diff --git a/test/test_neural_nets.py b/test/test_neural_nets.py index 0adc1cc54233a7635201e76b283db0c93143433f..e791d9271f2597ce4346debc1d6ceaf0e12a6f42 100644 --- a/test/test_neural_nets.py +++ b/test/test_neural_nets.py @@ -36,7 +36,9 @@ def test_make_hotshot_network_properties(): eq_(nn.loss, mse) assert np.allclose(nn.optimizer.lr.eval(), 0.7) print(nn.layers) - eq_(len(nn.layers), 2 + 2 * (1 + len(layer_sizes))) + # since the hotshot network doesn't have an embedding layer + flatten + # we expect two fewer total layers than the embedding network. + eq_(len(nn.layers), 2 * (1 + len(layer_sizes))) def test_make_embedding_network_small_dataset(): nn = make_embedding_network(