From fa353a34ca4d7b9ca615fcc549b21c4faa829470 Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Tue, 10 May 2016 13:08:59 -0400
Subject: [PATCH] fixed obscure shape bug on empty arrays

---
 .../class1_allele_specific_kmer_ic50_predictor_base.py    | 2 +-
 mhcflurry/dataset.py                                      | 8 ++++----
 mhcflurry/training_helpers.py                             | 6 ++----
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/mhcflurry/class1_allele_specific_kmer_ic50_predictor_base.py b/mhcflurry/class1_allele_specific_kmer_ic50_predictor_base.py
index 0d2ffd70..3a242d8b 100644
--- a/mhcflurry/class1_allele_specific_kmer_ic50_predictor_base.py
+++ b/mhcflurry/class1_allele_specific_kmer_ic50_predictor_base.py
@@ -175,5 +175,5 @@ class Class1AlleleSpecificKmerIC50PredictorBase(IC50PredictorBase):
             sample_weights=sample_weights,
             X_pretrain=X_pretrain,
             ic50_pretrain=ic50_pretrain,
-            sample_weights_pretrain=sample_weights,
+            sample_weights_pretrain=sample_weights_pretrain,
             **kwargs)
diff --git a/mhcflurry/dataset.py b/mhcflurry/dataset.py
index 67fed68d..d58180d6 100644
--- a/mhcflurry/dataset.py
+++ b/mhcflurry/dataset.py
@@ -437,10 +437,10 @@ class Dataset(object):
         """
         if len(self.peptides) == 0:
             return (
-                np.array([[]], dtype=int),
-                np.array([], dtype=float),
-                np.array([], dtype=float),
-                np.array([], dtype=int)
+                np.empty((0, kmer_size), dtype=int),
+                np.empty((0,), dtype=float),
+                np.empty((0,), dtype=float),
+                np.empty((0,), dtype=int)
             )
 
         X_index, _, original_peptide_indices, counts = \
diff --git a/mhcflurry/training_helpers.py b/mhcflurry/training_helpers.py
index c61adac8..5a3be5c4 100644
--- a/mhcflurry/training_helpers.py
+++ b/mhcflurry/training_helpers.py
@@ -49,7 +49,6 @@ def check_encoded_array_shapes(X, Y, sample_weights):
             sample_weights.shape,))
 
     n_samples, n_dims = X.shape
-
     if len(Y) != n_samples:
         raise ValueError("Mismatch between len(X) = %d and len(Y) = %d" % (
             n_samples, len(Y)))
@@ -78,7 +77,6 @@ def combine_training_arrays(
     """
     X = np.asarray(X)
     Y = np.asarray(Y)
-
     if sample_weights is None:
         sample_weights = np.ones_like(Y)
     else:
@@ -87,8 +85,8 @@ def combine_training_arrays(
     n_samples, n_dims = check_encoded_array_shapes(X, Y, sample_weights)
 
     if X_pretrain is None or Y_pretrain is None:
-        X_pretrain = np.empty((0, n_dims), dtype=X.dtype)
-        Y_pretrain = np.empty((0,), dtype=Y.dtype)
+        X_pretrain = np.zeros((0, n_dims), dtype=X.dtype)
+        Y_pretrain = np.zeros((0,), dtype=Y.dtype)
     else:
         X_pretrain = np.asarray(X_pretrain)
         Y_pretrain = np.asarray(Y_pretrain)
-- 
GitLab