From eeab5995e95889f974a0154f6440ee0a41baa086 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 4 Feb 2020 16:18:07 -0500 Subject: [PATCH] fix --- mhcflurry/class1_presentation_predictor.py | 25 ++++++++++++++++------ test/test_class1_presentation_predictor.py | 8 +++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py index 69a531c6..f2d8094e 100644 --- a/mhcflurry/class1_presentation_predictor.py +++ b/mhcflurry/class1_presentation_predictor.py @@ -129,7 +129,7 @@ class Class1PresentationPredictor(object): raise ValueError("No processing predictor with flanks") predictor = self.processing_predictor_with_flanks - num_chunks = int(numpy.ceil(len(peptides) / PREDICT_CHUNK_SIZE)) + num_chunks = int(numpy.ceil(float(len(peptides)) / PREDICT_CHUNK_SIZE)) peptide_chunks = numpy.array_split(peptides, num_chunks) n_flank_chunks = numpy.array_split(n_flanks, num_chunks) c_flank_chunks = numpy.array_split(c_flanks, num_chunks) @@ -222,9 +222,6 @@ class Class1PresentationPredictor(object): model = self._models_cache[name] return model - def predict_sequences(self, alleles, sequences): - raise NotImplementedError - def predict( self, peptides, @@ -241,7 +238,7 @@ class Class1PresentationPredictor(object): c_flanks=c_flanks, verbose=verbose).presentation_score.values - def predict_scan( + def predict_sequences( self, sequences, alleles, @@ -275,8 +272,24 @@ class Class1PresentationPredictor(object): ("sequence_%04d" % (i + 1), sequence) for (i, sequence) in enumerate(sequences)) + if isinstance(alleles, string_types): + alleles = [alleles] + if not isinstance(alleles, dict): - alleles = dict((name, alleles) for name in sequences.keys()) + if all([isinstance(item, string_types) for item in alleles]): + alleles = dict((name, alleles) for name in sequences.keys()) + elif len(alleles) != len(sequences): + raise ValueError( + "alleles must be (1) a string (a single allele), (2) a list of " + "strings (a single genotype), (3) a list of list of strings (" + "(multiple genotypes, where the total number of genotypes " + "must equal the number of sequences), or (4) a dict (in which " + "case the keys must match the sequences dict keys). Here " + "it seemed like option (3) was being used, but the length " + "of alleles (%d) did not match the length of sequences (%d)." + % (len(alleles), len(sequences))) + else: + alleles = dict(zip(sequences.keys(), alleles)) missing = [key for key in sequences if key not in alleles] if missing: diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py index 363ee84b..f811ab6b 100644 --- a/test/test_class1_presentation_predictor.py +++ b/test/test_class1_presentation_predictor.py @@ -136,7 +136,7 @@ def test_downloaded_predictor(): global PRESENTATION_PREDICTOR # Test sequence scanning - scan_results1 = PRESENTATION_PREDICTOR.predict_scan( + scan_results1 = PRESENTATION_PREDICTOR.predict_sequences( sequences=[ "MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE", "QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG", @@ -156,7 +156,7 @@ def test_downloaded_predictor(): assert (scan_results1.affinity < 200).all() assert (scan_results1.presentation_score > 0.7).all() - scan_results2 = PRESENTATION_PREDICTOR.predict_scan( + scan_results2 = PRESENTATION_PREDICTOR.predict_sequences( result="filtered", comparison_value=500, comparison_quantity="affinity", @@ -178,7 +178,7 @@ def test_downloaded_predictor(): assert len(scan_results2) > 10 assert (scan_results2.affinity <= 500).all() - scan_results3 = PRESENTATION_PREDICTOR.predict_scan( + scan_results3 = PRESENTATION_PREDICTOR.predict_sequences( result="filtered", comparison_value=0.9, comparison_quantity="presentation_score", @@ -200,7 +200,7 @@ def test_downloaded_predictor(): assert len(scan_results3) > 5, len(scan_results3) assert (scan_results3.presentation_score >= 0.9).all() - scan_results4 = PRESENTATION_PREDICTOR.predict_scan( + scan_results4 = PRESENTATION_PREDICTOR.predict_sequences( result="all", comparison_quantity="affinity", sequences={ -- GitLab