Skip to content
Snippets Groups Projects
Commit be308e08 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix test for theano

parent 086b3d07
No related merge requests found
......@@ -13,11 +13,19 @@ from mhcflurry.custom_loss import CUSTOM_LOSSES
def evaluate_loss(loss, y_true, y_pred):
session = K.get_session()
y_true_var = K.constant(y_true, name="y_true")
y_pred_var = K.constant(y_pred, name="y_pred")
result = loss(y_true_var, y_pred_var)
return result.eval(session=session)
if K.backend() == "tensorflow":
session = K.get_session()
y_true_var = K.constant(y_true, name="y_true")
y_pred_var = K.constant(y_pred, name="y_pred")
result = loss(y_true_var, y_pred_var)
return result.eval(session=session)
elif K.backend() == "theano":
y_true_var = K.constant(y_true, name="y_true")
y_pred_var = K.constant(y_pred, name="y_pred")
result = loss(y_true_var, y_pred_var)
return result.eval()
else:
raise ValueError("Unsupported backend: %s" % K.backend())
def test_mse_with_inequalities():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment