diff --git a/mhcflurry/__init__.py b/mhcflurry/__init__.py index 5d7ceb0fedefd35efb370b702acaf6df0bf799cc..525ce6a1c21feaab8ede3099a8764c1ab11623c6 100644 --- a/mhcflurry/__init__.py +++ b/mhcflurry/__init__.py @@ -4,10 +4,15 @@ Class I MHC ligand prediction package from .class1_affinity_predictor import Class1AffinityPredictor from .class1_neural_network import Class1NeuralNetwork +from .class1_presentation_predictor import Class1PresentationPredictor +from .class1_presentation_neural_network import Class1PresentationNeuralNetwork + from .version import __version__ __all__ = [ "__version__", "Class1AffinityPredictor", "Class1NeuralNetwork", + "Class1PresentationPredictor", + "Class1PresentationNeuralNetwork", ] diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py index a1c495fa2e96b8eb0f5c2d5e0e833f590bb84d8c..58be681341adab8f42461b7920e535433e02a8f0 100644 --- a/mhcflurry/class1_presentation_neural_network.py +++ b/mhcflurry/class1_presentation_neural_network.py @@ -554,6 +554,26 @@ class Class1PresentationNeuralNetwork(object): "peptide" ][:num_random_negatives] = random_negative_peptides_encoding + if i == 0: + (train_generator, test_generator) = ( + batch_generator.get_train_and_test_generators( + x_dict=x_dict_with_random_negatives, + y_list=[encoded_y1, encoded_y2], + epochs=1)) + pairs = [ + ("train", train_generator, batch_generator.num_train_batches), + ("test", test_generator, batch_generator.num_test_batches), + ] + for (kind, generator, steps) in pairs: + self.assert_allele_representations_hash( + allele_representations_hash) + metrics = self.network.evaluate_generator( + generator=generator, + steps=steps, + workers=0, + use_multiprocessing=False) + for (key, val) in zip(self.network.metrics_names, metrics): + fit_info["pre_fit_%s_%s" % (kind, key)] = val (train_generator, test_generator) = ( batch_generator.get_train_and_test_generators( x_dict=x_dict_with_random_negatives, @@ -631,7 +651,7 @@ class Class1PresentationNeuralNetwork(object): return { 'batch_generator': batch_generator, 'last_x': x_dict_with_random_negatives, - 'last_y': [encoded_y1, encoded_y2, encoded_y2], + 'last_y': [encoded_y1, encoded_y2], 'fit_info': fit_info, } diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index 6cec8a8b549b560dceda8135f7197971e2985dff..53574a9605ca94ec7da7a3d40d5871eb57b28e19 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -183,10 +183,14 @@ class MSEWithInequalities(Loss): diff3 *= K.cast(y_true >= 4.0, "float32") diff3 *= K.cast(diff3 > 0.0, "float32") + denominator = K.maximum( + K.sum(K.cast(K.not_equal(y_true, 2.0), "float32"), 0), + 1.0) + result = ( K.sum(K.square(diff1)) + K.sum(K.square(diff2)) + - K.sum(K.square(diff3))) / K.cast(K.shape(y_pred)[0], "float32") + K.sum(K.square(diff3))) / denominator return result diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml index 5899f29a2d60db7783ce99e13fef89dcc8421a24..5de672d9b7e7d7fff85aac79fbca36ebd57c15ea 100644 --- a/mhcflurry/downloads.yml +++ b/mhcflurry/downloads.yml @@ -30,7 +30,7 @@ releases: default: false - name: models_class1_pan_refined - url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191210.tar.bz2 + url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191211.tar.bz2 default: false - name: models_class1_pan_variants diff --git a/test/test_class1_presentation_neural_network.py b/test/test_class1_presentation_neural_network.py index 91799922d5e807fb836a6f700ab65f23ca13292e..75cd751d69b322700f0180e147b4e6a7d8c3c994 100644 --- a/test/test_class1_presentation_neural_network.py +++ b/test/test_class1_presentation_neural_network.py @@ -91,7 +91,7 @@ def make_motif(presentation_predictor, allele, peptides, frac=0.01): # TESTS ################################################### -def test_synthetic_allele_refinement(include_affinities=False): +def test_synthetic_allele_refinement(include_affinities=True): """ Test that in a synthetic example the model is able to learn that HLA-C*01:02 prefers P at position 3.