From eeabbafdf6a7593f3aaba38a4e73ce112aaa6393 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 16 May 2019 11:45:06 -0400
Subject: [PATCH] fix

---
 .../models_class1_pan_unselected/GENERATE.sh |  2 +-
 .../generate_hyperparameters.py              |  6 +++++-
 mhcflurry/train_pan_allele_models_command.py | 16 ++++++++++------
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/downloads-generation/models_class1_pan_unselected/GENERATE.sh b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
index 8475d01d..8724ed3f 100755
--- a/downloads-generation/models_class1_pan_unselected/GENERATE.sh
+++ b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
@@ -5,7 +5,7 @@
 set -e
 set -x
 
-DOWNLOAD_NAME=models_class1_unselected
+DOWNLOAD_NAME=models_class1_pan
 SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation
 SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")"
 SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH")
diff --git a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
index 6a9c2782..bfbd4432 100644
--- a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
+++ b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
@@ -39,7 +39,11 @@ base_hyperparameters = {
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
-    'train_data': {},
+    'train_data': {
+        'pretrain': True,
+        'pretrain_peptides_per_epoch': 1024,
+        'pretrain_patience': 10,
+    },
     'validation_split': 0.1,
 }
 
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 1520502a..3bdd9963 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -291,6 +291,10 @@ def main(args):
     if args.max_epochs:
         hyperparameters['max_epochs'] = args.max_epochs
 
+    if hyperparameters.get("train_data", {}).get("pretrain", False):
+        if not args.pretrain_data:
+            raise ValueError("--pretrain-data is required")
+
     for fold in range(args.ensemble_size):
         for replicate in range(args.num_replicates):
             work_dict = {
@@ -386,11 +390,10 @@ def train_model(
         num_replicates,
         hyperparameters,
         pretrain_data_filename,
-        pretrain_patience=1,
-        verbose=None,
-        progress_print_interval=None,
-        predictor=None,
-        save_to=None):
+        verbose,
+        progress_print_interval,
+        predictor,
+        save_to):
 
     if predictor is None:
         predictor = Class1AffinityPredictor()
@@ -424,7 +427,7 @@ def train_model(
                 replicate_num + 1,
                 num_replicates))
 
-        if pretrain_data_filename:
+        if hyperparameters.get("train_data", {}).get("pretrain", False):
             iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
             original_hyperparameters = dict(model.hyperparameters)
             model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
@@ -432,6 +435,7 @@ def train_model(
             model.hyperparameters['validation_split'] = 0.0
             model.hyperparameters['random_negative_rate'] = 0.0
             model.hyperparameters['random_negative_constant'] = 0
+            pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
             scores = []
             best_score = float('inf')
             best_score_epoch = 0
--
GitLab
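
Note (not part of the patch): the last hunk only shows where `pretrain_patience` is read out of the `train_data` hyperparameters; the early-stopping loop it feeds is outside the hunk's context. Below is a minimal, self-contained sketch of the patience-based early-stopping pattern that the surrounding context lines (`scores`, `best_score`, `best_score_epoch`) imply. The `fit_epoch` and `score` callables here are hypothetical stand-ins for the real Keras training and evaluation calls; this is an illustration of the technique, not mhcflurry's actual implementation.

import random

def pretrain_with_patience(fit_epoch, score, pretrain_patience=10, max_epochs=1000):
    """Call fit_epoch() repeatedly; stop once score() (lower is better)
    has not improved for more than `pretrain_patience` epochs.

    `fit_epoch` and `score` are hypothetical stand-ins; only the
    best_score / best_score_epoch bookkeeping mirrors the hunk's context.
    """
    scores = []
    best_score = float('inf')
    best_score_epoch = 0
    for epoch in range(max_epochs):
        fit_epoch()
        current = score()
        scores.append(current)
        if current < best_score:
            best_score = current
            best_score_epoch = epoch
        elif epoch - best_score_epoch > pretrain_patience:
            break  # patience exhausted: no recent improvement, stop pretraining
    return best_score

# Toy usage: a noisy score that stops improving around epoch 30,
# so the loop terminates shortly after the plateau begins.
state = {'epoch': 0}

def fit_epoch():
    state['epoch'] += 1

def score():
    return max(0.0, 30 - state['epoch']) + random.random()

print(pretrain_with_patience(fit_epoch, score))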