Commit 402d3176 authored by Tim O'Donnell

fixes

parent b3be67c6
@@ -59,7 +59,7 @@ base_hyperparameters = {
     ##########################################
     # TRAINING Data
     ##########################################
-    "train_data": {"subset": "all", "pretrain_min_points": 10000},
+    "train_data": {"subset": "all", "pretrain_min_points": 1000},
 }
 
 grid = []
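Only the tail of the hyperparameters script is visible in this hunk. As a rough sketch (not the actual script) of how a base dict like this typically feeds the grid list initialized above, varying the pretrain_min_points value the commit changes:

import copy

# Hypothetical sketch: expand a base hyperparameters dict into a grid
# of variants. The sweep values are invented for illustration.
base_hyperparameters = {
    "train_data": {"subset": "all", "pretrain_min_points": 1000},
}

grid = []
for min_points in [500, 1000, 5000]:
    hyperparameters = copy.deepcopy(base_hyperparameters)
    hyperparameters["train_data"]["pretrain_min_points"] = min_points
    grid.append(hyperparameters)

assert len(grid) == 3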
@@ -46,7 +46,7 @@ time mhcflurry-class1-train-allele-specific-models \
     --out-models-dir models \
     --percent-rank-calibration-num-peptides-per-length 0 \
     --min-measurements-per-allele 75 \
-    --num-jobs $(expr $PROCESSORS \* 2) --gpus $GPUS --max-workers-per-gpu 2 --max-tasks-per-worker 20
+    --num-jobs $(expr $PROCESSORS \* 2) --gpus $GPUS --max-workers-per-gpu 2 --max-tasks-per-worker 50
 
 cp $SCRIPT_ABSOLUTE_PATH .
 bzip2 LOG.txt
@@ -510,14 +510,15 @@ class Class1AffinityPredictor(object):
         ]
         if train_rounds is not None:
-            for round in range(1, train_rounds.max() + 1):
-                round_mask = train_rounds >= round
-                sub_encodable_peptides = EncodableSequences.create(
-                    encodable_peptides.sequences[round_mask])
-                peptides_affinities_inequalities_per_round.append((
-                    sub_encodable_peptides,
-                    affinities[round_mask],
-                    None if inequalities is None else inequalities[round_mask]))
+            for round in sorted(set(train_rounds)):
+                round_mask = train_rounds > round
+                if round_mask.any():
+                    sub_encodable_peptides = EncodableSequences.create(
+                        encodable_peptides.sequences[round_mask])
+                    peptides_affinities_inequalities_per_round.append((
+                        sub_encodable_peptides,
+                        affinities[round_mask],
+                        None if inequalities is None else inequalities[round_mask]))
 
 n_rounds = len(peptides_affinities_inequalities_per_round)
 n_architectures = len(architecture_hyperparameters_list)
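The rewritten loop changes the per-round subsetting in three ways: it iterates over the round values actually present instead of range(1, max + 1), it uses a strict greater-than comparison, and it skips empty subsets. A toy numpy illustration (standalone, not mhcflurry code) of the resulting difference:

import numpy

train_rounds = numpy.array([0, 0, 1, 1, 3])

# Old behavior: rounds 1..max with a >= mask. Absent round values
# (here 2) yield duplicate subsets.
old_subsets = [
    (train_rounds >= r).nonzero()[0].tolist()
    for r in range(1, train_rounds.max() + 1)
]
assert old_subsets == [[2, 3, 4], [4], [4]]

# New behavior: distinct values with a > mask; the empty "> 3" subset
# is dropped by the any() guard.
new_subsets = [
    (train_rounds > r).nonzero()[0].tolist()
    for r in sorted(set(train_rounds))
    if (train_rounds > r).any()
]
assert new_subsets == [[2, 3, 4], [4]]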
@@ -758,6 +758,7 @@ class Class1NeuralNetwork(object):
             numpy.array of nM affinity predictions
         """
         use_cache = (
+            self.prediction_cache is not None and
             allele_encoding is None and
             isinstance(peptides, EncodableSequences))
         if use_cache and peptides in self.prediction_cache:
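The added conjunct makes the cache optional: a predictor whose prediction_cache is None now skips lookups instead of failing on the `in` test. A minimal standalone sketch of the guarded-optional-cache pattern (illustrative names, not the mhcflurry class):

class CachingPredictor(object):
    def __init__(self, cache=None):
        # None disables caching; a dict enables it.
        self.prediction_cache = cache

    def predict(self, peptides):
        use_cache = self.prediction_cache is not None
        if use_cache and peptides in self.prediction_cache:
            return self.prediction_cache[peptides]
        result = [len(p) for p in peptides]  # stand-in for the model call
        if use_cache:
            self.prediction_cache[peptides] = result
        return result

predictor = CachingPredictor(cache={})
assert predictor.predict(("SIINFEKL",)) == [8]
assert ("SIINFEKL",) in predictor.prediction_cache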
+import traceback
+import sys
 from multiprocessing import Pool, Queue, cpu_count
 from six.moves import queue
 from multiprocessing.util import Finalize
@@ -98,3 +100,28 @@ def worker_init_entry_point(
     print("Initializing worker: %s" % str(kwargs))
     init_function(**kwargs)
+
+
+# Solution suggested in https://bugs.python.org/issue13831
+class WrapException(Exception):
+    """
+    Add traceback info to exception so exceptions raised in worker processes
+    can still show traceback info when re-raised in the parent.
+    """
+    def __init__(self):
+        exc_type, exc_value, exc_tb = sys.exc_info()
+        self.exception = exc_value
+        self.formatted = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))
+    def __str__(self):
+        return '%s\nOriginal traceback:\n%s' % (Exception.__str__(self), self.formatted)
+
+
+def call_wrapped(function, *args, **kwargs):
+    try:
+        return function(*args, **kwargs)
+    except:
+        raise WrapException()
+
+
+def call_wrapped_kwargs(function, kwargs):
+    return call_wrapped(function, **kwargs)
\ No newline at end of file
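Plain multiprocessing re-raises worker exceptions in the parent without the worker-side traceback; routing calls through call_wrapped preserves it. A usage sketch, assuming these helpers are importable as mhcflurry.parallelism (the package-relative import in the next hunk suggests they are):

from functools import partial
from multiprocessing import Pool

from mhcflurry.parallelism import WrapException, call_wrapped

def flaky(x):
    return 1 / x  # raises ZeroDivisionError when x == 0

if __name__ == "__main__":
    pool = Pool(2)
    try:
        pool.map(partial(call_wrapped, flaky), [1, 0])
    except WrapException as e:
        print(e)  # str(e) includes the original worker-side traceback
    finally:
        pool.terminate()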
@@ -20,7 +20,8 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
-from .parallelism import make_worker_pool, cpu_count
+from .parallelism import (
+    make_worker_pool, cpu_count, call_wrapped, call_wrapped_kwargs)
 from .hyperparameters import HyperparameterDefaults
 from .allele_encoding import AlleleEncoding
@@ -327,7 +328,9 @@ def run(argv=sys.argv[1:]):
         random.shuffle(work_items)
         results_generator = worker_pool.imap_unordered(
-            train_model_entry_point, work_items, chunksize=1)
+            partial(call_wrapped_kwargs, train_model),
+            work_items,
+            chunksize=1)
 
         unsaved_predictors = []
         last_save_time = time.time()
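Each work item here is a dict of keyword arguments, so partial(call_wrapped_kwargs, train_model) behaves like lambda item: train_model(**item) while staying picklable for the pool and adding the traceback wrapping. A toy check with a stand-in function (not the real train_model):

from functools import partial

from mhcflurry.parallelism import call_wrapped_kwargs

def toy_train_model(allele, n_models):  # stand-in for train_model
    return (allele, n_models)

entry = partial(call_wrapped_kwargs, toy_train_model)
assert entry({"allele": "HLA-A*02:01", "n_models": 4}) == ("HLA-A*02:01", 4)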
@@ -361,7 +364,7 @@ def run(argv=sys.argv[1:]):
         # as it goes so no saving is required at the end.
         for _ in tqdm.trange(len(work_items)):
             item = work_items.pop(0)  # want to keep freeing up memory
-            work_predictor = train_model_entry_point(item)
+            work_predictor = train_model(**item)
             assert work_predictor is predictor
 
     assert not work_items
@@ -418,14 +421,13 @@ def run(argv=sys.argv[1:]):
         results = worker_pool.imap_unordered(
             partial(
-                calibrate_percentile_ranks,
+                partial(call_wrapped, calibrate_percentile_ranks),
                 predictor=predictor),
             alleles,
             chunksize=1)
         for result in tqdm.tqdm(results, total=len(alleles)):
             predictor.allele_to_percent_rank_transform.update(result)
         print("Done calibrating %d additional alleles." % len(alleles))
         predictor.save(args.out_models_dir, model_names_to_write=[])
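The nested partial takes a moment to unwind: when the pool calls it with an allele, it resolves to call_wrapped(calibrate_percentile_ranks, allele, predictor=predictor). A self-contained demonstration with stand-in functions:

from functools import partial

def call_wrapped_demo(function, *args, **kwargs):  # stands in for call_wrapped
    return function(*args, **kwargs)

def calibrate_demo(allele, predictor):  # stands in for calibrate_percentile_ranks
    return {allele: predictor}

worker_fn = partial(partial(call_wrapped_demo, calibrate_demo), predictor="P")
assert worker_fn("HLA-A*02:01") == {"HLA-A*02:01": "P"}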
@@ -440,10 +442,6 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.out_models_dir)
 
 
-def train_model_entry_point(item):
-    return train_model(**item)
-
-
 def alleles_by_similarity(allele):
     global GLOBAL_DATA
     allele_similarity = GLOBAL_DATA['allele_similarity_matrix']
@@ -2,10 +2,10 @@ import json
 import os
 import shutil
 import tempfile
+import subprocess
 
 from numpy.testing import assert_array_less, assert_equal
 
-from mhcflurry import train_allele_specific_models_command
 from mhcflurry import Class1AffinityPredictor
 from mhcflurry.downloads import get_path
@@ -57,6 +57,7 @@ def run_and_check(n_jobs=0):
         json.dump(HYPERPARAMETERS, fd)
 
     args = [
+        "mhcflurry-class1-train-allele-specific-models",
         "--data", get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"),
         "--hyperparameters", hyperparameters_filename,
         "--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
@@ -66,7 +67,7 @@ def run_and_check(n_jobs=0):
         "--ignore-inequalities",
     ]
     print("Running with args: %s" % args)
-    train_allele_specific_models_command.run(args)
+    subprocess.check_call(args)
 
     result = Class1AffinityPredictor.load(models_dir)
     predictions = result.predict(
@@ -84,8 +85,9 @@ def run_and_check(n_jobs=0):
     shutil.rmtree(models_dir)
 
 
-def Xtest_run_parallel():
-    run_and_check(n_jobs=3)
+if os.environ.get("KERAS_BACKEND") != "theano":
+    def test_run_parallel():
+        run_and_check(n_jobs=3)
 
 
 def test_run_serial():
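Two patterns in the test changes are worth noting: the command is now exercised end-to-end through the installed console script via subprocess.check_call (rather than calling run() in-process), and the parallel test is defined only when the Keras backend is not theano. A minimal sketch of both, assuming the entry point is on PATH:

import os
import subprocess

def run_cli(extra_args):
    # check_call raises CalledProcessError on a nonzero exit status.
    subprocess.check_call(
        ["mhcflurry-class1-train-allele-specific-models"] + extra_args)

if os.environ.get("KERAS_BACKEND") != "theano":
    def test_cli_smoke():
        run_cli(["--help"])  # argparse exits 0 on --help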