Commit ba5620ea authored by Tim O'Donnell

Use all data in imputation

parent b80ab1a5
@@ -36,7 +36,7 @@ cp $SCRIPT_DIR/models.py .
 python models.py > models.json
 time mhcflurry-class1-allele-specific-ensemble-train \
-    --ensemble-size 8 \
+    --ensemble-size 16 \
     --model-architectures models.json \
     --train-data "$(mhcflurry-downloads path data_combined_iedb_kim2014)/combined_human_class1_dataset.csv" \
     --min-samples-per-allele 200 \
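Doubling --ensemble-size means twice as many member predictors are trained, and (as the assertion later in this commit shows) the predictor expects exactly one train/test split per ensemble member. A minimal sketch of that relationship, using a made-up split helper and column names rather than mhcflurry's real API:

```
import numpy as np
import pandas as pd

def make_member_splits(df, ensemble_size, test_fraction=0.5, seed=0):
    """Illustrative only: one random (train, test) split per ensemble member."""
    rng = np.random.RandomState(seed)
    splits = []
    for _ in range(ensemble_size):
        mask = rng.rand(len(df)) < test_fraction
        splits.append((df[~mask], df[mask]))
    return splits

measurements = pd.DataFrame({
    "allele": ["HLA-A0201"] * 8 + ["HLA-B0702"] * 8,
    "measurement_value": np.arange(16.0),
})
splits = make_member_splits(measurements, ensemble_size=16)
assert len(splits) == 16  # mirrors: assert len(splits) == self.ensemble_size
```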
@@ -21,13 +21,8 @@ To generate this download we run:
 ```
 To debug locally:
 ```
 ./GENERATE.sh \
-    --target-tasks 1 \
-    --backend kubernetes \
-    --storage-prefix gs://kubeface-tim \
-    --worker-image hammerlab/mhcflurry-misc:latest \
-    --kubernetes-task-resources-memory-mb 10000 \
-    --worker-path-prefix venv-py3/bin \
-    --max-simultaneous-tasks 200 \
+    --parallel-backend local-threads \
+    --target-tasks 1
 ```
@@ -17,8 +17,8 @@ models = HYPERPARAMETER_DEFAULTS.models_grid(
         # Arguments specific to imputation method (mice)
         {"n_burn_in": 5, "n_imputations": 50, "n_nearest_columns": 25}
     ],
-    impute_min_observations_per_peptide=[5],
-    impute_min_observations_per_allele=[100])
+    impute_min_observations_per_peptide=[1],
+    impute_min_observations_per_allele=[1])
 sys.stderr.write("Models: %d\n" % len(models))
 print(json.dumps(models, indent=4))
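Lowering both imputation thresholds to 1 is what the commit message refers to: instead of requiring at least 5 observations per peptide and 100 per allele before a row or column enters the imputation matrix, every peptide and allele with any measurement at all is kept. A rough sketch of the kind of filtering these parameters imply (the helper function and column names are illustrative, not mhcflurry's internals):

```
import pandas as pd

def filter_for_imputation(df, min_obs_per_peptide, min_obs_per_allele):
    """Keep only peptides/alleles with enough observations (illustrative)."""
    peptide_counts = df.groupby("peptide").size()
    allele_counts = df.groupby("allele").size()
    keep_peptides = peptide_counts[peptide_counts >= min_obs_per_peptide].index
    keep_alleles = allele_counts[allele_counts >= min_obs_per_allele].index
    return df[df.peptide.isin(keep_peptides) & df.allele.isin(keep_alleles)]

df = pd.DataFrame({
    "allele": ["HLA-A0201", "HLA-A0201", "HLA-B0702"],
    "peptide": ["SIINFEKL", "SIINFEKV", "SIINFEKL"],
    "measurement_value": [120.0, 5300.0, 88.0],
})
# With the old thresholds (5 / 100) this toy data would be dropped entirely;
# with thresholds of 1 every measurement is retained.
assert len(filter_for_imputation(df, 1, 1)) == 3
assert len(filter_for_imputation(df, 5, 100)) == 0
```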
@@ -307,7 +307,7 @@ class Class1EnsembleMultiAllelePredictor(object):
         assert len(splits) == self.ensemble_size, len(splits)
-        alleles = measurement_collection.df.allele.unique()
+        alleles = set(measurement_collection.df.allele.unique())
         total_work = (
             len(alleles) *
@@ -335,8 +335,18 @@ class Class1EnsembleMultiAllelePredictor(object):
             task_alleles.clear()
             task_models.clear()
+        assert all(
+            allele in set(train_split.df.allele.unique())
+            for allele in alleles), (
+                "%s not in %s" % (
+                    alleles, set(train_split.df.allele.unique())))
+        assert all(
+            allele in set(test_split.df.allele.unique())
+            for allele in alleles), (
+                "%s not in %s" % (
+                    alleles, set(test_split.df.allele.unique())))
         for allele in alleles:
             task_alleles.append(allele)
             for model in self.hyperparameters_to_search:
                 task_models.append(model)
                 if len(task_alleles) * len(task_models) > work_per_task:
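The new assertions guard the invariant the ensemble relies on: every allele seen in the full measurement collection must also appear in each train and test split, otherwise a per-allele task would have nothing to train or evaluate on. A standalone sketch of the same check on a toy DataFrame (names and split method are illustrative, not the mhcflurry API):

```
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

df = pd.DataFrame({
    "allele": ["HLA-A0201"] * 10 + ["HLA-B0702"] * 10,
    "measurement_value": range(20),
})

# Stratify on allele so that both sides of the split cover every allele.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_idx, test_idx = next(splitter.split(df, df.allele))
train_split, test_split = df.iloc[train_idx], df.iloc[test_idx]

alleles = set(df.allele.unique())
for split in (train_split, test_split):
    missing = alleles - set(split.allele.unique())
    assert not missing, "%s not in split" % missing
```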
@@ -68,7 +68,8 @@ class MeasurementCollection(object):
         """
         assert isinstance(allele, str), type(allele)
         assert len(self.df) > 0
-        assert allele in self.df.allele.unique()
+        alleles = set(self.df.allele.unique())
+        assert allele in alleles, "%s not in %s" % (allele, alleles)
         return MeasurementCollection(
             self.df.ix[self.df.allele == allele],
             check=False)
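The behavioural change here is in diagnostics: membership is now tested against a set, and the available alleles are reported when the check fails instead of a bare AssertionError. A toy illustration (the allele names are invented):

```
alleles = {"HLA-A0201", "HLA-B0702"}  # hypothetical contents of self.df.allele

try:
    # Old style (bare assert) would fail with no context at all.
    assert "HLA-C0401" in alleles, "%s not in %s" % ("HLA-C0401", alleles)
except AssertionError as error:
    print(error)  # e.g. HLA-C0401 not in {'HLA-A0201', 'HLA-B0702'}
```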
@@ -100,6 +101,8 @@ class MeasurementCollection(object):
                 None if random_state is None
                 else random_state + len(results)))
             stratification_groups = self.df.allele + self.df.measurement_type
+            assert len(stratification_groups.unique()) > 1, (
+                stratification_groups.unique())
             (indices1, indices2) = next(
                 cv.split(self.df.values, stratification_groups))
             assert len(indices1) > 0
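Stratifying on allele concatenated with measurement type only makes sense when that labelling actually distinguishes more than one group, which is what the new assertion enforces before the split is taken. A small sketch of how such labels are built (the column names mirror the diff; the DataFrame contents and measurement-type values are made up):

```
import pandas as pd

df = pd.DataFrame({
    "allele": ["HLA-A0201", "HLA-A0201", "HLA-B0702", "HLA-B0702"],
    "measurement_type": ["affinity", "ms_hit", "affinity", "affinity"],
})

# String concatenation of the two columns gives one stratification label per row.
stratification_groups = df.allele + df.measurement_type
print(stratification_groups.unique())
# ['HLA-A0201affinity' 'HLA-A0201ms_hit' 'HLA-B0702affinity']

# With a single allele and a single measurement type there would be only one
# group, and a stratified split would be meaningless -- hence the assert.
assert len(stratification_groups.unique()) > 1, stratification_groups.unique()
```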