diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py index 86bddf273934cf7ca2d03d9e71018d68f0f0e977..26742d45a6251566a150e8e6d7df1c0f16061cbf 100644 --- a/mhcflurry/class1_presentation_neural_network.py +++ b/mhcflurry/class1_presentation_neural_network.py @@ -182,12 +182,15 @@ class Class1PresentationNeuralNetwork(object): # Apply allele mask: zero out all outputs corresponding to alleles # with the special index 0. + def alleles_to_mask(x): + import keras.backend as K + return K.cast(K.expand_dims(K.not_equal(x, 0.0)), "float32") + + allele_mask = Lambda(alleles_to_mask, name="allele_mask")(input_alleles) + affinity_predictor_matrix_output = Multiply( name="affinity_matrix_output")([ - Lambda( - lambda x: K.cast( - K.expand_dims(K.not_equal(x, 0.0)), - "float32"))(input_alleles), + allele_mask, pre_mask_affinity_predictor_matrix_output ]) @@ -241,10 +244,7 @@ class Class1PresentationNeuralNetwork(object): # Apply allele mask: zero out all outputs corresponding to alleles # with the special index 0. presentation_output = Multiply(name="presentation_output")([ - Lambda( - lambda x: K.cast( - K.expand_dims(K.not_equal(x, 0.0)), - "float32"))(input_alleles), + allele_mask, pre_mask_presentation_output ]) @@ -711,14 +711,8 @@ class Class1PresentationNeuralNetwork(object): dict """ - result = dict(self.__dict__) - result['network'] = None - result['network_json'] = None - result['network_weights'] = None - - if self.network is not None: - result['network_json'] = self.network.to_json() - result['network_weights'] = self.network.get_weights() + result = self.get_config() + result['network_weights'] = self.get_weights() return result def __setstate__(self, state): @@ -734,6 +728,19 @@ class Class1PresentationNeuralNetwork(object): if network_weights is not None: self.network.set_weights(network_weights) + def get_weights(self): + """ + Get the network weights + + Returns + ------- + list of numpy.array giving weights for each layer or None if there is no + network + """ + if self.network is None: + return None + return self.network.get_weights() + def get_config(self): """ serialize to a dict all attributes except model weights @@ -743,11 +750,9 @@ class Class1PresentationNeuralNetwork(object): dict """ result = dict(self.__dict__) - result['network'] = None - result['network_weights'] = None + del result['network'] result['network_json'] = None if self.network: - result['network_weights'] = self.network.get_weights() result['network_json'] = self.network.to_json() return result @@ -771,12 +776,11 @@ class Class1PresentationNeuralNetwork(object): config = dict(config) instance = cls(**config.pop('hyperparameters')) network_json = config.pop('network_json') - network_weights = config.pop('network_weights') instance.__dict__.update(config) assert instance.network is None if network_json is not None: import keras.models instance.network = keras.models.model_from_json(network_json) - if network_weights is not None: - instance.network.set_weights(network_weights) + if weights is not None: + instance.network.set_weights(weights) return instance \ No newline at end of file diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py index a2c19879184b0151f05ffbc7db0680880178b75b..42c317e9ef636f59edf127ac45ca5ad91a95fb7c 100644 --- a/mhcflurry/class1_presentation_predictor.py +++ b/mhcflurry/class1_presentation_predictor.py @@ -61,9 +61,10 @@ class Class1PresentationPredictor(object): if self._manifest_df is None: rows = [] for (i, model) in enumerate(self.models): + model_config = model.get_config() rows.append(( self.model_name(i), - json.dumps(model.get_config()), + json.dumps(model_config), model )) self._manifest_df = pandas.DataFrame( @@ -244,8 +245,7 @@ class Class1PresentationPredictor(object): updated_network_config_jsons.append( json.dumps(row.model.get_config())) weights_path = self.weights_path(models_dir, row.model_name) - self.save_weights( - row.model.get_weights(), weights_path) + save_weights(row.model.get_weights(), weights_path) logging.info("Wrote: %s", weights_path) sub_manifest_df["config_json"] = updated_network_config_jsons self.manifest_df.loc[ diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py index f3f654706c729fadeb024275c7f2aa63b8a45500..9c04c618116342846cba659ab379b97e1a52c281 100644 --- a/test/test_class1_presentation_predictor.py +++ b/test/test_class1_presentation_predictor.py @@ -23,6 +23,7 @@ import argparse import sys import copy import os +import tempfile from numpy.testing import assert_, assert_equal, assert_allclose, assert_array_equal from nose.tools import assert_greater, assert_less @@ -95,6 +96,7 @@ def test_basic(): for affinity_network in affinity_predictor.class1_pan_allele_models: presentation_network = Class1PresentationNeuralNetwork() presentation_network.load_from_class1_neural_network(affinity_network) + print(presentation_network.network.get_config()) models.append(presentation_network) predictor = Class1PresentationPredictor( @@ -116,10 +118,25 @@ def test_basic(): merged_df = pandas.merge( df, df2.set_index("peptide"), left_index=True, right_index=True) - assert_array_equal(merged_df["tightest_affinity"], merged_df["affinity"]) - assert_array_equal(merged_df["tightest_affinity"], to_ic50(merged_df["score"])) + #import ipdb ; ipdb.set_trace() + + assert_allclose( + merged_df["tightest_affinity"], merged_df["affinity"], rtol=1e-5) + assert_allclose( + merged_df["tightest_affinity"], to_ic50(merged_df["score"]), rtol=1e-5) assert_array_equal(merged_df["tightest_allele"], merged_df["allele"]) + models_dir = tempfile.mkdtemp("_models") + print(models_dir) + predictor.save(models_dir) + predictor2 = Class1PresentationPredictor.load(models_dir) + + df3 = predictor2.predict_to_dataframe( + peptides=df.index.values, + alleles=alleles) + + assert_array_equal(df2.values, df3.values) + # TODO: test fitting, saving, and loading