From 34e2343e28d090a82c8d6a6b57b361aaf8d280a3 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 12 Jun 2019 15:38:02 -0400
Subject: [PATCH] fix

---
 mhcflurry/class1_affinity_predictor.py       | 16 ++++++++++++++--
 mhcflurry/train_pan_allele_models_command.py |  2 +-
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 59e9b7ac..d6410b49 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -1032,8 +1032,20 @@ class Class1AffinityPredictor(object):
                 logging.warning(msg)
                 if throw:
                     raise ValueError(msg)
-            mask = df.supported_peptide_length
-            if mask.sum() > 0:
+            mask = df.supported_peptide_length & (
+                ~df.normalized_allele.isin(unsupported_alleles))
+            if mask is None or mask.all():
+                # Common case optimization
+                allele_encoding = AlleleEncoding(
+                    df.normalized_allele,
+                    borrow_from=master_allele_encoding)
+                for (i, model) in enumerate(self.class1_pan_allele_models):
+                    predictions_array[:, i] = (
+                        model.predict(
+                            peptides,
+                            allele_encoding=allele_encoding,
+                            **model_kwargs))
+            elif mask.sum() > 0:
                 masked_allele_encoding = AlleleEncoding(
                     df.loc[mask].normalized_allele,
                     borrow_from=master_allele_encoding)
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 2c16f263..88ca3e2a 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -417,7 +417,7 @@ def train_model(
         save_to):
 
     import keras.backend as K
-    K.clear_session()
+    K.clear_session()  # release memory
 
     df = GLOBAL_DATA["train_data"]
     folds_df = GLOBAL_DATA["folds_df"]
-- 
GitLab