diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index 12b2c669245e280431278842af856957c74497e9..3660ae329101ed0c9bfa196413f141d5a6529cfc 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -291,8 +291,9 @@ class MultiallelicMassSpecLoss(Loss): pos_max = tf.reduce_max(pos, axis=1) neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0)) term = tf.reshape(neg, (-1, 1)) - pos_max + self.delta - result = tf.reduce_sum(tf.maximum(0.0, term) ** 2) / tf.cast( - tf.size(term), tf.float32) * self.multiplier + result = tf.math.divide_no_nan( + tf.reduce_sum(tf.maximum(0.0, term) ** 2), + tf.cast(tf.size(term), tf.float32)) * self.multiplier return result diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py index eeb574d35fdd8f7df687ba0e7dd5bb66ce73ffb3..37837571a502aca7191e7459a476e6eb229c8a1b 100644 --- a/test/test_class1_presentation_predictor.py +++ b/test/test_class1_presentation_predictor.py @@ -91,7 +91,7 @@ def teardown(): cleanup() -def test_basic(): +def Xtest_basic(): affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC models = [] for affinity_network in affinity_predictor.class1_pan_allele_models: @@ -485,9 +485,11 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10): import ipdb ; ipdb.set_trace() +def test_synthetic_allele_refinement_with_affinity_data(): + test_synthetic_allele_refinement(include_affinities=True) -def Xtest_synthetic_allele_refinement_with_affinity_data( - max_epochs=10, include_affinities=False): + +def test_synthetic_allele_refinement(max_epochs=10, include_affinities=False): refine_allele = "HLA-C*01:02" alleles = [ "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01", @@ -508,11 +510,10 @@ def Xtest_synthetic_allele_refinement_with_affinity_data( "curated_training_data.no_mass_spec.csv.bz2")) def filter_df(df): - df = df.loc[ + return df.loc[ (df.allele.isin(alleles)) & (df.peptide.str.len() == length) ] - return df train_with_ms = filter_df(train_with_ms) train_no_ms = filter_df(train_no_ms) @@ -560,12 +561,13 @@ def Xtest_synthetic_allele_refinement_with_affinity_data( (affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models presentation_model = Class1PresentationNeuralNetwork( auxiliary_input_features=["gene"], + batch_generator_batch_size=1024, max_epochs=max_epochs, - learning_rate=0.0001, + learning_rate=0.001, patience=5, min_delta=0.0, - random_negative_rate=0.0, - random_negative_constant=0) # WHY DOES THIS BREAK WITH RANDOM NEG? + random_negative_rate=1.0, + random_negative_constant=25) presentation_model.load_from_class1_neural_network(affinity_model) presentation_model = pickle.loads(pickle.dumps(presentation_model))