From c69906668b9fa0cfafc50ec28a6730824e66058f Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 28 Nov 2017 13:24:30 -0500 Subject: [PATCH] implement Class1AffinityPredictor.merge() method --- .../class1_affinity_predictor.py | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py index b424aaf5..fa0623fb 100644 --- a/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_prediction/class1_affinity_predictor.py @@ -96,10 +96,67 @@ class Class1AffinityPredictor(object): columns=["model_name", "allele", "config_json", "model"]) self.manifest_df = manifest_df - if allele_to_percent_rank_transform is None: + if not allele_to_percent_rank_transform: allele_to_percent_rank_transform = {} self.allele_to_percent_rank_transform = allele_to_percent_rank_transform + @property + def neural_networks(self): + """ + List of the neural networks in the ensemble. + + Returns + ------- + list of Class1NeuralNetwork + """ + result = [] + for models in self.allele_to_allele_specific_models.values(): + result.extend(models) + result.extend(self.class1_pan_allele_models) + return result + + @classmethod + def merge(cls, predictors): + """ + Merge the ensembles of two or more Class1AffinityPredictor instances. + + Note: the resulting merged predictor will NOT have calibrated percentile + ranks. Call calibrate_percentile_ranks() on it if these are needed. + + Parameters + ---------- + predictors : sequence of Class1AffinityPredictor + + Returns + ------- + Class1AffinityPredictor + + """ + assert len(predictors) > 0 + if len(predictors) == 1: + return predictors[0] + + allele_to_allele_specific_models = collections.defaultdict(list) + class1_pan_allele_models = [] + allele_to_pseudosequence = predictors[0].allele_to_pseudosequence + + for predictor in predictors: + for (allele, networks) in ( + predictor.allele_to_allele_specific_models.items()): + allele_to_allele_specific_models[allele].extend(networks) + class1_pan_allele_models.extend( + predictor.class1_pan_allele_models) + + return Class1AffinityPredictor( + allele_to_allele_specific_models=allele_to_allele_specific_models, + class1_pan_allele_models=class1_pan_allele_models, + allele_to_pseudosequence=allele_to_pseudosequence + ) + + @property + def num_networks(self): + return self.manifest_df.shape[0] + @property def supported_alleles(self): -- GitLab