diff --git a/downloads-generation/models_class1_pan_unselected/GENERATE.sh b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
index 8475d01da065f943399f0964ad3540d8e55f8dd2..8724ed3f90d5667dab62e4f4ce9ced88c49619bf 100755
--- a/downloads-generation/models_class1_pan_unselected/GENERATE.sh
+++ b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
@@ -5,7 +5,7 @@
 set -e
 set -x
 
-DOWNLOAD_NAME=models_class1_unselected
+DOWNLOAD_NAME=models_class1_pan
 SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation
 SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")"
 SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH")
diff --git a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
index 6a9c2782873be030b070fbffc777e34927ccb9c1..bfbd44323aff111b1184a6d3a902ced76d93ebc8 100644
--- a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
+++ b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
@@ -39,7 +39,11 @@ base_hyperparameters = {
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
-    'train_data': {},
+    'train_data': {
+        'pretrain': True,
+        'pretrain_peptides_per_epoch': 1024,
+        'pretrain_patience': 10,
+    },
     'validation_split': 0.1,
 }
 
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 1520502a5fef4084be3b03c76c61c49c0c68e854..3bdd9963ff154d76518cd6d66b24708a25040582 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -291,6 +291,10 @@ def main(args):
     if args.max_epochs:
         hyperparameters['max_epochs'] = args.max_epochs
 
+    if hyperparameters.get("train_data", {}).get("pretrain", False):
+        if not args.pretrain_data:
+            raise ValueError("--pretrain-data is required")
+
    for fold in range(args.ensemble_size):
         for replicate in range(args.num_replicates):
             work_dict = {
@@ -386,11 +390,10 @@ def train_model(
         num_replicates,
         hyperparameters,
         pretrain_data_filename,
-        pretrain_patience=1,
-        verbose=None,
-        progress_print_interval=None,
-        predictor=None,
-        save_to=None):
+        verbose,
+        progress_print_interval,
+        predictor,
+        save_to):
 
     if predictor is None:
         predictor = Class1AffinityPredictor()
@@ -424,7 +427,7 @@ def train_model(
                 replicate_num + 1,
                 num_replicates))
 
-        if pretrain_data_filename:
+        if hyperparameters.get("train_data", {}).get("pretrain", False):
             iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
             original_hyperparameters = dict(model.hyperparameters)
             model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
@@ -432,6 +435,7 @@ def train_model(
             model.hyperparameters['validation_split'] = 0.0
             model.hyperparameters['random_negative_rate'] = 0.0
             model.hyperparameters['random_negative_constant'] = 0
+            pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
             scores = []
             best_score = float('inf')
             best_score_epoch = 0
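
The `pretrain_patience` value moved into the `train_data` hyperparameters reads like standard patience-based early stopping for the pretraining phase: track the best score seen so far (`best_score`, `best_score_epoch` in the last hunk) and stop once it has not improved for `pretrain_patience` consecutive epochs. Below is a minimal sketch of that stopping rule, assuming lower scores are better; the helper name `stop_epoch` and the dummy score sequence are illustrative only, not part of mhcflurry.

def stop_epoch(scores, pretrain_patience):
    """Return the epoch at which pretraining would stop, or None if it never stalls."""
    best_score = float('inf')
    best_score_epoch = 0
    for epoch, score in enumerate(scores):
        if score < best_score:
            # New best score: reset the patience window.
            best_score = score
            best_score_epoch = epoch
        elif epoch - best_score_epoch >= pretrain_patience:
            # No improvement for `pretrain_patience` epochs: stop pretraining.
            return epoch
    return None

# Improvement stalls after epoch 2; with patience 3, pretraining stops at epoch 5.
print(stop_epoch([5.0, 4.0, 3.5, 3.6, 3.7, 3.8, 3.9], pretrain_patience=3))  # -> 5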