Skip to content
Snippets Groups Projects
Commit 0a45437f authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent a9fc0466
No related merge requests found
...@@ -4,10 +4,15 @@ Class I MHC ligand prediction package ...@@ -4,10 +4,15 @@ Class I MHC ligand prediction package
from .class1_affinity_predictor import Class1AffinityPredictor from .class1_affinity_predictor import Class1AffinityPredictor
from .class1_neural_network import Class1NeuralNetwork from .class1_neural_network import Class1NeuralNetwork
from .class1_presentation_predictor import Class1PresentationPredictor
from .class1_presentation_neural_network import Class1PresentationNeuralNetwork
from .version import __version__ from .version import __version__
__all__ = [ __all__ = [
"__version__", "__version__",
"Class1AffinityPredictor", "Class1AffinityPredictor",
"Class1NeuralNetwork", "Class1NeuralNetwork",
"Class1PresentationPredictor",
"Class1PresentationNeuralNetwork",
] ]
...@@ -554,6 +554,26 @@ class Class1PresentationNeuralNetwork(object): ...@@ -554,6 +554,26 @@ class Class1PresentationNeuralNetwork(object):
"peptide" "peptide"
][:num_random_negatives] = random_negative_peptides_encoding ][:num_random_negatives] = random_negative_peptides_encoding
if i == 0:
(train_generator, test_generator) = (
batch_generator.get_train_and_test_generators(
x_dict=x_dict_with_random_negatives,
y_list=[encoded_y1, encoded_y2],
epochs=1))
pairs = [
("train", train_generator, batch_generator.num_train_batches),
("test", test_generator, batch_generator.num_test_batches),
]
for (kind, generator, steps) in pairs:
self.assert_allele_representations_hash(
allele_representations_hash)
metrics = self.network.evaluate_generator(
generator=generator,
steps=steps,
workers=0,
use_multiprocessing=False)
for (key, val) in zip(self.network.metrics_names, metrics):
fit_info["pre_fit_%s_%s" % (kind, key)] = val
(train_generator, test_generator) = ( (train_generator, test_generator) = (
batch_generator.get_train_and_test_generators( batch_generator.get_train_and_test_generators(
x_dict=x_dict_with_random_negatives, x_dict=x_dict_with_random_negatives,
...@@ -631,7 +651,7 @@ class Class1PresentationNeuralNetwork(object): ...@@ -631,7 +651,7 @@ class Class1PresentationNeuralNetwork(object):
return { return {
'batch_generator': batch_generator, 'batch_generator': batch_generator,
'last_x': x_dict_with_random_negatives, 'last_x': x_dict_with_random_negatives,
'last_y': [encoded_y1, encoded_y2, encoded_y2], 'last_y': [encoded_y1, encoded_y2],
'fit_info': fit_info, 'fit_info': fit_info,
} }
......
...@@ -183,10 +183,14 @@ class MSEWithInequalities(Loss): ...@@ -183,10 +183,14 @@ class MSEWithInequalities(Loss):
diff3 *= K.cast(y_true >= 4.0, "float32") diff3 *= K.cast(y_true >= 4.0, "float32")
diff3 *= K.cast(diff3 > 0.0, "float32") diff3 *= K.cast(diff3 > 0.0, "float32")
denominator = K.maximum(
K.sum(K.cast(K.not_equal(y_true, 2.0), "float32"), 0),
1.0)
result = ( result = (
K.sum(K.square(diff1)) + K.sum(K.square(diff1)) +
K.sum(K.square(diff2)) + K.sum(K.square(diff2)) +
K.sum(K.square(diff3))) / K.cast(K.shape(y_pred)[0], "float32") K.sum(K.square(diff3))) / denominator
return result return result
......
...@@ -30,7 +30,7 @@ releases: ...@@ -30,7 +30,7 @@ releases:
default: false default: false
- name: models_class1_pan_refined - name: models_class1_pan_refined
url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191210.tar.bz2 url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191211.tar.bz2
default: false default: false
- name: models_class1_pan_variants - name: models_class1_pan_variants
......
...@@ -91,7 +91,7 @@ def make_motif(presentation_predictor, allele, peptides, frac=0.01): ...@@ -91,7 +91,7 @@ def make_motif(presentation_predictor, allele, peptides, frac=0.01):
# TESTS # TESTS
################################################### ###################################################
def test_synthetic_allele_refinement(include_affinities=False): def test_synthetic_allele_refinement(include_affinities=True):
""" """
Test that in a synthetic example the model is able to learn that HLA-C*01:02 Test that in a synthetic example the model is able to learn that HLA-C*01:02
prefers P at position 3. prefers P at position 3.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment