From 44bdd65b6c8245e779f45dfbb36a3e369a8daf50 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Fri, 7 Jun 2019 16:48:39 -0400 Subject: [PATCH] fix --- .../models_class1_pan/generate_hyperparameters.py | 2 +- mhcflurry/train_pan_allele_models_command.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/downloads-generation/models_class1_pan/generate_hyperparameters.py b/downloads-generation/models_class1_pan/generate_hyperparameters.py index fe05e313..ea728dc3 100644 --- a/downloads-generation/models_class1_pan/generate_hyperparameters.py +++ b/downloads-generation/models_class1_pan/generate_hyperparameters.py @@ -48,7 +48,7 @@ base_hyperparameters = { } grid = [] -for layer_sizes in [[1024, 512], [512, 512], [1024, 1024]]: +for layer_sizes in [[512, 256], [1024, 512], [1024, 1024]]: for l1 in [0.0, 0.0001, 0.001, 0.01]: new = deepcopy(base_hyperparameters) new["layer_sizes"] = layer_sizes diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py index 9880f804..fa530d15 100644 --- a/mhcflurry/train_pan_allele_models_command.py +++ b/mhcflurry/train_pan_allele_models_command.py @@ -434,7 +434,7 @@ def train_model( train_peptides = EncodableSequences(train_data.peptide.values) train_alleles = AlleleEncoding( train_data.allele.values, borrow_from=allele_encoding) - train_target = from_ic50(train_data.measurement_value) + train_target = from_ic50(train_data.measurement_value.values) model = Class1NeuralNetwork(**hyperparameters) @@ -468,6 +468,7 @@ def train_model( peptides=peptides, affinities=affinities, allele_encoding=alleles) + fit_time = time.time() - start start = time.time() predictions = model.predict( @@ -484,7 +485,7 @@ def train_model( mask = train_data.measurement_inequality == inequality predictions[mask.values] = func( predictions[mask.values], - train_data.loc[mask.values].measurement_value.values) + train_data.loc[mask].measurement_value.values) score_mse = numpy.mean((from_ic50(predictions) - train_target)**2) score_time = time.time() - start print( -- GitLab