Skip to content
Snippets Groups Projects
Commit 5a66b6fa authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

fixed class1_binding_predictor tests

parent ca453c69
No related branches found
No related tags found
No related merge requests found
......@@ -109,7 +109,7 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase):
if any(len(peptide) != self.kmer_size for peptide in peptides):
raise ValueError("Can only predict 9mer peptides")
X, _ = self.encode_peptides(peptides)
return self.predict_scores_for_kmer_array(X)
return self.predict_scores_for_kmer_encoded_array(X)
def predict_kmer_peptides_ic50(self, peptides):
scores = self.predict_scores_for_kmer_peptides(peptides)
......@@ -132,7 +132,8 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase):
# peptides of lengths other than self.kmer_size get multiple predictions,
# which are then combined with the combine_fn argument
multiple_predictions_dict = defaultdict(list)
fixed_length_predictions = self.predict(input_matrix)
fixed_length_predictions = \
self.predict_scores_for_kmer_encoded_array(input_matrix)
for i, yi in enumerate(fixed_length_predictions):
original_peptide_index = original_peptide_indices[i]
original_peptide = peptides[original_peptide_index]
......
......@@ -314,7 +314,7 @@ class Class1BindingPredictor(Class1AlleleSpecificKmerIC50PredictorBase):
alleles = alleles_with_models.intersection(alleles_with_weights)
return list(sorted(alleles))
def predict_scores_for_kmer_array(self, X):
def predict_scores_for_kmer_encoded_array(self, X):
"""
Given an encoded array of amino acid indices, returns a vector
of predicted log IC50 values.
......
......@@ -62,7 +62,7 @@ class Dataset(object):
# make allele and peptide columns the index, and copy it
# so we can add a column without any observable side-effect in
# the calling code
df = df.set_index(["allele", "peptide"])
df = df.set_index(["allele", "peptide"], drop=False)
if "sample_weight" not in columns:
df["sample_weight"] = np.ones(len(df), dtype=float)
......@@ -209,6 +209,7 @@ class Dataset(object):
"Wrong length for column '%s', expected %d but got %d" % (
column_name, column))
df[column_name] = np.asarray(column)
print(df)
return cls(df)
@classmethod
......
# Copyright (c) 2016. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from mhcflurry import Class1BindingPredictor
class Dummy9merIndexEncodingModel(object):
"""
Dummy molde used for testing the pMHC binding predictor.
"""
def __init__(self, constant_output_value=0):
self.constant_output_value = constant_output_value
def predict(self, X, verbose=False):
assert isinstance(X, np.ndarray)
assert len(X.shape) == 2
n_rows, n_cols = X.shape
n_cols == 9, "Expected 9mer index input input, got %d columns" % (
n_cols,)
return np.ones(n_rows, dtype=float) * self.constant_output_value
always_zero_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=True)
always_zero_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=False)
always_one_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=True)
always_one_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=False)
......@@ -15,7 +15,7 @@
import numpy as np
from dummy_predictors import always_zero_predictor_with_unknown_AAs
from dummy_models import always_zero_predictor_with_unknown_AAs
def test_always_zero_9mer_inputs():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment