From be308e086c6356ad870f01f3e26dafa21af23a63 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 19 Jan 2018 11:58:07 -0500
Subject: [PATCH] fix test for theano

---
 test/test_custom_loss.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/test/test_custom_loss.py b/test/test_custom_loss.py
index 3546af03..f3644c66 100644
--- a/test/test_custom_loss.py
+++ b/test/test_custom_loss.py
@@ -13,11 +13,19 @@ from mhcflurry.custom_loss import CUSTOM_LOSSES
 
 
 def evaluate_loss(loss, y_true, y_pred):
-    session = K.get_session()
-    y_true_var = K.constant(y_true, name="y_true")
-    y_pred_var = K.constant(y_pred, name="y_pred")
-    result = loss(y_true_var, y_pred_var)
-    return result.eval(session=session)
+    if K.backend() == "tensorflow":
+        session = K.get_session()
+        y_true_var = K.constant(y_true, name="y_true")
+        y_pred_var = K.constant(y_pred, name="y_pred")
+        result = loss(y_true_var, y_pred_var)
+        return result.eval(session=session)
+    elif K.backend() == "theano":
+        y_true_var = K.constant(y_true, name="y_true")
+        y_pred_var = K.constant(y_pred, name="y_pred")
+        result = loss(y_true_var, y_pred_var)
+        return result.eval()
+    else:
+        raise ValueError("Unsupported backend: %s" % K.backend())
 
 
 def test_mse_with_inequalities():
-- 
GitLab