Skip to content
Snippets Groups Projects
Commit eeabbafd authored by Tim O'Donnell
Browse files

fix

parent 8f797a74
No related merge requests found
......@@ -5,7 +5,7 @@
set -e
set -x
DOWNLOAD_NAME=models_class1_unselected
DOWNLOAD_NAME=models_class1_pan
SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation
SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")"
SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH")
......
......@@ -39,7 +39,11 @@ base_hyperparameters = {
'random_negative_distribution_smoothing': 0.0,
'random_negative_match_distribution': True,
'random_negative_rate': 0.2,
'train_data': {},
'train_data': {
'pretrain': True,
'pretrain_peptides_per_epoch': 1024,
'pretrain_patience': 10,
},
'validation_split': 0.1,
}
......
......@@ -291,6 +291,10 @@ def main(args):
if args.max_epochs:
hyperparameters['max_epochs'] = args.max_epochs
if hyperparameters.get("train_data", {}).get("pretrain", False):
if not args.pretrain_data:
raise ValueError("--pretrain-data is required")
for fold in range(args.ensemble_size):
for replicate in range(args.num_replicates):
work_dict = {
......@@ -386,11 +390,10 @@ def train_model(
num_replicates,
hyperparameters,
pretrain_data_filename,
pretrain_patience=1,
verbose=None,
progress_print_interval=None,
predictor=None,
save_to=None):
verbose,
progress_print_interval,
predictor,
save_to):
if predictor is None:
predictor = Class1AffinityPredictor()
......@@ -424,7 +427,7 @@ def train_model(
replicate_num + 1,
num_replicates))
if pretrain_data_filename:
if hyperparameters.get("train_data", {}).get("pretrain", False):
iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
original_hyperparameters = dict(model.hyperparameters)
model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
......@@ -432,6 +435,7 @@ def train_model(
model.hyperparameters['validation_split'] = 0.0
model.hyperparameters['random_negative_rate'] = 0.0
model.hyperparameters['random_negative_constant'] = 0
pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
scores = []
best_score = float('inf')
best_score_epoch = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment