From 0994b6d515a5da2e817feed5cb60c44df77fa675 Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Fri, 2 Oct 2015 13:34:42 -0400
Subject: [PATCH] added model selection script

---
 experiments/nips2015-model-selection.py | 380 ++++++++++++++++++++++++
 experiments/results.csv                 | 120 ++++++++
 mhcflurry/data_helpers.py               |  69 ++++-
 mhcflurry/feedforward.py                |   3 +
 mhcflurry/mhc1_binding_predictor.py     |   1 +
 5 files changed, 570 insertions(+), 3 deletions(-)
 create mode 100755 experiments/nips2015-model-selection.py
 create mode 100644 experiments/results.csv

diff --git a/experiments/nips2015-model-selection.py b/experiments/nips2015-model-selection.py
new file mode 100755
index 00000000..b021536f
--- /dev/null
+++ b/experiments/nips2015-model-selection.py
@@ -0,0 +1,380 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2015. Mount Sinai School of Medicine
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import (
+    print_function,
+    division,
+    absolute_import,
+    unicode_literals
+)
+from collections import namedtuple, OrderedDict
+from os.path import join
+import argparse
+
+import numpy as np
+import pandas as pd
+import sklearn
+import sklearn.metrics
+import sklearn.cross_validation
+from sklearn.cross_validation import KFold
+
+from mhcflurry.common import normalize_allele_name
+from mhcflurry.feedforward import make_embedding_network, make_hotshot_network
+from mhcflurry.data_helpers import load_data, indices_to_hotshot_encoding
+from mhcflurry.paths import (
+    CLASS1_DATA_DIRECTORY
+)
+
+# Binding-data file names and their locations under CLASS1_DATA_DIRECTORY.
+# Only PETERS2009_CSV_PATH is referenced later in this script.
+PETERS2009_CSV_FILENAME = "bdata.2009.mhci.public.1.txt"
+PETERS2009_CSV_PATH = join(CLASS1_DATA_DIRECTORY, PETERS2009_CSV_FILENAME)
+PETERS2013_CSV_FILENAME = "bdata.20130222.mhci.public.1.txt"
+PETERS2013_CSV_PATH = join(CLASS1_DATA_DIRECTORY, PETERS2013_CSV_FILENAME)
+COMBINED_CSV_FILENAME = "combined_human_class1_dataset.csv"
+COMBINED_CSV_PATH = join(CLASS1_DATA_DIRECTORY, COMBINED_CSV_FILENAME)
+
+# Command-line interface; the arguments are parsed in the __main__ block.
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--binding-data-csv-path",
+    help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns",
+    default=PETERS2009_CSV_PATH)
+
+parser.add_argument(
+    "--min-samples-per-allele",
+    type=int,
+    help="Don't train predictors for alleles with fewer samples than this",
+    default=5)
+
+parser.add_argument(
+    "--results-filename",
+    help="Write all hyperparameter/allele results to this filename",
+    required=True)
+
+parser.add_argument(
+    "--cv-folds",
+    help="Number cross-validation folds",
+    type=int,
+    default=5)
+
+parser.add_argument(
+    "--training-epochs",
+    help="Number of passes over the dataset to perform during model fitting",
+    type=int,
+    default=125)
+
+parser.add_argument(
+    "--max-dropout",
+    help="Degree of dropout regularization to try in hyperparameter search",
+    type=float,
+    default=0.25)
+
+
+ModelConfig = namedtuple(  # one point in the hyperparameter grid
+    "ModelConfig",
+    [
+        "embedding_size",  # falsy (0) selects the hotshot network instead
+        "hidden_layer_size",
+        "activation",
+        "loss",
+        "init",
+        "n_pretrain_epochs",  # epochs of fitting on the other alleles' data
+        "n_epochs",  # epochs of fitting inside each cross-validation fold
+        "dropout_probability",
+        "max_ic50",  # here only keys the dataset cache in __main__
+    ])
+
+HIDDEN1_LAYER_SIZES = [  # hidden layer sizes tried by the grid search
+    64,
+    128,
+    256,
+    512,
+]
+
+INITILIZATION_METHODS = [  # NOTE(review): name misspells "INITIALIZATION"
+    "glorot_uniform",
+    "glorot_normal",
+    "uniform",
+]
+
+ACTIVATIONS = [  # activation names tried by the grid search
+    "relu",
+    "tanh",
+    "prelu",
+]
+
+
+def generate_all_model_configs(
+        embedding_sizes=(0, 16, 64),
+        n_training_epochs=125,
+        max_dropout=0.25):
+    """
+    Enumerate the hyperparameter grid as a list of ModelConfig objects.
+
+    Takes the cross product of activations, "mse" loss, initializations,
+    pre-training epochs (0, 10), hidden layer sizes, embedding sizes, dropout
+    (0, max_dropout) and max_ic50 (5000, 50000); max_ic50 varies fastest.
+    """
+    from itertools import product
+    grid = product(
+        ACTIVATIONS, ["mse"], INITILIZATION_METHODS, [0, 10],
+        HIDDEN1_LAYER_SIZES, embedding_sizes, [0, max_dropout], [5000, 50000])
+    configurations = []
+    for (activation, loss, init, n_pretrain_epochs, hidden_layer_size,
+            embedding_size, dropout, max_ic50) in grid:
+        config = ModelConfig(
+            embedding_size=embedding_size, hidden_layer_size=hidden_layer_size,
+            activation=activation, loss=loss, init=init,
+            n_pretrain_epochs=n_pretrain_epochs, n_epochs=n_training_epochs,
+            dropout_probability=dropout, max_ic50=max_ic50)
+        print(config)
+        configurations.append(config)
+    return configurations
+
+
+def kfold_cross_validation_for_single_allele(
+        allele_name, model, X, Y, ic50,
+        n_training_epochs=100,
+        cv_folds=5, max_ic50=5000):
+    """
+    Estimate the per-allele AUC score of a model via k-fold cross-validation.
+    Returns the per-fold AUC scores and accuracies as two lists. Folds whose
+    test labels (ic50 <= 500) are all identical are skipped, so the lists
+    may hold fewer than cv_folds entries. The model's starting weights are
+    restored before each fold is fit. max_ic50 is the base used to convert
+    predictions back to IC50 values (default 5000).
+    """
+    initial_weights = [w.copy() for w in model.get_weights()]
+    fold_aucs = []
+    fold_accuracies = []
+    folds = KFold(n=len(Y), n_folds=cv_folds, shuffle=True, random_state=0)
+    for cv_iter, (train_idx, test_idx) in enumerate(folds):
+        X_train, Y_train = X[train_idx, :], Y[train_idx]
+        X_test = X[test_idx, :]
+        label_test = ic50[test_idx] <= 500
+        if label_test.all() or not label_test.any():
+            print(
+                "Skipping CV iter %d of %s since all outputs are the same" % (
+                    cv_iter, allele_name))
+            continue
+        model.set_weights(initial_weights)
+
+        history = model.fit(
+            X_train, Y_train, nb_epoch=n_training_epochs, verbose=0)
+        losses = history.history["loss"]
+        print("-- CV iter #%d for %s: first=%0.4f, min=%0.4f, last=%0.4f" % (
+            cv_iter + 1,
+            allele_name,
+            losses[0],
+            min(losses),
+            losses[-1]))
+
+        # Flatten: an (n, 1) prediction array would otherwise broadcast
+        # against the 1-D labels into an (n, n) comparison in the accuracy.
+        pred = np.ravel(model.predict(X_test))
+        auc = sklearn.metrics.roc_auc_score(label_test, pred)
+        # Assumes targets were encoded as 1 - log(ic50) / log(max_ic50),
+        # so this inverts the encoding -- TODO confirm against load_data.
+        ic50_pred = max_ic50 ** (1.0 - pred)
+        accuracy = np.mean(label_test == (ic50_pred <= 500))
+        fold_aucs.append(auc)
+        fold_accuracies.append(accuracy)
+    return fold_aucs, fold_accuracies
+
+
+def leave_out_allele_cross_validation(
+        model,
+        binary_encoding=False,
+        n_pretrain_epochs=0,
+        min_samples_per_allele=5,
+        cv_folds=5):
+    """
+    Fit the model for every allele in the dataset and return a DataFrame
+    with one row per evaluated allele and the following columns:
+            allele_name, dataset_size,
+            auc_mean, auc_median, auc_std, auc_min, auc_max,
+            accuracy_mean, accuracy_median, accuracy_std, accuracy_min,
+            accuracy_max
+
+    If n_pretrain_epochs > 0 the model is first fit for that many epochs on
+    the pooled samples of all other alleles. Alleles are skipped when their
+    normalized name is all digits or shorter than 4 characters, when they
+    have fewer than min_samples_per_allele samples, or when no CV fold
+    yielded a score.
+
+    NOTE(review): reads the module-level globals `allele_datasets` and
+    `config` (for config.n_epochs) instead of taking them as parameters,
+    so it only works after the __main__ block has assigned both.
+    """
+    result_dict = OrderedDict([
+        ("allele_name", []),
+        ("dataset_size", []),
+        ("auc_mean", []),
+        ("auc_median", []),
+        ("auc_std", []),
+        ("auc_min", []),
+        ("auc_max", []),
+        ("accuracy_mean", []),
+        ("accuracy_median", []),
+        ("accuracy_std", []),
+        ("accuracy_min", []),
+        ("accuracy_max", [])
+    ])
+    initial_weights = [w.copy() for w in model.get_weights()]
+    for allele_name, dataset in allele_datasets.items():
+        allele_name = normalize_allele_name(allele_name)
+        if allele_name.isdigit() or len(allele_name) < 4:
+            print("Skipping allele %s" % (allele_name,))
+            continue
+        X_allele = dataset.X
+        n_samples_allele = X_allele.shape[0]
+        if n_samples_allele < min_samples_per_allele:
+            print("Skipping allele %s due to too few samples: %d" % (
+                allele_name, n_samples_allele))
+            continue
+        if binary_encoding:
+            X_allele = indices_to_hotshot_encoding(X_allele, n_indices=20)
+        Y_allele = dataset.Y
+        ic50_allele = dataset.ic50
+        model.set_weights(initial_weights)
+        if n_pretrain_epochs > 0:
+            # Gather the other alleles' datasets once so that X and Y are
+            # built from the same objects (the dict keys are name strings).
+            other_datasets = [
+                other_dataset
+                for (other_allele, other_dataset) in allele_datasets.items()
+                if normalize_allele_name(other_allele) != allele_name]
+            X_other_alleles = np.vstack([d.X for d in other_datasets])
+            if binary_encoding:
+                X_other_alleles = indices_to_hotshot_encoding(
+                    X_other_alleles, n_indices=20)
+            Y_other_alleles = np.concatenate([d.Y for d in other_datasets])
+            print("Pre-training X shape: %s" % (X_other_alleles.shape,))
+            print("Pre-training Y shape: %s" % (Y_other_alleles.shape,))
+            model.fit(
+                X_other_alleles, Y_other_alleles, nb_epoch=n_pretrain_epochs)
+        aucs, accuracies = kfold_cross_validation_for_single_allele(
+            allele_name=allele_name,
+            model=model,
+            X=X_allele,
+            Y=Y_allele,
+            ic50=ic50_allele,
+            n_training_epochs=config.n_epochs,
+            cv_folds=cv_folds)
+        if len(aucs) == 0:
+            print("Skipping allele %s" % allele_name)
+            continue
+        result_dict["allele_name"].append(allele_name)
+        result_dict["dataset_size"].append(len(ic50_allele))
+        for (name, values) in [("auc", aucs), ("accuracy", accuracies)]:
+            result_dict["%s_mean" % name].append(np.mean(values))
+            result_dict["%s_median" % name].append(np.median(values))
+            result_dict["%s_std" % name].append(np.std(values))
+            result_dict["%s_min" % name].append(np.min(values))
+            result_dict["%s_max" % name].append(np.max(values))
+    return pd.DataFrame(result_dict)
+
+
+def evaluate_model_config(
+        config,
+        allele_datasets,
+        min_samples_per_allele=5,
+        cv_folds=5):
+    """
+    Build the network described by config (embedding network when
+    config.embedding_size is truthy, hotshot network otherwise) and return
+    the per-allele results of leave_out_allele_cross_validation for it.
+    """
+    # NOTE(review): allele_datasets is accepted but not used in this body.
+    print("===")
+    print(config)
+    shared_kwargs = dict(
+        peptide_length=9, layer_sizes=[config.hidden_layer_size],
+        activation=config.activation,
+        init=config.init,
+        loss=config.loss,
+        dropout_probability=config.dropout_probability)
+    if config.embedding_size:
+        model = make_embedding_network(
+            embedding_input_dim=20,
+            embedding_output_dim=config.embedding_size,
+            **shared_kwargs)
+    else:
+        model = make_hotshot_network(**shared_kwargs)
+    return leave_out_allele_cross_validation(
+        model, binary_encoding=config.embedding_size == 0,
+        n_pretrain_epochs=config.n_pretrain_epochs,
+        min_samples_per_allele=min_samples_per_allele,
+        cv_folds=cv_folds)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    configs = generate_all_model_configs(
+        max_dropout=args.max_dropout,
+        n_training_epochs=args.training_epochs)
+    print("Total # configurations = %d" % len(configs))
+
+    # NOTE(review): the cache is keyed by max_ic50, yet load_data is called
+    # with the same arguments for every key -- should it receive max_ic50?
+    datasets_by_max_ic50 = {}
+    all_dataframes = []
+    # `config` and `allele_datasets` must stay module-level names:
+    # leave_out_allele_cross_validation reads them as globals.
+    for i, config in enumerate(configs):
+        if config.max_ic50 not in datasets_by_max_ic50:
+            datasets_by_max_ic50[config.max_ic50], _ = load_data(
+                args.binding_data_csv_path,
+                peptide_length=9,
+                binary_encoding=False)
+        allele_datasets = datasets_by_max_ic50[config.max_ic50]
+        result_df = evaluate_model_config(
+            config,
+            allele_datasets,
+            min_samples_per_allele=args.min_samples_per_allele,
+            cv_folds=args.cv_folds)
+        n_rows = len(result_df)
+        result_df["config_idx"] = [i] * n_rows
+        for hyperparameter_name in config._fields:
+            value = getattr(config, hyperparameter_name)
+            result_df[hyperparameter_name] = [value] * n_rows
+        # The first config truncates the results file and writes the header
+        # row; later configs append rows only, so the CSV has one header.
+        is_first = (i == 0)
+        with open(args.results_filename, "w" if is_first else "a") as f:
+            result_df.to_csv(f, index=False, header=is_first)
+        all_dataframes.append(result_df)
+    combined_df = pd.concat(all_dataframes)
+
+    print("\n=== Hyperparameters ===")
+    for hyperparameter_name in ModelConfig._fields:
+        print("\n%s" % hyperparameter_name)
+        groups = combined_df.groupby(hyperparameter_name)
+        for hyperparameter_value, group in groups:
+            aucs = group["auc_mean"]
+            accuracies = group["accuracy_mean"]
+            unique_configs = group["config_idx"].unique()
+            print(
+                "-- %s (%d): AUC=%0.4f/%0.4f/%0.4f, Acc=%0.4f/%0.4f/%0.4f" % (
+                    hyperparameter_value,
+                    len(unique_configs),
+                    np.percentile(aucs, 25.0),
+                    np.percentile(aucs, 50.0),
+                    np.percentile(aucs, 75.0),
+                    np.percentile(accuracies, 25.0),
+                    np.percentile(accuracies, 50.0),
+                    np.percentile(accuracies, 75.0)))
diff --git a/experiments/results.csv b/experiments/results.csv
new file mode 100644
index 00000000..d4112bd5
--- /dev/null
+++ b/experiments/results.csv
@@ -0,0 +1,120 @@
+allele_name,dataset_size,auc_mean,auc_median,auc_std,auc_min,auc_max,accuracy_mean,accuracy_median,accuracy_std,accuracy_min,accuracy_max,config_idx,embedding_size,hidden_layer_size,activation,loss,init,n_pretrain_epochs,n_epochs,dropout_probability,max_ic50
+A0211,1038,0.4752274242081633,0.5058909394223752,0.06112243314759529,0.3985399915361828,0.5472972972972973,0.3477657004830918,0.34782608695652173,0.016934936574622965,0.3285024154589372,0.375,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1509,346,0.5419579181000932,0.5651041666666667,0.1530827102576647,0.2626262626262626,0.7058823529411764,0.04612836438923396,0.043478260869565216,0.024531636703871547,0.014492753623188406,0.08571428571428572,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0201,6961,0.6786953335000259,0.6792339114712382,0.013382075367814387,0.6585076468564242,0.6956627680311891,0.4720830667775885,0.4811732065001982,0.020700102722806486,0.4356130977933458,0.4958548024838156,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0205,36,0.45634920634920634,0.5,0.1253616587216771,0.2857142857142857,0.5833333333333333,0.7678571428571428,0.8571428571428571,0.1390871600660467,0.5714285714285714,0.875,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4601,1406,0.485459229164695,0.4951415293620617,0.07410775263975425,0.361407249466951,0.5916976456009913,0.06473334847681785,0.06405693950177936,0.017230888454962318,0.042704626334519574,0.08896797153024912,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4001,2326,0.5970401165610543,0.6043756967670011,0.030145610810665364,0.5452646768436243,0.6384767043685795,0.12616488160559994,0.12108220603537981,0.015369994428072544,0.10755000578101515,0.14963579604578564,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4002,474,0.5627075362132403,0.5731707317073171,0.02172643276290265,0.527179487179487,0.5892857142857143,0.3458902575587906,0.3368421052631579,0.0461267452661712,0.2978723404255319,0.43157894736842106,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B0803,217,0.6615607865380979,0.6428571428571428,0.05022619382473292,0.6071428571428571,0.7441860465116279,0.04154334038054969,0.045454545454545456,0.017464588945508147,0.022727272727272728,0.06976744186046512,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A6802,3188,0.5539368687287517,0.5408994548758328,0.03384487389562976,0.5178667866786679,0.6081171294189045,0.22591803992886592,0.2240887962972062,0.0063726253858333885,0.21735242381658984,0.23571785917603366,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2402,1985,0.49021637587886246,0.4840541721275666,0.030748660773272447,0.4488422892092617,0.5380667236954663,0.16272040302267005,0.15869017632241814,0.011443644021763769,0.15113350125944586,0.17632241813602015,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5701,1719,0.5660086658990283,0.5810810810810811,0.04068213383409081,0.494983552631579,0.6124422026061369,0.1254869466909069,0.11661807580174927,0.012394950458606994,0.11337209302325581,0.14163061114115738,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1801,1661,0.6018579105101394,0.6281693880260188,0.07573601505708398,0.4852961868580273,0.6911458333333333,0.11246944796443999,0.10542168674698796,0.030406977148720576,0.07187182464798955,0.16216216216216217,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2602,202,0.5474068368413196,0.5238095238095238,0.05474899468692869,0.49425287356321834,0.648,0.33182926829268294,0.34146341463414637,0.041430043960080536,0.275,0.375,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0219,1203,0.4486568621820883,0.45662205814490076,0.048211124621219556,0.38876560332871013,0.5173519163763065,0.16955739972337483,0.17083333333333334,0.022206503035208772,0.14522821576763487,0.2033195020746888,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0206,3223,0.5926826715144984,0.5935531788472965,0.024386645287943296,0.554904275111529,0.6215946502057613,0.3947451140741851,0.38794543597139597,0.01555802233684022,0.375265909500631,0.42050846803749853,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B2705,2324,0.5047796689623727,0.4934737957993772,0.023232989959665412,0.4839851290484021,0.544467228677755,0.1582402254186905,0.15567580067059775,0.010689881183563909,0.14474026456599287,0.17202913631633715,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B3501,1993,0.5954807107338234,0.6042087542087542,0.04339058434055422,0.5446123842028279,0.6484487734487734,0.2546334239411526,0.25064541051877814,0.021025518007833945,0.22750688113936518,0.2882205513784461,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2301,1513,0.4911816800547356,0.49309551208285385,0.028697830277714432,0.45557491289198604,0.5315766550522647,0.19232400061198174,0.18543046357615894,0.013135933221589437,0.18151815181518152,0.21782178217821782,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0203,3937,0.5993341737308165,0.5963626779124935,0.018158592642259684,0.5773862972306552,0.6324191897030862,0.33513301905046233,0.33924154191038164,0.014660323415106542,0.3146605658339374,0.3573927658633222,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0212,1143,0.40716864411759984,0.41738268341044277,0.04240605050655353,0.33189229249011853,0.4578598484848484,0.24056155673025356,0.25,0.02992556260964635,0.19298245614035087,0.2794759825327511,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5802,31,0.6127777777777779,0.5,0.19660388200580653,0.375,0.888888888888889,0.2952380952380952,0.3333333333333333,0.13006190746426347,0.14285714285714285,0.5,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5301,620,0.5805001831618799,0.5924288477479965,0.04941278212693718,0.4886904761904762,0.6303716608594658,0.3403225806451613,0.3387096774193548,0.025194353793247273,0.3064516129032258,0.3790322580645161,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4403,502,0.6054721530480192,0.5950468540829986,0.07850773062064206,0.5009041591320071,0.7418666666666667,0.20120792079207922,0.21,0.03430680353726096,0.15,0.25,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4501,483,0.6577870995963101,0.6695156695156695,0.048511381224192074,0.5666666666666667,0.6987012987012987,0.21327319587628865,0.20618556701030927,0.022579143644246433,0.1875,0.25,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A3002,912,0.5441990182883065,0.5691853233830846,0.045108522986082686,0.48414179104477617,0.5885549872122763,0.25768330030625114,0.26229508196721313,0.007179378290443697,0.2459016393442623,0.26373626373626374,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B0802,486,0.5282541056170611,0.5934782608695652,0.1295964655763792,0.3333333333333333,0.6702127659574468,0.03907006101409636,0.030927835051546393,0.009973536549863176,0.030927835051546393,0.05154639175257732,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A8001,782,0.5731812771308868,0.5696148963182395,0.0699450947916836,0.48550100488084985,0.6682330827067668,0.14450432794381837,0.15286624203821655,0.0216861596036534,0.10897435897435898,0.17307692307692307,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B3901,879,0.40570805986410824,0.41746132655223567,0.044160110044391104,0.34237349836184927,0.46428571428571425,0.19344805194805195,0.1875,0.02999267212465846,0.1590909090909091,0.2342857142857143,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0207,30,0.48124999999999996,0.5,0.24960906934644825,0.125,0.8,0.29166666666666663,0.16666666666666666,0.21650635094610965,0.16666666666666666,0.6666666666666666,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A3101,3309,0.4706008576639079,0.4714637146371463,0.026463110039464513,0.431575743137722,0.5074443860571027,0.2324601036972303,0.2366033533830469,0.015525280849335377,0.20823103111508656,0.2526291938359566,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1501,3142,0.5072863217265713,0.5087464300285598,0.022087642994642078,0.4779472768633748,0.5408558160106147,0.2774885200555778,0.28341855368882396,0.014403818720090006,0.2576473690616252,0.29574222077974766,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1502,164,0.3916529143371249,0.37969924812030076,0.034344731400447776,0.3518518518518518,0.455,0.7566287878787878,0.7878787878787878,0.09494286257192637,0.5757575757575758,0.84375,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2601,2395,0.46864866628882745,0.48286663502610355,0.033299992094589545,0.4036697247706422,0.4930192758061611,0.1124646423263497,0.10855949895615867,0.01439014951285748,0.0914832135494528,0.13460105212233212,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5401,621,0.637386043215034,0.6393617021276596,0.03816522606742487,0.5666284841542573,0.673758865248227,0.22388387096774193,0.22580645161290322,0.01849174071053678,0.192,0.24193548387096775,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0101,3169,0.5308969021645761,0.537469345406527,0.02007219127328141,0.49128637963833555,0.5453790813503043,0.15712500569201052,0.15870393774442973,0.006935786940577993,0.1459911035038661,0.1662935593440299,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0301,4601,0.47658347406935464,0.47548010973936894,0.024011884612868878,0.4419498163353501,0.508007386032535,0.3694693485053866,0.3600803402646503,0.02293847494291814,0.33824938902976864,0.4041304347826087,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2603,205,0.44787215787215795,0.47619047619047616,0.15620684800042128,0.21111111111111114,0.6824324324324325,0.12195121951219512,0.12195121951219512,0.021815297341461357,0.0975609756097561,0.14634146341463414,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1503,416,0.5668852400213785,0.5960061443932412,0.07340876758369155,0.42370892018779344,0.6291079812206574,0.7980493402180149,0.8095238095238095,0.054718903699418786,0.7228915662650602,0.8554216867469879,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A6801,1700,0.46091826410839143,0.4708695652173913,0.02722439572351732,0.4183197831978319,0.4919982698961938,0.3776470588235294,0.38529411764705884,0.02787763673193596,0.3235294117647059,0.4,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0216,894,0.4561539187863346,0.45812714122573284,0.07061447597949588,0.3446236559139785,0.5377680311890838,0.17899064716590296,0.19662921348314608,0.03051841392231157,0.1340782122905028,0.20670391061452514,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B3801,136,0.43209876543209874,0.2962962962962963,0.3888868596792863,0.038461538461538436,0.9615384615384615,0.03659611992945326,0.037037037037037035,0.0006235509534272905,0.03571428571428571,0.037037037037037035,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B7301,115,0.6286659161272166,0.6514705882352942,0.2170249396682407,0.3026315789473685,0.9090909090909091,0.15217391304347827,0.15217391304347827,0.07838154946660846,0.043478260869565216,0.2608695652173913,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5101,1336,0.6466870231979104,0.6471906754333533,0.0530375950156576,0.5843579234972678,0.7073891625615764,0.12802280731175583,0.13108614232209737,0.027350225737370962,0.08955223880597014,0.15730337078651685,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B0801,2084,0.5595646246772512,0.564582503650604,0.029550271531495347,0.5063746322327558,0.5893126957803574,0.22105299050902313,0.22302158273381295,0.014761727649122166,0.19617927144970415,0.23980815347721823,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2403,1165,0.5331048970591111,0.5423644338118022,0.04790386800106813,0.4632075471698113,0.5956077630234933,0.20686695278969958,0.22746781115879827,0.03774872691381291,0.1459227467811159,0.2446351931330472,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B1517,846,0.550405682973064,0.5695488721804511,0.051663765367253624,0.46955576265921095,0.6215126592485083,0.320341106856944,0.3254437869822485,0.021107926836311696,0.28402366863905326,0.3431952662721893,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4402,1295,0.630661595900837,0.6169859985261607,0.03681854003149678,0.5952459557609773,0.6895964997569275,0.08175787480806786,0.0888030888030888,0.014011779298965807,0.06516003041099566,0.10038610038610038,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B0702,2974,0.604133546558787,0.590381879855564,0.03464991232073981,0.5610647181628392,0.6610153108567893,0.2031603419771873,0.2013049925852694,0.006676549059310762,0.19360836195853032,0.21271096673963702,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B5801,2444,0.5923824420752225,0.5936062851303014,0.026322302773618907,0.5457818274499808,0.6265133171912833,0.1685313399140081,0.16528452122565562,0.012154544537313077,0.15682855123556694,0.19051024376779957,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A3001,1949,0.6266491768092188,0.6246437493402766,0.015326844593368704,0.6036420863309353,0.6511036155985561,0.292225294886161,0.28717948717948716,0.028128649396647828,0.25000657462195924,0.33676092544987146,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A1101,3862,0.5509774961706325,0.5457070468636893,0.01719498617767373,0.5261896693614756,0.5762351915686041,0.32343581689765666,0.32142205650269695,0.010052986329009302,0.3090896006721013,0.33610298263040617,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0250,132,0.5392389931256327,0.5476190476190476,0.12029285219601231,0.33333333333333337,0.7100591715976332,0.6655270655270655,0.7037037037037037,0.10960110627513293,0.5,0.7777777777777778,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A0202,2314,0.5441887900386823,0.539284245166598,0.023793518210466033,0.5215302925117736,0.5874214550685138,0.46370742288668854,0.4427645788336933,0.034523410524400625,0.4298056155507559,0.5226781857451404,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A3301,1616,0.42782240333789245,0.4412852112676056,0.023965884858028356,0.39720223820943246,0.4537375415282392,0.13909291463886383,0.13312693498452013,0.019663530401905917,0.12309137440213172,0.17647058823529413,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A6901,2079,0.5144423240051859,0.518922626098884,0.030975147831123814,0.4667644183773216,0.554656512771696,0.10866050697138736,0.10576923076923077,0.019036407942023808,0.08154449121788358,0.1411566198224852,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2501,519,0.4461564329411273,0.43579454253611566,0.07948608502246761,0.35080304311073546,0.5827715355805245,0.12718446601941746,0.125,0.011403258864118122,0.11538461538461539,0.14423076923076922,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+B4801,861,0.36913410573443145,0.345388788426763,0.06452815755612938,0.289171974522293,0.468984321745058,0.07900255410673478,0.08139534883720931,0.022907103346930243,0.05232558139534884,0.11627906976744186,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A3201,575,0.4513633123853187,0.45644283121597096,0.039913711178527427,0.3764455264759586,0.4863636363636364,0.4782608695652174,0.4782608695652174,0.03849729325422376,0.41739130434782606,0.5304347826086957,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+A2902,1839,0.5665509818533655,0.5571045998739761,0.023313890164006083,0.5444937645318115,0.6113782991202346,0.25587822990012515,0.25271739130434784,0.018522150990840636,0.2270350897920605,0.2826086956521739,0,0,64,relu,mse,glorot_uniform,0,1,0,5000
+allele_name,dataset_size,auc_mean,auc_median,auc_std,auc_min,auc_max,accuracy_mean,accuracy_median,accuracy_std,accuracy_min,accuracy_max,config_idx,embedding_size,hidden_layer_size,activation,loss,init,n_pretrain_epochs,n_epochs,dropout_probability,max_ic50
+A0211,1038,0.535778864902839,0.5203836930455635,0.06926152103536534,0.47256958450988296,0.6626232741617358,0.34917102592687876,0.34782608695652173,0.017272033632816385,0.3285024154589372,0.37740384615384615,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1509,346,0.41199920671222257,0.43939393939393934,0.190784900523725,0.17647058823529416,0.6197916666666666,0.04612836438923396,0.043478260869565216,0.024531636703871547,0.014492753623188406,0.08571428571428572,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0201,6961,0.7330178517135651,0.7378165683261779,0.016532288615636364,0.7031860363192897,0.7532285575048734,0.5260819006027354,0.5305842499009116,0.014480401804528482,0.5027650816898563,0.5466334720570749,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0205,36,0.1111111111111111,0.0,0.15713484026367722,0.0,0.3333333333333333,0.7678571428571428,0.8571428571428571,0.1390871600660467,0.5714285714285714,0.875,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4601,1406,0.47511910436589355,0.49333688699360345,0.05164527997195994,0.4001404001404002,0.5430925221799746,0.085125739038351,0.08267372500348273,0.019148636294221994,0.052468940362964,0.1053051506439888,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4001,2326,0.6031217679396037,0.6060606060606061,0.05604256627423693,0.5409119581868973,0.7013192121887775,0.19569649800182012,0.1838224072147069,0.024705966939634103,0.16620187304890738,0.23741935483870968,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4002,474,0.5642831005722428,0.555103884372177,0.023908260499132642,0.5384199134199134,0.5982142857142857,0.3458902575587906,0.3368421052631579,0.0461267452661712,0.2978723404255319,0.43157894736842106,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B0803,217,0.46863247170677685,0.47560975609756095,0.24826285927981953,0.16279069767441856,0.8809523809523809,0.04154334038054969,0.045454545454545456,0.017464588945508147,0.022727272727272728,0.06976744186046512,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A6802,3188,0.45065173751915444,0.4542754275427543,0.01807242448229544,0.42743566176470593,0.47704593950043933,0.33575029831540193,0.34702882243688643,0.025765575058827178,0.29476911586953747,0.3653310134583963,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2402,1985,0.43502949858243467,0.45238095238095233,0.0374930359929341,0.3747487986020096,0.4708169506334644,0.19336205419741256,0.1930790754335095,0.019596693068205915,0.16695112588748104,0.22034909173968492,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5701,1719,0.5324206475255974,0.5239301801801801,0.023678253687709285,0.5009075907590759,0.5716282894736842,0.15566729413161545,0.15462094875434557,0.012067984828405959,0.14081936181719848,0.17725797728501894,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1801,1661,0.5661277632032762,0.5729457055621929,0.05436526036669351,0.46855213170113974,0.6195833333333334,0.1367802286560354,0.1410763536071999,0.028929388997801012,0.09781898679053563,0.18448177907637367,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2602,202,0.4496168051616328,0.43466666666666665,0.09176973488682882,0.3247126436781609,0.5893333333333333,0.33182926829268294,0.34146341463414637,0.041430043960080536,0.275,0.375,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0219,1203,0.4399296131458562,0.4565510479225395,0.032715664220692615,0.380998613037448,0.47498257839721253,0.17392116361919846,0.17083333333333334,0.022463038123716493,0.1481723799521358,0.2107057385375596,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0206,3223,0.5848626464742674,0.5946920182214299,0.02673926408657368,0.5366161616161615,0.6144506329113925,0.40398531170972707,0.40084129559521664,0.013912386481860679,0.3859744005768884,0.4287836117433741,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B2705,2324,0.5505542783702545,0.5398110661268556,0.025971963387274866,0.5190534067348253,0.5833143464722412,0.23314773745090092,0.22696195005945302,0.010985996939380712,0.22439125910509886,0.2542675453809689,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B3501,1993,0.5288436900855231,0.5264790764790764,0.023041191893061468,0.49592068909475057,0.5674161166116611,0.2716637726464288,0.26920685171575554,0.018477325945791826,0.24402161561576727,0.2967129603457265,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2301,1513,0.4612228025936888,0.4605911330049261,0.03206971488741619,0.41572299651567945,0.5026392961876833,0.201961013387883,0.1979299153545897,0.015960908267460206,0.18751370553923075,0.23085971963532986,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0203,3937,0.6268966157508354,0.6163760193018355,0.015001752645097241,0.6128930817610063,0.6472023495745161,0.37603396676042466,0.372207843789405,0.011850760011011702,0.36624558736375584,0.39863796864227946,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0212,1143,0.4825450318907448,0.47532574125371907,0.014105206276359297,0.46676136363636367,0.5027539105529852,0.24563187013626636,0.25,0.029168831418613553,0.19836872883964296,0.2794759825327511,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5802,31,0.6666666666666667,0.6666666666666667,0.24152294576982397,0.25,1.0,0.2952380952380952,0.3333333333333333,0.13006190746426347,0.14285714285714285,0.5,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5301,620,0.5480056491086901,0.5590575275397798,0.024532960400518715,0.5068181818181818,0.575595238095238,0.3403225806451613,0.3387096774193548,0.025194353793247273,0.3064516129032258,0.3790322580645161,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4403,502,0.514630237956788,0.5334538878842676,0.08075361662285097,0.40276179516685845,0.6026666666666667,0.20120792079207922,0.21,0.03430680353726096,0.15,0.25,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4501,483,0.5163789273591906,0.5619212962962963,0.07265756418386628,0.40823211875843457,0.5948051948051949,0.21327319587628865,0.20618556701030927,0.022579143644246433,0.1875,0.25,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A3002,912,0.5299028927070624,0.5471082089552239,0.03767455768857478,0.48019323671497594,0.5775255754475703,0.25872182222389395,0.26229508196721313,0.00827223392491466,0.2459016393442623,0.26892887332447774,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B0802,486,0.5349014751375171,0.4078014184397163,0.21354956062126967,0.37173913043478257,0.9290780141843972,0.041004379220069366,0.04059942608141141,0.009107055900030383,0.030927835051546393,0.05154639175257732,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A8001,782,0.5587634494611649,0.5561868686868687,0.042859680439655205,0.5002115954295387,0.6274509803921569,0.14627316237929247,0.15384615384615385,0.02263595155033026,0.10897435897435898,0.17307692307692307,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B3901,879,0.5263933644155132,0.5632549268912905,0.04757921955535455,0.4555477363359136,0.5664361121223153,0.19555290731995278,0.19105113636363635,0.029766216511244947,0.1590909090909091,0.2342857142857143,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0207,30,0.65,0.6000000000000001,0.2598076211353316,0.4,1.0,0.29166666666666663,0.16666666666666666,0.21650635094610965,0.16666666666666666,0.6666666666666666,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A3101,3309,0.5712716492283991,0.5721927284427284,0.01636295625025263,0.5536095371256163,0.5981088560885608,0.3662244138413136,0.351041885342412,0.02711823838238824,0.33919004025154936,0.4113832935473461,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1501,3142,0.5125756548149105,0.5095142513046969,0.01115095767972765,0.4988209662950147,0.5323557054400707,0.36231281179536057,0.36533936468010875,0.012744160921505486,0.34475193420297695,0.3811970059637308,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1502,164,0.5966571842712194,0.5934065934065935,0.014538517086794751,0.575187969924812,0.6172839506172839,0.7566287878787878,0.7878787878787878,0.09494286257192637,0.5757575757575758,0.84375,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2601,2395,0.4451717634521584,0.4165903540903541,0.06044668041130616,0.37417324514614897,0.5416231608922638,0.20013249593577434,0.1968174824900519,0.02628453960308532,0.16342327657219066,0.23907671253176196,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5401,621,0.4981825075006247,0.5109929078014185,0.027248731148211174,0.4482269503546099,0.5211912943871707,0.22388387096774193,0.22580645161290322,0.01849174071053678,0.192,0.24193548387096775,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0101,3169,0.616283916366293,0.6167940011318619,0.03801113037149527,0.5698510922091269,0.6836146467441432,0.3604200548799734,0.3607086792999059,0.007518857196019624,0.3489735194896954,0.37236413935853674,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0301,4601,0.5422641545824203,0.5410339041360378,0.012727403181484784,0.5272640638494297,0.565437013246481,0.5790190571302555,0.593218336483932,0.024765107806535017,0.5357687261049631,0.6018241965973535,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2603,205,0.40012012012012005,0.41428571428571426,0.07443041055791036,0.304054054054054,0.4857142857142857,0.12195121951219512,0.12195121951219512,0.021815297341461357,0.0975609756097561,0.14634146341463414,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1503,416,0.44203340754451637,0.4079710144927537,0.12301792526761235,0.2472426470588235,0.5903755868544601,0.7980493402180149,0.8095238095238095,0.054718903699418786,0.7228915662650602,0.8554216867469879,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A6801,1700,0.4911823909434065,0.484284890426759,0.022907961418758362,0.46663501223565507,0.5252173913043477,0.3795916955017301,0.39001730103806226,0.027172027113974196,0.32664359861591696,0.40058823529411763,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0216,894,0.4607577383534484,0.46364674533688616,0.05923833561377449,0.35804195804195804,0.5348306052531405,0.1811454918449491,0.20003787400580736,0.030355752143960683,0.13816672388502232,0.20998096189257515,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B3801,136,0.5968660968660968,0.4444444444444444,0.2878691465695972,0.34615384615384615,1.0,0.03659611992945326,0.037037037037037035,0.0006235509534272905,0.03571428571428571,0.037037037037037035,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B7301,115,0.5021296556900272,0.43122009569377995,0.1655657266469078,0.3627450980392157,0.7833333333333333,0.15217391304347827,0.15217391304347827,0.07838154946660846,0.043478260869565216,0.2608695652173913,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5101,1336,0.5002448407314062,0.5153688524590164,0.05604026361649571,0.40944411237298267,0.5669841269841269,0.13823343981883976,0.13937634137103902,0.0244238872458405,0.11078848069127074,0.1701384505323402,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B0801,2084,0.5259325147397333,0.5307646356033453,0.03596644425614314,0.4792703150912106,0.583554196398233,0.2446660192917566,0.24478834198827987,0.014229140164107121,0.22406619822485208,0.26851037155886803,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2403,1165,0.494176785755462,0.5033707865168539,0.0350257306424301,0.4298180004986288,0.529350104821803,0.21158614083884397,0.22746781115879827,0.03559934098422238,0.15504061596271804,0.24682716572417984,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B1517,846,0.5614535926183779,0.5538911845730027,0.02181726073104921,0.5285803044423734,0.5881309466215126,0.32136348002314963,0.3254437869822485,0.01937693511090932,0.28913553447008156,0.3431952662721893,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4402,1295,0.5622256516747279,0.5653292181069959,0.0643068520617266,0.48894620486366985,0.6730675741370928,0.09266111119392972,0.10467941742073016,0.01737741233144423,0.06516003041099566,0.10785468314425843,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B0702,2974,0.571965202809532,0.5771542725505723,0.01466255928999966,0.5467761149579331,0.5862280701754387,0.31977692029135446,0.3180001412329638,0.014399737692419822,0.29733493397358945,0.33651894931356213,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B5801,2444,0.5384711840374043,0.5623502710881352,0.038048529713611605,0.47199566713393654,0.569526060296372,0.25275198769205376,0.2525374182945036,0.008964400515103197,0.24152207459821595,0.2639375044433571,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A3001,1949,0.6488440633840044,0.6472572341558672,0.038602774095513395,0.5924844305267232,0.7051657527417746,0.3042702544539587,0.30449704142011835,0.027904896996150383,0.2603155818540434,0.345992955373015,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A1101,3862,0.6523621423686827,0.6598313910240516,0.01748704304803498,0.6192285903125295,0.667562106264115,0.4151258952801193,0.4163903728959167,0.007931870566692758,0.4010198668181795,0.42351472522752287,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0250,132,0.5251364649745217,0.5515151515151515,0.11437923501335094,0.3412698412698413,0.6627218934911243,0.6655270655270655,0.7037037037037037,0.10960110627513293,0.5,0.7777777777777778,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A0202,2314,0.5913396064097032,0.5881452070334656,0.017876108661417797,0.5635026361708915,0.6155083655083655,0.4640929338673724,0.4437535277955301,0.03418423044108848,0.43071526200150206,0.5224822618942104,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A3301,1616,0.46765939339654705,0.470514950166113,0.06518250747712884,0.37197544239797764,0.5444542253521127,0.16568748188292043,0.16493017281867936,0.017690128495328798,0.14657477786617334,0.19650336914951738,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A6901,2079,0.4609815133554065,0.4624353139223144,0.013217267261997154,0.44104349951124144,0.4796282026584473,0.15711083386354285,0.15315273668639054,0.017682145650575992,0.13220496443605748,0.18276164940828402,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2501,519,0.42240214118676034,0.41251056635672023,0.084546696598235,0.33967391304347827,0.5714285714285714,0.13151995751625417,0.13221153846153846,0.011228211627623524,0.11538461538461539,0.14423076923076922,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+B4801,861,0.464537087199676,0.4669393319700068,0.06231486138678245,0.3670886075949367,0.5447852760736196,0.08202605549918655,0.08139534883720931,0.021159128336628395,0.05753109789075176,0.11627906976744186,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A3201,575,0.5275963840248374,0.5422885572139303,0.04026619271472424,0.4532483302975106,0.5727272727272728,0.4782608695652174,0.4782608695652174,0.03849729325422376,0.41739130434782606,0.5304347826086957,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
+A2902,1839,0.5358499013255416,0.5411197936423641,0.038279566116431074,0.4747231934731934,0.5813642091997479,0.26957053185104113,0.27287629962192816,0.022964554717268413,0.23150992438563328,0.30151228733459357,1,0,64,relu,mse,glorot_uniform,0,1,0,50000
diff --git a/mhcflurry/data_helpers.py b/mhcflurry/data_helpers.py
index 7d7d5e1f..a6b19d26 100644
--- a/mhcflurry/data_helpers.py
+++ b/mhcflurry/data_helpers.py
@@ -27,6 +27,7 @@ from .amino_acid import amino_acid_letter_indices
 
 AlleleData = namedtuple("AlleleData", "X Y peptides ic50")
 
+
 def hotshot_encoding(peptides, peptide_length):
     """
     Encode a set of equal length peptides as a binary matrix,
@@ -41,6 +42,7 @@ def hotshot_encoding(peptides, peptide_length):
             X[i, j, k] = 1
     return X
 
+
 def index_encoding(peptides, peptide_length):
     """
     Encode a set of equal length peptides as a vector of their
@@ -52,16 +54,51 @@ def index_encoding(peptides, peptide_length):
             X[i, j] = amino_acid_letter_indices[amino_acid]
     return X
 
+
+def indices_to_hotshot_encoding(X, n_indices=None, first_index_value=0):
+    """
+    Expand an (n_samples, peptide_length) matrix of integer indices into a
+    boolean matrix of shape (n_samples, peptide_length * n_indices) in which
+    each position sets one entry of its own block of n_indices columns.
+    """
+    n_rows, n_cols = X.shape
+    if not n_indices:
+        # infer the alphabet size from the largest index present
+        n_indices = X.max() - first_index_value + 1
+    encoded = np.zeros((n_rows, n_cols * n_indices), dtype=bool)
+    # flat column = start of the position's block + offset of its index
+    flat_columns = n_indices * np.arange(n_cols) + X - first_index_value
+    encoded[np.arange(n_rows)[:, np.newaxis], flat_columns] = True
+    return encoded
+
+
+def _infer_csv_separator(filename):
+    """
+    Infer the separator from the first non-blank, non-comment ("#") line,
+    so a comma in a later data field cannot override a tab-separated header.
+    Returns (sep, delim_whitespace); (None, True) means whitespace-delimited.
+    """
+    with open(filename, "r") as f:
+        for line in f:
+            if line.startswith("#") or not line.strip():
+                continue
+            for candidate in [",", "\t"]:  # comma keeps precedence over tab
+                if candidate in line:
+                    return candidate, False
+            break  # first content line has neither separator
+    return None, True
+
+
 def load_data(
         filename,
         peptide_length=9,
         max_ic50=5000.0,
         binary_encoding=True,
         flatten_binary_encoding=True,
-        sep="\t",
+        sep=None,
         species_column_name="species",
         allele_column_name="mhc",
-        peptide_column_name="sequence",
+        peptide_column_name=None,
         peptide_length_column_name="peptide_length",
         ic50_column_name="meas"):
     """
@@ -93,8 +130,34 @@ def load_data(
     flatten_features : bool
         If False, returns a (n_samples, peptide_length, 20) matrix, otherwise
         returns the 2D flattened version of the same data.
+
+    sep : str, optional
+        Separator in CSV file; if None, inferred as comma, tab or whitespace
+
+    peptide_column_name : str, optional
+        Default behavior is to try {"sequence", "peptide", "peptide_sequence"}
     """
-    df = pd.read_csv(filename, sep=sep)
+    if sep is None:
+        sep, delim_whitespace = _infer_csv_separator(filename)
+    else:
+        delim_whitespace = False
+    df = pd.read_csv(
+        filename,
+        sep=sep,
+        delim_whitespace=delim_whitespace,
+        engine="c")
+    # hack: column naming varies between datasets, so when no
+    # peptide_column_name is given use the first known candidate present
+    if peptide_column_name is None:
+        columns = set(df.keys())
+        for candidate in ["sequence", "peptide", "peptide_sequence"]:
+            if candidate in columns:
+                peptide_column_name = candidate
+                break
+        if peptide_column_name is None:
+            raise ValueError(
+                "Couldn't find peptide column name, candidates: %s" % (
+                    columns))
     human_mask = df[species_column_name] == "human"
     length_mask = df[peptide_length_column_name] == peptide_length
     df = df[human_mask & length_mask]
diff --git a/mhcflurry/feedforward.py b/mhcflurry/feedforward.py
index f9072d8c..555c5fbc 100644
--- a/mhcflurry/feedforward.py
+++ b/mhcflurry/feedforward.py
@@ -48,6 +48,7 @@ def compile_forward_predictor(model, theano_mode=None):
         allow_input_downcast=True,
         mode=theano_mode)
 
+
 def make_network(
         input_size,
         embedding_input_dim=None,
@@ -111,6 +112,7 @@ def make_network(
         compile_forward_predictor(model)
     return model
 
+
 def make_hotshot_network(
         peptide_length=9,
         layer_sizes=[500],
@@ -130,6 +132,7 @@ def make_hotshot_network(
         dropout_probability=dropout_probability,
         optimizer=optimizer)
 
+
 def make_embedding_network(
         peptide_length=9,
         embedding_input_dim=20,
diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py
index 3a2061a8..218850c4 100644
--- a/mhcflurry/mhc1_binding_predictor.py
+++ b/mhcflurry/mhc1_binding_predictor.py
@@ -41,6 +41,7 @@ from .paths import CLASS1_MODEL_DIRECTORY
 
 _allele_model_cache = {}
 
+
 class Mhc1BindingPredictor(object):
     def __init__(
             self,
-- 
GitLab