From e29cacc766bc1bf75034b737ffb7dd1a535702c5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 15 May 2019 16:50:21 -0400
Subject: [PATCH] First cut at train_pan_allele_models_command

---
 downloads-generation/data_curated/GENERATE.sh |   2 -
 downloads-generation/data_curated/curate.py   |  37 +-
 downloads-generation/data_iedb/GENERATE.sh    |   8 +-
 mhcflurry/allele_encoding.py                  |   4 +-
 mhcflurry/class1_neural_network.py            |  15 +-
 mhcflurry/downloads.yml                       |  14 +-
 mhcflurry/train_pan_allele_models_command.py  | 479 ++++++++++++++++++
 setup.py                                      |   2 +
 8 files changed, 512 insertions(+), 49 deletions(-)
 create mode 100644 mhcflurry/train_pan_allele_models_command.py

diff --git a/downloads-generation/data_curated/GENERATE.sh b/downloads-generation/data_curated/GENERATE.sh
index e2ba9187..30a17bd0 100755
--- a/downloads-generation/data_curated/GENERATE.sh
+++ b/downloads-generation/data_curated/GENERATE.sh
@@ -46,8 +46,6 @@ time python curate.py \
         "$(mhcflurry-downloads path data_published)/bdata.20130222.mhci.public.1.txt" \
     --data-systemhc-atlas \
         "$(mhcflurry-downloads path data_systemhcatlas)/data.csv.bz2" \
-    --data-abelin-mass-spec \
-        "$(mhcflurry-downloads path data_published)/abelin2017.hits.csv.bz2" \
     --include-iedb-mass-spec \
     --out-csv curated_training_data.with_mass_spec.csv
 
diff --git a/downloads-generation/data_curated/curate.py b/downloads-generation/data_curated/curate.py
index bb35f489..d61719f7 100755
--- a/downloads-generation/data_curated/curate.py
+++ b/downloads-generation/data_curated/curate.py
@@ -34,11 +34,6 @@ parser.add_argument(
     action="append",
     default=[],
     help="Path to systemhc-atlas-style mass-spec data")
-parser.add_argument(
-    "--data-abelin-mass-spec",
-    action="append",
-    default=[],
-    help="Path to Abelin Immunity 2017 mass-spec hits")
 parser.add_argument(
     "--include-iedb-mass-spec",
     action="store_true",
@@ -120,29 +115,6 @@ def load_data_systemhc_atlas(filename, min_probability=0.99):
     return df
 
 
-def load_data_abelin_mass_spec(filename):
-    df = pandas.read_csv(filename)
-    print("Loaded Abelin mass-spec data: %s" % str(df.shape))
-
-    df["measurement_source"] = "abelin-mass-spec"
-    df["measurement_value"] = QUALITATIVE_TO_AFFINITY["Positive"]
-    df["measurement_inequality"] = "<"
-    df["measurement_type"] = "qualitative"
-    df["original_allele"] = df.allele
-    df["allele"] = df.original_allele.map(normalize_allele_name)
-
-    print("Dropping un-parseable alleles: %s" % ", ".join(
-        str(x) for x in df.ix[df.allele == "UNKNOWN"]["allele"].unique()))
-    df = df.loc[df.allele != "UNKNOWN"]
-    print("Abelin mass-spec data now: %s" % str(df.shape))
-
-    print("Removing duplicates")
-    df = df.drop_duplicates(["allele", "peptide"])
-    print("Abelin mass-spec data now: %s" % str(df.shape))
-
-    return df
-
-
 def load_data_iedb(iedb_csv, include_qualitative=True, include_mass_spec=False):
     iedb_df = pandas.read_csv(iedb_csv, skiprows=1, low_memory=False)
     print("Loaded iedb data: %s" % str(iedb_df.shape))
@@ -171,10 +143,12 @@ def load_data_iedb(iedb_csv, include_qualitative=True, include_mass_spec=False):
 
     quantitative = iedb_df.ix[iedb_df["Units"] == "nM"].copy()
     quantitative["measurement_type"] = "quantitative"
-    quantitative["measurement_inequality"] = "="
+    quantitative["measurement_inequality"] = quantitative[
+        "Measurement Inequality"
+    ].fillna("=")
     print("Quantitative measurements: %d" % len(quantitative))
 
-    qualitative = iedb_df.ix[iedb_df["Units"] != "nM"].copy()
+    qualitative = iedb_df.ix[iedb_df["Units"].isnull()].copy()
     qualitative["measurement_type"] = "qualitative"
     print("Qualitative measurements: %d" % len(qualitative))
     if not include_mass_spec:
@@ -256,9 +230,6 @@ def run():
     for filename in args.data_systemhc_atlas:
         df = load_data_systemhc_atlas(filename)
         dfs.append(df)
-    for filename in args.data_abelin_mass_spec:
-        df = load_data_abelin_mass_spec(filename)
-        dfs.append(df)
 
     df = pandas.concat(dfs, ignore_index=True)
     print("Combined df: %s" % (str(df.shape)))
diff --git a/downloads-generation/data_iedb/GENERATE.sh b/downloads-generation/data_iedb/GENERATE.sh
index a6067a36..7165476b 100755
--- a/downloads-generation/data_iedb/GENERATE.sh
+++ b/downloads-generation/data_iedb/GENERATE.sh
@@ -22,11 +22,17 @@ date
 
 cd $SCRATCH_DIR/$DOWNLOAD_NAME
 
-wget --quiet http://www.iedb.org/doc/mhc_ligand_full.zip
+wget -q http://www.iedb.org/doc/mhc_ligand_full.zip
+wget -q "http://www.iedb.org/downloader.php?file_name=doc/tcell_full_v3.zip" -O tcell_full_v3.zip
+
 unzip mhc_ligand_full.zip
 rm mhc_ligand_full.zip
 bzip2 mhc_ligand_full.csv
 
+unzip tcell_full_v3.zip
+rm tcell_full_v3.zip
+bzip2 tcell_full_v3.csv
+
 cp $SCRIPT_ABSOLUTE_PATH .
 bzip2 LOG.txt
 tar -cjf "../${DOWNLOAD_NAME}.tar.bz2" *
diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py
index 7a5cc479..20350da0 100644
--- a/mhcflurry/allele_encoding.py
+++ b/mhcflurry/allele_encoding.py
@@ -59,8 +59,8 @@ class AlleleEncoding(object):
         if alleles is not None:
             assert all(
                 allele in self.allele_to_index for allele in alleles),\
-                "Missing alleles: " + " ".join([
-                    a for a in alleles if a not in self.allele_to_index])
+                "Missing alleles: " + " ".join(set(
+                    a for a in alleles if a not in self.allele_to_index))
             self.indices = alleles.map(self.allele_to_index)
             assert not self.indices.isnull().any()
         else:
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index a9c1edb5..7667da15 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -567,6 +567,8 @@ class Class1NeuralNetwork(object):
                 allele_representations=allele_representations,
                 **self.network_hyperparameter_defaults.subselect(
                     self.hyperparameters))
+            if verbose:  # also guards against verbose=None
+                self.network().summary()
 
         if allele_representations is not None:
             self.set_allele_representations(allele_representations)
@@ -852,10 +854,6 @@ class Class1NeuralNetwork(object):
             current_layer = BatchNormalization(name="batch_norm_early")(
                 current_layer)
 
-        if dropout_probability:
-            current_layer = Dropout(dropout_probability, name="dropout_early")(
-                current_layer)
-
         if allele_representations is not None:
             allele_input = Input(
                 shape=(1,),
@@ -877,6 +875,8 @@ class Class1NeuralNetwork(object):
                     kernel_regularizer=kernel_regularizer,
                     activation=activation)(allele_layer)
 
+            allele_layer = Flatten(name="allele_flat")(allele_layer)
+
             if peptide_allele_merge_method == 'concatenate':
                 current_layer = keras.layers.concatenate([
                     current_layer, allele_layer
@@ -904,12 +904,13 @@ class Class1NeuralNetwork(object):
                 name="dense_%d" % i)(current_layer)
 
             if batch_normalization:
-                current_layer = BatchNormalization(name="batch_norm_%d" % i)\
-                    (current_layer)
+                current_layer = BatchNormalization(
+                    name="batch_norm_%d" % i)(current_layer)
 
             if dropout_probability > 0:
                 current_layer = Dropout(
-                    dropout_probability, name="dropout_%d" % i)(current_layer)
+                    rate=dropout_probability,  # Keras Dropout rate = fraction dropped
+                    name="dropout_%d" % i)(current_layer)
 
         output = Dense(
             1,
diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml
index b8a4aaa7..f1f3e980 100644
--- a/mhcflurry/downloads.yml
+++ b/mhcflurry/downloads.yml
@@ -36,6 +36,14 @@ releases:
               url: https://github.com/openvax/mhcflurry/releases/download/pan-dev1/random_peptide_predictions.20190506.tar.bz2
               default: false
 
+            - name: data_published
+              url: https://github.com/openvax/mhcflurry/releases/download/pan-dev1/data_published.tar.bz2
+              default: false
+
+            - name: data_curated
+              url: https://github.com/openvax/mhcflurry/releases/download/pan-dev1/data_curated.20190514.tar.bz2
+              default: true
+
             # Older downloads
             - name: models_class1
               url: https://github.com/openvax/mhcflurry/releases/download/pre-1.2/models_class1.20180225.tar.bz2
@@ -43,7 +51,7 @@ releases:
 
             - name: models_class1_selected_no_mass_spec
               url: https://github.com/openvax/mhcflurry/releases/download/pre-1.2/models_class1_selected_no_mass_spec.20180225.tar.bz2
-              default: true
+              default: false
 
             - name: models_class1_unselected
               url: https://github.com/openvax/mhcflurry/releases/download/pre-1.2/models_class1_unselected.20180221.tar.bz2
@@ -61,9 +69,7 @@ releases:
               url: https://github.com/openvax/mhcflurry/releases/download/pre-1.2/models_class1_minimal.20180226.tar.bz2
               default: false
 
-            - name: data_published
-              url: http://github.com/openvax/mhcflurry/releases/download/pre-1.1/data_published.tar.bz2
-              default: false
+
 
     1.2.0:
         compatibility-version: 2
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
new file mode 100644
index 00000000..6844d8ca
--- /dev/null
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -0,0 +1,479 @@
+"""
+Train Class1 pan-allele models.
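+
+Example invocation (hypothetical paths):
+
+    mhcflurry-class1-train-pan-allele-models \\
+        --data curated_training_data.with_mass_spec.csv \\
+        --allele-sequences allele_sequences.csv \\
+        --hyperparameters hyperparameters.yaml \\
+        --ensemble-size 4 \\
+        --out-models-dir models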
+"""
+import argparse
+import os
+import signal
+import sys
+import time
+import traceback
+import random
+from functools import partial
+
+import numpy
+import pandas
+import yaml
+import tqdm  # progress bar
+tqdm.tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
+
+from .class1_affinity_predictor import Class1AffinityPredictor
+from .class1_neural_network import Class1NeuralNetwork
+from .common import configure_logging
+from .parallelism import (
+    add_worker_pool_args,
+    worker_pool_with_gpu_assignments_from_args,
+    call_wrapped_kwargs)
+from .allele_encoding import AlleleEncoding
+from .encodable_sequences import EncodableSequences
+from .regression_target import from_ic50
+
+
+# To avoid pickling large matrices to send to child processes when running in
+# parallel, we use this global variable as a place to store data. Data stored
+# here before the worker pool is created will be inherited by the child
+# processes on fork(), allowing us to share large data with the workers via
+# copy-on-write shared memory.
+GLOBAL_DATA = {}
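+#
+# A minimal sketch of the sharing pattern (hypothetical worker function;
+# assumes a fork-based multiprocessing start method, as on Linux):
+#
+#     GLOBAL_DATA["train_data"] = df    # set in the parent, before the pool
+#     pool = multiprocessing.Pool(4)    # children inherit GLOBAL_DATA
+#     pool.map(work, items)             # workers read GLOBAL_DATA directly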
+
+# Note on parallelization:
+# It seems essential currently (tensorflow==1.4.1) that no processes are forked
+# after tensorflow has been used at all, which includes merely importing
+# keras.backend. So we must make sure not to use tensorflow in the main process
+# if we are running in parallel.
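+#
+# In practice that means module-level imports in this file must be
+# tensorflow-free, and anything like
+#     import keras.backend as K
+# must happen inside the worker processes, after the pool has forked.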
+
+parser = argparse.ArgumentParser(usage=__doc__)
+
+parser.add_argument(
+    "--data",
+    metavar="FILE.csv",
+    required=True,
+    help=(
+        "Training data CSV. Expected columns: "
+        "allele, peptide, measurement_value"))
+parser.add_argument(
+    "--pretrain-data",
+    metavar="FILE.csv",
+    help=(
+        "Pre-training data CSV. Expected columns: "
+        "allele, peptide, measurement_value"))
+parser.add_argument(
+    "--out-models-dir",
+    metavar="DIR",
+    required=True,
+    help="Directory to write models and manifest")
+parser.add_argument(
+    "--hyperparameters",
+    metavar="FILE.json",
+    required=True,
+    help="JSON or YAML of hyperparameters")
+parser.add_argument(
+    "--held-out-measurements-per-allele-fraction-and-max",
+    type=float,
+    metavar="X",
+    nargs=2,
+    default=[0.25, 100],
+    help="Fraction of measurements per allele to hold out, and maximum number")
+parser.add_argument(
+    "--ignore-inequalities",
+    action="store_true",
+    default=False,
+    help="Do not use affinity value inequalities even when present in data")
+parser.add_argument(
+    "--ensemble-size",
+    type=int,
+    metavar="N",
+    required=True,
+    help="Ensemble size, i.e. how many models to retain the final predictor. "
+    "In the current implementation, this is also the number of training folds.")
+parser.add_argument(
+    "--num-replicates",
+    type=int,
+    metavar="N",
+    default=1,
+    help="Number of replicates per (architecture, fold) pair to train.")
+parser.add_argument(
+    "--max-epochs",
+    type=int,
+    metavar="N",
+    help="Max training epochs. If specified here it overrides any 'max_epochs' "
+    "specified in the hyperparameters.")
+parser.add_argument(
+    "--allele-sequences",
+    metavar="FILE.csv",
+    required=True,
+    help="Allele sequences file.")
+parser.add_argument(
+    "--save-interval",
+    type=float,
+    metavar="N",
+    default=60,
+    help="Write models to disk every N seconds. Only affects parallel runs; "
+    "serial runs write each model to disk as it is trained.")
+parser.add_argument(
+    "--verbosity",
+    type=int,
+    help="Keras verbosity. Default: %(default)s",
+    default=0)
+
+add_worker_pool_args(parser)
+
+
+def assign_folds(df, num_folds, held_out_fraction, held_out_max):
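+    """
+    Assign training folds. Returns a DataFrame indexed like df with one
+    boolean column per fold; True means the point is used for training in
+    that fold.
+
+    For each fold, every allele holds out up to held_out_fraction of its
+    peptides (capped at held_out_max), sampled evenly from below- and
+    above-median binders so each held-out set spans the affinity range.
+    E.g. with held_out_fraction=0.25 and held_out_max=100, an allele with
+    1000 peptides holds out min(1000 * 0.25, 100) = 100 peptides per fold.
+    """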
+    result_df = pandas.DataFrame(index=df.index)
+    for fold in range(num_folds):
+        result_df["fold_%d" % fold] = True
+        for (allele, sub_df) in df.groupby("allele"):
+            medians = sub_df.groupby("peptide").measurement_value.median()
+
+            low_peptides = medians[medians < medians.median()].index.values
+            high_peptides = medians[medians >= medians.median()].index.values
+
+            held_out_count = int(
+                min(len(medians) * held_out_fraction, held_out_max))
+            held_out_low_count = min(
+                len(low_peptides),
+                int(held_out_count / 2))
+            held_out_high_count = min(
+                len(high_peptides),
+                held_out_count - held_out_low_count)
+
+            held_out_low = pandas.Series(low_peptides).sample(
+                n=held_out_low_count)
+            held_out_high = pandas.Series(high_peptides).sample(
+                n=held_out_high_count)
+            held_out_peptides = set(held_out_low).union(set(held_out_high))
+
+            result_df.loc[
+                sub_df.index[sub_df.peptide.isin(held_out_peptides)],
+                "fold_%d" % fold
+            ] = False
+
+    print("Training points per fold")
+    print(result_df.sum())
+
+    print("Test points per fold")
+    print((~result_df).sum())
+
+    return result_df
+
+
+def pretrain_data_iterator(
+        filename,
+        master_allele_encoding,
+        peptides_per_chunk=1024):
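+    """
+    Iterate over the pretraining CSV in chunks, yielding tuples of
+    (allele_encoding, encodable_peptides, affinities) ready to pass to
+    Class1NeuralNetwork.fit.
+
+    The CSV is expected to have one row per peptide and one column per
+    allele, e.g. (hypothetical values):
+
+        peptide,HLA-A*02:01,HLA-B*07:02
+        SIINFEKL,2345.0,31000.0
+
+    Alleles without a sequence in master_allele_encoding are skipped, and
+    short final chunks are dropped, so every yielded batch covers exactly
+    peptides_per_chunk peptides times the usable alleles.
+    """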
+    empty = pandas.read_csv(filename, index_col=0, nrows=0)
+    usable_alleles = [
+        c for c in empty.columns
+        if c in master_allele_encoding.allele_to_sequence
+    ]
+    print("Using %d / %d alleles" % (len(usable_alleles), len(empty.columns)))
+    print("Skipped alleles: ", [
+        c for c in empty.columns
+        if c not in master_allele_encoding.allele_to_sequence
+    ])
+
+    allele_encoding = AlleleEncoding(
+        numpy.tile(usable_alleles, peptides_per_chunk),
+        borrow_from=master_allele_encoding)
+
+    synthetic_iter = pandas.read_csv(
+        filename, index_col=0, chunksize=peptides_per_chunk)
+    for (k, df) in enumerate(synthetic_iter):
+        if len(df) != peptides_per_chunk:
+            continue
+
+        df = df[usable_alleles]
+        encodable_peptides = EncodableSequences(
+            numpy.repeat(
+                df.index.values,
+                len(usable_alleles)))
+
+        yield (allele_encoding, encodable_peptides, df.stack().values)
+
+
+def run(argv=sys.argv[1:]):
+    global GLOBAL_DATA
+
+    # On sigusr1 print stack trace
+    print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
+    signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())
+
+    args = parser.parse_args(argv)
+
+    args.out_models_dir = os.path.abspath(args.out_models_dir)
+
+    configure_logging(verbose=args.verbosity > 1)
+
+    with open(args.hyperparameters) as fd:
+        hyperparameters_lst = yaml.safe_load(fd)
+    assert isinstance(hyperparameters_lst, list)
+    print("Loaded hyperparameters list: %s" % str(hyperparameters_lst))
+
+    allele_sequences = pandas.read_csv(
+        args.allele_sequences, index_col=0).sequence
+
+    df = pandas.read_csv(args.data)
+    print("Loaded training data: %s" % (str(df.shape)))
+    df = df.loc[
+        (df.peptide.str.len() >= 8) & (df.peptide.str.len() <= 15)
+    ]
+    print("Subselected to 8-15mers: %s" % (str(df.shape)))
+
+    df = df.loc[df.allele.isin(allele_sequences.index)]
+    print("Subselected to alleles with sequences: %s" % (str(df.shape)))
+
+    if args.ignore_inequalities and "measurement_inequality" in df.columns:
+        print("Dropping measurement_inequality column")
+        del df["measurement_inequality"]
+    # Allele names in data are assumed to be already normalized.
+    print("Training data: %s" % (str(df.shape)))
+
+    (held_out_fraction, held_out_max) = (
+        args.held_out_measurements_per_allele_fraction_and_max)
+
+    folds_df = assign_folds(
+        df=df,
+        num_folds=args.ensemble_size,
+        held_out_fraction=held_out_fraction,
+        held_out_max=held_out_max)
+
+    allele_encoding = AlleleEncoding(
+        alleles=allele_sequences.index.values,
+        allele_to_sequence=allele_sequences.to_dict())
+
+    GLOBAL_DATA["train_data"] = df
+    GLOBAL_DATA["folds_df"] = folds_df
+    GLOBAL_DATA["allele_encoding"] = allele_encoding
+    GLOBAL_DATA["args"] = args
+
+    if not os.path.exists(args.out_models_dir):
+        print("Attempting to create directory: %s" % args.out_models_dir)
+        os.mkdir(args.out_models_dir)
+        print("Done.")
+
+    predictor = Class1AffinityPredictor(
+        metadata_dataframes={
+            'train_data': df,
+            'training_folds': folds_df,
+        })
+    serial_run = args.num_jobs == 1
+
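+    # One work item per (architecture, fold, replicate) triple; each item
+    # trains a single neural network.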
+    work_items = []
+    for (h, hyperparameters) in enumerate(hyperparameters_lst):
+        if 'n_models' in hyperparameters:
+            raise ValueError("n_models is unsupported")
+
+        if args.max_epochs:
+            hyperparameters['max_epochs'] = args.max_epochs
+
+        for fold in range(args.ensemble_size):
+            for replicate in range(args.num_replicates):
+                work_dict = {
+                    'architecture_num': h,
+                    'num_architectures': len(hyperparameters_lst),
+                    'fold_num': fold,
+                    'num_folds': args.ensemble_size,
+                    'replicate_num': replicate,
+                    'num_replicates': args.num_replicates,
+                    'hyperparameters': hyperparameters,
+                    'pretrain_data_filename': args.pretrain_data,
+                    'verbose': args.verbosity,
+                    'progress_print_interval': 5.0 if serial_run else None,
+                    'predictor': predictor if serial_run else None,
+                    'save_to': args.out_models_dir if serial_run else None,
+                }
+                work_items.append(work_dict)
+
+    start = time.time()
+
+    worker_pool = worker_pool_with_gpu_assignments_from_args(args)
+
+    if worker_pool:
+        print("Processing %d work items in parallel." % len(work_items))
+
+        # The estimated time to completion is more accurate if we randomize
+        # the order of the work.
+        random.shuffle(work_items)
+
+        results_generator = worker_pool.imap_unordered(
+            partial(call_wrapped_kwargs, train_model),
+            work_items,
+            chunksize=1)
+
+        unsaved_predictors = []
+        last_save_time = time.time()
+        for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
+            unsaved_predictors.append(new_predictor)
+
+            if time.time() > last_save_time + args.save_interval:
+                # Save current predictor.
+                save_start = time.time()
+                new_model_names = predictor.merge_in_place(unsaved_predictors)
+                predictor.save(
+                    args.out_models_dir,
+                    model_names_to_write=new_model_names,
+                    write_metadata=False)
+                print(
+                    "Saved predictor (%d models total) including %d new models "
+                    "in %0.2f sec to %s" % (
+                        len(predictor.neural_networks),
+                        len(new_model_names),
+                        time.time() - save_start,
+                        args.out_models_dir))
+                unsaved_predictors = []
+                last_save_time = time.time()
+
+        predictor.merge_in_place(unsaved_predictors)
+
+    else:
+        # Run in serial. In this case, every call to train_model is passed
+        # the same predictor, which it adds models to, so no merging is
+        # required. Models are also saved as they are trained, so no final
+        # save is strictly required.
+        for _ in tqdm.trange(len(work_items)):
+            item = work_items.pop(0)  # pop so memory can be freed as we go
+            work_predictor = train_model(**item)
+            assert work_predictor is predictor
+        assert not work_items
+
+    print("Saving final predictor to: %s" % args.out_models_dir)
+    predictor.save(args.out_models_dir)  # write all models just to be sure
+    print("Done.")
+
+    print("*" * 30)
+    training_time = time.time() - start
+    print("Trained affinity predictor with %d networks in %0.2f min." % (
+        len(predictor.neural_networks), training_time / 60.0))
+    print("*" * 30)
+
+    if worker_pool:
+        worker_pool.close()
+        worker_pool.join()
+
+    print("Predictor written to: %s" % args.out_models_dir)
+
+
+def train_model(
+        architecture_num,
+        num_architectures,
+        fold_num,
+        num_folds,
+        replicate_num,
+        num_replicates,
+        hyperparameters,
+        pretrain_data_filename,
+        pretrain_patience=20,
+        verbose=None,
+        progress_print_interval=None,
+        predictor=None,
+        save_to=None):
+
+    if predictor is None:
+        predictor = Class1AffinityPredictor()
+
+    df = GLOBAL_DATA["train_data"]
+    folds_df = GLOBAL_DATA["folds_df"]
+    allele_encoding = GLOBAL_DATA["allele_encoding"]
+    args = GLOBAL_DATA["args"]
+
+    numpy.testing.assert_equal(len(df), len(folds_df))
+
+    train_data = df.loc[
+        folds_df["fold_%d" % fold_num]
+    ].sample(frac=1.0)
+
+    train_peptides = EncodableSequences(train_data.peptide.values)
+    train_alleles = AlleleEncoding(
+        train_data.allele.values, borrow_from=allele_encoding)
+    train_target = from_ic50(train_data.measurement_value)
+
+    model = Class1NeuralNetwork(**hyperparameters)
+
+    progress_preamble = (
+        "[%2d / %2d folds] "
+        "[%2d / %2d architectures] "
+        "[%4d / %4d replicates] " % (
+            fold_num + 1,
+            num_folds,
+            architecture_num + 1,
+            num_architectures,
+            replicate_num + 1,
+            num_replicates))
+
+    if pretrain_data_filename:
+        iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
+        original_hyperparameters = dict(model.hyperparameters)
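+        # Note: this next() call consumes the iterator's first chunk solely
+        # to size the minibatch to one full chunk; that chunk is not trained
+        # on.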
+        model.hyperparameters['minibatch_size'] = len(next(iterator)[-1])
+        model.hyperparameters['max_epochs'] = 1
+        model.hyperparameters['validation_split'] = 0.0
+        model.hyperparameters['random_negative_rate'] = 0.0
+        model.hyperparameters['random_negative_constant'] = 0
+        scores = []
+        best_score = float('inf')
+        best_score_epoch = 0
+        for (epoch, (alleles, peptides, affinities)) in enumerate(iterator):
+            # Fit one epoch.
+            start = time.time()
+            model.fit(
+                peptides=peptides,
+                affinities=affinities,
+                allele_encoding=alleles)
+            fit_time = time.time() - start
+            start = time.time()
+            predictions = model.predict(
+                train_peptides,
+                allele_encoding=train_alleles)
+            assert len(predictions) == len(train_data)
+
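+            # Quick textual histogram of the predictions (nM): count per
+            # bin, keyed by each bin's left edge.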
+            print("Prediction histogram:")
+            print(
+                pandas.Series(
+                    dict([k, v] for (v, k) in zip(*numpy.histogram(predictions)))))
+
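+            # Respect censored measurements: clamp predictions already
+            # consistent with a ">" or "<" inequality to the measured value
+            # so they contribute no error to the score.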
+            for (inequality, func) in [(">", numpy.minimum), ("<", numpy.maximum)]:
+                mask = train_data.measurement_inequality == inequality
+                predictions[mask.values] = func(
+                    predictions[mask.values],
+                    train_data.loc[mask].measurement_value.values)
+            score = numpy.mean((from_ic50(predictions) - train_target)**2)
+            score_time = time.time() - start
+            print(
+                progress_preamble,
+                "PRETRAIN epoch %d [%d values, %0.2f sec]. "
+                "Score [%0.2f sec.]: %f" % (
+                    epoch, len(affinities), fit_time, score_time, score))
+            scores.append(score)
+
+            if score < best_score:
+                print("New best score", score)
+                best_score = score
+                best_score_epoch = epoch
+
+            if epoch - best_score_epoch > pretrain_patience:
+                print("Stopping pretraining")
+                break
+
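+        # Restore the full hyperparameters for fine-tuning on the measured
+        # data, but lower the learning rate 10x (or to 1e-4 if unset) so
+        # fine-tuning does not clobber what pretraining learned.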
+        model.hyperparameters = original_hyperparameters
+        if model.hyperparameters['learning_rate']:
+            model.hyperparameters['learning_rate'] /= 10
+        else:
+            model.hyperparameters['learning_rate'] = 0.0001
+
+    model.fit(
+        peptides=train_peptides,
+        affinities=train_data.measurement_value,
+        allele_encoding=train_alleles,
+        inequalities=(
+            train_data.measurement_inequality.values
+            if "measurement_inequality" in train_data.columns else None),
+        models_dir_for_save=save_to,
+        progress_preamble=progress_preamble,
+        progress_print_interval=progress_print_interval,
+        verbose=verbose)
+
+    predictor.class1_pan_allele_models.append(model)
+    predictor.clear_cache()
+    return predictor
+
+
+if __name__ == '__main__':
+    run()
diff --git a/setup.py b/setup.py
index 97224549..e8e4e079 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,8 @@ if __name__ == '__main__':
                 'mhcflurry-predict = mhcflurry.predict_command:run',
                 'mhcflurry-class1-train-allele-specific-models = '
                     'mhcflurry.train_allele_specific_models_command:run',
+                'mhcflurry-class1-train-pan-allele-models = '
+                    'mhcflurry.train_pan_allele_models_command:run',
                 'mhcflurry-class1-select-allele-specific-models = '
                     'mhcflurry.select_allele_specific_models_command:run',
                 'mhcflurry-calibrate-percentile-ranks = '
-- 
GitLab