Skip to content
Snippets Groups Projects
Commit d9cbbcd3 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

starting ligandome predictor

parent c8976b0b
No related branches found
No related tags found
No related merge requests found
from .hyperparameters import HyperparameterDefaults
from .class1_neural_network import Class1NeuralNetwork
class Class1LigandomePredictor(object):
    """
    Work-in-progress predictor built on an ensemble of pan-allele
    Class1NeuralNetwork models taken from a Class1AffinityPredictor.

    ``fit`` and ``predict`` are placeholders and currently do nothing.
    """

    network_hyperparameter_defaults = HyperparameterDefaults(
        retrain_mode="all",
    )

    def __init__(self, class1_affinity_predictor):
        """
        Parameters
        ----------
        class1_affinity_predictor : object
            Must expose a non-empty ``pan_allele_models`` and an empty
            ``allele_to_allele_specific_models``.

        Raises
        ------
        NotImplementedError
            If there are no pan-allele models, or if any allele-specific
            models are present.
        """
        if not class1_affinity_predictor.pan_allele_models:
            raise NotImplementedError("Pan allele models required")
        if class1_affinity_predictor.allele_to_allele_specific_models:
            raise NotImplementedError("Only pan allele models are supported")

        self.binding_predictors = class1_affinity_predictor.pan_allele_models
        # Combine the ensemble into a single network by summing the outputs.
        self.network = Class1NeuralNetwork.merge(
            self.binding_predictors, merge_method="sum")

    def make_network(self, merge_method="sum"):
        """
        Merge ``self.binding_predictors`` into one Class1NeuralNetwork.

        Parameters
        ----------
        merge_method : string
            One of "average", "sum", or "concatenate". Selects how the
            outputs of the per-model sub-networks are combined.

        Returns
        -------
        Class1NeuralNetwork
            The single input model itself when there is exactly one model,
            otherwise a new merged network.

        Raises
        ------
        NotImplementedError
            If ``merge_method`` is not supported, or if the layer names of
            the models do not match the expected pan-allele architecture.
        """
        import keras
        import keras.backend as K
        from keras.layers import Input
        from keras.models import Model

        models = self.binding_predictors

        if len(models) == 1:
            return models[0]
        assert len(models) > 1

        result = Class1NeuralNetwork(**dict(models[0].hyperparameters))

        # Remove hyperparameters that are not shared by all models.
        for model in models:
            for (key, value) in model.hyperparameters.items():
                if result.hyperparameters.get(key, value) != value:
                    del result.hyperparameters[key]

        assert result._network is None

        networks = [model.network() for model in models]

        layer_names = [
            [layer.name for layer in network.layers]
            for network in networks
        ]

        pan_allele_layer_names = [
            'allele', 'peptide', 'allele_representation',
            'flattened_0', 'allele_flat', 'allele_peptide_merged', 'dense_0',
            'dropout_0', 'dense_1', 'dropout_1', 'output',
        ]

        if all(names == pan_allele_layer_names for names in layer_names):
            # Merging an ensemble of pan-allele architectures.
            # The input-side layers are taken from the first network only.
            network = networks[0]
            peptide_input = Input(
                shape=tuple(
                    int(x) for x in K.int_shape(network.inputs[0])[1:]),
                dtype='float32',
                name='peptide')
            allele_input = Input(shape=(1,), dtype='float32', name='allele')

            allele_embedding = network.get_layer("allele_representation")(
                allele_input)
            peptide_flat = network.get_layer("flattened_0")(peptide_input)
            allele_flat = network.get_layer("allele_flat")(allele_embedding)
            allele_peptide_merged = network.get_layer("allele_peptide_merged")(
                [peptide_flat, allele_flat])

            # Every layer after the merge layer is re-applied, per network,
            # on top of the shared merged representation.
            first_sub_network_layer = (
                pan_allele_layer_names.index("allele_peptide_merged") + 1)

            sub_networks = []
            for (i, network) in enumerate(networks):
                layers = network.layers[first_sub_network_layer:]
                node = allele_peptide_merged
                for layer in layers:
                    # Suffix with the network index so names stay unique in
                    # the merged model. This renames the layers of the input
                    # networks in place.
                    # NOTE(review): some Keras versions may not allow
                    # assigning to layer.name -- confirm.
                    layer.name += "_%d" % i
                    node = layer(node)
                sub_networks.append(node)

            if merge_method == 'average':
                output = keras.layers.average(sub_networks)
            elif merge_method == 'sum':
                output = keras.layers.add(sub_networks)
            elif merge_method == 'concatenate':
                output = keras.layers.concatenate(sub_networks)
            else:
                raise NotImplementedError(
                    "Unsupported merge method", merge_method)

            result._network = Model(
                inputs=[peptide_input, allele_input],
                outputs=[output],
                name="merged_predictor")
            result.update_network_description()
        else:
            raise NotImplementedError(
                "Don't know how to merge networks with layer names: ",
                layer_names)
        return result

    def fit(self, peptides, labels, experiment_names,
            experiment_name_to_alleles):
        """Placeholder: training is not implemented yet; returns None."""
        pass

    def predict(self, allele_lists, peptides):
        """Placeholder: prediction is not implemented yet; returns None."""
        pass
"""
Idea:

- take an allele where MS vs. no-MS trained predictors are very different. One
  possibility is DLA-88*501:01 but human would be better
- generate synthetic multi-allele MS by combining single-allele MS for different
  alleles, including the selected allele
- train ligandome predictor based on the no-ms pan-allele models on this
  synthetic dataset
- see if the pan-allele predictor learns the "correct" motif for the selected
  allele, i.e. updates to become more similar to the with-ms pan allele predictor.
"""
from sklearn.metrics import roc_auc_score
import pandas
import argparse
import sys
from numpy.testing import assert_, assert_equal
import numpy
from random import shuffle
from mhcflurry import Class1AffinityPredictor,Class1NeuralNetwork
from mhcflurry.allele_encoding import AlleleEncoding
from mhcflurry.class1_ligandome_predictor import Class1LigandomePredictor
from mhcflurry.downloads import get_path
from mhcflurry.testing_utils import cleanup, startup
from mhcflurry.amino_acid import COMMON_AMINO_ACIDS
# Rebind the imported collection to a sorted list so its order is deterministic.
COMMON_AMINO_ACIDS = sorted(COMMON_AMINO_ACIDS)

# Module-level fixtures: populated by setup() and reset to None by teardown().
PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
def setup():
    """
    Populate the module-level fixtures.

    Calls startup(), loads the "models.no_mass_spec" pan-allele predictor,
    and reads the frequency-matrix CSVs for the with-mass-spec and
    no-mass-spec model sets into data frames.
    """
    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC, \
        PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF, \
        PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF

    startup()

    predictor_path = get_path("models_class1_pan", "models.no_mass_spec")
    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = Class1AffinityPredictor.load(
        predictor_path)

    with_ms_motifs_path = get_path(
        "models_class1_pan",
        "models.with_mass_spec/frequency_matrices.csv.bz2")
    PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = pandas.read_csv(with_ms_motifs_path)

    no_ms_motifs_path = get_path(
        "models_class1_pan",
        "models.no_mass_spec/frequency_matrices.csv.bz2")
    PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = pandas.read_csv(no_ms_motifs_path)
def teardown():
    """Reset the module-level fixtures to None, then call cleanup()."""
    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC, \
        PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF, \
        PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF

    # Drop the references to the loaded predictor and motif tables.
    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = \
        PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = \
        PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None

    cleanup()
def sample_peptides_from_pssm(pssm, count):
    """
    Draw random peptides from a position-specific scoring matrix.

    Parameters
    ----------
    pssm : pandas.DataFrame
        One row per position and one column per symbol. Each row is passed
        to numpy.random.choice as the probability vector over the columns.
    count : int
        Number of peptides to draw.

    Returns
    -------
    pandas.Series
        Indexed 0..count-1. Each value is the row-wise concatenation of the
        symbols sampled for every position.
    """
    symbols = pssm.columns
    draws = pandas.DataFrame(
        index=numpy.arange(count), columns=pssm.index, dtype=object)

    # One independent draw of `count` symbols per position, in row order.
    for (position, probabilities) in pssm.iterrows():
        sampled_column = numpy.random.choice(
            symbols, size=count, replace=True, p=probabilities.values)
        draws.loc[:, position] = sampled_column

    peptides = draws.apply("".join, axis=1)
    return peptides
def scramble_peptide(peptide):
    """Return the characters of ``peptide`` joined in a randomly shuffled order."""
    residues = [residue for residue in peptide]
    shuffle(residues)  # in-place shuffle using the global random state
    return "".join(residues)
def test_synthetic_allele_refinement():
    """
    Build a synthetic multi-allele dataset and run the ligandome predictor.

    Hits are 9-mer peptides for the selected alleles that appear in the
    with-mass-spec curated training data but not in the no-mass-spec data,
    sampled to a fixed count per allele. Decoys are scrambled copies of the
    hits. The predictor is fit on hits plus decoys as a single experiment
    and its predictions are printed. There are no assertions yet.
    """
    refine_allele = "HLA-C*01:02"
    alleles = [
        "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
        "HLA-A*03:01", "HLA-B*15:01", refine_allele
    ]
    peptides_per_allele = [
        2000, 1000, 500,
        1500, 1200, 800,
    ]

    allele_to_peptides = dict(zip(alleles, peptides_per_allele))

    length = 9

    train_with_ms = pandas.read_csv(
        get_path("data_curated", "curated_training_data.with_mass_spec.csv.bz2"))
    train_no_ms = pandas.read_csv(get_path("data_curated",
        "curated_training_data.no_mass_spec.csv.bz2"))

    def filter_df(df):
        # Keep only the selected alleles and peptides of the target length.
        df = df.loc[
            (df.allele.isin(alleles)) &
            (df.peptide.str.len() == length)
        ]
        return df

    train_with_ms = filter_df(train_with_ms)
    train_no_ms = filter_df(train_no_ms)

    # Peptides present only in the with-mass-spec training data.
    ms_specific = train_with_ms.loc[
        ~train_with_ms.peptide.isin(train_no_ms.peptide)
    ]

    train_peptides = []
    train_true_alleles = []
    for allele in alleles:
        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
            n=allele_to_peptides[allele])
        train_peptides.extend(peptides)
        train_true_alleles.extend([allele] * len(peptides))

    hits_df = pandas.DataFrame({"peptide": train_peptides})
    hits_df["true_allele"] = train_true_alleles
    hits_df["hit"] = 1.0

    # Decoys: one scrambled copy of every hit, with no true allele.
    decoys_df = hits_df.copy()
    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)
    decoys_df["true_allele"] = ""
    decoys_df["hit"] = 0.0

    train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)

    predictor = Class1LigandomePredictor(PAN_ALLELE_PREDICTOR_NO_MASS_SPEC)
    predictor.fit(
        peptides=train_df.peptide.values,
        labels=train_df.hit.values,
        experiment_names=["experiment1"] * len(train_df),
        experiment_name_to_alleles={
            "experiment1": alleles,
        }
    )

    # NOTE(review): the predict() stub in this commit is declared as
    # predict(self, allele_lists, peptides); confirm these keyword arguments
    # once predict is implemented.
    predictions = predictor.predict(
        peptides=train_df.peptide.values,
        alleles=alleles,
        output_format="concatenate"
    )

    print(predictions)
    # No debugger breakpoint here: it would block automated test runs. The
    # __main__ entry point already opens an interactive ipdb session.
"""
def test_simple_synethetic(
num_peptide_per_allele_and_length=100, lengths=[8,9,10,11]):
alleles = [
"HLA-A*02:01", "HLA-B*52:01", "HLA-C*07:01",
"HLA-A*03:01", "HLA-B*57:02", "HLA-C*03:01",
]
cutoff = PAN_ALLELE_MOTIFS_DF.cutoff_fraction.min()
peptides_and_alleles = []
for allele in alleles:
sub_df = PAN_ALLELE_MOTIFS_DF.loc[
(PAN_ALLELE_MOTIFS_DF.allele == allele) &
(PAN_ALLELE_MOTIFS_DF.cutoff_fraction == cutoff)
]
assert len(sub_df) > 0, allele
for length in lengths:
pssm = sub_df.loc[
sub_df.length == length
].set_index("position")[COMMON_AMINO_ACIDS]
peptides = sample_peptides_from_pssm(pssm, num_peptide_per_allele_and_length)
for peptide in peptides:
peptides_and_alleles.append((peptide, allele))
hits_df = pandas.DataFrame(
peptides_and_alleles,
columns=["peptide", "allele"]
)
hits_df["hit"] = 1
decoys = hits_df.copy()
decoys["peptide"] = decoys.peptide.map(scramble_peptide)
decoys["hit"] = 0.0
train_df = pandas.concat([hits_df, decoys], ignore_index=True)
return train_df
return
pass
"""
# Command-line interface used when this file is run directly as a script.
parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument(
    "--alleles",
    nargs="+",
    default=None,
    help="Which alleles to test")

if __name__ == '__main__':
    # If run directly from python, leave the user in a shell to explore results.
    # Parse arguments before setup() so that --help or invalid arguments exit
    # immediately instead of loading the models first.
    # NOTE(review): args (including --alleles) is not used below yet.
    args = parser.parse_args(sys.argv[1:])
    setup()
    result = test_synthetic_allele_refinement()

    # Leave in ipython
    import ipdb  # pylint: disable=import-error
    ipdb.set_trace()
......@@ -30,7 +30,6 @@ def teardown():
def test_merge():
assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) > 1
peptides = random_peptides(100, length=9)
peptides.extend(random_peptides(100, length=10))
peptides = pandas.Series(peptides).sample(frac=1.0)
......
......@@ -22,11 +22,11 @@ from mhcflurry.testing_utils import cleanup, startup
ALLELE_SPECIFIC_PREDICTOR = None
PAN_ALLELE_PREDICTOR = None
PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
def setup():
global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
startup()
ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
get_path("models_class1", "models"))
......@@ -36,7 +36,7 @@ def setup():
def teardown():
global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
ALLELE_SPECIFIC_PREDICTOR = None
PAN_ALLELE_PREDICTOR = None
cleanup()
......@@ -97,7 +97,7 @@ def test_speed_allele_specific(profile=False, num=DEFAULT_NUM_PREDICTIONS):
def test_speed_pan_allele(profile=False, num=DEFAULT_NUM_PREDICTIONS):
global PAN_ALLELE_PREDICTOR
global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
starts = collections.OrderedDict()
timings = collections.OrderedDict()
profilers = collections.OrderedDict()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment