Skip to content
Snippets Groups Projects
Commit 4a85f542 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fixes

parent 3d5ba47e
No related branches found
No related tags found
No related merge requests found
......@@ -182,12 +182,15 @@ class Class1PresentationNeuralNetwork(object):
# Apply allele mask: zero out all outputs corresponding to alleles
# with the special index 0.
def alleles_to_mask(x):
import keras.backend as K
return K.cast(K.expand_dims(K.not_equal(x, 0.0)), "float32")
allele_mask = Lambda(alleles_to_mask, name="allele_mask")(input_alleles)
affinity_predictor_matrix_output = Multiply(
name="affinity_matrix_output")([
Lambda(
lambda x: K.cast(
K.expand_dims(K.not_equal(x, 0.0)),
"float32"))(input_alleles),
allele_mask,
pre_mask_affinity_predictor_matrix_output
])
......@@ -241,10 +244,7 @@ class Class1PresentationNeuralNetwork(object):
# Apply allele mask: zero out all outputs corresponding to alleles
# with the special index 0.
presentation_output = Multiply(name="presentation_output")([
Lambda(
lambda x: K.cast(
K.expand_dims(K.not_equal(x, 0.0)),
"float32"))(input_alleles),
allele_mask,
pre_mask_presentation_output
])
......@@ -711,14 +711,8 @@ class Class1PresentationNeuralNetwork(object):
dict
"""
result = dict(self.__dict__)
result['network'] = None
result['network_json'] = None
result['network_weights'] = None
if self.network is not None:
result['network_json'] = self.network.to_json()
result['network_weights'] = self.network.get_weights()
result = self.get_config()
result['network_weights'] = self.get_weights()
return result
def __setstate__(self, state):
......@@ -734,6 +728,19 @@ class Class1PresentationNeuralNetwork(object):
if network_weights is not None:
self.network.set_weights(network_weights)
def get_weights(self):
"""
Get the network weights
Returns
-------
list of numpy.array giving weights for each layer or None if there is no
network
"""
if self.network is None:
return None
return self.network.get_weights()
def get_config(self):
"""
serialize to a dict all attributes except model weights
......@@ -743,11 +750,9 @@ class Class1PresentationNeuralNetwork(object):
dict
"""
result = dict(self.__dict__)
result['network'] = None
result['network_weights'] = None
del result['network']
result['network_json'] = None
if self.network:
result['network_weights'] = self.network.get_weights()
result['network_json'] = self.network.to_json()
return result
......@@ -771,12 +776,11 @@ class Class1PresentationNeuralNetwork(object):
config = dict(config)
instance = cls(**config.pop('hyperparameters'))
network_json = config.pop('network_json')
network_weights = config.pop('network_weights')
instance.__dict__.update(config)
assert instance.network is None
if network_json is not None:
import keras.models
instance.network = keras.models.model_from_json(network_json)
if network_weights is not None:
instance.network.set_weights(network_weights)
if weights is not None:
instance.network.set_weights(weights)
return instance
\ No newline at end of file
......@@ -61,9 +61,10 @@ class Class1PresentationPredictor(object):
if self._manifest_df is None:
rows = []
for (i, model) in enumerate(self.models):
model_config = model.get_config()
rows.append((
self.model_name(i),
json.dumps(model.get_config()),
json.dumps(model_config),
model
))
self._manifest_df = pandas.DataFrame(
......@@ -244,8 +245,7 @@ class Class1PresentationPredictor(object):
updated_network_config_jsons.append(
json.dumps(row.model.get_config()))
weights_path = self.weights_path(models_dir, row.model_name)
self.save_weights(
row.model.get_weights(), weights_path)
save_weights(row.model.get_weights(), weights_path)
logging.info("Wrote: %s", weights_path)
sub_manifest_df["config_json"] = updated_network_config_jsons
self.manifest_df.loc[
......
......@@ -23,6 +23,7 @@ import argparse
import sys
import copy
import os
import tempfile
from numpy.testing import assert_, assert_equal, assert_allclose, assert_array_equal
from nose.tools import assert_greater, assert_less
......@@ -95,6 +96,7 @@ def test_basic():
for affinity_network in affinity_predictor.class1_pan_allele_models:
presentation_network = Class1PresentationNeuralNetwork()
presentation_network.load_from_class1_neural_network(affinity_network)
print(presentation_network.network.get_config())
models.append(presentation_network)
predictor = Class1PresentationPredictor(
......@@ -116,10 +118,25 @@ def test_basic():
merged_df = pandas.merge(
df, df2.set_index("peptide"), left_index=True, right_index=True)
assert_array_equal(merged_df["tightest_affinity"], merged_df["affinity"])
assert_array_equal(merged_df["tightest_affinity"], to_ic50(merged_df["score"]))
#import ipdb ; ipdb.set_trace()
assert_allclose(
merged_df["tightest_affinity"], merged_df["affinity"], rtol=1e-5)
assert_allclose(
merged_df["tightest_affinity"], to_ic50(merged_df["score"]), rtol=1e-5)
assert_array_equal(merged_df["tightest_allele"], merged_df["allele"])
models_dir = tempfile.mkdtemp("_models")
print(models_dir)
predictor.save(models_dir)
predictor2 = Class1PresentationPredictor.load(models_dir)
df3 = predictor2.predict_to_dataframe(
peptides=df.index.values,
alleles=alleles)
assert_array_equal(df2.values, df3.values)
# TODO: test fitting, saving, and loading
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment