diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index eafbbf96398e58eb0028a8bfaadab62835377a67..eafad7c28f422f9c753613b680a1aceaf7a6ae2d 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -403,6 +403,18 @@ class Class1NeuralNetwork(object): x_dict_without_random_negatives['pseudosequence'] = ( pseudosequences_input) + # Shuffle y_values and the contents of x_dict_without_random_negatives + # This ensures different data is used for the test set for early stopping + # when multiple models are trained. + shuffle_permutation = numpy.random.permutation(len(y_values)) + y_values = y_values[shuffle_permutation] + peptide_encoding = peptide_encoding[shuffle_permutation] + for key in x_dict_without_random_negatives: + x_dict_without_random_negatives[key] = ( + x_dict_without_random_negatives[key][shuffle_permutation]) + if sample_weights is not None: + sample_weights = sample_weights[shuffle_permutation] + if self.network() is None: self._network = self.make_network( pseudosequence_length=pseudosequence_length,