Skip to content
Snippets Groups Projects
Commit eeabbafd authored by Tim O'Donnell
Browse files

fix

parent 8f797a74
No related merge requests found
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
set -e set -e
set -x set -x
DOWNLOAD_NAME=models_class1_unselected DOWNLOAD_NAME=models_class1_pan
SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation
SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")" SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")"
SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH") SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH")
......
...@@ -39,7 +39,11 @@ base_hyperparameters = { ...@@ -39,7 +39,11 @@ base_hyperparameters = {
'random_negative_distribution_smoothing': 0.0, 'random_negative_distribution_smoothing': 0.0,
'random_negative_match_distribution': True, 'random_negative_match_distribution': True,
'random_negative_rate': 0.2, 'random_negative_rate': 0.2,
'train_data': {}, 'train_data': {
'pretrain': True,
'pretrain_peptides_per_epoch': 1024,
'pretrain_patience': 10,
},
'validation_split': 0.1, 'validation_split': 0.1,
} }
......
...@@ -291,6 +291,10 @@ def main(args): ...@@ -291,6 +291,10 @@ def main(args):
if args.max_epochs: if args.max_epochs:
hyperparameters['max_epochs'] = args.max_epochs hyperparameters['max_epochs'] = args.max_epochs
if hyperparameters.get("train_data", {}).get("pretrain", False):
if not args.pretrain_data:
raise ValueError("--pretrain-data is required")
for fold in range(args.ensemble_size): for fold in range(args.ensemble_size):
for replicate in range(args.num_replicates): for replicate in range(args.num_replicates):
work_dict = { work_dict = {
...@@ -386,11 +390,10 @@ def train_model( ...@@ -386,11 +390,10 @@ def train_model(
num_replicates, num_replicates,
hyperparameters, hyperparameters,
pretrain_data_filename, pretrain_data_filename,
pretrain_patience=1, verbose,
verbose=None, progress_print_interval,
progress_print_interval=None, predictor,
predictor=None, save_to):
save_to=None):
if predictor is None: if predictor is None:
predictor = Class1AffinityPredictor() predictor = Class1AffinityPredictor()
...@@ -424,7 +427,7 @@ def train_model( ...@@ -424,7 +427,7 @@ def train_model(
replicate_num + 1, replicate_num + 1,
num_replicates)) num_replicates))
if pretrain_data_filename: if hyperparameters.get("train_data", {}).get("pretrain", False):
iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding) iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
original_hyperparameters = dict(model.hyperparameters) original_hyperparameters = dict(model.hyperparameters)
model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100) model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
...@@ -432,6 +435,7 @@ def train_model( ...@@ -432,6 +435,7 @@ def train_model(
model.hyperparameters['validation_split'] = 0.0 model.hyperparameters['validation_split'] = 0.0
model.hyperparameters['random_negative_rate'] = 0.0 model.hyperparameters['random_negative_rate'] = 0.0
model.hyperparameters['random_negative_constant'] = 0 model.hyperparameters['random_negative_constant'] = 0
pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
scores = [] scores = []
best_score = float('inf') best_score = float('inf')
best_score_epoch = 0 best_score_epoch = 0
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment