Skip to content
Snippets Groups Projects
Commit 9d197d67 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

tests

parent 30cc4451
No related merge requests found
...@@ -291,8 +291,9 @@ class MultiallelicMassSpecLoss(Loss): ...@@ -291,8 +291,9 @@ class MultiallelicMassSpecLoss(Loss):
pos_max = tf.reduce_max(pos, axis=1) pos_max = tf.reduce_max(pos, axis=1)
neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0)) neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0))
term = tf.reshape(neg, (-1, 1)) - pos_max + self.delta term = tf.reshape(neg, (-1, 1)) - pos_max + self.delta
result = tf.reduce_sum(tf.maximum(0.0, term) ** 2) / tf.cast( result = tf.math.divide_no_nan(
tf.size(term), tf.float32) * self.multiplier tf.reduce_sum(tf.maximum(0.0, term) ** 2),
tf.cast(tf.size(term), tf.float32)) * self.multiplier
return result return result
......
...@@ -91,7 +91,7 @@ def teardown(): ...@@ -91,7 +91,7 @@ def teardown():
cleanup() cleanup()
def test_basic(): def Xtest_basic():
affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
models = [] models = []
for affinity_network in affinity_predictor.class1_pan_allele_models: for affinity_network in affinity_predictor.class1_pan_allele_models:
...@@ -485,9 +485,11 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10): ...@@ -485,9 +485,11 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10):
import ipdb ; ipdb.set_trace() import ipdb ; ipdb.set_trace()
def test_synthetic_allele_refinement_with_affinity_data():
test_synthetic_allele_refinement(include_affinities=True)
def Xtest_synthetic_allele_refinement_with_affinity_data(
max_epochs=10, include_affinities=False): def test_synthetic_allele_refinement(max_epochs=10, include_affinities=False):
refine_allele = "HLA-C*01:02" refine_allele = "HLA-C*01:02"
alleles = [ alleles = [
"HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01", "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
...@@ -508,11 +510,10 @@ def Xtest_synthetic_allele_refinement_with_affinity_data( ...@@ -508,11 +510,10 @@ def Xtest_synthetic_allele_refinement_with_affinity_data(
"curated_training_data.no_mass_spec.csv.bz2")) "curated_training_data.no_mass_spec.csv.bz2"))
def filter_df(df): def filter_df(df):
df = df.loc[ return df.loc[
(df.allele.isin(alleles)) & (df.allele.isin(alleles)) &
(df.peptide.str.len() == length) (df.peptide.str.len() == length)
] ]
return df
train_with_ms = filter_df(train_with_ms) train_with_ms = filter_df(train_with_ms)
train_no_ms = filter_df(train_no_ms) train_no_ms = filter_df(train_no_ms)
...@@ -560,12 +561,13 @@ def Xtest_synthetic_allele_refinement_with_affinity_data( ...@@ -560,12 +561,13 @@ def Xtest_synthetic_allele_refinement_with_affinity_data(
(affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models (affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
presentation_model = Class1PresentationNeuralNetwork( presentation_model = Class1PresentationNeuralNetwork(
auxiliary_input_features=["gene"], auxiliary_input_features=["gene"],
batch_generator_batch_size=1024,
max_epochs=max_epochs, max_epochs=max_epochs,
learning_rate=0.0001, learning_rate=0.001,
patience=5, patience=5,
min_delta=0.0, min_delta=0.0,
random_negative_rate=0.0, random_negative_rate=1.0,
random_negative_constant=0) # WHY DOES THIS BREAK WITH RANDOM NEG? random_negative_constant=25)
presentation_model.load_from_class1_neural_network(affinity_model) presentation_model.load_from_class1_neural_network(affinity_model)
presentation_model = pickle.loads(pickle.dumps(presentation_model)) presentation_model = pickle.loads(pickle.dumps(presentation_model))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment