Skip to content
Snippets Groups Projects
Commit 57edfb49 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 24985286
No related branches found
No related tags found
No related merge requests found
......@@ -35,8 +35,6 @@ install:
env:
global:
- PYTHONHASHSEED=0
- CUDA_VISIBLE_DEVICES="" # for tensorflow
matrix:
- KERAS_BACKEND=tensorflow
script:
# download data and models, then run tests
......
......@@ -418,7 +418,7 @@ class Class1AffinityPredictor(object):
logging.info("Wrote: %s", percent_ranks_path)
@staticmethod
def load(models_dir=None, max_models=None):
def load(models_dir=None, max_models=None, optimization_level=None):
"""
Deserialize a predictor from a directory on disk.
......@@ -431,12 +431,18 @@ class Class1AffinityPredictor(object):
max_models : int, optional
Maximum number of `Class1NeuralNetwork` instances to load
optimization_level : int
If >0, model optimization will be attempted. Defaults to value of
environment variable MHCFLURRY_OPTIMIZATION_LEVEL.
Returns
-------
`Class1AffinityPredictor` instance
"""
if models_dir is None:
models_dir = get_default_class1_models_dir()
if optimization_level is None:
optimization_level = OPTIMIZATION_LEVEL
manifest_path = join(models_dir, "manifest.csv")
manifest_df = pandas.read_csv(manifest_path, nrows=max_models)
......@@ -497,11 +503,11 @@ class Class1AffinityPredictor(object):
manifest_df=manifest_df,
allele_to_percent_rank_transform=allele_to_percent_rank_transform,
)
if OPTIMIZATION_LEVEL >= 1:
logging.info("Optimizing models")
if optimization_level >= 1:
optimized = result.optimize()
logging.info(
"Optimization %s", ("succeeded" if optimized else "failed"))
"Model optimization %s",
"succeeded" if optimized else "not supported for these models")
return result
def optimize(self):
......
......@@ -14,13 +14,16 @@ from mhcflurry.common import random_peptides
from mhcflurry.downloads import get_path
ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
get_path("models_class1", "models"))
get_path("models_class1", "models"), optimization_level=0)
PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
get_path("models_class1_pan", "models.with_mass_spec"))
get_path("models_class1_pan", "models.with_mass_spec"),
optimization_level=0)
def test_merge():
assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) > 1
peptides = random_peptides(100, length=9)
peptides.extend(random_peptides(100, length=10))
peptides = pandas.Series(peptides).sample(frac=1.0)
......@@ -40,3 +43,4 @@ def test_merge():
)
predictions2 = merged_predictor.predict(peptides=peptides, alleles=alleles)
numpy.testing.assert_allclose(predictions1, predictions2, atol=0.1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment