diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py index 0a202170acebea98fdccc2cfce2bcb55def2c6e8..994d7c0cfa2304ed51252f46e8f965372ff2500a 100644 --- a/mhcflurry/train_pan_allele_models_command.py +++ b/mhcflurry/train_pan_allele_models_command.py @@ -406,7 +406,7 @@ def train_model( if pretrain_data_filename: iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding) original_hyperparameters = dict(model.hyperparameters) - model.hyperparameters['minibatch_size'] = len(next(iterator)[-1]) + model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100) model.hyperparameters['max_epochs'] = 1 model.hyperparameters['validation_split'] = 0.0 model.hyperparameters['random_negative_rate'] = 0.0 @@ -443,7 +443,7 @@ def train_model( print( progress_preamble, "PRETRAIN epoch %d [%d values, %0.2f sec]. " - "Score [%0.2f sec.]: %f" % ( + "Score [%0.2f sec.]: %10f" % ( epoch, len(affinities), fit_time, score_time, score)) scores.append(score) @@ -464,8 +464,8 @@ def train_model( model.fit( - train_peptides, - train_data.measurement_value, + peptides=train_peptides, + affinities=train_data.measurement_value.values, allele_encoding=train_alleles, inequalities=( train_data.measurement_inequality.values @@ -476,9 +476,18 @@ def train_model( predictor.class1_pan_allele_models.append(model) predictor.clear_cache() - return predictor + if save_to: + predictor.save(save_to) + + return predictor if __name__ == '__main__': - run() + try: + run() + except Exception as e: + print(e) + import ipdb ; ipdb.set_trace() + raise +