diff --git a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py index d1e0f02cc210ccdc5ab4fe56d63334cad8b5c2bb..ada2e9e54cbf875a019db00f56c92cc2506e78bf 100644 --- a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py @@ -437,7 +437,7 @@ class Class1AffinityPredictor(object): if models_dir_for_save: self.save( models_dir_for_save, model_names_to_write=[model_name]) - return models + return models_list def fit_class1_pan_allele_models( self, @@ -498,8 +498,10 @@ class Class1AffinityPredictor(object): verbose=verbose, progress_preamble=progress_preamble) + models_list = [] for (i, model) in enumerate(models): model_name = self.model_name("pan-class1", i) + models_list.append(model) # models is a generator self.class1_pan_allele_models.append(model) row = pandas.Series(collections.OrderedDict([ ("model_name", model_name), @@ -512,7 +514,7 @@ class Class1AffinityPredictor(object): if models_dir_for_save: self.save( models_dir_for_save, model_names_to_write=[model_name]) - return models + return models_list def _fit_predictors( self, diff --git a/mhcflurry/class1_affinity_prediction/class1_neural_network.py b/mhcflurry/class1_affinity_prediction/class1_neural_network.py index c52a07e325fe00962b33e451ef06f94a9276b806..bf385c96906ea2619ab1b4db9b1387a1fb57e529 100644 --- a/mhcflurry/class1_affinity_prediction/class1_neural_network.py +++ b/mhcflurry/class1_affinity_prediction/class1_neural_network.py @@ -643,7 +643,4 @@ class Class1NeuralNetwork(object): inputs=inputs, outputs=[output], name="predictor") - - print("*** ARCHITECTURE ***") - model.summary() return model diff --git a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py index cc9f82a68ac2c4244232ac88bcdc222d903952cf..c763b92dc08566113bc594dc04b0e01f2a23fe56 100644 --- a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py +++ b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py @@ -241,7 +241,7 @@ def process_work( allele)) train_data = sub_df.dropna().sample(frac=1.0) - predictor.fit_allele_specific_predictors( + (model,) = predictor.fit_allele_specific_predictors( n_models=1, architecture_hyperparameters=hyperparameters, allele=allele, @@ -251,6 +251,12 @@ def process_work( progress_preamble=progress_preamble, verbose=verbose) + if allele_num == 0 and model_group == 0: + # For the first model for the first allele, print the architecture. + print("*** ARCHITECTURE FOR HYPERPARAMETER SET %d***" % + (hyperparameter_set_num + 1)) + model.network(borrow=True).summary() + return predictor