Skip to content
Snippets Groups Projects
Commit 4a1cad10 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

write out extended test sets with new predictions

parent a9b105db
No related merge requests found
......@@ -24,8 +24,8 @@ predictions on a separate test set
of the test set with mchflurry predictions (writing results to a new directory)
"""
from os import listdir
from os.path import join
from os import listdir, makedirs
from os.path import join, exists
from argparse import ArgumentParser
from itertools import groupby
......@@ -166,6 +166,7 @@ if __name__ == "__main__":
model = make_model(config)
binary_encoding = (args.embedding_size == 0)
training_datasets, _ = load_data(
filename=args.training_csv,
peptide_length=9,
......@@ -185,16 +186,26 @@ if __name__ == "__main__":
shuffle=True)
old_weights = model.get_weights()
if not exists(args.output_dir):
makedirs(args.output_dir)
for filename in listdir(args.input_dir):
filepath = join(args.input_dir, filename)
parts = filename.split(".")
if len(parts) != 2:
print("Skipping %s" % filepath)
continue
allele, ext = parts
allele_name, ext = parts
if ext != "csv":
print("Skipping %s, only reading CSV files" % filepath)
continue
allele_name = normalize_allele_name(allele_name)
if allele_name not in training_datasets:
print("Skipping %s because allele %s not in training data" % (
filepath,
allele_name))
continue
print("Loading %s" % filepath)
df = pd.read_csv(filepath)
......@@ -207,7 +218,6 @@ if __name__ == "__main__":
true_ic50 = list(df["meas"])
model.set_weights(old_weights)
allele_name = normalize_allele_name(filename.split(".")[0])
allele_dataset = training_datasets[allele_name]
X_train = allele_dataset.X
Y_train = allele_dataset.Y
......@@ -242,8 +252,25 @@ if __name__ == "__main__":
Y_pred_mean = np.mean(Y_pred)
Y_pred_ic50 = args.max_ic50 ** (1.0 - Y_pred_mean)
predictions[peptide] = Y_pred_ic50
df[args.predictor_name] = [
predictions[peptide]
for peptide in peptide_sequences
]
print(df[["sequence", "meas", "mhcflurry"]])
pos = df["meas"] <= 500
pred = df[args.predictor_name] <= 500
tp = (pred & pos).sum()
fp = (pred & ~pos).sum()
tn = (~pred & ~pos).sum()
fn = (~pred & pos).sum()
assert (tp + fp + tn + fn) == len(pos), "Expected %d but got %d" % (
len(pos),
(tp + fp + tn + fn))
precision = tp / float(tp + fp)
recall = tp / float(tp + fn)
f1 = 2 * precision * recall / (precision + recall)
print("-- %s: tp=%d fp=%d tn=%d fn=%d P=%0.4f R=%0.4f F1=%0.4f" % (
filename, tp, fp, tn, fn, precision, recall, f1))
output_path = join(args.output_dir, filename)
df.to_csv(output_path, index=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment