Skip to content
Snippets Groups Projects
Commit e7b34afb authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

small fix to training script

parent 552af789
No related merge requests found
......@@ -46,13 +46,8 @@ from keras.optimizers import RMSprop
from mhcflurry.common import normalize_allele_name
from mhcflurry.data import load_allele_datasets
from mhcflurry.class1_binding_predictor import Class1BindingPredictor
from mhcflurry.class1_allele_specific_hyperparameters import (
add_hyperparameter_arguments_to_parser
)
from mhcflurry.paths import (
CLASS1_MODEL_DIRECTORY,
CLASS1_DATA_DIRECTORY
)
from mhcflurry.feedforward_hyperparameters import add_hyperparameter_arguments_to_parser
from mhcflurry.paths import (CLASS1_MODEL_DIRECTORY, CLASS1_DATA_DIRECTORY)
from mhcflurry.imputation import create_imputed_datasets, imputer_from_name
CSV_FILENAME = "combined_human_class1_dataset.csv"
......
......@@ -17,7 +17,7 @@ def test_make_embedding_network_properties():
optimizer=RMSprop(lr=0.7, rho=0.9, epsilon=1e-6))
eq_(nn.layers[0].input_dim, 3)
eq_(nn.loss, mse)
eq_(nn.optimizer.lr, 0.7)
eq_(nn.optimizer.lr.eval(), 0.7)
print(nn.layers)
# embedding + flatten + (dense->activation) * hidden layers and last layer
eq_(len(nn.layers), 2 + 2 * (1 + len(layer_sizes)))
......@@ -34,7 +34,7 @@ def test_make_hotshot_network_properties():
optimizer=RMSprop(lr=0.7, rho=0.9, epsilon=1e-6))
eq_(nn.layers[0].input_dim, 6)
eq_(nn.loss, mse)
eq_(nn.optimizer.lr, 0.7)
eq_(nn.optimizer.lr.eval(), 0.7)
print(nn.layers)
eq_(len(nn.layers), 2 + 2 * (1 + len(layer_sizes)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment