Skip to content
Snippets Groups Projects
Commit eeab5995 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 36d65c39
No related merge requests found
......@@ -129,7 +129,7 @@ class Class1PresentationPredictor(object):
raise ValueError("No processing predictor with flanks")
predictor = self.processing_predictor_with_flanks
num_chunks = int(numpy.ceil(len(peptides) / PREDICT_CHUNK_SIZE))
num_chunks = int(numpy.ceil(float(len(peptides)) / PREDICT_CHUNK_SIZE))
peptide_chunks = numpy.array_split(peptides, num_chunks)
n_flank_chunks = numpy.array_split(n_flanks, num_chunks)
c_flank_chunks = numpy.array_split(c_flanks, num_chunks)
......@@ -222,9 +222,6 @@ class Class1PresentationPredictor(object):
model = self._models_cache[name]
return model
def predict_sequences(self, alleles, sequences):
raise NotImplementedError
def predict(
self,
peptides,
......@@ -241,7 +238,7 @@ class Class1PresentationPredictor(object):
c_flanks=c_flanks,
verbose=verbose).presentation_score.values
def predict_scan(
def predict_sequences(
self,
sequences,
alleles,
......@@ -275,8 +272,24 @@ class Class1PresentationPredictor(object):
("sequence_%04d" % (i + 1), sequence)
for (i, sequence) in enumerate(sequences))
if isinstance(alleles, string_types):
alleles = [alleles]
if not isinstance(alleles, dict):
alleles = dict((name, alleles) for name in sequences.keys())
if all([isinstance(item, string_types) for item in alleles]):
alleles = dict((name, alleles) for name in sequences.keys())
elif len(alleles) != len(sequences):
raise ValueError(
"alleles must be (1) a string (a single allele), (2) a list of "
"strings (a single genotype), (3) a list of list of strings ("
"(multiple genotypes, where the total number of genotypes "
"must equal the number of sequences), or (4) a dict (in which "
"case the keys must match the sequences dict keys). Here "
"it seemed like option (3) was being used, but the length "
"of alleles (%d) did not match the length of sequences (%d)."
% (len(alleles), len(sequences)))
else:
alleles = dict(zip(sequences.keys(), alleles))
missing = [key for key in sequences if key not in alleles]
if missing:
......
......@@ -136,7 +136,7 @@ def test_downloaded_predictor():
global PRESENTATION_PREDICTOR
# Test sequence scanning
scan_results1 = PRESENTATION_PREDICTOR.predict_scan(
scan_results1 = PRESENTATION_PREDICTOR.predict_sequences(
sequences=[
"MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE",
"QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG",
......@@ -156,7 +156,7 @@ def test_downloaded_predictor():
assert (scan_results1.affinity < 200).all()
assert (scan_results1.presentation_score > 0.7).all()
scan_results2 = PRESENTATION_PREDICTOR.predict_scan(
scan_results2 = PRESENTATION_PREDICTOR.predict_sequences(
result="filtered",
comparison_value=500,
comparison_quantity="affinity",
......@@ -178,7 +178,7 @@ def test_downloaded_predictor():
assert len(scan_results2) > 10
assert (scan_results2.affinity <= 500).all()
scan_results3 = PRESENTATION_PREDICTOR.predict_scan(
scan_results3 = PRESENTATION_PREDICTOR.predict_sequences(
result="filtered",
comparison_value=0.9,
comparison_quantity="presentation_score",
......@@ -200,7 +200,7 @@ def test_downloaded_predictor():
assert len(scan_results3) > 5, len(scan_results3)
assert (scan_results3.presentation_score >= 0.9).all()
scan_results4 = PRESENTATION_PREDICTOR.predict_scan(
scan_results4 = PRESENTATION_PREDICTOR.predict_sequences(
result="all",
comparison_quantity="affinity",
sequences={
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment