From e062613c0c34c2860b0364c99b072c5ab50ccbaf Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 28 Nov 2017 14:59:55 -0500
Subject: [PATCH] better progress printing

---
 .../class1_affinity_predictor.py                          | 6 ++++--
 .../class1_affinity_prediction/class1_neural_network.py   | 3 ---
 .../train_allele_specific_models_command.py               | 8 +++++++-
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py
index d1e0f02c..ada2e9e5 100644
--- a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py
@@ -437,7 +437,7 @@ class Class1AffinityPredictor(object):
             if models_dir_for_save:
                 self.save(
                     models_dir_for_save, model_names_to_write=[model_name])
-        return models
+        return models_list
 
     def fit_class1_pan_allele_models(
             self,
@@ -498,8 +498,10 @@ class Class1AffinityPredictor(object):
             verbose=verbose,
             progress_preamble=progress_preamble)
 
+        models_list = []
         for (i, model) in enumerate(models):
             model_name = self.model_name("pan-class1", i)
+            models_list.append(model)  # models is a generator
             self.class1_pan_allele_models.append(model)
             row = pandas.Series(collections.OrderedDict([
                 ("model_name", model_name),
@@ -512,7 +514,7 @@ class Class1AffinityPredictor(object):
             if models_dir_for_save:
                 self.save(
                     models_dir_for_save, model_names_to_write=[model_name])
-        return models
+        return models_list
 
     def _fit_predictors(
             self,
diff --git a/mhcflurry/class1_affinity_prediction/class1_neural_network.py b/mhcflurry/class1_affinity_prediction/class1_neural_network.py
index c52a07e3..bf385c96 100644
--- a/mhcflurry/class1_affinity_prediction/class1_neural_network.py
+++ b/mhcflurry/class1_affinity_prediction/class1_neural_network.py
@@ -643,7 +643,4 @@ class Class1NeuralNetwork(object):
             inputs=inputs,
             outputs=[output],
             name="predictor")
-
-        print("*** ARCHITECTURE ***")
-        model.summary()
         return model
diff --git a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
index cc9f82a6..c763b92d 100644
--- a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
+++ b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
@@ -241,7 +241,7 @@ def process_work(
             allele))
 
     train_data = sub_df.dropna().sample(frac=1.0)
-    predictor.fit_allele_specific_predictors(
+    (model,) = predictor.fit_allele_specific_predictors(
         n_models=1,
         architecture_hyperparameters=hyperparameters,
         allele=allele,
@@ -251,6 +251,12 @@ def process_work(
         progress_preamble=progress_preamble,
         verbose=verbose)
 
+    if allele_num == 0 and model_group == 0:
+        # For the first model for the first allele, print the architecture.
+        print("*** ARCHITECTURE FOR HYPERPARAMETER SET %d***" %
+              (hyperparameter_set_num + 1))
+        model.network(borrow=True).summary()
+
     return predictor
 
 
-- 
GitLab