From aea2f493d6eea5626bf5d8d7d1de79c5047994c5 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Sun, 14 Jan 2018 15:34:52 -0500 Subject: [PATCH] Shuffle in Class1NeuralNetwork.fit --- mhcflurry/class1_neural_network.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index eafbbf96..eafad7c2 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -403,6 +403,18 @@ class Class1NeuralNetwork(object): x_dict_without_random_negatives['pseudosequence'] = ( pseudosequences_input) + # Shuffle y_values and the contents of x_dict_without_random_negatives + # This ensures different data is used for the test set for early stopping + # when multiple models are trained. + shuffle_permutation = numpy.random.permutation(len(y_values)) + y_values = y_values[shuffle_permutation] + peptide_encoding = peptide_encoding[shuffle_permutation] + for key in x_dict_without_random_negatives: + x_dict_without_random_negatives[key] = ( + x_dict_without_random_negatives[key][shuffle_permutation]) + if sample_weights is not None: + sample_weights = sample_weights[shuffle_permutation] + if self.network() is None: self._network = self.make_network( pseudosequence_length=pseudosequence_length, -- GitLab