From 89816bf517f9be22d8ae72559af10dca972963a8 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Wed, 15 May 2019 17:36:51 -0400 Subject: [PATCH] fix --- mhcflurry/train_pan_allele_models_command.py | 21 ++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py index 0a202170..994d7c0c 100644 --- a/mhcflurry/train_pan_allele_models_command.py +++ b/mhcflurry/train_pan_allele_models_command.py @@ -406,7 +406,7 @@ def train_model( if pretrain_data_filename: iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding) original_hyperparameters = dict(model.hyperparameters) - model.hyperparameters['minibatch_size'] = len(next(iterator)[-1]) + model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100) model.hyperparameters['max_epochs'] = 1 model.hyperparameters['validation_split'] = 0.0 model.hyperparameters['random_negative_rate'] = 0.0 @@ -443,7 +443,7 @@ def train_model( print( progress_preamble, "PRETRAIN epoch %d [%d values, %0.2f sec]. " - "Score [%0.2f sec.]: %f" % ( + "Score [%0.2f sec.]: %10f" % ( epoch, len(affinities), fit_time, score_time, score)) scores.append(score) @@ -464,8 +464,8 @@ def train_model( model.fit( - train_peptides, - train_data.measurement_value, + peptides=train_peptides, + affinities=train_data.measurement_value.values, allele_encoding=train_alleles, inequalities=( train_data.measurement_inequality.values @@ -476,9 +476,18 @@ def train_model( predictor.class1_pan_allele_models.append(model) predictor.clear_cache() - return predictor + if save_to: + predictor.save(save_to) + + return predictor if __name__ == '__main__': - run() + try: + run() + except Exception as e: + print(e) + import ipdb ; ipdb.set_trace() + raise + -- GitLab