Skip to content
Snippets Groups Projects
Commit 4b337b22 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

added cross_validation iterator to Dataset

parent 37e0372b
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ from six import string_types
import pandas as pd
import numpy as np
from typechecks import require_iterable_of
from sklearn.cross_validation import KFold
from .common import geometric_mean
from .dataset_helpers import (
......@@ -603,6 +604,69 @@ class Dataset(object):
right = self.slice(all_indices[n:])
return left, right
def cross_validation_iterator(
        self, test_allele=None, n_folds=3, shuffle=True):
    """
    Generate (train, test) Dataset pairs via k-fold cross-validation.

    If test_allele is None, folds are drawn from every pMHC entry in
    this Dataset. Otherwise only the measurements for the given allele
    are split into folds; measurements for every other allele are kept
    in each training set.

    Parameters
    ----------
    test_allele : str, optional
        Allele whose measurements should be held out across folds.

    n_folds : int
        Number of cross-validation folds.

    shuffle : bool
        Whether KFold shuffles indices before splitting.

    Yields
    ------
    (train_dataset, test_dataset) pairs of Dataset objects.
    """
    # Validate the allele name up front so the error surfaces before
    # any folds are generated.
    if test_allele is not None and test_allele not in self.unique_alleles():
        raise ValueError("Allele '%s' not in Dataset" % test_allele)

    if test_allele is None:
        splittable_indices = np.arange(len(self))
    else:
        splittable_indices = np.where(self.alleles == test_allele)[0]

    total_size = len(self)
    kfold = KFold(
        n=len(splittable_indices),
        n_folds=n_folds,
        shuffle=shuffle)
    for _, fold_positions in kfold:
        # Map positions within the candidate subset back to indices
        # into the full dataset.
        held_out = splittable_indices[fold_positions]
        keep_mask = np.ones(total_size, dtype=bool)
        keep_mask[held_out] = False
        yield self.slice(keep_mask), self.slice(held_out)
def split_allele_randomly_and_impute_training_set(
        self, allele, n_training_samples=None, **kwargs):
    """
    Split an allele into training and test sets, then impute affinities
    for peptides missing from the training set using the measurements
    of other alleles in this Dataset.

    (apologies for the wordy name, this turns out to be a common operation)

    Parameters
    ----------
    allele : str
        Name of allele

    n_training_samples : int, optional
        Size of the training set to return for this allele.

    **kwargs : dict
        Extra keyword arguments passed to Dataset.impute_missing_values

    Returns three Dataset objects:
        - training set with original pMHC affinities for given allele
        - larger imputed training set for given allele
        - test set
    """
    single_allele = self.get_allele(allele)
    train_split, test_split = single_allele.random_split(
        n=n_training_samples)
    # Drop the held-out test measurements before imputing so they
    # cannot leak into the imputed training data.
    imputation_source = self.difference(test_split)
    imputed_full = imputation_source.impute_missing_values(**kwargs)
    return train_split, imputed_full.get_allele(allele), test_split
def drop_allele_peptide_lists(self, alleles, peptides):
"""
Drop all allele-peptide pairs in the given lists.
......@@ -657,39 +721,6 @@ class Dataset(object):
alleles=other_dataset.alleles,
peptides=other_dataset.peptides)
def split_allele_randomly_and_impute_training_set(
        self, allele, n_training_samples=None, **kwargs):
    """
    Split an allele into training and test sets, and then impute values
    for peptides missing from the training set using data from other
    alleles in this Dataset.

    (apologies for the wordy name, this turns out to be a common operation)

    Parameters
    ----------
    allele : str
        Name of allele

    n_training_samples : int, optional
        Size of the training set to return for this allele.

    **kwargs : dict
        Extra keyword arguments passed to Dataset.impute_missing_values

    Returns three Dataset objects:
        - training set with original pMHC affinities for given allele
        - larger imputed training set for given allele
        - test set
    """
    allele_data = self.get_allele(allele)
    allele_train, allele_test = allele_data.random_split(n=n_training_samples)
    # Impute only from data that excludes the test measurements.
    without_test = self.difference(allele_test)
    imputed = without_test.impute_missing_values(**kwargs)
    imputed_allele = imputed.get_allele(allele)
    return allele_train, imputed_allele, allele_test
def impute_missing_values(
self,
imputation_method,
......
......@@ -48,6 +48,25 @@ def test_dataset_difference():
}})
eq_(dataset_diff, expected_result)
def test_dataset_cross_validation():
    # Two alleles; only HLA-A*02:01 (2 entries) is split into folds.
    dataset = Dataset.from_nested_dictionary({
        "H-2-Kb": {
            "SIINFEKL": 10.0,
            "FEKLSIIN": 20000.0,
            "SIFEKLIN": 50000.0,
        },
        "HLA-A*02:01": {"ASASAS": 1.0, "CCC": 0.0}})
    folds_seen = 0
    splits = dataset.cross_validation_iterator(
        test_allele="HLA-A*02:01", n_folds=2)
    for train_split, test_split in splits:
        # Training data keeps every allele: H-2-Kb entries are never
        # held out, and one HLA-A*02:01 entry stays in each fold.
        assert train_split.unique_alleles() == {"H-2-Kb", "HLA-A*02:01"}
        assert test_split.unique_alleles() == {"HLA-A*02:01"}
        assert len(test_split) == 1
        folds_seen += 1
    # 2 candidate entries / 2 folds -> exactly 2 train/test pairs.
    assert folds_seen == 2
if __name__ == "__main__":
test_create_allele_data_from_single_allele_dict()
test_dataset_random_split()
......
0% Loading, or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment