Skip to content
Snippets Groups Projects
Commit 33e8c10a authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

lazy-load model weights

parent c0469948
No related branches found
No related tags found
No related merge requests found
...@@ -2,13 +2,13 @@ import collections ...@@ -2,13 +2,13 @@ import collections
import hashlib import hashlib
import json import json
import logging import logging
import sys
import time import time
import warnings import warnings
from os.path import join, exists from os.path import join, exists, abspath
from os import mkdir from os import mkdir
from socket import gethostname from socket import gethostname
from getpass import getuser from getpass import getuser
from functools import partial
import mhcnames import mhcnames
import numpy import numpy
...@@ -294,9 +294,14 @@ class Class1AffinityPredictor(object): ...@@ -294,9 +294,14 @@ class Class1AffinityPredictor(object):
for (_, row) in manifest_df.iterrows(): for (_, row) in manifest_df.iterrows():
weights_filename = Class1AffinityPredictor.weights_path( weights_filename = Class1AffinityPredictor.weights_path(
models_dir, row.model_name) models_dir, row.model_name)
weights = Class1AffinityPredictor.load_weights(weights_filename)
config = json.loads(row.config_json) config = json.loads(row.config_json)
model = Class1NeuralNetwork.from_config(config, weights=weights)
# We will lazy-load weights when the network is used.
model = Class1NeuralNetwork.from_config(
config,
weights_loader=partial(
Class1AffinityPredictor.load_weights,
abspath(weights_filename)))
if row.allele == "pan-class1": if row.allele == "pan-class1":
class1_pan_allele_models.append(model) class1_pan_allele_models.append(model)
else: else:
......
...@@ -148,6 +148,7 @@ class Class1NeuralNetwork(object): ...@@ -148,6 +148,7 @@ class Class1NeuralNetwork(object):
self._network = None self._network = None
self.network_json = None self.network_json = None
self.network_weights = None self.network_weights = None
self.network_weights_loader = None
self.loss_history = None self.loss_history = None
self.fit_seconds = None self.fit_seconds = None
...@@ -217,6 +218,7 @@ class Class1NeuralNetwork(object): ...@@ -217,6 +218,7 @@ class Class1NeuralNetwork(object):
keras.models.Model keras.models.Model
""" """
if self._network is None and self.network_json is not None: if self._network is None and self.network_json is not None:
self.load_weights()
if borrow: if borrow:
return self.borrow_cached_network( return self.borrow_cached_network(
self.network_json, self.network_json,
...@@ -250,7 +252,7 @@ class Class1NeuralNetwork(object): ...@@ -250,7 +252,7 @@ class Class1NeuralNetwork(object):
return result return result
@classmethod @classmethod
def from_config(cls, config, weights=None): def from_config(cls, config, weights=None, weights_loader=None):
""" """
deserialize from a dict returned by get_config(). deserialize from a dict returned by get_config().
...@@ -259,6 +261,8 @@ class Class1NeuralNetwork(object): ...@@ -259,6 +261,8 @@ class Class1NeuralNetwork(object):
config : dict config : dict
weights : list of array, optional weights : list of array, optional
Network weights to restore Network weights to restore
weights_loader : callable, optional
Function to call (no arguments) to load weights when needed
Returns Returns
------- -------
...@@ -269,8 +273,14 @@ class Class1NeuralNetwork(object): ...@@ -269,8 +273,14 @@ class Class1NeuralNetwork(object):
assert all(hasattr(instance, key) for key in config), config.keys() assert all(hasattr(instance, key) for key in config), config.keys()
instance.__dict__.update(config) instance.__dict__.update(config)
instance.network_weights = weights instance.network_weights = weights
instance.network_weights_loader = weights_loader
return instance return instance
def load_weights(self):
if self.network_weights_loader:
self.network_weights = self.network_weights_loader()
self.network_weights_loader = None
def get_weights(self): def get_weights(self):
""" """
Get the network weights Get the network weights
...@@ -281,6 +291,7 @@ class Class1NeuralNetwork(object): ...@@ -281,6 +291,7 @@ class Class1NeuralNetwork(object):
or None if there is no network or None if there is no network
""" """
self.update_network_description() self.update_network_description()
self.load_weights()
return self.network_weights return self.network_weights
def __getstate__(self): def __getstate__(self):
...@@ -293,7 +304,7 @@ class Class1NeuralNetwork(object): ...@@ -293,7 +304,7 @@ class Class1NeuralNetwork(object):
""" """
self.update_network_description() self.update_network_description()
self.update_network_description() self.load_weights()
result = dict(self.__dict__) result = dict(self.__dict__)
result['_network'] = None result['_network'] = None
return result return result
......
...@@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]): ...@@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]):
tqdm.tqdm( tqdm.tqdm(
worker_pool.imap_unordered( worker_pool.imap_unordered(
train_model_entrypoint, work_items, chunksize=1), train_model_entrypoint, work_items, chunksize=1),
ascii=True,
total=len(work_items)), total=len(work_items)),
key=lambda pair: pair[0]) key=lambda pair: pair[0])
] ]
...@@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]): ...@@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]):
alleles, alleles,
chunksize=1) chunksize=1)
for result in tqdm.tqdm(results, ascii=True, total=len(alleles)): for result in tqdm.tqdm(results, total=len(alleles)):
predictor.allele_to_percent_rank_transform.update(result) predictor.allele_to_percent_rank_transform.update(result)
print("Done calibrating %d additional alleles." % len(alleles)) print("Done calibrating %d additional alleles." % len(alleles))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment