diff --git a/test/test_custom_loss.py b/test/test_custom_loss.py index 3546af037ada7ee001402d3ea2f9aaea00732e0d..f3644c66652136ca7c9f05299f1a20e6995fc8fd 100644 --- a/test/test_custom_loss.py +++ b/test/test_custom_loss.py @@ -13,11 +13,19 @@ from mhcflurry.custom_loss import CUSTOM_LOSSES def evaluate_loss(loss, y_true, y_pred): - session = K.get_session() - y_true_var = K.constant(y_true, name="y_true") - y_pred_var = K.constant(y_pred, name="y_pred") - result = loss(y_true_var, y_pred_var) - return result.eval(session=session) + if K.backend() == "tensorflow": + session = K.get_session() + y_true_var = K.constant(y_true, name="y_true") + y_pred_var = K.constant(y_pred, name="y_pred") + result = loss(y_true_var, y_pred_var) + return result.eval(session=session) + elif K.backend() == "theano": + y_true_var = K.constant(y_true, name="y_true") + y_pred_var = K.constant(y_pred, name="y_pred") + result = loss(y_true_var, y_pred_var) + return result.eval() + else: + raise ValueError("Unsupported backend: %s" % K.backend()) def test_mse_with_inequalities():