From aea2f493d6eea5626bf5d8d7d1de79c5047994c5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 14 Jan 2018 15:34:52 -0500
Subject: [PATCH] Shuffle in Class1NeuralNetwork.fit

---
 mhcflurry/class1_neural_network.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index eafbbf96..eafad7c2 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -403,6 +403,18 @@ class Class1NeuralNetwork(object):
             x_dict_without_random_negatives['pseudosequence'] = (
                 pseudosequences_input)
 
+        # Shuffle y_values, peptide_encoding, sample_weights (if given) and each
+        # entry of x_dict_without_random_negatives with one permutation, so
+        # the early-stopping test set differs when multiple models are trained.
+        shuffle_permutation = numpy.random.permutation(len(y_values))
+        y_values = y_values[shuffle_permutation]
+        peptide_encoding = peptide_encoding[shuffle_permutation]
+        for key in x_dict_without_random_negatives:
+            x_dict_without_random_negatives[key] = (
+                x_dict_without_random_negatives[key][shuffle_permutation])
+        if sample_weights is not None:
+            sample_weights = sample_weights[shuffle_permutation]
+
         if self.network() is None:
             self._network = self.make_network(
                 pseudosequence_length=pseudosequence_length,
-- 
GitLab