From be308e086c6356ad870f01f3e26dafa21af23a63 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Fri, 19 Jan 2018 11:58:07 -0500 Subject: [PATCH] fix test for theano --- test/test_custom_loss.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/test_custom_loss.py b/test/test_custom_loss.py index 3546af03..f3644c66 100644 --- a/test/test_custom_loss.py +++ b/test/test_custom_loss.py @@ -13,11 +13,19 @@ from mhcflurry.custom_loss import CUSTOM_LOSSES def evaluate_loss(loss, y_true, y_pred): - session = K.get_session() - y_true_var = K.constant(y_true, name="y_true") - y_pred_var = K.constant(y_pred, name="y_pred") - result = loss(y_true_var, y_pred_var) - return result.eval(session=session) + if K.backend() == "tensorflow": + session = K.get_session() + y_true_var = K.constant(y_true, name="y_true") + y_pred_var = K.constant(y_pred, name="y_pred") + result = loss(y_true_var, y_pred_var) + return result.eval(session=session) + elif K.backend() == "theano": + y_true_var = K.constant(y_true, name="y_true") + y_pred_var = K.constant(y_pred, name="y_pred") + result = loss(y_true_var, y_pred_var) + return result.eval() + else: + raise ValueError("Unsupported backend: %s" % K.backend()) def test_mse_with_inequalities(): -- GitLab