From 6102b31e86c0e445e28a76fdcd98f1e7f4e85c22 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 24 Sep 2018 15:20:12 -0400
Subject: [PATCH] Require 'transform:' prefix for allele encoding transforms; add model_kwargs to predict

---
 mhcflurry/allele_encoding.py            | 22 +++++++++++++---------
 mhcflurry/allele_encoding_transforms.py |  6 ++++--
 mhcflurry/class1_affinity_predictor.py  | 18 +++++++++++++-----
 test/test_allele_encoding.py            |  3 ++-
 4 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py
index e7d47cf0..35a124db 100644
--- a/mhcflurry/allele_encoding.py
+++ b/mhcflurry/allele_encoding.py
@@ -74,13 +74,16 @@ class AlleleEncoding(object):
             "allele_representations",
             encoding_name)
         if cache_key not in self.encoding_cache:
-            if callable(encoding_name):
-                vector_encoded = encoding_name(self)
-                assert len(vector_encoded)== len(self.allele_to_sequence)
-            elif ":" in encoding_name:
-                # Apply transform
-                (transform_name, rest) = encoding_name.split(":", 2)
-                preliminary_encoded = self.allele_representations(rest)
+            if ":" in encoding_name:
+                # Transform
+                pieces = encoding_name.split(":", 3)
+                if pieces[0] != "transform":
+                    raise RuntimeError(
+                        "Expected 'transform' but saw: %s" % pieces[0])
+                if len(pieces) == 1:
+                    raise RuntimeError("Expected: 'transform:<name>[:argument]")
+                transform_name = pieces[1]
+                argument = None if len(pieces) == 2 else pieces[2]
                 try:
                     transform = self.transforms[transform_name]
                 except KeyError:
@@ -88,8 +91,9 @@ class AlleleEncoding(object):
                         "Unsupported transform: %s. Supported transforms: %s" % (
                             transform_name,
                             " ".join(self.transforms) if self.transforms else "(none)"))
-
-                vector_encoded = transform.transform(preliminary_encoded)
+                vector_encoded = (
+                    transform.transform(self) if argument is None
+                    else transform.transform(self, argument))
             else:
                 # No transform.
                 index_encoded_matrix = amino_acid.index_encoding(
diff --git a/mhcflurry/allele_encoding_transforms.py b/mhcflurry/allele_encoding_transforms.py
index 88b361f5..018d831e 100644
--- a/mhcflurry/allele_encoding_transforms.py
+++ b/mhcflurry/allele_encoding_transforms.py
@@ -5,7 +5,7 @@ import sklearn.decomposition
 
 
 class AlleleEncodingTransform(object):
-    def transform(self, data):
+    def transform(self, allele_encoding, argument=None):
         raise NotImplementedError()
 
     def get_fit(self):
@@ -61,7 +61,9 @@ class PCATransform(AlleleEncodingTransform):
         self.model.mean_ = fit["mean"]
         self.model.components_ = fit["components"]
 
-    def transform(self, allele_representations):
+    def transform(self, allele_encoding, underlying_representation):
+        allele_representations = allele_encoding.allele_representations(
+            underlying_representation)
         if not self.is_fit():
             self.fit(allele_representations)
         flattened = allele_representations.reshape(
diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 1e9ad71f..87c00eda 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -816,7 +816,8 @@ class Class1AffinityPredictor(object):
             alleles=None,
             allele=None,
             throw=True,
-            centrality_measure=DEFAULT_CENTRALITY_MEASURE):
+            centrality_measure=DEFAULT_CENTRALITY_MEASURE,
+            model_kwargs={}):
         """
         Predict nM binding affinities.
         
@@ -840,6 +841,8 @@ class Class1AffinityPredictor(object):
         centrality_measure : string or callable
             Measure of central tendency to use to combine predictions in the
             ensemble. Options include: mean, median, robust_mean.
+        model_kwargs : dict
+            Additional keyword arguments to pass to Class1NeuralNetwork.predict
 
         Returns
         -------
@@ -853,6 +856,7 @@ class Class1AffinityPredictor(object):
             include_percentile_ranks=False,
             include_confidence_intervals=False,
             centrality_measure=centrality_measure,
+            model_kwargs=model_kwargs
         )
         return df.prediction.values
 
@@ -865,7 +869,8 @@ class Class1AffinityPredictor(object):
             include_individual_model_predictions=False,
             include_percentile_ranks=True,
             include_confidence_intervals=True,
-            centrality_measure=DEFAULT_CENTRALITY_MEASURE):
+            centrality_measure=DEFAULT_CENTRALITY_MEASURE,
+            model_kwargs={}):
         """
         Predict nM binding affinities. Gives more detailed output than `predict`
         method, including 5-95% prediction intervals.
@@ -897,6 +902,8 @@ class Class1AffinityPredictor(object):
         centrality_measure : string or callable
             Measure of central tendency to use to combine predictions in the
             ensemble. Options include: mean, median, robust_mean.
+        model_kwargs : dict
+            Additional keyword arguments to pass to Class1NeuralNetwork.predict
 
         Returns
         -------
@@ -1002,7 +1009,8 @@ class Class1AffinityPredictor(object):
                 for (i, model) in enumerate(self.class1_pan_allele_models):
                     predictions_array[mask, i] = model.predict(
                         masked_peptides,
-                        allele_encoding=masked_allele_encoding)
+                        allele_encoding=masked_allele_encoding,
+                        **model_kwargs)
 
         if self.allele_to_allele_specific_models:
             unsupported_alleles = [
@@ -1031,7 +1039,7 @@ class Class1AffinityPredictor(object):
                     # Common case optimization
                     for (i, model) in enumerate(models):
                         predictions_array[:, num_pan_models + i] = (
-                            model.predict(peptides))
+                            model.predict(peptides, **model_kwargs))
                 elif mask.sum() > 0:
                     peptides_for_allele = EncodableSequences.create(
                         df.ix[mask].peptide.values)
@@ -1039,7 +1047,7 @@ class Class1AffinityPredictor(object):
                         predictions_array[
                             mask,
                             num_pan_models + i,
-                        ] = model.predict(peptides_for_allele)
+                        ] = model.predict(peptides_for_allele, **model_kwargs)
 
         if callable(centrality_measure):
             centrality_function = centrality_measure
diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py
index ef526e2d..a6c1a972 100644
--- a/test/test_allele_encoding.py
+++ b/test/test_allele_encoding.py
@@ -47,7 +47,8 @@ def test_pca():
             "A*02:03": "AE",
         }
     )
-    encoded1 = encoding.fixed_length_vector_encoded_sequences("pca:BLOSUM62")
+    encoded1 = encoding.fixed_length_vector_encoded_sequences(
+        "transform:pca:BLOSUM62")
 
     numpy.testing.assert_array_equal(encoded1[0], encoded1[2])
     assert not numpy.array_equal(encoded1[0], encoded1[1])
-- 
GitLab