Skip to content
Snippets Groups Projects
Commit e7b34afb authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

small fix to training script

parent 552af789
No related branches found
No related tags found
No related merge requests found
......@@ -46,13 +46,8 @@ from keras.optimizers import RMSprop
from mhcflurry.common import normalize_allele_name
from mhcflurry.data import load_allele_datasets
from mhcflurry.class1_binding_predictor import Class1BindingPredictor
from mhcflurry.class1_allele_specific_hyperparameters import (
add_hyperparameter_arguments_to_parser
)
from mhcflurry.paths import (
CLASS1_MODEL_DIRECTORY,
CLASS1_DATA_DIRECTORY
)
from mhcflurry.feedforward_hyperparameters import add_hyperparameter_arguments_to_parser
from mhcflurry.paths import (CLASS1_MODEL_DIRECTORY, CLASS1_DATA_DIRECTORY)
from mhcflurry.imputation import create_imputed_datasets, imputer_from_name
CSV_FILENAME = "combined_human_class1_dataset.csv"
......
......@@ -17,7 +17,7 @@ def test_make_embedding_network_properties():
optimizer=RMSprop(lr=0.7, rho=0.9, epsilon=1e-6))
eq_(nn.layers[0].input_dim, 3)
eq_(nn.loss, mse)
eq_(nn.optimizer.lr, 0.7)
eq_(nn.optimizer.lr.eval(), 0.7)
print(nn.layers)
# embedding + flatten + (dense->activation) * hidden layers and last layer
eq_(len(nn.layers), 2 + 2 * (1 + len(layer_sizes)))
......@@ -34,7 +34,7 @@ def test_make_hotshot_network_properties():
optimizer=RMSprop(lr=0.7, rho=0.9, epsilon=1e-6))
eq_(nn.layers[0].input_dim, 6)
eq_(nn.loss, mse)
eq_(nn.optimizer.lr, 0.7)
eq_(nn.optimizer.lr.eval(), 0.7)
print(nn.layers)
eq_(len(nn.layers), 2 + 2 * (1 + len(layer_sizes)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment