"""
Tests for training and predicting using Class1 pan-allele models.
"""
import json
import os
import shutil
import tempfile
import subprocess
from copy import deepcopy
from sklearn.metrics import roc_auc_score
import pandas
from numpy.testing import assert_, assert_equal, assert_array_less
from mhcflurry import Class1AffinityPredictor,Class1NeuralNetwork
from mhcflurry.allele_encoding import AlleleEncoding
from mhcflurry.downloads import get_path
# Shared hyperparameter template for the test ensemble. The two ensemble
# members below are identical except for the hidden layer size, so the
# common settings live here once instead of being duplicated (the previous
# duplicated literals risked silently drifting apart when edited).
_BASE_HYPERPARAMETERS = {
    'activation': 'tanh',
    'allele_dense_layer_sizes': [],
    'batch_normalization': False,
    'dense_layer_l1_regularization': 0.0,
    'dense_layer_l2_regularization': 0.0,
    'dropout_probability': 0.5,
    'early_stopping': True,
    'init': 'glorot_uniform',
    'layer_sizes': [64],
    'learning_rate': None,
    'locally_connected_layers': [],
    'loss': 'custom:mse_with_inequalities',
    # Kept tiny so the end-to-end training test stays fast.
    'max_epochs': 5,
    'minibatch_size': 128,
    'optimizer': 'rmsprop',
    'output_activation': 'sigmoid',
    'patience': 10,
    'peptide_allele_merge_activation': '',
    'peptide_allele_merge_method': 'concatenate',
    'peptide_amino_acid_encoding': 'BLOSUM62',
    'peptide_dense_layer_sizes': [],
    'peptide_encoding': {
        'alignment_method': 'left_pad_centered_right_pad',
        'max_length': 15,
        'vector_encoding_name': 'BLOSUM62',
    },
    'random_negative_affinity_max': 50000.0,
    'random_negative_affinity_min': 20000.0,
    'random_negative_constant': 25,
    'random_negative_distribution_smoothing': 0.0,
    'random_negative_match_distribution': True,
    'random_negative_rate': 0.2,
    'train_data': {},
    'validation_split': 0.1,
}

# Two ensemble members: one with a 64-unit hidden layer, one with 32.
# deepcopy ensures the two dicts (and their nested lists/dicts) share no
# mutable state, matching the original independent literals.
HYPERPARAMETERS_LIST = [
    deepcopy(_BASE_HYPERPARAMETERS),
    dict(deepcopy(_BASE_HYPERPARAMETERS), layer_sizes=[32]),
]
def run_and_check(n_jobs=0):
    """
    Train a tiny pan-allele ensemble end-to-end via the command-line tool,
    reload it, and sanity-check its predictions.

    Parameters
    ----------
    n_jobs : int
        Passed through to the tool's --num-jobs flag (0 lets the tool
        choose its own parallelism).

    Raises
    ------
    subprocess.CalledProcessError
        If the training command exits non-zero.
    AssertionError
        If the reloaded predictor's output has the wrong shape or does not
        predict a strong binder (< 1000 nM) for the known A*02:01 epitope.
    """
    models_dir = tempfile.mkdtemp(prefix="mhcflurry-test-models")
    try:
        hyperparameters_filename = os.path.join(
            models_dir, "hyperparameters.yaml")
        # JSON is a subset of YAML, so json.dump output is a valid
        # hyperparameters.yaml for the training tool.
        with open(hyperparameters_filename, "w") as fd:
            json.dump(HYPERPARAMETERS_LIST, fd)
        args = [
            "mhcflurry-class1-train-pan-allele-models",
            "--data", get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"),
            "--allele-sequences", get_path("allele_sequences", "allele_sequences.csv"),
            "--hyperparameters", hyperparameters_filename,
            "--out-models-dir", models_dir,
            "--num-jobs", str(n_jobs),
            "--ensemble-size", "2",
        ]
        print("Running with args: %s" % args)
        subprocess.check_call(args)
        result = Class1AffinityPredictor.load(models_dir)
        predictions = result.predict(
            peptides=["SLYNTVATL"],
            alleles=["HLA-A*02:01"])
        assert_equal(predictions.shape, (1,))
        # SLYNTVATL is a well-known HLA-A*02:01 binder; even this tiny
        # model should predict an affinity under 1000 nM.
        assert_array_less(predictions, 1000)
        df = result.predict_to_dataframe(
            peptides=["SLYNTVATL"],
            alleles=["HLA-A*02:01"])
        print(df)
    finally:
        # Clean up even when training or an assertion fails, so failed
        # test runs do not leak temporary model directories.
        print("Deleting: %s" % models_dir)
        shutil.rmtree(models_dir)
# Define the end-to-end training tests only when the Keras backend is not
# theano; under theano, pytest simply will not collect them.
# NOTE(review): presumably the pan-allele pipeline is unsupported or too
# slow on theano — confirm before relying on this guard.
_keras_backend = os.environ.get("KERAS_BACKEND")
if _keras_backend != "theano":

    def test_run_parallel():
        """End-to-end training/prediction with two worker jobs."""
        run_and_check(n_jobs=2)

    def test_run_serial():
        """End-to-end training/prediction with a single job."""
        run_and_check(n_jobs=1)