From c9aa6743711846d7605cf712102db673e0c0d724 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Sun, 23 Jun 2019 16:16:43 -0400 Subject: [PATCH] fix --- mhcflurry/class1_neural_network.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index e434e1b2..92cf8b27 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -82,7 +82,8 @@ class Class1NeuralNetwork(object): random_negative_affinity_min=20000.0, random_negative_affinity_max=50000.0, random_negative_match_distribution=True, - random_negative_distribution_smoothing=0.0) + random_negative_distribution_smoothing=0.0, + random_negative_output_indices=None) """ Hyperparameters for neural network training. """ @@ -725,8 +726,13 @@ class Class1NeuralNetwork(object): sample_weights_with_random_negatives = None if output_indices is not None: + random_negative_output_indices = ( + self.hyperparameters['random_negative_output_indices'] + if self.hyperparameters['random_negative_output_indices'] + else list(range(0, self.hyperparameters['num_outputs']))) output_indices_with_random_negatives = numpy.concatenate([ - numpy.zeros(int(num_random_negative.sum()), dtype=int), + pandas.Series(random_negative_output_indices, dtype=int).sample( + n=int(num_random_negative.sum()), replace=True).values, output_indices ]) else: -- GitLab