From 6102b31e86c0e445e28a76fdcd98f1e7f4e85c22 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Mon, 24 Sep 2018 15:20:12 -0400 Subject: [PATCH] fix --- mhcflurry/allele_encoding.py | 22 +++++++++++++--------- mhcflurry/allele_encoding_transforms.py | 6 ++++-- mhcflurry/class1_affinity_predictor.py | 18 +++++++++++++----- test/test_allele_encoding.py | 3 ++- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py index e7d47cf0..35a124db 100644 --- a/mhcflurry/allele_encoding.py +++ b/mhcflurry/allele_encoding.py @@ -74,13 +74,16 @@ class AlleleEncoding(object): "allele_representations", encoding_name) if cache_key not in self.encoding_cache: - if callable(encoding_name): - vector_encoded = encoding_name(self) - assert len(vector_encoded)== len(self.allele_to_sequence) - elif ":" in encoding_name: - # Apply transform - (transform_name, rest) = encoding_name.split(":", 2) - preliminary_encoded = self.allele_representations(rest) + if ":" in encoding_name: + # Transform + pieces = encoding_name.split(":", 3) + if pieces[0] != "transform": + raise RuntimeError( + "Expected 'transform' but saw: %s" % pieces[0]) + if len(pieces) == 1: + raise RuntimeError("Expected: 'transform:<name>[:argument]") + transform_name = pieces[1] + argument = None if len(pieces) == 2 else pieces[2] try: transform = self.transforms[transform_name] except KeyError: @@ -88,8 +91,9 @@ class AlleleEncoding(object): "Unsupported transform: %s. Supported transforms: %s" % ( transform_name, " ".join(self.transforms) if self.transforms else "(none)")) - - vector_encoded = transform.transform(preliminary_encoded) + vector_encoded = ( + transform.transform(self) if argument is None + else transform.transform(self, argument)) else: # No transform. index_encoded_matrix = amino_acid.index_encoding( diff --git a/mhcflurry/allele_encoding_transforms.py b/mhcflurry/allele_encoding_transforms.py index 88b361f5..018d831e 100644 --- a/mhcflurry/allele_encoding_transforms.py +++ b/mhcflurry/allele_encoding_transforms.py @@ -5,7 +5,7 @@ import sklearn.decomposition class AlleleEncodingTransform(object): - def transform(self, data): + def transform(self, allele_encoding, argument=None): raise NotImplementedError() def get_fit(self): @@ -61,7 +61,9 @@ class PCATransform(AlleleEncodingTransform): self.model.mean_ = fit["mean"] self.model.components_ = fit["components"] - def transform(self, allele_representations): + def transform(self, allele_encoding, underlying_representation): + allele_representations = allele_encoding.allele_representations( + underlying_representation) if not self.is_fit(): self.fit(allele_representations) flattened = allele_representations.reshape( diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 1e9ad71f..87c00eda 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -816,7 +816,8 @@ class Class1AffinityPredictor(object): alleles=None, allele=None, throw=True, - centrality_measure=DEFAULT_CENTRALITY_MEASURE): + centrality_measure=DEFAULT_CENTRALITY_MEASURE, + model_kwargs={}): """ Predict nM binding affinities. @@ -840,6 +841,8 @@ class Class1AffinityPredictor(object): centrality_measure : string or callable Measure of central tendency to use to combine predictions in the ensemble. Options include: mean, median, robust_mean. + model_kwargs : dict + Additional keyword arguments to pass to Class1NeuralNetwork.predict Returns ------- @@ -853,6 +856,7 @@ class Class1AffinityPredictor(object): include_percentile_ranks=False, include_confidence_intervals=False, centrality_measure=centrality_measure, + model_kwargs=model_kwargs ) return df.prediction.values @@ -865,7 +869,8 @@ class Class1AffinityPredictor(object): include_individual_model_predictions=False, include_percentile_ranks=True, include_confidence_intervals=True, - centrality_measure=DEFAULT_CENTRALITY_MEASURE): + centrality_measure=DEFAULT_CENTRALITY_MEASURE, + model_kwargs={}): """ Predict nM binding affinities. Gives more detailed output than `predict` method, including 5-95% prediction intervals. @@ -897,6 +902,8 @@ class Class1AffinityPredictor(object): centrality_measure : string or callable Measure of central tendency to use to combine predictions in the ensemble. Options include: mean, median, robust_mean. + model_kwargs : dict + Additional keyword arguments to pass to Class1NeuralNetwork.predict Returns ------- @@ -1002,7 +1009,8 @@ class Class1AffinityPredictor(object): for (i, model) in enumerate(self.class1_pan_allele_models): predictions_array[mask, i] = model.predict( masked_peptides, - allele_encoding=masked_allele_encoding) + allele_encoding=masked_allele_encoding, + **model_kwargs) if self.allele_to_allele_specific_models: unsupported_alleles = [ @@ -1031,7 +1039,7 @@ class Class1AffinityPredictor(object): # Common case optimization for (i, model) in enumerate(models): predictions_array[:, num_pan_models + i] = ( - model.predict(peptides)) + model.predict(peptides, **model_kwargs)) elif mask.sum() > 0: peptides_for_allele = EncodableSequences.create( df.ix[mask].peptide.values) @@ -1039,7 +1047,7 @@ class Class1AffinityPredictor(object): predictions_array[ mask, num_pan_models + i, - ] = model.predict(peptides_for_allele) + ] = model.predict(peptides_for_allele, **model_kwargs) if callable(centrality_measure): centrality_function = centrality_measure diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py index ef526e2d..a6c1a972 100644 --- a/test/test_allele_encoding.py +++ b/test/test_allele_encoding.py @@ -47,7 +47,8 @@ def test_pca(): "A*02:03": "AE", } ) - encoded1 = encoding.fixed_length_vector_encoded_sequences("pca:BLOSUM62") + encoded1 = encoding.fixed_length_vector_encoded_sequences( + "transform:pca:BLOSUM62") numpy.testing.assert_array_equal(encoded1[0], encoded1[2]) assert not numpy.array_equal(encoded1[0], encoded1[1]) -- GitLab