Commit 32d6b970 authored by Alex Rubinsteyn

added script to summarize performance on test data of multiple classifiers

parent 6fc1f9cd
#!/usr/bin/env python
#
# Copyright (c) 2015. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Compute accuracy, AUC, and F1 score for allele-specific test datasets
"""
from os import listdir
from os.path import join
from argparse import ArgumentParser
from collections import defaultdict, OrderedDict
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from model_selection_helpers import f1_score
parser = ArgumentParser()
parser.add_argument(
    "--test-data-dir",
    help="Directory which contains one CSV file per allele",
    required=True)

parser.add_argument(
    "--true-ic50-column-name",
    default="meas")

parser.add_argument(
    "--peptide-sequence-column-name",
    default="sequence")

parser.add_argument(
    "--peptide-length-column-name",
    default="length")

if __name__ == "__main__":
    args = parser.parse_args()
    # mapping from predictor names to dictionaries of per-allele metrics
    results = defaultdict(lambda: OrderedDict([
        ("allele", []),
        ("length", []),
        ("auc", []),
        ("accuracy", []),
        ("f1", [])]))
    for filename in listdir(args.test_data_dir):
        filepath = join(args.test_data_dir, filename)
        parts = filename.split(".")
        if len(parts) != 2:
            print("Skipping %s" % filepath)
            continue
        allele, ext = parts
        if ext != "csv":
            print("Skipping %s, only reading CSV files" % filepath)
            continue
        df = pd.read_csv(filepath)
        columns = set(df.columns)
        # every column other than the measured IC50 and the peptide
        # annotations is assumed to hold one predictor's IC50 predictions
        drop_columns = {
            args.true_ic50_column_name,
            args.peptide_length_column_name,
            args.peptide_sequence_column_name,
        }
        predictor_names = columns.difference(drop_columns)
        true_ic50 = df[args.true_ic50_column_name]
        true_label = true_ic50 <= 500
        n = len(df)
        print("%s (total = %d, n_pos = %d, n_neg = %d)" % (
            allele,
            n,
            true_label.sum(),
            n - true_label.sum()))
        for predictor in sorted(predictor_names):
            pred_ic50 = df[predictor]
            pred_label = pred_ic50 <= 500
            if true_label.std() == 0:
                # can't compute AUC from a single class
                auc = np.nan
            else:
                # using negative IC50 since it's inversely related to binding
                auc = roc_auc_score(true_label, -pred_ic50)
            f1 = f1_score(true_label, pred_label)
            accuracy = np.mean(true_label == pred_label)
            print("-- %s AUC=%0.4f, acc=%0.4f, F1=%0.4f" % (
                predictor,
                auc,
                accuracy,
                f1))
            results[predictor]["allele"].append(allele)
            results[predictor]["length"].append(n)
            results[predictor]["f1"].append(f1)
            results[predictor]["accuracy"].append(accuracy)
            results[predictor]["auc"].append(auc)

    print("\n=== Aggregate Results ===\n")
    for (predictor, data) in sorted(results.items()):
        df = pd.DataFrame(data)
        print(predictor)
        aucs = df["auc"]
        auc_lower = aucs.quantile(0.25)
        auc_upper = aucs.quantile(0.75)
        auc_iqr = auc_upper - auc_lower
        print("-- AUC: median=%0.4f, mean=%0.4f, iqr=%0.4f" % (
            aucs.median(),
            aucs.mean(),
            auc_iqr))
        f1s = df["f1"]
        f1_lower = f1s.quantile(0.25)
        f1_upper = f1s.quantile(0.75)
        f1_iqr = f1_upper - f1_lower
        print("-- F1: median=%0.4f, mean=%0.4f, iqr=%0.4f" % (
            f1s.median(),
            f1s.mean(),
            f1_iqr))
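
A note on the conventions used by this script: peptides with measured IC50 at or below 500 nM are treated as binders, and since lower IC50 means stronger binding, predicted IC50s are negated before computing AUC so that higher scores rank binders first. A minimal standalone sketch of that logic on toy values (illustrative only, not part of the commit):

import numpy as np
from sklearn.metrics import roc_auc_score

# toy measured and predicted IC50 values in nM (illustrative only)
true_ic50 = np.array([25.0, 80.0, 4000.0, 12000.0])
pred_ic50 = np.array([60.0, 300.0, 900.0, 20000.0])

true_label = true_ic50 <= 500  # binder if measured IC50 <= 500 nM
pred_label = pred_ic50 <= 500

# negate predictions so that larger scores mean stronger predicted binding
auc = roc_auc_score(true_label, -pred_ic50)
accuracy = np.mean(true_label == pred_label)
print("AUC=%0.4f accuracy=%0.4f" % (auc, accuracy))  # AUC=1.0000 accuracy=1.0000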
#!/usr/bin/env python
#
# Copyright (c) 2015. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os.path import join, exists
from os import makedirs
from argparse import ArgumentParser
from test_data import load_test_data
parser = ArgumentParser()
parser.add_argument(
    "--test-data-input-dirs",
    nargs='*',
    type=str,
    help="Multiple directories from other predictors",
    required=True)

parser.add_argument(
    "--test-data-input-sep",
    default=r"\s+",
    help="Separator to use for loading test data CSV/TSV files",
    type=str)

parser.add_argument(
    "--test-data-output-dir",
    help="Save combined test datasets to this directory",
    required=True)

if __name__ == "__main__":
    args = parser.parse_args()
    dataframes, predictor_names = load_test_data(args.test_data_input_dirs)
    if not exists(args.test_data_output_dir):
        makedirs(args.test_data_output_dir)
    print("Loaded test data:")
    for (allele, df) in dataframes.items():
        df.index.name = "sequence"
        print("%s: %d results" % (allele, len(df)))
        filename = "blind-%s.csv" % allele
        filepath = join(args.test_data_output_dir, filename)
        df.to_csv(filepath)

    # deliberate halt: the evaluation code below is disabled leftover
    # from an earlier version of this script
    assert False
    """
    combined_df = evaluate_model_configs(
        configs=configs,
        results_filename=args.output,
        train_fn=lambda config: evaluate_model_config_train_vs_test(
            config,
            training_allele_datasets=training_datasets,
            testing_allele_datasets=testing_datasets,
            min_samples_per_allele=5))
    """
model_selection_helpers.py
@@ -19,7 +19,6 @@ from __future__ import (
     unicode_literals
 )
 from collections import OrderedDict
-import logging
 
 import numpy as np
 import sklearn
@@ -34,30 +33,27 @@ from mhcflurry.data_helpers import indices_to_hotshot_encoding
 from score_collection import ScoreCollection
 
 
-def score_predictions(predicted_log_ic50, true_label, max_ic50):
-    """Computes accuracy, AUC, and F1 score of predictions"""
-    auc = sklearn.metrics.roc_auc_score(true_label, predicted_log_ic50)
-    ic50_pred = max_ic50 ** (1.0 - predicted_log_ic50)
-    label_pred = (ic50_pred <= 500)
-    same_mask = true_label == label_pred
-    accuracy = np.mean(same_mask)
+def f1_score(true_label, label_pred, cutoff=500):
     tp = (true_label & label_pred).sum()
     fp = ((~true_label) & label_pred).sum()
-    tn = ((~true_label) & (~label_pred)).sum()
     fn = (true_label & (~label_pred)).sum()
     sensitivity = (tp / float(tp + fn)) if (tp + fn) > 0 else 0.0
     precision = (tp / float(tp + fp)) if (tp + fp) > 0 else 0.0
     if (precision + sensitivity) > 0:
-        f1_score = (2 * precision * sensitivity) / (precision + sensitivity)
+        return (2 * precision * sensitivity) / (precision + sensitivity)
     else:
-        f1_score = 0.0
-    # sanity check that we're computing accuracy correctly
-    accuracy_estimate2 = (tp + tn) / float(tp + fp + tn + fn)
-    if abs(accuracy - accuracy_estimate2) > 0.00001:
-        logging.warn(
-            "!!! Conflicting accuracy estimates! (%0.5f vs. %0.5f)" % (
-                accuracy, accuracy_estimate2))
-    return accuracy, auc, f1_score
+        return 0.0
+
+
+def score_predictions(predicted_log_ic50, true_label, max_ic50):
+    """Computes accuracy, AUC, and F1 score of predictions"""
+    auc = sklearn.metrics.roc_auc_score(true_label, predicted_log_ic50)
+    ic50_pred = max_ic50 ** (1.0 - predicted_log_ic50)
+    label_pred = (ic50_pred <= 500)
+    same_mask = true_label == label_pred
+    accuracy = np.mean(same_mask)
+    f1 = f1_score(true_label, label_pred)
+    return accuracy, auc, f1
 
 
 def train_model_and_return_scores(
...
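
A quick way to sanity-check the extracted helper is to compare it with scikit-learn's implementation on boolean arrays. A minimal check, assuming model_selection_helpers is on the import path (this snippet is not part of the commit):

import numpy as np
import sklearn.metrics

from model_selection_helpers import f1_score

true_label = np.array([True, True, False, False, True])
label_pred = np.array([True, False, False, True, True])

# tp=2, fp=1, fn=1 -> precision=2/3, recall=2/3, F1=2/3
assert abs(f1_score(true_label, label_pred) - 2.0 / 3) < 1e-9
assert abs(sklearn.metrics.f1_score(true_label, label_pred) - 2.0 / 3) < 1e-9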