From 0a45437f67f71f32ebdef7f150ed976e7fa39178 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Wed, 11 Dec 2019 14:15:52 -0500 Subject: [PATCH] fix --- mhcflurry/__init__.py | 5 +++++ .../class1_presentation_neural_network.py | 22 ++++++++++++++++++- mhcflurry/custom_loss.py | 6 ++++- mhcflurry/downloads.yml | 2 +- ...test_class1_presentation_neural_network.py | 2 +- 5 files changed, 33 insertions(+), 4 deletions(-) diff --git a/mhcflurry/__init__.py b/mhcflurry/__init__.py index 5d7ceb0f..525ce6a1 100644 --- a/mhcflurry/__init__.py +++ b/mhcflurry/__init__.py @@ -4,10 +4,15 @@ Class I MHC ligand prediction package from .class1_affinity_predictor import Class1AffinityPredictor from .class1_neural_network import Class1NeuralNetwork +from .class1_presentation_predictor import Class1PresentationPredictor +from .class1_presentation_neural_network import Class1PresentationNeuralNetwork + from .version import __version__ __all__ = [ "__version__", "Class1AffinityPredictor", "Class1NeuralNetwork", + "Class1PresentationPredictor", + "Class1PresentationNeuralNetwork", ] diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py index a1c495fa..58be6813 100644 --- a/mhcflurry/class1_presentation_neural_network.py +++ b/mhcflurry/class1_presentation_neural_network.py @@ -554,6 +554,26 @@ class Class1PresentationNeuralNetwork(object): "peptide" ][:num_random_negatives] = random_negative_peptides_encoding + if i == 0: + (train_generator, test_generator) = ( + batch_generator.get_train_and_test_generators( + x_dict=x_dict_with_random_negatives, + y_list=[encoded_y1, encoded_y2], + epochs=1)) + pairs = [ + ("train", train_generator, batch_generator.num_train_batches), + ("test", test_generator, batch_generator.num_test_batches), + ] + for (kind, generator, steps) in pairs: + self.assert_allele_representations_hash( + allele_representations_hash) + metrics = self.network.evaluate_generator( + generator=generator, + steps=steps, + workers=0, + use_multiprocessing=False) + for (key, val) in zip(self.network.metrics_names, metrics): + fit_info["pre_fit_%s_%s" % (kind, key)] = val (train_generator, test_generator) = ( batch_generator.get_train_and_test_generators( x_dict=x_dict_with_random_negatives, @@ -631,7 +651,7 @@ class Class1PresentationNeuralNetwork(object): return { 'batch_generator': batch_generator, 'last_x': x_dict_with_random_negatives, - 'last_y': [encoded_y1, encoded_y2, encoded_y2], + 'last_y': [encoded_y1, encoded_y2], 'fit_info': fit_info, } diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index 6cec8a8b..53574a96 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -183,10 +183,14 @@ class MSEWithInequalities(Loss): diff3 *= K.cast(y_true >= 4.0, "float32") diff3 *= K.cast(diff3 > 0.0, "float32") + denominator = K.maximum( + K.sum(K.cast(K.not_equal(y_true, 2.0), "float32"), 0), + 1.0) + result = ( K.sum(K.square(diff1)) + K.sum(K.square(diff2)) + - K.sum(K.square(diff3))) / K.cast(K.shape(y_pred)[0], "float32") + K.sum(K.square(diff3))) / denominator return result diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml index 5899f29a..5de672d9 100644 --- a/mhcflurry/downloads.yml +++ b/mhcflurry/downloads.yml @@ -30,7 +30,7 @@ releases: default: false - name: models_class1_pan_refined - url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191210.tar.bz2 + url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191211.tar.bz2 default: false - name: models_class1_pan_variants diff --git a/test/test_class1_presentation_neural_network.py b/test/test_class1_presentation_neural_network.py index 91799922..75cd751d 100644 --- a/test/test_class1_presentation_neural_network.py +++ b/test/test_class1_presentation_neural_network.py @@ -91,7 +91,7 @@ def make_motif(presentation_predictor, allele, peptides, frac=0.01): # TESTS ################################################### -def test_synthetic_allele_refinement(include_affinities=False): +def test_synthetic_allele_refinement(include_affinities=True): """ Test that in a synthetic example the model is able to learn that HLA-C*01:02 prefers P at position 3. -- GitLab