diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py index 685f2083ade3af8734e2bdd71a94515724943b1f..2c32ad9dbf7aa48153e72b604edac506f91f66e1 100644 --- a/mhcflurry/class1_ligandome_predictor.py +++ b/mhcflurry/class1_ligandome_predictor.py @@ -150,7 +150,7 @@ class Class1LigandomePredictor(object): return network @staticmethod - def loss(y_true, y_pred, delta=0.2, alpha=None): + def loss(y_true, y_pred, sample_weight=None, delta=0.2, alpha=None): """ Loss function for ligandome prediction. """ diff --git a/test/test_network_merging.py b/test/test_network_merging.py index 31c63377ff1b41a2e439f5fbd427a8dfe2094aa6..591d420f086184ff6eb4d889591dde95be3ce0d7 100644 --- a/test/test_network_merging.py +++ b/test/test_network_merging.py @@ -10,7 +10,6 @@ from mhcflurry.downloads import get_path from mhcflurry.testing_utils import cleanup, startup - PAN_ALLELE_PREDICTOR = None diff --git a/test/test_speed.py b/test/test_speed.py index 3f75579b4c1a04b85749a9284b8254c8dbe7e2e3..037ae61c637a34f70d1e33eb5f382587107537c0 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -22,11 +22,11 @@ from mhcflurry.testing_utils import cleanup, startup ALLELE_SPECIFIC_PREDICTOR = None -PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None +PAN_ALLELE_PREDICTOR = None def setup(): - global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC + global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR startup() ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load( get_path("models_class1", "models")) @@ -36,7 +36,7 @@ def setup(): def teardown(): - global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC + global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR ALLELE_SPECIFIC_PREDICTOR = None PAN_ALLELE_PREDICTOR = None cleanup() @@ -97,7 +97,7 @@ def test_speed_allele_specific(profile=False, num=DEFAULT_NUM_PREDICTIONS): def test_speed_pan_allele(profile=False, num=DEFAULT_NUM_PREDICTIONS): - global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC + global PAN_ALLELE_PREDICTOR starts = collections.OrderedDict() timings = collections.OrderedDict() profilers = collections.OrderedDict()