Commit 52a88ace authored by Alex Rubinsteyn

updates to dataset size sensitivity script

parent f0fdbc98
@@ -382,7 +382,8 @@ class Dataset(object):
         Get Dataset for a single allele
         """
         if allele_name not in self.unique_alleles():
-            raise KeyError("Allele '%s' not found" % (allele_name,))
+            raise KeyError("Allele '%s' not found, available alleles: %s" % (
+                allele_name, list(sorted(self.unique_alleles()))))
         df = self.to_dataframe()
         df_allele = df[df.allele == allele_name]
         return self.__class__(df_allele)
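The richer KeyError message above makes failed allele lookups self-explanatory. A minimal usage sketch (hedged: `dataset` is a hypothetical, already-loaded `Dataset` instance, and the allele string is made up):

```python
# Hypothetical: `dataset` is an already-loaded Dataset instance.
try:
    single_allele = dataset.get_allele("HLA-A0201")
except KeyError as error:
    # After this commit the message also lists every available allele, e.g.
    # "Allele 'HLA-A0201' not found, available alleles: ['HLA-A*02:01', ...]"
    print("Lookup failed: %s" % error)
```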
@@ -25,6 +25,7 @@ from fancyimpute.iterative_svd import IterativeSVD
 from fancyimpute.simple_fill import SimpleFill
 from fancyimpute.soft_impute import SoftImpute
 from fancyimpute.mice import MICE
+from fancyimpute.biscaler import BiScaler


 def check_dense_pMHC_array(X, peptide_list, allele_list):
@@ -134,6 +135,8 @@ def imputer_from_name(imputation_method_name, **kwargs):
         kwargs["rank"] = kwargs.get("rank", 10)
         return IterativeSVD(**kwargs)
     elif imputation_method_name == "svt" or imputation_method_name == "softimpute":
+        kwargs["init_fill_method"] = kwargs.get("init_fill_method", "min")
+        kwargs["normalizer"] = kwargs.get("normalizer", BiScaler())
         return SoftImpute(**kwargs)
     elif imputation_method_name == "mean":
         return SimpleFill("mean", **kwargs)
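With these defaults, asking `imputer_from_name` for `"svt"` or `"softimpute"` now returns a `SoftImpute` solver that first fills missing entries with the column minimum and normalizes rows and columns with `BiScaler` before low-rank completion. A rough standalone sketch of the equivalent behavior (assumptions: the toy matrix `X` is invented, and `complete()` is the fancyimpute solver API of this era; newer releases use `fit_transform()`):

```python
import numpy as np
from fancyimpute.biscaler import BiScaler
from fancyimpute.soft_impute import SoftImpute

# Toy peptide-by-allele affinity matrix; NaN marks unmeasured pairs.
X = np.array([
    [1.0, 2.0, np.nan],
    [2.0, np.nan, 6.0],
    [np.nan, 4.0, 9.0],
    [4.0, 8.0, 12.0],
])

# Mirrors imputer_from_name("softimpute") after this commit: min-fill the
# missing entries, BiScale rows/columns, then soft-threshold the SVD.
imputer = SoftImpute(init_fill_method="min", normalizer=BiScaler())
X_completed = imputer.complete(X)
print(X_completed.round(2))
```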
@@ -70,6 +70,11 @@ parser.add_argument(
     type=int,
     default=500)

+parser.add_argument(
+    "--min-observations-per-peptide",
+    type=int,
+    default=2)
+
 parser.add_argument(
     "--sample-censored-affinities",
     default=False,

@@ -86,6 +91,11 @@ parser.add_argument(
     action="store_true",
     default=False)

+parser.add_argument(
+    "--pretraining-weight-decay",
+    choices=("exponential", "quadratic", "linear"),
+    default="quadratic",
+    help="Rate at which weight of imputed samples decays")
+
 """
 parser.add_argument(
     "--remove-similar-peptides-from-test-data",
@@ -111,11 +121,13 @@ def subsample_performance(
         imputer=None,
         min_training_samples=20,
         max_training_samples=3000,
+        min_observations_per_peptide=2,
         n_subsample_sizes=10,
         n_repeats_per_size=1,
         n_training_epochs=200,
         n_random_negative_samples=100,
         batch_size=32,
+        pretrain_weight_decay_fn=lambda t: np.exp(-t),
         sample_censored_affinities=False):
     dataset_allele = dataset.get_allele(allele)
@@ -151,7 +163,7 @@ def subsample_performance(
             allele=allele,
             n_training_samples=n_train,
             imputation_method=imputer,
-            min_observations_per_peptide=3,
+            min_observations_per_peptide=min_observations_per_peptide,
             min_observations_per_allele=1,
             stratify_fn=stratify_by_binder_label)
         print("=== #%d/%d: Training model for %s with sample_size = %d/%d" % (
@@ -226,13 +238,15 @@ if __name__ == "__main__":
     args = parser.parse_args()

     base_filename = \
-        "%s-vs-nsamples-hidden-%s-activation-%s-impute-%s-epochs-%d-embedding-%d" % (
+        ("%s-vs-nsamples-hidden-%s-activation-%s"
+         "-impute-%s-epochs-%d-embedding-%d-pretrain-%s") % (
             args.allele,
             args.hidden_layer_size,
             args.activation,
             args.imputation_method,
             args.training_epochs,
-            args.embedding_size)
+            args.embedding_size,
+            args.pretraining_weight_decay)
     csv_filename = base_filename + ".csv"

     if args.load_existing_data:
@@ -244,6 +258,22 @@ if __name__ == "__main__":
     def make_model():
         return predictor_from_args(allele_name=args.allele, args=args)

+    if args.pretraining_weight_decay == "exponential":
+        def pretrain_weight_decay_fn(t):
+            return np.exp(-t)
+
+    elif args.pretraining_weight_decay == "quadratic":
+        def pretrain_weight_decay_fn(t):
+            return 1.0 / (t + 1) ** 2.0
+
+    elif args.pretraining_weight_decay == "linear":
+        def pretrain_weight_decay_fn(t):
+            return 1.0 / (t + 1)
+
+    else:
+        raise ValueError("Invalid weight decay schedule: '%s'" % (
+            args.pretraining_weight_decay))
+
     results_df = subsample_performance(
         dataset=dataset,
         allele=args.allele,
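The three decay schedules defined above differ sharply in how quickly the weight on imputed pretraining samples falls off per epoch. A quick standalone comparison (the printed values follow directly from the formulas in the diff; nothing else is assumed):

```python
import numpy as np

# The same schedules the script selects via --pretraining-weight-decay.
schedules = {
    "exponential": lambda t: np.exp(-t),
    "quadratic": lambda t: 1.0 / (t + 1) ** 2.0,
    "linear": lambda t: 1.0 / (t + 1),
}

# Weight given to imputed samples during the first five training epochs.
for name, fn in schedules.items():
    print("%-12s %s" % (name, [round(float(fn(t)), 4) for t in range(5)]))
# exponential  [1.0, 0.3679, 0.1353, 0.0498, 0.0183]
# quadratic    [1.0, 0.25, 0.1111, 0.0625, 0.04]
# linear       [1.0, 0.5, 0.3333, 0.25, 0.2]
```

The default "quadratic" sits between the other two: it discounts imputed data faster than linear decay but keeps a non-negligible weight for more epochs than exponential decay.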
@@ -252,8 +282,10 @@ if __name__ == "__main__":
         n_repeats_per_size=args.repeat,
         n_training_epochs=args.training_epochs,
         batch_size=args.batch_size,
+        pretrain_weight_decay_fn=pretrain_weight_decay_fn,
         min_training_samples=args.min_training_samples,
         max_training_samples=args.max_training_samples,
+        min_observations_per_peptide=args.min_observations_per_peptide,
         n_subsample_sizes=args.number_dataset_sizes,
         n_random_negative_samples=args.random_negative_samples,
         sample_censored_affinities=args.sample_censored_affinities)
@@ -264,77 +296,57 @@ if __name__ == "__main__":
     metrics = ["auc", "f1", "tau"]
-    if args.seaborn_lmplot:
-        for score_name in metrics:
-            seaborn.lmplot(
-                data=results_df,
-                x="num_samples",
-                y=score_name,
-                hue="impute",
-                legend=True,
-                fit_reg=True,
-                logx=True,
-                truncate=True,
-                x_jitter=0.5,
-                y_jitter=0.01)
-            seaborn.plt.xlim(
-                max(-1, results_df["num_samples"].min() - 2),
-                results_df["num_samples"].max() + 50,
-            )
-            seaborn.plt.ylim(0, 1)
-            seaborn.plt.xlabel("# samples (subset of %s)" % args.allele)
-            seaborn.plt.ylabel(score_name)
-            image_filename = "%s-%s.png" % (base_filename, score_name)
-            print("Writing image to %s" % image_filename)
-            seaborn.plt.savefig(image_filename)
-    else:
-        titles = {
-            "tau": "Kendall's $\\tau$",
-            "auc": "AUC",
-            "f1": "$F_1$ score"
-        }
-        pyplot.figure(figsize=(6, 4))
-        seaborn.set_style("whitegrid")
-        for (j, score_name) in enumerate(metrics):
-            ax = pyplot.subplot2grid((1, 3), (0, j))
-            groups = results_df.groupby(["num_samples", "impute"])
-            groups_score = groups[score_name].mean().to_frame().reset_index()
-            groups_score["std_error"] = \
-                groups[score_name].std().to_frame().reset_index()[score_name]
-            for impute in [True, False]:
-                sub = groups_score[groups_score.impute == impute]
-                color = seaborn.get_color_cycle()[0] if impute else seaborn.get_color_cycle()[1]
-                pyplot.errorbar(
-                    x=sub.num_samples.values,
-                    y=sub[score_name].values,
-                    yerr=sub.std_error.values,
-                    label=("with" if impute else "without") + " imputation",
-                    color=color)
-            if j == 1:
-                pyplot.xlabel("Training set size")
-            pyplot.xscale("log")
-            pyplot.title(titles[score_name])
-            if score_name == "auc":
-                pyplot.ylim(ymin=0.5, ymax=1.0)
-            if score_name == "f1":
-                pyplot.ylim(ymin=0, ymax=1)
-            if score_name == "tau":
-                pyplot.ylim(ymin=0, ymax=0.6)
-                pyplot.yticks(np.arange(0, 0.61, 0.15))
-            if j == 0:
-                pyplot.legend(
-                    loc=(-0.1, 0.05),
-                    fancybox=True,
-                    frameon=True,
-                    fontsize="small")
-        pyplot.tight_layout()
-        # Put the legend out of the figure
-        # pyplot.legend(
-        #     bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fancybox=True, frameon=True)
-        image_filename = base_filename + ".png"
-        print("Writing PNG to %s" % image_filename)
-        pyplot.savefig(image_filename)
+    titles = {
+        "tau": "Kendall's $\\tau$",
+        "auc": "AUC",
+        "f1": "$F_1$ score"
+    }
+    pyplot.figure(figsize=(6.5, 3.5))
+    seaborn.set_style("whitegrid")
+    for (j, score_name) in enumerate(metrics):
+        ax = pyplot.subplot2grid((1, 4), (0, j))
+        groups = results_df.groupby(["num_samples", "impute"])
+        groups_score = groups[score_name].mean().to_frame().reset_index()
+        groups_score["std_error"] = \
+            groups[score_name].std().to_frame().reset_index()[score_name]
+        for impute in [True, False]:
+            sub = groups_score[groups_score.impute == impute]
+            color = seaborn.get_color_cycle()[0] if impute else seaborn.get_color_cycle()[1]
+            pyplot.errorbar(
+                x=sub.num_samples.values,
+                y=sub[score_name].values,
+                yerr=sub.std_error.values,
+                label=("with" if impute else "without") + " imputation",
+                color=color)
+        pyplot.xscale("log")
+        pyplot.title(titles[score_name])
+        if score_name == "auc":
+            pyplot.ylim(ymin=0.5, ymax=1.0)
+        if score_name == "f1":
+            pyplot.ylim(ymin=0, ymax=1)
+        if score_name == "tau":
+            pyplot.ylim(ymin=0, ymax=0.6)
+            pyplot.yticks(np.arange(0, 0.61, 0.15))
+    # Put the legend out of the figure
+    pyplot.legend(
+        bbox_to_anchor=(1.1, 1),
+        loc=2,
+        borderaxespad=0.,
+        fancybox=True,
+        frameon=True,
+        fontsize="small")
+    pyplot.tight_layout()
+    image_filename = base_filename + ".png"
+    print("Writing PNG to %s" % image_filename)
+    pyplot.savefig(image_filename)
+
+    pdf_filename = base_filename + ".pdf"
+    print("Writing PDF to %s" % pdf_filename)
+    pyplot.savefig(pdf_filename)