From d1e6fbb5edc830b3910c26b997b9860b835fdd8d Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Mon, 24 Jun 2019 12:35:47 -0400 Subject: [PATCH] fix --- mhcflurry/class1_neural_network.py | 2 ++ mhcflurry/custom_loss.py | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index c1f8e106..80e58b88 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -719,6 +719,8 @@ class Class1NeuralNetwork(object): ]), } adjusted_inequalities_with_random_negatives = None + assert numpy.isnan(y_dict_with_random_negatives['output']).sum() == 0, ( + y_dict_with_random_negatives) if sample_weights is not None: sample_weights_with_random_negatives = numpy.concatenate([ numpy.ones(int(num_random_negative.sum())), diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index eabdd130..fcb970bd 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -82,11 +82,11 @@ class MSEWithInequalities(Loss): def encode_y(y, inequalities=None): y = array(y, dtype="float32") if isnan(y).any(): - raise ValueError("y contains NaN: %s" % str(y)) + raise ValueError("y contains NaN", y) if (y > 1.0).any(): - raise ValueError("y contains values > 1.0") + raise ValueError("y contains values > 1.0", y) if (y < 0.0).any(): - raise ValueError("y contains values < 0.0") + raise ValueError("y contains values < 0.0", y) if inequalities is None: encoded = y @@ -141,11 +141,11 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): def encode_y(y, inequalities=None, output_indices=None): y = array(y, dtype="float32") if isnan(y).any(): - raise ValueError("y contains NaN: %s" % str(y)) + raise ValueError("y contains NaN", y) if (y > 1.0).any(): - raise ValueError("y contains values > 1.0") + raise ValueError("y contains values > 1.0", y) if (y < 0.0).any(): - raise ValueError("y contains values < 0.0") + raise ValueError("y contains values < 0.0", y) encoded = MSEWithInequalities.encode_y( y, inequalities=inequalities) -- GitLab