Skip to content
Snippets Groups Projects
Commit fbbef344 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

update

parent 053361ea
No related merge requests found
......@@ -73,13 +73,12 @@ class Class1NeuralNetwork(object):
self.hyperparameters = self.hyperparameter_defaults.with_defaults(
hyperparameters)
self.network = None
self.fit_history = None
self.loss_history = None
self.fit_seconds = None
def __getstate__(self):
result = dict(self.__dict__)
del result['network']
result['fit_history'] = None
result['network_json'] = self.network.to_json()
result['network_weights'] = self.get_weights()
return result
......@@ -204,7 +203,7 @@ class Class1NeuralNetwork(object):
min_val_loss_iteration = None
min_val_loss = None
self.fit_history = collections.defaultdict(list)
self.loss_history = collections.defaultdict(list)
start = time.time()
for i in range(self.hyperparameters['max_epochs']):
random_negative_peptides_list = []
......@@ -243,17 +242,17 @@ class Class1NeuralNetwork(object):
sample_weight=sample_weights)
for (key, value) in fit_history.history.items():
self.fit_history[key].extend(value)
self.loss_history[key].extend(value)
logging.info(
"Epoch %3d / %3d: loss=%g. Min val loss at epoch %s" % (
i,
self.hyperparameters['max_epochs'],
self.fit_history['loss'][-1],
self.loss_history['loss'][-1],
min_val_loss_iteration))
if self.hyperparameters['validation_split']:
val_loss = fit_history.history['val_loss'][-1]
val_loss = self.loss_history['val_loss'][-1]
val_losses.append(val_loss)
if min_val_loss is None or val_loss <= min_val_loss:
......
......@@ -16,6 +16,7 @@ import mhcnames
from .class1_neural_network import Class1NeuralNetwork
from ..common import configure_logging
def normalize_allele_name(s):
try:
return mhcnames.normalize_allele_name(s)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment