Commit 054eee27 authored by Tim O'Donnell

pan allele model training checkpointing

parent b5010513
@@ -312,11 +312,6 @@
                 str(self.class1_pan_allele_models),
                 str(self.allele_to_allele_specific_models)))
 
-    def save_metadata_df(self, models_dir, name):
-        df = self.metadata_dataframes[name]
-        metadata_df_path = join(models_dir, "%s.csv.bz2" % name)
-        df.to_csv(metadata_df_path, index=False, compression="bz2")
-
     def save(self, models_dir, model_names_to_write=None, write_metadata=True):
         """
         Serialize the predictor to a directory on disk. If the directory does
@@ -380,8 +375,9 @@
                 info_path, sep="\t", header=False, index=False)
 
         if self.metadata_dataframes:
-            for name in self.metadata_dataframes:
-                self.save_metadata_df(models_dir, name)
+            for (name, df) in self.metadata_dataframes.items():
+                metadata_df_path = join(models_dir, "%s.csv.bz2" % name)
+                df.to_csv(metadata_df_path, index=False, compression="bz2")
 
         # Save allele sequences
         if self.allele_to_sequence is not None:
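An aside for readers, not part of this commit: each metadata dataframe is written by save() as "<name>.csv.bz2", and pandas can read it back directly because the compression is inferred from the file extension. A minimal sketch, where the directory and dataframe name are hypothetical:

import os
import pandas

models_dir = "models"  # hypothetical directory previously written by predictor.save()
name = "training_work_items"  # hypothetical metadata dataframe name
path = os.path.join(models_dir, "%s.csv.bz2" % name)
if os.path.exists(path):
    # pandas infers bz2 compression from the .bz2 extension
    print(pandas.read_csv(path).head())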
@@ -3,6 +3,7 @@ Train Class1 pan-allele models.
 """
 import argparse
 import os
+from os.path import join
 import signal
 import sys
 import time
@@ -11,14 +12,12 @@ import random
 import pprint
 import hashlib
 import pickle
-import subprocess
-import uuid
 from functools import partial
 
 import numpy
 import pandas
 import yaml
 from mhcnames import normalize_allele_name
 import tqdm # progress bar
 tqdm.monitor_interval = 0 # see https://github.com/tqdm/tqdm/issues/481
@@ -41,7 +40,7 @@ from .encodable_sequences import EncodableSequences
 # stored here before creating the thread pool will be inherited to the child
 # processes upon fork() call, allowing us to share large data with the workers
 # via shared memory.
-GLOBAL_DATA = None
+GLOBAL_DATA = {}
 
 # Note on parallelization:
 # It seems essential currently (tensorflow==1.4.1) that no processes are forked
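The comment block above is the key assumption behind the training command's parallelism: anything stored in the module-level GLOBAL_DATA before the worker pool is created is inherited by forked child processes, so large training data does not have to be pickled per task. A self-contained sketch of that pattern, illustration only (names and sizes are made up, not code from this commit):

import multiprocessing

GLOBAL_DATA = {}

def worker(key):
    # A forked child inherits the parent's memory, so the dict populated
    # before the pool was created is readable here without being re-sent.
    return len(GLOBAL_DATA[key])

if __name__ == "__main__":
    GLOBAL_DATA["train_data"] = list(range(1000000))  # stands in for a large dataset
    with multiprocessing.get_context("fork").Pool(2) as pool:  # fork is POSIX-only
        print(pool.map(worker, ["train_data", "train_data"]))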
@@ -51,16 +50,9 @@ GLOBAL_DATA = None
 
 parser = argparse.ArgumentParser(usage=__doc__)
 
-parser.add_argument(
-    "--action",
-    nargs="+",
-    default=["setup", "train", "finalize"],
-    choices=["setup", "train", "finalize"],
-    help="Actions to run")
 parser.add_argument(
     "--data",
     metavar="FILE.csv",
-    required=True,
     help=(
         "Training data CSV. Expected columns: "
         "allele, peptide, measurement_value"))
@@ -78,7 +70,6 @@
 parser.add_argument(
     "--hyperparameters",
     metavar="FILE.json",
-    required=True,
     help="JSON or YAML of hyperparameters")
 parser.add_argument(
     "--held-out-measurements-per-allele-fraction-and-max",
@@ -96,7 +87,6 @@
     "--ensemble-size",
     type=int,
     metavar="N",
-    required=True,
     help="Ensemble size, i.e. how many models to retain the final predictor. "
     "In the current implementation, this is also the number of training folds.")
 parser.add_argument(
@@ -125,6 +115,12 @@
     action="store_true",
     default=False,
     help="Launch python debugger on error")
+parser.add_argument(
+    "--continue-incomplete",
+    action="store_true",
+    default=False,
+    help="Continue training models from an incomplete training run. If this is "
+    "specified then the only required argument is --out-models-dir")
 
 add_local_parallelism_args(parser)
 add_cluster_parallelism_args(parser)
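For illustration, not part of the diff: with the new flag, an interrupted run can be resumed against the same output directory, and per the help text only --out-models-dir is then required. A hedged sketch, assuming the module path below is where this script lives and using a placeholder directory name:

from mhcflurry.train_pan_allele_models_command import run  # assumed module path

run([
    "--out-models-dir", "models",  # placeholder: directory from the interrupted run
    "--continue-incomplete",
])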
@@ -212,18 +208,6 @@ def pretrain_data_iterator(
         yield (allele_encoding, encodable_peptides, df.stack().values)
 
 
-def mark_work_item_complete(predictor, work_item_name):
-    if "training_work_items" not in predictor.metadata_dataframes:
-        predictor.metadata_dataframes["training_work_items"] = pandas.DataFrame({
-            'work_item_name': [work_item_name],
-            'complete': [True],
-        }).set_index("work_item_name", drop=False)
-    else:
-        predictor.metadata_dataframes["training_work_items"].loc[
-            work_item_name, "complete"
-        ] = True
-
-
 def run(argv=sys.argv[1:]):
     # On sigusr1 print stack trace
     print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
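For context on the helper removed above (an illustration, not code from the commit): it relied on pandas label-based assignment, where .loc with an existing index label updates in place and with a new label enlarges the frame:

import pandas

df = pandas.DataFrame({
    "work_item_name": ["train_0"],
    "complete": [True],
}).set_index("work_item_name", drop=False)

# Existing label: updates in place. New label: appends a row (NaN in other columns).
df.loc["train_1", "complete"] = True
print(df)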
@@ -243,18 +227,29 @@ def run(argv=sys.argv[1:]):
 def main(args):
     global GLOBAL_DATA
 
     print("Arguments:")
     print(args)
 
     args.out_models_dir = os.path.abspath(args.out_models_dir)
     configure_logging(verbose=args.verbosity > 1)
 
-def setup_action(args):
+    if not args.continue_incomplete:
+        initialize_training(args)
+
+    train_models(args)
+
+
+def initialize_training(args):
+    required_arguments = [
+        "data",
+        "out_models_dir",
+        "hyperparameters",
+        "ensemble_size",
+    ]
+    for arg in required_arguments:
+        if getattr(args, arg) is None:
+            parser.error("Missing required arg: %s" % arg)
+
-    print("Beginning setup action.")
+    print("Initializing training.")
 
     hyperparameters_lst = yaml.load(open(args.hyperparameters))
     assert isinstance(hyperparameters_lst, list)
     print("Loaded hyperparameters list:")
@@ -312,11 +307,6 @@ def setup_action(args):
         alleles=allele_sequences_in_use.index.values,
         allele_to_sequence=allele_sequences_in_use.to_dict())
 
-    global_data = {}
-    global_data["train_data"] = df
-    global_data["folds_df"] = folds_df
-    global_data["allele_encoding"] = allele_encoding
-
     if not os.path.exists(args.out_models_dir):
         print("Attempting to create directory: %s" % args.out_models_dir)
         os.mkdir(args.out_models_dir)
@@ -359,39 +349,46 @@ def setup_action(args):
         }
         work_items.append(work_dict)
 
-    predictor.metadata_dataframes["training_work_items"] = pandas.DataFrame(
-        work_dict,
-    ).set_index("work_item_name", drop=False)
-    predictor.metadata_dataframes["training_work_items"]["complete"] = False
-
-    with open(os.path.join(args.out_models_dir, "global_data.pkl"), "wb") as fd:
-        pickle.dump(global_data, fd, protocol=pickle.HIGHEST_PROTOCOL)
+    training_init_info = {}
+    training_init_info["train_data"] = df
+    training_init_info["folds_df"] = folds_df
+    training_init_info["allele_encoding"] = allele_encoding
+    training_init_info["full_allele_encoding"] = full_allele_encoding
+    training_init_info["work_items"] = work_items
 
     # Save empty predictor (for metadata)
     predictor.save(args.out_models_dir)
 
-def train_action(args):
+    # Write training_init_info.
+    with open(join(args.out_models_dir, "training_init_info.pkl"), "wb") as fd:
+        pickle.dump(training_init_info, fd, protocol=pickle.HIGHEST_PROTOCOL)
+
+    print("Done initializing training.")
+
+
+def train_models(args):
     global GLOBAL_DATA
 
-    print("Beginning train action.")
+    print("Beginning training.")
     predictor = Class1AffinityPredictor.load(args.out_models_dir)
     print("Loaded predictor with %d networks" % len(predictor.neural_networks))
 
-    with open(os.path.join(args.out_models_dir, "global_data.pkl"), "rb") as fd:
-        GLOBAL_DATA = pickle.load(fd)
-    print("Loaded global data:")
+    with open(join(args.out_models_dir, "training_init_info.pkl"), "rb") as fd:
+        GLOBAL_DATA.update(pickle.load(fd))
+    print("Loaded training init info:")
     print(GLOBAL_DATA)
 
-    work_items_df = predictor.metadata_dataframes[
-        "training_work_items"
-    ].set_index("work_item_name", drop=False)
-    print("Loaded work items:")
-    print(work_items_df)
-    print("Work items complete:")
-    print(work_items_df.complete.value_counts())
-
-    work_items_df = work_items_df.loc[
-        work_items_df.complete
-    ].copy()
-    del work_items_df["complete"]
-    work_items = work_items_df.to_dict('records')
-    print("Will process %d work items" % len(work_items))
+    all_work_items = GLOBAL_DATA["work_items"]
+    complete_work_item_names = [
+        network.fit_info[-1]["training_info"]["work_item_name"] for network in
+        predictor.neural_networks
+    ]
+    work_items = [
+        item for item in all_work_items
+        if item["work_item_name"] not in complete_work_item_names
+    ]
+    print("Found %d work items, of which %d are incomplete and will run now." % (
+        len(all_work_items), len(work_items)))
 
     serial_run = args.num_jobs == 0
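To spell out the resume logic added just above with toy data (an illustration only, not code from the commit): a work item counts as complete when its work_item_name appears in the training_info recorded in some saved network's fit_info, and only the remaining items are run.

# Toy data standing in for GLOBAL_DATA["work_items"] and the fit_info scan above.
all_work_items = [{"work_item_name": "train_0"}, {"work_item_name": "train_1"}]
complete_work_item_names = ["train_0"]
work_items = [
    item for item in all_work_items
    if item["work_item_name"] not in complete_work_item_names
]
assert [item["work_item_name"] for item in work_items] == ["train_1"]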
@@ -445,19 +442,14 @@ def train_action(args):
         results_generator = None
 
     if results_generator:
-        for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
+        #for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
+        for new_predictor in results_generator:
             save_start = time.time()
-            (work_item_name,) = new_predictor.metadata_dataframes[
-                "training_work_items"
-            ].work_item_name.values
             (new_model_name,) = predictor.merge_in_place([new_predictor])
-            mark_work_item_complete(predictor, work_item_name)
             predictor.save(
                 args.out_models_dir,
                 model_names_to_write=[new_model_name],
                 write_metadata=False)
-            predictor.save_metadata_df(
-                args.out_models_dir, "training_work_items")
             print(
                 "Saved predictor (%d models total) with 1 new models"
                 "in %0.2f sec to %s" % (
@@ -468,7 +460,8 @@
     print("Saving final predictor to: %s" % args.out_models_dir)
     # We want the final predictor to support all alleles with sequences, not
     # just those we actually used for model training.
-    predictor.allele_to_sequence = full_allele_encoding.allele_to_sequence
+    predictor.allele_to_sequence = (
+        GLOBAL_DATA['full_allele_encoding'].allele_to_sequence)
     predictor.clear_cache()
     predictor.save(args.out_models_dir)
     print("Done.")
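Once the final predictor is saved, it can be loaded back the same way train_models does above and used for prediction; a short sketch assuming the released mhcflurry API, with placeholder directory, peptide, and allele:

from mhcflurry import Class1AffinityPredictor

predictor = Class1AffinityPredictor.load("models")  # placeholder models directory
# Predicted binding affinities (nM) for peptide/allele pairs.
print(predictor.predict(peptides=["SIINFEKL"], alleles=["HLA-A*02:01"]))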
@@ -642,7 +635,6 @@ def train_model(
     numpy.testing.assert_equal(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))
-    mark_work_item_complete(predictor, work_item_name)
     predictor.add_pan_allele_model(model, models_dir_for_save=save_to)
     numpy.testing.assert_equal(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))