From b5dca88e94bf5dbf25c929835e2419a5f9e5b209 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Thu, 5 Sep 2019 17:11:17 -0400 Subject: [PATCH] docs --- mhcflurry/cluster_parallelism.py | 84 +++++++++++++ mhcflurry/common.py | 13 ++ mhcflurry/custom_loss.py | 67 ++++++++++- .../data_dependent_weights_initialization.py | 31 +++++ mhcflurry/downloads.yml | 4 +- mhcflurry/encodable_sequences.py | 113 +++++++++++++----- mhcflurry/local_parallelism.py | 73 +++++++++++ mhcflurry/percent_rank_transform.py | 1 + mhcflurry/predict_command.py | 9 +- mhcflurry/select_pan_allele_models_command.py | 50 ++++++-- ... expensive_verify_pretrain_optimizable.py} | 6 +- 11 files changed, 397 insertions(+), 54 deletions(-) rename test/{expensive_test_pretrain_optimizable.py => expensive_verify_pretrain_optimizable.py} (97%) diff --git a/mhcflurry/cluster_parallelism.py b/mhcflurry/cluster_parallelism.py index 31af13c1..976f53a2 100644 --- a/mhcflurry/cluster_parallelism.py +++ b/mhcflurry/cluster_parallelism.py @@ -1,3 +1,8 @@ +""" +Simple, relatively naive parallel map implementation for HPC clusters. + +Used for training MHCflurry models. +""" import traceback import sys import os @@ -18,6 +23,14 @@ except ImportError: def add_cluster_parallelism_args(parser): + """ + Add commandline arguments controlling cluster parallelism to an argparse + ArgumentParser. + + Parameters + ---------- + parser : argparse.ArgumentParser + """ group = parser.add_argument_group("Cluster parallelism") group.add_argument( "--cluster-parallelism", @@ -45,6 +58,27 @@ def cluster_results_from_args( constant_data=None, result_serialization_method="pickle", clear_constant_data=False): + """ + Parallel map configurable using commandline arguments. See the + cluster_results() function for docs. + + The `args` parameter should be an argparse.Namespace from an argparse parser + generated using the add_cluster_parallelism_args() function. + + + Parameters + ---------- + args + work_function + work_items + constant_data + result_serialization_method + clear_constant_data + + Returns + ------- + generator + """ return cluster_results( work_function=work_function, work_items=work_items, @@ -67,6 +101,49 @@ def cluster_results( result_serialization_method="pickle", max_retries=3, clear_constant_data=False): + """ + Parallel map on an HPC cluster. + + Returns [work_function(item) for item in work_items] where each invocation + of work_function is performed as a separate HPC cluster job. Order is + preserved. + + Optionally, "constant data" can be specified, which will be passed to + each work_function() invocation as a keyword argument called constant_data. + This data is serialized once and all workers read it from the same source, + which is more efficient than serializing it separately for each worker. + + Each worker's input is serialized to a shared NFS directory and the + submit_command is used to launch a job to process that input. The shared + filesystem is polled occasionally to watch for results, which are fed back + to the user. + + Parameters + ---------- + work_function : A -> B + work_items : list of A + constant_data : object + submit_command : string + For running on LSF, we use "bsub" here. + results_workdir : string + Path to NFS shared directory where inputs and results can be written + script_prefix_path : string + Path to script that will be invoked to run each worker. A line calling + the _mhcflurry-cluster-worker-entry-point command will be appended to + the contents of this file. 
+ result_serialization_method : string, one of "pickle" or "save_predictor" + The "save_predictor" works only when the return type of work_function + is Class1AffinityPredictor + max_retries : int + How many times to attempt to re-launch a failed worker + clear_constant_data : bool + If True, the constant data dict is cleared on the launching host after + it is serialized to disk. + + Returns + ------- + generator of B + """ constant_payload = { 'constant_data': constant_data, @@ -231,6 +308,13 @@ parser.add_argument( def worker_entry_point(argv=sys.argv[1:]): + """ + Entry point for the worker command. + + Parameters + ---------- + argv : list of string + """ # On sigusr1 print stack trace print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid()) signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack()) diff --git a/mhcflurry/common.py b/mhcflurry/common.py index 68a51084..7c8e1628 100644 --- a/mhcflurry/common.py +++ b/mhcflurry/common.py @@ -150,6 +150,19 @@ def random_peptides(num, length=9, distribution=None): def positional_frequency_matrix(peptides): + """ + Given a set of peptides, calculate a length x amino acids frequency matrix. + + Parameters + ---------- + peptides : list of string + All of same length + + Returns + ------- + pandas.DataFrame + Index is position, columns are amino acids + """ length = len(peptides[0]) assert all(len(peptide) == length for peptide in peptides) counts = pandas.DataFrame( diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index fcb970bd..47d51cb4 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -14,6 +14,17 @@ CUSTOM_LOSSES = {} def get_loss(name): + """ + Get a custom_loss.Loss instance by name. + + Parameters + ---------- + name : string + + Returns + ------- + custom_loss.Loss + """ if name.startswith("custom:"): try: custom_loss = CUSTOM_LOSSES[name.replace("custom:", "")] @@ -29,6 +40,17 @@ def get_loss(name): class Loss(object): + """ + Thin wrapper to keep track of neural network loss functions, which could + be custom or baked into Keras. + + Each subclass or instance should define these properties/methods: + - name : string + - loss : string or function + This is what gets passed to keras.fit() + - encode_y : numpy.ndarray -> numpy.ndarray + Transformation to apply to regression target before fitting + """ def __init__(self, name=None): self.name = name if name else self.name # use name from class instance @@ -37,6 +59,9 @@ class Loss(object): class StandardKerasLoss(Loss): + """ + A loss function supported by Keras, such as MSE. + """ supports_inequalities = False supports_multiple_outputs = False @@ -51,7 +76,7 @@ class StandardKerasLoss(Loss): class MSEWithInequalities(Loss): """ - Supports training a regressor on data that includes inequalities + Supports training a regression model on data that includes inequalities (e.g. x < 100). Mean square error is used as the loss for elements with an (=) inequality. For elements with e.g. a (> 0.5) inequality, then the loss for that element is (y - 0.5)^2 (standard MSE) if y < 500 and 0 otherwise. @@ -63,15 +88,15 @@ class MSEWithInequalities(Loss): y_true is interpreted as follows: between 0 - 1 - Regular MSE loss is used. Penality (y_pred - y_true)**2 is applied if + Regular MSE loss is used. Penalty (y_pred - y_true)**2 is applied if y_pred is greater or less than y_true. between 2 - 3: - Treated as a "<" inequality. Penality (y_pred - (y_true - 2))**2 is + Treated as a "<" inequality. 
Penalty (y_pred - (y_true - 2))**2 is applied only if y_pred is greater than y_true - 2. between 4 - 5: - Treated as a ">" inequality. Penality (y_pred - (y_true - 4))**2 is + Treated as a ">" inequality. Penalty (y_pred - (y_true - 4))**2 is applied only if y_pred is less than y_true - 4. """ name = "mse_with_inequalities" @@ -104,8 +129,8 @@ class MSEWithInequalities(Loss): @staticmethod def loss(y_true, y_pred): - # We always delay import of Keras so that mhcflurry can be imported initially - # without tensorflow debug output, etc. + # We always delay import of Keras so that mhcflurry can be imported + # initially without tensorflow debug output, etc. from keras import backend as K # Handle (=) inequalities @@ -133,6 +158,26 @@ class MSEWithInequalities(Loss): class MSEWithInequalitiesAndMultipleOutputs(Loss): + """ + Loss supporting inequalities and multiple outputs. + + This loss assumes that the normal range for y_true and y_pred is 0 - 1. As a + hack, the implementation uses other intervals for y_pred to encode the + inequality and output-index information. + + Inequalities are encoded into the regression target as in + the MSEWithInequalities loss. + + Multiple outputs are encoded by mapping each regression target x (after + transforming for inequalities) using the rule x -> x + i * 10 where i is + the output index. + + The reason for explicitly encoding multiple outputs this way (rather than + just making the regression target a matrix instead of a vector) is that + in our use cases we frequently have missing data in the regression target. + This encoding gives a simple way to penalize only on (data point, output + index) pairs that have labels. + """ name = "mse_with_inequalities_and_multiple_outputs" supports_inequalities = True supports_multiple_outputs = True @@ -189,6 +234,16 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): def check_shape(name, arr, expected_shape): + """ + Raise ValueError if arr.shape != expected_shape. + + Parameters + ---------- + name : string + Included in error message to aid debugging + arr : numpy.ndarray + expected_shape : tuple of int + """ if arr.shape != expected_shape: raise ValueError("Expected %s to have shape %s not %s" % ( name, str(expected_shape), str(arr.shape))) diff --git a/mhcflurry/data_dependent_weights_initialization.py b/mhcflurry/data_dependent_weights_initialization.py index c6cfd6c8..8cc118ce 100644 --- a/mhcflurry/data_dependent_weights_initialization.py +++ b/mhcflurry/data_dependent_weights_initialization.py @@ -1,3 +1,11 @@ +""" +Layer-sequential unit-variance initialization for neural networks. + +See: + Mishkin and Matas, "All you need is a good init". 2016. + https://arxiv.org/abs/1511.06422 +""" +# # LSUV initialization code in this file is adapted from: # https://github.com/ducha-aiki/LSUV-keras/blob/master/lsuv_init.py # by Dmytro Mishkin @@ -58,6 +66,29 @@ def get_activations(model, layer, X_batch): def lsuv_init(model, batch, verbose=True, margin=0.1, max_iter=100): + """ + Initialize neural network weights using layer-sequential unit-variance + initialization. + + See: + Mishkin and Matas, "All you need is a good init". 2016. + https://arxiv.org/abs/1511.06422 + + Parameters + ---------- + model : keras.Model + batch : dict + Training data, as would be passed keras.Model.fit() + verbose : boolean + Whether to print progress to stdout + margin : float + max_iter : int + + Returns + ------- + keras.Model + Same as what was passed in. 
+ """ from keras.layers import Dense, Convolution2D needed_variance = 1.0 layers_inintialized = 0 diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml index 6d22b610..338b8e72 100644 --- a/mhcflurry/downloads.yml +++ b/mhcflurry/downloads.yml @@ -8,7 +8,7 @@ # by name, the downloads with "default=true" are downloaded. # This should usually be the latest release. -current-release: 2.0.0 +current-release: 1.3.0 # An integer indicating what models the current MHCflurry code base is compatible # with. Increment this integer when changes are made to MHCflurry that would break @@ -17,7 +17,7 @@ current-compatibility-version: 2 # Add new releases here as they are made. releases: - 2.0.0: + 1.3.0: compatibility-version: 2 downloads: - name: models_class1_pan diff --git a/mhcflurry/encodable_sequences.py b/mhcflurry/encodable_sequences.py index a124d867..f6322835 100644 --- a/mhcflurry/encodable_sequences.py +++ b/mhcflurry/encodable_sequences.py @@ -14,6 +14,9 @@ from . import amino_acid class EncodingError(ValueError): + """ + Exception raised when peptides cannot be encoded + """ def __init__(self, message, supported_peptide_lengths): self.supported_peptide_lengths = supported_peptide_lengths ValueError.__init__( @@ -65,21 +68,27 @@ class EncodableSequences(object): right_edge=4, max_length=15): """ - Encode variable-length sequences using a fixed-length encoding designed - for preserving the anchor positions of class I peptides. - - The sequences must be of length at least left_edge + right_edge, and at - most max_length. + Encode variable-length sequences to a fixed-size index-encoded (integer) + matrix. + + See `sequences_to_fixed_length_index_encoded_array` for details. Parameters ---------- + alignment_method : string + One of "pad_middle" or "left_pad_right_pad" left_edge : int, size of fixed-position left side + Only relevant for pad_middle alignment method right_edge : int, size of the fixed-position right side - max_length : sequence length of the resulting encoding + Only relevant for pad_middle alignment method + max_length : maximum supported peptide length Returns ------- - numpy.array of integers with shape (num sequences, max_length) + numpy.array of integers with shape (num sequences, encoded length) + + For pad_middle, the encoded length is max_length. For left_pad_right_pad, + it's 3 * max_length. """ cache_key = ( @@ -108,11 +117,12 @@ class EncodableSequences(object): right_edge=4, max_length=15): """ - Encode variable-length sequences using a fixed-length encoding designed - for preserving the anchor positions of class I peptides. + Encode variable-length sequences to a fixed-size matrix. Amino acids + are encoded as specified by the vector_encoding_name argument. - The sequences must be of length at least left_edge + right_edge, and at - most max_length. + See `sequences_to_fixed_length_index_encoded_array` for details. + + See also: variable_length_to_fixed_length_categorical. Parameters ---------- @@ -120,14 +130,22 @@ class EncodableSequences(object): How to represent amino acids. One of "BLOSUM62", "one-hot", etc. Full list of supported vector encodings is given by available_vector_encodings(). 
+ alignment_method : string + One of "pad_middle" or "left_pad_right_pad" left_edge : int, size of fixed-position left side + Only relevant for pad_middle alignment method right_edge : int, size of the fixed-position right side - max_length : sequence length of the resulting encoding + Only relevant for pad_middle alignment method + max_length : maximum supported peptide length Returns ------- - numpy.array with shape (num sequences, max_length, m) where m is - vector_encoding_length(vector_encoding_name) + numpy.array with shape (num sequences, encoded length, m) + + where + - m is the vector encoding length (usually 21). + - encoded length is max_length if alignment_method is pad_middle; + 3 * max_length if it's left_pad_right_pad. """ cache_key = ( "fixed_length_vector_encoding", @@ -160,32 +178,63 @@ class EncodableSequences(object): right_edge=4, max_length=15): """ - Transform a sequence of strings, where each string is of length at least - left_edge + right_edge and at most max_length into strings of length - max_length using a scheme designed to preserve the anchor positions of - class I peptides. + Encode variable-length sequences to a fixed-size index-encoded (integer) + matrix. + + How variable length sequences get mapped to fixed length is set by the + "alignment_method" argument. Supported alignment methods are: + + pad_middle + Encoding designed for preserving the anchor positions of class + I peptides. This is what is used in allele-specific models. + + Each string must be of length at least left_edge + right_edge + and at most max_length. The first left_edge characters in the + input always map to the first left_edge characters in the + output. Similarly for the last right_edge characters. The + middle characters are filled in based on the length, with the + X character filling in the blanks. - The first left_edge characters in the input always map to the first - left_edge characters in the output. Similarly for the last right_edge - characters. The middle characters are filled in based on the length, - with the X character filling in the blanks. + Example: - For example, using defaults: + AAAACDDDD -> AAAAXXXCXXXDDDD - AAAACDDDD -> AAAAXXXCXXXDDDD + left_pad_centered_right_pad + Encoding that makes no assumptions on anchor positions but is + 3x larger than pad_middle, since it duplicates the peptide + (left aligned + centered + right aligned). This is what is used + for the pan-allele models. - The strings are also converted to int categorical amino acid indices. + Example: + + AAAACDDDD -> AAAACDDDDXXXXXXXXXAAAACDDDDXXXXXXXXXAAAACDDDD + + left_pad_right_pad + Same as left_pad_centered_right_pad but only includes left- + and right-padded peptide. + + Example: + + AAAACDDDD -> AAAACDDDDXXXXXXXXXXXXAAAACDDDD Parameters ---------- - sequence : string - left_edge : int - right_edge : int - max_length : int + sequences : list of string + alignment_method : string + One of "pad_middle" or "left_pad_right_pad" + left_edge : int, size of fixed-position left side + Only relevant for pad_middle alignment method + right_edge : int, size of the fixed-position right side + Only relevant for pad_middle alignment method + max_length : maximum supported peptide length Returns ------- - numpy array of shape (len(sequences), max_length) and dtype int + numpy.array of integers with shape (num sequences, encoded length) + + For pad_middle, the encoded length is max_length. For left_pad_right_pad, + it's 2 * max_length. For left_pad_centered_right_pad, it's + 3 * max_length. 
""" result = None if alignment_method == 'pad_middle': @@ -213,8 +262,8 @@ class EncodableSequences(object): len(sub_df)), supported_peptide_lengths=( min_length, max_length)) - # Array of shape (num peptides, length) giving fixed-length amino - # acid encoding each peptide of the current length. + # Array of shape (num peptides, length) giving fixed-length + # amino acid encoding each peptide of the current length. fixed_length_sequences = numpy.stack( sub_df.peptide.map( lambda s: numpy.array([ diff --git a/mhcflurry/local_parallelism.py b/mhcflurry/local_parallelism.py index ac3facaa..100a270c 100644 --- a/mhcflurry/local_parallelism.py +++ b/mhcflurry/local_parallelism.py @@ -1,3 +1,8 @@ +""" +Infrastructure for "local" parallelism, i.e. multiprocess parallelism on one +compute node. +""" + import traceback import sys import os @@ -14,6 +19,13 @@ from .common import set_keras_backend def add_local_parallelism_args(parser): + """ + Add local parallelism arguments to the given argparse.ArgumentParser. + + Parameters + ---------- + parser : argparse.ArgumentParser + """ group = parser.add_argument_group("Local parallelism") group.add_argument( @@ -54,6 +66,20 @@ def add_local_parallelism_args(parser): def worker_pool_with_gpu_assignments_from_args(args): + """ + Create a multiprocessing.Pool where each worker uses its own GPU. + + Uses commandline arguments. See `worker_pool_with_gpu_assignments`. + + Parameters + ---------- + args : argparse.ArgumentParser + + Returns + ------- + multiprocessing.Pool + """ + return worker_pool_with_gpu_assignments( num_jobs=args.num_jobs, num_gpus=args.gpus, @@ -71,6 +97,23 @@ def worker_pool_with_gpu_assignments( max_workers_per_gpu=1, max_tasks_per_worker=None, worker_log_dir=None): + """ + Create a multiprocessing.Pool where each worker uses its own GPU. + + Parameters + ---------- + num_jobs : int + Number of worker processes. + num_gpus : int + backend : string + max_workers_per_gpu : int + max_tasks_per_worker : int + worker_log_dir : string + + Returns + ------- + multiprocessing.Pool + """ if num_jobs == 0: if backend: @@ -247,6 +290,20 @@ class WrapException(Exception): def call_wrapped(function, *args, **kwargs): + """ + Run function on args and kwargs and return result, wrapping any exception + raised in a WrapException. + + Parameters + ---------- + function : arbitrary function + + Any other arguments provided are passed to the function. + + Returns + ------- + object + """ try: return function(*args, **kwargs) except: @@ -254,4 +311,20 @@ def call_wrapped(function, *args, **kwargs): def call_wrapped_kwargs(function, kwargs): + """ + Invoke function on given kwargs and return result, wrapping any exception + raised in a WrapException. + + Parameters + ---------- + function : arbitrary function + kwargs : dict + + Returns + ------- + object + + result of calling function(**kwargs) + + """ return call_wrapped(function, **kwargs) \ No newline at end of file diff --git a/mhcflurry/percent_rank_transform.py b/mhcflurry/percent_rank_transform.py index 6f42477d..a9098bc2 100644 --- a/mhcflurry/percent_rank_transform.py +++ b/mhcflurry/percent_rank_transform.py @@ -1,6 +1,7 @@ import numpy import pandas + class PercentRankTransform(object): """ Transform arbitrary values into percent ranks. 
diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py index 6d92a501..dea1ddf4 100644 --- a/mhcflurry/predict_command.py +++ b/mhcflurry/predict_command.py @@ -163,9 +163,9 @@ def run(argv=sys.argv[1:]): models_dir = args.models if models_dir is None: - # The reason we set the default here instead of in the argument parser is that - # we want to test_exists at this point, so the user gets a message instructing - # them to download the models if needed. + # The reason we set the default here instead of in the argument parser + # is that we want to test_exists at this point, so the user gets a + # message instructing them to download the models if needed. models_dir = get_default_class1_models_dir(test_exists=True) predictor = Class1AffinityPredictor.load(models_dir) @@ -224,7 +224,8 @@ def run(argv=sys.argv[1:]): predictions = predictor.predict_to_dataframe( peptides=df[args.peptide_column].values, alleles=df[args.allele_column].values, - include_individual_model_predictions=args.include_individual_model_predictions, + include_individual_model_predictions=( + args.include_individual_model_predictions), throw=not args.no_throw) for col in predictions.columns: diff --git a/mhcflurry/select_pan_allele_models_command.py b/mhcflurry/select_pan_allele_models_command.py index 3c3fc4d6..0032f8fc 100644 --- a/mhcflurry/select_pan_allele_models_command.py +++ b/mhcflurry/select_pan_allele_models_command.py @@ -1,5 +1,10 @@ """ -Model select class1 pan allele models. +Model select class1 pan-allele models. + +APPROACH: For each training fold, we select at least min and at most max models +(where min and max are set by the --{min/max}-models-per-fold argument) using a +step-up (forward) selection procedure. The final ensemble is the union of all +selected models across all folds. """ import argparse import os @@ -94,6 +99,22 @@ def mse( actual, inequalities=None, affinities_are_already_01_transformed=False): + """ + Mean squared error of predictions vs. actual + + Parameters + ---------- + predictions : list of float + actual : list of float + inequalities : list of string (">", "<", or "=") + affinities_are_already_01_transformed : boolean + Predictions and actual are taken to be nanomolar affinities if + affinities_are_already_01_transformed is False, otherwise 0-1 values. + + Returns + ------- + float + """ if not affinities_are_already_01_transformed: predictions = from_ic50(predictions) actual = from_ic50(actual) @@ -286,14 +307,29 @@ def run(argv=sys.argv[1:]): print("Predictor written to: %s" % args.out_models_dir) -def do_model_select_task(item): - return model_select(**item) +def do_model_select_task(item, constant_data=GLOBAL_DATA): + return model_select(constant_data=constant_data, **item) -def model_select(fold_num, models, min_models, max_models): - global GLOBAL_DATA - full_data = GLOBAL_DATA["data"] - input_predictor = GLOBAL_DATA["input_predictor"] +def model_select( + fold_num, models, min_models, max_models, constant_data=GLOBAL_DATA): + """ + Model select for a fold. 
+ + Parameters + ---------- + fold_num : int + models : list of Class1NeuralNetwork + min_models : int + max_models : int + constant_data : dict + + Returns + ------- + dict with keys 'fold_num', 'selected_indices', 'summary' + """ + full_data = constant_data["data"] + input_predictor = constant_data["input_predictor"] df = full_data.loc[ full_data["fold_%d" % fold_num] == 0 ] diff --git a/test/expensive_test_pretrain_optimizable.py b/test/expensive_verify_pretrain_optimizable.py similarity index 97% rename from test/expensive_test_pretrain_optimizable.py rename to test/expensive_verify_pretrain_optimizable.py index eec444dc..618ca829 100644 --- a/test/expensive_test_pretrain_optimizable.py +++ b/test/expensive_verify_pretrain_optimizable.py @@ -1,4 +1,4 @@ -# Expensive test - not run by default. +# Expensive test - not run by nose. from mhcflurry import train_pan_allele_models_command from mhcflurry.downloads import get_path @@ -60,7 +60,7 @@ HYPERPARAMTERS = { } -def test_optimizable(): +def verify_optimizable(): predictor = train_pan_allele_models_command.train_model( work_item_name="work-item0", work_item_num=0, @@ -94,4 +94,4 @@ def test_optimizable(): if __name__ == "__main__": - test_optimizable() + verify_optimizable() -- GitLab
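The regression-target encoding described in the custom_loss.py docstrings above can be made concrete with a small worked example. The function below is a minimal sketch of that convention, not the library's Loss.encode_y implementation: targets in the interval 0 - 1 are plain (=) values, adding 2 marks a (<) inequality, adding 4 marks a (>) inequality, and adding 10 * i assigns the value to output index i.

    import numpy

    # Minimal sketch of the documented target-encoding convention; the
    # offsets come from the MSEWithInequalities docstring in this patch.
    INEQUALITY_OFFSETS = {"=": 0.0, "<": 2.0, ">": 4.0}

    def encode_targets(y, inequalities=None, output_indices=None):
        # y: regression targets in [0, 1] (e.g. 0-1 transformed affinities).
        y = numpy.array(y, dtype="float64")
        if inequalities is not None:
            y = y + numpy.array([INEQUALITY_OFFSETS[i] for i in inequalities])
        if output_indices is not None:
            y = y + 10.0 * numpy.array(output_indices, dtype="float64")
        return y

    # A "< 0.2" measurement on output 0 and a "> 0.8" measurement on output 1:
    print(encode_targets([0.2, 0.8], inequalities=["<", ">"],
                         output_indices=[0, 1]))
    # -> [ 2.2  14.8]

At training time the loss functions decode these offsets back out of y_true, which is what lets a single flat target vector carry inequality and output-index information for data points with missing labels.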