Skip to content
Snippets Groups Projects
Commit a9bac0c6 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fixes

parent ba5620ea
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ IMPUTE_HYPERPARAMETER_DEFAULTS = HyperparameterDefaults(
impute_method='mice',
impute_min_observations_per_peptide=1,
impute_min_observations_per_allele=1,
imputer_args=[{"n_burn_in": 5, "n_imputations": 25}])
imputer_args={"n_burn_in": 5, "n_imputations": 25})
HYPERPARAMETER_DEFAULTS = (
HyperparameterDefaults(
......@@ -64,14 +64,13 @@ def fit_and_test(
fold_num,
train_measurement_collection_broadcast,
test_measurement_collection_broadcast,
alleles,
hyperparameters_list):
allele_and_hyperparameter_pairs):
assert len(train_measurement_collection_broadcast.value.df) > 0
assert len(test_measurement_collection_broadcast.value.df) > 0
results = []
for all_hyperparameters in hyperparameters_list:
for (allele, all_hyperparameters) in allele_and_hyperparameter_pairs:
measurement_collection_hyperparameters = (
MEASUREMENT_COLLECTION_HYPERPARAMETER_DEFAULTS.subselect(
all_hyperparameters))
......@@ -79,34 +78,33 @@ def fit_and_test(
Class1BindingPredictor.hyperparameter_defaults.subselect(
all_hyperparameters))
for allele in alleles:
train_dataset = (
train_measurement_collection_broadcast
.value
.select_allele(allele)
.to_dataset(**measurement_collection_hyperparameters))
test_dataset = (
test_measurement_collection_broadcast
.value
.select_allele(allele)
.to_dataset(**measurement_collection_hyperparameters))
assert len(train_dataset) > 0
assert len(test_dataset) > 0
model = Class1BindingPredictor(**model_hyperparameters)
model.fit_dataset(train_dataset)
predictions = model.predict(test_dataset.peptides)
scores = scoring.make_scores(
test_dataset.affinities, predictions)
results.append({
'fold_num': fold_num,
'allele': allele,
'hyperparameters': all_hyperparameters,
'model': model,
'scores': scores
})
train_dataset = (
train_measurement_collection_broadcast
.value
.select_allele(allele)
.to_dataset(**measurement_collection_hyperparameters))
test_dataset = (
test_measurement_collection_broadcast
.value
.select_allele(allele)
.to_dataset(**measurement_collection_hyperparameters))
assert len(train_dataset) > 0
assert len(test_dataset) > 0
model = Class1BindingPredictor(**model_hyperparameters)
model.fit_dataset(train_dataset)
predictions = model.predict(test_dataset.peptides)
scores = scoring.make_scores(
test_dataset.affinities, predictions)
results.append({
'fold_num': fold_num,
'allele': allele,
'hyperparameters': all_hyperparameters,
'model': model,
'scores': scores
})
return results
......@@ -321,19 +319,16 @@ class Class1EnsembleMultiAllelePredictor(object):
train_broadcast = parallel_backend.broadcast(train_split)
test_broadcast = parallel_backend.broadcast(test_split)
task_alleles = []
task_models = []
task_allele_model_pairs = []
def make_task():
if task_alleles and task_models:
if task_allele_model_pairs:
tasks.append((
fold_num,
train_broadcast,
test_broadcast,
list(task_alleles),
list(task_models)))
task_alleles.clear()
task_models.clear()
list(task_allele_model_pairs)))
task_allele_model_pairs.clear()
assert all(
allele in set(train_split.df.allele.unique())
......@@ -348,12 +343,11 @@ class Class1EnsembleMultiAllelePredictor(object):
for allele in alleles:
for model in self.hyperparameters_to_search:
task_models.append(model)
if len(task_alleles) * len(task_models) > work_per_task:
task_allele_model_pairs.append((allele, model))
if len(task_allele_model_pairs) > work_per_task:
make_task()
make_task()
assert not task_alleles
assert not task_models
assert not task_allele_model_pairs
logging.info(
"Training and scoring models: %d tasks (target was %d), "
......@@ -365,6 +359,7 @@ class Class1EnsembleMultiAllelePredictor(object):
len(self.hyperparameters_to_search),
total_work))
assert len(tasks) > 0
results = parallel_backend.map(call_fit_and_test, tasks)
# fold number -> allele -> best model
......@@ -404,6 +399,7 @@ class Class1EnsembleMultiAllelePredictor(object):
manifest_entry["%s_%s" % (key, sub_key)] = value
manifest_rows.append(manifest_entry)
assert len(manifest_rows) > 0
manifest_df = pandas.DataFrame(manifest_rows)
manifest_df.index = manifest_df.model_name
del manifest_df["model_name"]
......
......@@ -21,12 +21,6 @@ from mhcflurry \
Class1EnsembleMultiAllelePredictor,
HYPERPARAMETER_DEFAULTS)
try:
import kubeface
KUBEFACE_INSTALLED = True
except ImportError:
KUBEFACE_INSTALLED = False
def test_basic():
model_hyperparameters = HYPERPARAMETER_DEFAULTS.models_grid(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment