Skip to content
Snippets Groups Projects
Commit 5a66b6fa authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

fixed class1_binding_predictor tests

parent ca453c69
No related branches found
No related tags found
No related merge requests found
...@@ -109,7 +109,7 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase): ...@@ -109,7 +109,7 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase):
if any(len(peptide) != self.kmer_size for peptide in peptides): if any(len(peptide) != self.kmer_size for peptide in peptides):
raise ValueError("Can only predict 9mer peptides") raise ValueError("Can only predict 9mer peptides")
X, _ = self.encode_peptides(peptides) X, _ = self.encode_peptides(peptides)
return self.predict_scores_for_kmer_array(X) return self.predict_scores_for_kmer_encoded_array(X)
def predict_kmer_peptides_ic50(self, peptides): def predict_kmer_peptides_ic50(self, peptides):
scores = self.predict_scores_for_kmer_peptides(peptides) scores = self.predict_scores_for_kmer_peptides(peptides)
...@@ -132,7 +132,8 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase): ...@@ -132,7 +132,8 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase):
# peptides of lengths other than self.kmer_size get multiple predictions, # peptides of lengths other than self.kmer_size get multiple predictions,
# which are then combined with the combine_fn argument # which are then combined with the combine_fn argument
multiple_predictions_dict = defaultdict(list) multiple_predictions_dict = defaultdict(list)
fixed_length_predictions = self.predict(input_matrix) fixed_length_predictions = \
self.predict_scores_for_kmer_encoded_array(input_matrix)
for i, yi in enumerate(fixed_length_predictions): for i, yi in enumerate(fixed_length_predictions):
original_peptide_index = original_peptide_indices[i] original_peptide_index = original_peptide_indices[i]
original_peptide = peptides[original_peptide_index] original_peptide = peptides[original_peptide_index]
......
...@@ -314,7 +314,7 @@ class Class1BindingPredictor(Class1AlleleSpecificKmerIC50PredictorBase): ...@@ -314,7 +314,7 @@ class Class1BindingPredictor(Class1AlleleSpecificKmerIC50PredictorBase):
alleles = alleles_with_models.intersection(alleles_with_weights) alleles = alleles_with_models.intersection(alleles_with_weights)
return list(sorted(alleles)) return list(sorted(alleles))
def predict_scores_for_kmer_array(self, X): def predict_scores_for_kmer_encoded_array(self, X):
""" """
Given an encoded array of amino acid indices, returns a vector Given an encoded array of amino acid indices, returns a vector
of predicted log IC50 values. of predicted log IC50 values.
......
...@@ -62,7 +62,7 @@ class Dataset(object): ...@@ -62,7 +62,7 @@ class Dataset(object):
# make allele and peptide columns the index, and copy it # make allele and peptide columns the index, and copy it
# so we can add a column without any observable side-effect in # so we can add a column without any observable side-effect in
# the calling code # the calling code
df = df.set_index(["allele", "peptide"]) df = df.set_index(["allele", "peptide"], drop=False)
if "sample_weight" not in columns: if "sample_weight" not in columns:
df["sample_weight"] = np.ones(len(df), dtype=float) df["sample_weight"] = np.ones(len(df), dtype=float)
...@@ -209,6 +209,7 @@ class Dataset(object): ...@@ -209,6 +209,7 @@ class Dataset(object):
"Wrong length for column '%s', expected %d but got %d" % ( "Wrong length for column '%s', expected %d but got %d" % (
column_name, column)) column_name, column))
df[column_name] = np.asarray(column) df[column_name] = np.asarray(column)
print(df)
return cls(df) return cls(df)
@classmethod @classmethod
......
# Copyright (c) 2016. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from mhcflurry import Class1BindingPredictor
class Dummy9merIndexEncodingModel(object):
"""
Dummy molde used for testing the pMHC binding predictor.
"""
def __init__(self, constant_output_value=0):
self.constant_output_value = constant_output_value
def predict(self, X, verbose=False):
assert isinstance(X, np.ndarray)
assert len(X.shape) == 2
n_rows, n_cols = X.shape
n_cols == 9, "Expected 9mer index input input, got %d columns" % (
n_cols,)
return np.ones(n_rows, dtype=float) * self.constant_output_value
always_zero_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=True)
always_zero_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=False)
always_one_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=True)
always_one_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=False)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import numpy as np import numpy as np
from dummy_predictors import always_zero_predictor_with_unknown_AAs from dummy_models import always_zero_predictor_with_unknown_AAs
def test_always_zero_9mer_inputs(): def test_always_zero_9mer_inputs():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment