Skip to content
Snippets Groups Projects
Commit e6f26c20 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 214a1474
No related branches found
No related tags found
No related merge requests found
......@@ -116,28 +116,24 @@ class MultiallelicMassSpecBatchGenerator(object):
def plan_from_dataframe(df, hyperparameters):
affinity_fraction = hyperparameters["batch_generator_affinity_fraction"]
batch_size = hyperparameters["batch_generator_batch_size"]
equivalence_columns = ["is_affinity", "is_binder", "experiment_name"]
df["equivalence_key"] = df[equivalence_columns].astype(str).sum(1)
equivalence_map = dict(
(v, i)
for (i, v) in zip(*df.equivalence_key.factorize()))
df["equivalence_class"] = df.equivalence_key.map(equivalence_map)
df["first_allele"] = df.alleles.str.get(0)
equivalence_columns = [
"is_affinity",
"is_binder",
"experiment_name",
"first_allele",
]
df["equivalence_key"] = numpy.where(
df.is_affinity,
df.first_allele,
df.experiment_name,
) + " " + df.is_binder.map({True: "binder", False: "nonbinder"})
(df["equivalence_class"], equivalence_class_labels) = (
df.equivalence_key.factorize())
df["unused"] = True
df["idx"] = df.index
equivalence_class_to_label = dict(
(idx, (
"{first_allele} {binder}" if row.is_affinity else
"{experiment_name} {binder}"
).format(
binder="binder" if row.is_binder else "nonbinder",
**row.to_dict()))
for (idx, row) in df.drop_duplicates(
"equivalence_class").set_index("equivalence_class").iterrows())
df = df.sample(frac=1.0)
#df["key"] = df.is_binder ^ (numpy.arange(len(df)) % 2).astype(bool)
#df = df.sort_values("key")
#del df["key"]
affinities_per_batch = int(affinity_fraction * batch_size)
......@@ -206,10 +202,7 @@ class MultiallelicMassSpecBatchGenerator(object):
return BatchPlan(
equivalence_classes=equivalence_classes,
batch_compositions=batch_compositions,
equivalence_class_labels=[
equivalence_class_to_label[i] for i in
range(len(class_to_indices))
])
equivalence_class_labels=equivalence_class_labels)
def plan(
self,
......
......@@ -137,7 +137,7 @@ def test_large(sample_rate=0.01):
planner = MultiallelicMassSpecBatchGenerator(
hyperparameters=dict(
batch_generator_validation_split=0.2,
batch_generator_batch_size=1024,
batch_generator_batch_size=128,
batch_generator_affinity_fraction=0.5))
s = time.time()
......@@ -168,6 +168,7 @@ def test_large(sample_rate=0.01):
combined_train_df.loc[idx, "kind"] = kind
combined_train_df.loc[idx, "idx"] = idx
combined_train_df.loc[idx, "batch"] = i
import ipdb ; ipdb.set_trace()
combined_train_df["idx"] = combined_train_df.idx.astype(int)
combined_train_df["batch"] = combined_train_df.batch.astype(int)
......
......@@ -22,9 +22,10 @@ import pandas
import argparse
import sys
import copy
from functools import partial
import os
from numpy.testing import assert_, assert_equal, assert_allclose
from nose.tools import assert_greater, assert_less
import numpy
from random import shuffle
......@@ -47,6 +48,14 @@ PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
def data_path(name):
'''
Return the absolute path to a file in the test/data directory.
The name specified should be relative to test/data.
'''
return os.path.join(os.path.dirname(__file__), "data", name)
def setup():
global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
......@@ -341,7 +350,7 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10):
import ipdb ; ipdb.set_trace()
def test_synthetic_allele_refinement_with_affinity_data(max_epochs=10):
def Xtest_synthetic_allele_refinement_with_affinity_data(max_epochs=10):
refine_allele = "HLA-C*01:02"
alleles = [
"HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
......@@ -744,6 +753,77 @@ def Xtest_synthetic_allele_refinement(max_epochs=10):
return (predictor, predictions, metrics, motifs)
def test_refinemeent_large(sample_rate=0.1):
multi_train_df = pandas.read_csv(
data_path("multiallelic_ms.benchmark1.csv.bz2"))
multi_train_df["label"] = multi_train_df.hit
multi_train_df["is_affinity"] = False
sample_table = multi_train_df.loc[
multi_train_df.label == True
].drop_duplicates("sample_id").set_index("sample_id").loc[
multi_train_df.sample_id.unique()
]
grouped = multi_train_df.groupby("sample_id").nunique()
for col in sample_table.columns:
if (grouped[col] > 1).any():
del sample_table[col]
sample_table["alleles"] = sample_table.hla.str.split()
pan_train_df = pandas.read_csv(
get_path(
"models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
pan_sub_train_df = pan_train_df
pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
del pan_sub_train_df["measurement_value"]
pan_sub_train_df["is_affinity"] = True
pan_sub_train_df = pan_sub_train_df.sample(frac=sample_rate)
multi_train_df = multi_train_df.sample(frac=sample_rate)
pan_predictor = Class1AffinityPredictor.load(
get_path("models_class1_pan", "models.with_mass_spec"),
optimization_level=0,
max_models=1)
allele_encoding = MultipleAlleleEncoding(
experiment_names=multi_train_df.sample_id.values,
experiment_to_allele_list=sample_table.alleles.to_dict(),
max_alleles_per_experiment=sample_table.alleles.str.len().max(),
allele_to_sequence=pan_predictor.allele_to_sequence,
)
allele_encoding.append_alleles(pan_sub_train_df.allele.values)
allele_encoding = allele_encoding.compact()
combined_train_df = pandas.concat(
[multi_train_df, pan_sub_train_df], ignore_index=True, sort=True)
ligandome_predictor = Class1LigandomePredictor(
pan_predictor,
auxiliary_input_features=[],
max_ensemble_size=1,
max_epochs=0,
batch_generator_batch_size=128,
learning_rate=0.0001,
patience=5,
min_delta=0.0,
random_negative_rate=1.0)
fit_results = ligandome_predictor.fit(
peptides=combined_train_df.peptide.values,
labels=combined_train_df.label.values,
allele_encoding=allele_encoding,
affinities_mask=combined_train_df.is_affinity.values,
inequalities=combined_train_df.measurement_inequality.values,
)
batch_generator = fit_results['batch_generator']
train_batch_plan = batch_generator.train_batch_plan
assert_greater(len(train_batch_plan.equivalence_class_labels), 100)
assert_less(len(train_batch_plan.equivalence_class_labels), 1000)
parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument(
"--out-metrics-csv",
......@@ -760,6 +840,8 @@ parser.add_argument(
help="Max epochs")
if __name__ == '__main__':
# If run directly from python, leave the user in a shell to explore results.
setup()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment