Skip to content
Snippets Groups Projects
Commit 2c380b44 authored by Alex Rubinsteyn
Browse files

added Dataset tests

parent c894b11f
No related branches found
No related tags found
No related merge requests found
......@@ -24,6 +24,7 @@ import logging
from six import string_types
import pandas as pd
import numpy as np
from typechecks import require_iterable_of
from .common import geometric_mean
from .dataset_helpers import (
......@@ -127,6 +128,9 @@ class Dataset(object):
return "Dataset(n=%d, alleles=%s)" % (
len(self), self.unique_alleles())
def __repr__(self):
    """
    Use the same text for repr() as for str().
    """
    text = str(self)
    return text
def __eq__(self, other):
"""
Two datasets are equal if they contain the same number of samples
......@@ -136,15 +140,25 @@ class Dataset(object):
return False
elif len(self) != len(other):
return False
elif list(self.columns) != list(other.columns):
columns = self.columns
if len(columns) != len(other.columns):
return False
elif set(columns) != set(other.columns):
return False
self_df = self.to_dataframe()
other_df = other.to_dataframe()
# test for equality of the rows of the two DataFrames regardless
# of order
my_dict = self.allele_and_peptide_pair_to_row_dictionary()
other_dict = other.allele_and_peptide_pair_to_row_dictionary()
for column_name in self.columns:
if (self_df[column_name] != other_df[column_name]).any():
return False
if set(my_dict.keys()) != set(other_dict.keys()):
return False
for key, my_row in my_dict.items():
for column in columns:
if my_row[column] != other_dict[key][column]:
return False
return True
def iterrows(self):
......@@ -154,14 +168,26 @@ class Dataset(object):
"""
return self.to_dataframe().iterrows()
def allele_and_peptide_pair_to_row_dictionary(self):
    """
    Build a dictionary from the (key, row) pairs produced by `iterrows()`.

    The keys are whatever labels `iterrows()` yields; these are expected to
    be (allele, peptide) pairs -- NOTE(review): confirm against the index
    of the underlying DataFrame.
    """
    return dict(self.iterrows())
@property
def columns(self):
    """
    Column labels of the DataFrame returned by `to_dataframe()`.
    """
    frame = self.to_dataframe()
    return frame.columns
def unique_alleles(self):
    """
    Collect the distinct allele names of this Dataset into a set.
    """
    distinct_alleles = set()
    distinct_alleles.update(self.alleles)
    return distinct_alleles
def unique_peptides(self):
    """
    Collect the distinct peptide sequences of this Dataset into a set.
    """
    distinct_peptides = set()
    distinct_peptides.update(self.peptides)
    return distinct_peptides
def groupby_allele(self):
......@@ -553,15 +579,15 @@ class Dataset(object):
right = self.slice(all_indices[n:])
return left, right
def drop_pMHCs(self, alleles, peptides):
def drop_allele_peptide_lists(self, alleles, peptides):
"""
Drop all allele-peptide combinations in the given sequences.
Drop all allele-peptide pairs in the given lists.
Parameters
----------
alleles : sequence of str
alleles : list of str
peptides : sequence of str
peptides : list of str
The two arguments are assumed to be the same length.
......@@ -571,9 +597,25 @@ class Dataset(object):
raise ValueError(
"Expected alleles to be same length (%d) as peptides (%d)" % (
len(alleles), len(peptides)))
my_keys = list(zip(self.alleles, self.peptides))
keys_to_remove = set(zip(alleles, peptides))
remove_mask = np.array([k in keys_to_remove for k in my_keys])
return self.drop_allele_peptide_pairs(list(zip(alleles, peptides)))
def drop_allele_peptide_pairs(self, allele_peptide_pairs):
    """
    Drop every sample whose (allele, peptide) pair is in the given list.

    Parameters
    ----------
    allele_peptide_pairs : list of (str, str) tuples
        Each tuple is an (allele, peptide) pair to remove. Pairs that do
        not occur in this Dataset are ignored.

    Returns Dataset of equal or smaller size.
    """
    require_iterable_of(allele_peptide_pairs, tuple)
    keys_to_remove_set = set(allele_peptide_pairs)
    # Force a boolean dtype: for an empty Dataset the list is empty and
    # np.array([]) would default to float64, on which `~` raises TypeError.
    remove_mask = np.array(
        [
            (key in keys_to_remove_set)
            for key in zip(self.alleles, self.peptides)
        ],
        dtype=bool)
    keep_mask = ~remove_mask
    return self.slice(keep_mask)
......@@ -587,7 +629,9 @@ class Dataset(object):
Returns a new Dataset object of equal or lesser size.
"""
return self.drop_pMHCs(other_dataset.alleles, other_dataset.peptides)
return self.drop_allele_peptide_lists(
alleles=other_dataset.alleles,
peptides=other_dataset.peptides)
def split_allele_randomly_and_impute_training_set(
self, allele, n_training_samples=None, **kwargs):
......
......@@ -21,5 +21,34 @@ def test_create_allele_data_from_single_allele_dict():
for pi, pj in zip(sorted(expected_peptides), sorted(dataset.unique_peptides())):
eq_(pi, pj)
def test_dataset_random_split():
    """
    Splitting three samples with n=2 should give parts of size 2 and 1.
    """
    affinities = {
        "SIINFEKL": 10.0,
        "FEKLSIIN": 20000.0,
        "SIFEKLIN": 50000.0,
    }
    dataset = Dataset.from_nested_dictionary({"H-2-Kb": affinities})
    first_part, second_part = dataset.random_split(n=2)
    assert len(first_part) == 2
    assert len(second_part) == 1
def test_dataset_difference():
    """
    Removing a one-sample Dataset from a three-sample Dataset should leave
    exactly the two samples that were not shared.
    """
    allele = "H-2-Kb"
    full_dataset = Dataset.from_nested_dictionary({
        allele: {
            "SIINFEKL": 10.0,
            "FEKLSIIN": 20000.0,
            "SIFEKLIN": 50000.0,
        },
    })
    dataset_to_remove = Dataset.from_nested_dictionary(
        {allele: {"SIINFEKL": 10.0}})
    remaining = full_dataset.difference(dataset_to_remove)
    expected_remaining = Dataset.from_nested_dictionary({
        allele: {
            "FEKLSIIN": 20000.0,
            "SIFEKLIN": 50000.0,
        },
    })
    eq_(remaining, expected_remaining)
if __name__ == "__main__":
    # Allow running this test module directly, without a test runner.
    for run_test in (
            test_create_allele_data_from_single_allele_dict,
            test_dataset_random_split,
            test_dataset_difference):
        run_test()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment