From 9d197d67e47adb9ef691ea951e5bc0efe884ece9 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 5 Dec 2019 19:33:51 -0500
Subject: [PATCH] tests

---
 mhcflurry/custom_loss.py                   |  5 +++--
 test/test_class1_presentation_predictor.py | 18 ++++++++++--------
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index 12b2c669..3660ae32 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -291,8 +291,9 @@ class MultiallelicMassSpecLoss(Loss):
         pos_max = tf.reduce_max(pos, axis=1)
         neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0))
         term = tf.reshape(neg, (-1, 1)) - pos_max + self.delta
-        result = tf.reduce_sum(tf.maximum(0.0, term) ** 2) / tf.cast(
-            tf.size(term), tf.float32) * self.multiplier
+        result = tf.math.divide_no_nan(
+            tf.reduce_sum(tf.maximum(0.0, term) ** 2),
+            tf.cast(tf.size(term), tf.float32)) * self.multiplier
         return result
 
 
diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py
index eeb574d3..37837571 100644
--- a/test/test_class1_presentation_predictor.py
+++ b/test/test_class1_presentation_predictor.py
@@ -91,7 +91,7 @@ def teardown():
     cleanup()
 
 
-def test_basic():
+def Xtest_basic():
     affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     models = []
     for affinity_network in affinity_predictor.class1_pan_allele_models:
@@ -485,9 +485,11 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10):
 
     import ipdb ; ipdb.set_trace()
 
+def test_synthetic_allele_refinement_with_affinity_data():
+    test_synthetic_allele_refinement(include_affinities=True)
 
-def Xtest_synthetic_allele_refinement_with_affinity_data(
-        max_epochs=10, include_affinities=False):
+
+def test_synthetic_allele_refinement(max_epochs=10, include_affinities=False):
     refine_allele = "HLA-C*01:02"
     alleles = [
         "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
@@ -508,11 +510,10 @@ def Xtest_synthetic_allele_refinement_with_affinity_data(
         "curated_training_data.no_mass_spec.csv.bz2"))
 
     def filter_df(df):
-        df = df.loc[
+        return df.loc[
             (df.allele.isin(alleles)) &
             (df.peptide.str.len() == length)
         ]
-        return df
 
     train_with_ms = filter_df(train_with_ms)
     train_no_ms = filter_df(train_no_ms)
@@ -560,12 +561,13 @@ def Xtest_synthetic_allele_refinement_with_affinity_data(
     (affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
     presentation_model = Class1PresentationNeuralNetwork(
         auxiliary_input_features=["gene"],
+        batch_generator_batch_size=1024,
         max_epochs=max_epochs,
-        learning_rate=0.0001,
+        learning_rate=0.001,
         patience=5,
         min_delta=0.0,
-        random_negative_rate=0.0,
-        random_negative_constant=0)  # WHY DOES THIS BREAK WITH RANDOM NEG?
+        random_negative_rate=1.0,
+        random_negative_constant=25)
     presentation_model.load_from_class1_neural_network(affinity_model)
 
     presentation_model = pickle.loads(pickle.dumps(presentation_model))
-- 
GitLab