From dec7c6dc51f2ec2495c27893394730f8f2c72edc Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 15 Feb 2017 16:24:56 -0500
Subject: [PATCH] Expand antigen presentation tests; improve error reporting
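
Strengthen the antigen presentation tests and tighten a few error
paths: include the offending values in assertion messages, log the
feature expression when eval() raises SyntaxError, rename the
NEON_DEBUG environment variable to MHCFLURRY_DEBUG, and extend the
tests to cover an amino acid composition component model plus
pickling, clone(), and get_fit()/restore_fit() round-trips.

The serialization round-trip the new tests exercise looks roughly
like this (a sketch, assuming a fitted PresentationModel `model`, its
unfit counterpart `unfit_model`, and a peptides DataFrame `peptides`):

    import pickle
    from numpy.testing import assert_array_equal

    # A pickled-and-restored model must predict identically.
    model2 = pickle.loads(pickle.dumps(model))
    assert_array_equal(model.predict(peptides), model2.predict(peptides))

    # An unfit clone restored from another model's saved fit must agree.
    model3 = unfit_model.clone()
    assert not model3.has_been_fit
    model3.restore_fit(model2.get_fit())
    assert_array_equal(model.predict(peptides), model3.predict(peptides))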

---
 .../presentation_component_model.py           |  9 +--
 .../presentation_model.py                     | 14 ++--
 mhcflurry/common.py                           |  4 +-
 test/test_antigen_presentation.py             | 67 ++++++++++++++++---
 4 files changed, 73 insertions(+), 21 deletions(-)
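
Note: raise_or_debug() now reads MHCFLURRY_DEBUG instead of NEON_DEBUG,
so a failing test can be dropped into ipdb with an invocation along
these lines (assuming nose as the test runner, as the tests' imports
suggest):

    MHCFLURRY_DEBUG=1 nosetests test/test_antigen_presentation.py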

diff --git a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
index 52fe54a2..eb20fbd5 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
@@ -22,10 +22,11 @@ def cache_dict_for_policy(policy):
 
 class PresentationComponentModel(object):
     '''
-    Base class for inputs to a "final model". By "final model" we mean
-    something like a logistic regression model that takes as input expression,
-    mhc binding affinity, cleavage, etc. By "final model input" we mean
-    the predictors for expression, mhc binding affinity, etc.
+    Base class for the component models of a presentation model.
+
+    Component models predict quantities such as expression, MHC binding
+    affinity, and cleavage. The presentation model is typically a
+    logistic regression model over these predictions.
     '''
     def __init__(
             self, fit_cache_policy="weak", predictions_cache_policy="weak"):
diff --git a/mhcflurry/antigen_presentation/presentation_model.py b/mhcflurry/antigen_presentation/presentation_model.py
index fe8664e9..73c1f688 100644
--- a/mhcflurry/antigen_presentation/presentation_model.py
+++ b/mhcflurry/antigen_presentation/presentation_model.py
@@ -52,7 +52,7 @@ def build_presentation_models(term_dict, formulas, **kwargs):
             (term_inputs, term_expressions) = term_dict[name]
             inputs.extend(term_inputs)
             expressions.extend(term_expressions)
-        assert len(set(expressions)) == len(expressions)
+        assert len(set(expressions)) == len(expressions), expressions
         presentation_model = PresentationModel(
             inputs,
             expressions,
@@ -301,7 +301,11 @@ class PresentationModel(object):
         for expression in self.feature_expressions:
             # We use numpy module as globals here so math functions
             # like log, log1p, exp, are in scope.
-            values = eval(expression, numpy.__dict__, input_df)
+            try:
+                values = eval(expression, numpy.__dict__, input_df)
+            except SyntaxError:
+                logging.error("Syntax error in expression: %s", expression)
+                raise
             assert len(values) == len(input_df), expression
             if hasattr(values, 'values'):
                 values = values.values
@@ -419,7 +423,6 @@ class PresentationModel(object):
         assert 'hit' in peptides_df.columns
 
         peptides_df["prediction"] = self.predict(peptides_df)
-        # print(sorted(peptides_df.prediction[peptides_df.hit].values))
         top_n = float(peptides_df.hit.sum())
 
         if not include_hit_indices:
@@ -503,9 +506,8 @@ class PresentationModel(object):
                 "those of this PresentationModel: '%s'" % (
                     feature_expressions, self.feature_expressions))
         assert not fit, "Unhandled data in fit: %s" % fit
-        assert len(model_input_fits) == (
-            2 if self.component_models_require_fitting else 1), (
-            "Wrong length: %s" % model_input_fits)
+        assert len(model_input_fits) == len(
+            self.presentation_models_predictors), model_input_fits
 
         self.trained_component_models = []
         for model_input_fits_for_fold in model_input_fits:
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 887d73b3..1a8c5a71 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -201,10 +201,10 @@ def describe_nulls(df, related_df_with_same_index_to_describe=None):
 
 def raise_or_debug(exception):
     """
-    Raise the exception unless the NEON_DEBUG environment variable is set,
+    Raise the exception unless the MHCFLURRY_DEBUG environment variable is set,
     in which case drop into ipython debugger (ipdb).
     """
-    if environ.get("NEON_DEBUG"):
+    if environ.get("MHCFLURRY_DEBUG"):
         import ipdb
         ipdb.set_trace()
     raise exception
diff --git a/test/test_antigen_presentation.py b/test/test_antigen_presentation.py
index df597226..9d2e0674 100644
--- a/test/test_antigen_presentation.py
+++ b/test/test_antigen_presentation.py
@@ -1,7 +1,9 @@
+import pickle
+
 from nose.tools import eq_, assert_less
 
 import numpy
-from numpy.testing import assert_allclose
+from numpy.testing import assert_allclose, assert_array_equal
 import pandas
 from mhcflurry import amino_acid
 from mhcflurry.antigen_presentation import (
@@ -10,6 +12,8 @@ from mhcflurry.antigen_presentation import (
     presentation_component_models,
     presentation_model)
 
+from mhcflurry.amino_acid import common_amino_acid_letters
+
 
 ######################
 # Helper functions
@@ -31,7 +35,8 @@ def hit_criterion(experiment_name, peptide):
 ######################
 # Small test dataset
 
-PEPTIDES = make_random_peptides(100, 9)
+PEPTIDES = make_random_peptides(1000, 9)
+OTHER_PEPTIDES = make_random_peptides(1000, 9)
 
 TRANSCRIPTS = [
     "transcript-%d" % i
@@ -66,6 +71,15 @@ PEPTIDES_DF["hit"] = [
 ]
 print("Hit rate: %0.3f" % PEPTIDES_DF.hit.mean())
 
+AA_COMPOSITION_DF = pandas.DataFrame({
+    'peptide': sorted(set(PEPTIDES).union(set(OTHER_PEPTIDES))),
+})
+for aa in sorted(common_amino_acid_letters):
+    AA_COMPOSITION_DF[aa] = AA_COMPOSITION_DF.peptide.str.count(aa)
+
+AA_COMPOSITION_DF.index = AA_COMPOSITION_DF.peptide
+del AA_COMPOSITION_DF['peptide']
+
 HITS_DF = PEPTIDES_DF.ix[PEPTIDES_DF.hit].reset_index().copy()
 del HITS_DF["hit"]
 
@@ -89,7 +103,7 @@ def test_mhcflurry_trained_on_hits():
         experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
         transcripts=TRANSCIPTS_DF,
         peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
-        random_peptides_for_percent_rank=make_random_peptides(10000, 9),
+        random_peptides_for_percent_rank=OTHER_PEPTIDES,
     )
     mhcflurry_model.fit(HITS_DF)
 
@@ -112,32 +126,67 @@ def test_presentation_model():
         experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
         transcripts=TRANSCIPTS_DF,
         peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
-        random_peptides_for_percent_rank=make_random_peptides(1000, 9),
+        random_peptides_for_percent_rank=OTHER_PEPTIDES,
     )
 
+    aa_content_model = (
+        presentation_component_models.FixedPerPeptideQuantity(
+            "aa composition",
+            numpy.log1p(AA_COMPOSITION_DF)))
+
     decoys = decoy_strategies.UniformRandom(
-        make_random_peptides(1000, 9),
+        OTHER_PEPTIDES,
         decoys_per_hit=50)
 
     terms = {
         'A_ms': (
             [mhcflurry_model],
             ["log1p(mhcflurry_basic_affinity)"]),
+        'P': (
+            [aa_content_model],
+            list(AA_COMPOSITION_DF.columns)),
     }
 
-    for kwargs in [{}, {'ensemble_size': 6}]:
+    for kwargs in [{}, {'ensemble_size': 3}]:
         models = presentation_model.build_presentation_models(
             terms,
-            ["A_ms"],
+            ["A_ms", "A_ms + P"],
             decoy_strategy=decoys,
             **kwargs)
-        eq_(len(models), 1)
+        eq_(len(models), 2)
 
-        model = models["A_ms"]
+        unfit_model = models["A_ms"]
+        model = unfit_model.clone()
         model.fit(HITS_DF.ix[HITS_DF.experiment_name == "exp1"])
 
         peptides = PEPTIDES_DF.copy()
         peptides["prediction"] = model.predict(peptides)
+        print(peptides)
+        print("Hit mean", peptides.prediction[peptides.hit].mean())
+        print("Decoy mean", peptides.prediction[~peptides.hit].mean())
+
         assert_less(
             peptides.prediction[~peptides.hit].mean(),
             peptides.prediction[peptides.hit].mean())
+
+        model2 = pickle.loads(pickle.dumps(model))
+        assert_array_equal(
+            model.predict(peptides), model2.predict(peptides))
+
+        model3 = unfit_model.clone()
+        assert not model3.has_been_fit
+        model3.restore_fit(model2.get_fit())
+        assert_array_equal(
+            model.predict(peptides), model3.predict(peptides))
+
+        better_unfit_model = models["A_ms + P"]
+        model = better_unfit_model.clone()
+        model.fit(HITS_DF.ix[HITS_DF.experiment_name == "exp1"])
+        peptides["prediction_better"] = model.predict(peptides)
+
+        assert_less(
+            peptides.prediction_better[~peptides.hit].mean(),
+            peptides.prediction[~peptides.hit].mean())
+        assert_less(
+            peptides.prediction[peptides.hit].mean(),
+            peptides.prediction_better[peptides.hit].mean())
-- 
GitLab