Skip to content
Snippets Groups Projects
Commit 33e8c10a authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

lazy-load model weights

parent c0469948
No related branches found
No related tags found
No related merge requests found
......@@ -2,13 +2,13 @@ import collections
import hashlib
import json
import logging
import sys
import time
import warnings
from os.path import join, exists
from os.path import join, exists, abspath
from os import mkdir
from socket import gethostname
from getpass import getuser
from functools import partial
import mhcnames
import numpy
......@@ -294,9 +294,14 @@ class Class1AffinityPredictor(object):
for (_, row) in manifest_df.iterrows():
weights_filename = Class1AffinityPredictor.weights_path(
models_dir, row.model_name)
weights = Class1AffinityPredictor.load_weights(weights_filename)
config = json.loads(row.config_json)
model = Class1NeuralNetwork.from_config(config, weights=weights)
# We will lazy-load weights when the network is used.
model = Class1NeuralNetwork.from_config(
config,
weights_loader=partial(
Class1AffinityPredictor.load_weights,
abspath(weights_filename)))
if row.allele == "pan-class1":
class1_pan_allele_models.append(model)
else:
......
......@@ -148,6 +148,7 @@ class Class1NeuralNetwork(object):
self._network = None
self.network_json = None
self.network_weights = None
self.network_weights_loader = None
self.loss_history = None
self.fit_seconds = None
......@@ -217,6 +218,7 @@ class Class1NeuralNetwork(object):
keras.models.Model
"""
if self._network is None and self.network_json is not None:
self.load_weights()
if borrow:
return self.borrow_cached_network(
self.network_json,
......@@ -250,7 +252,7 @@ class Class1NeuralNetwork(object):
return result
@classmethod
def from_config(cls, config, weights=None):
def from_config(cls, config, weights=None, weights_loader=None):
"""
deserialize from a dict returned by get_config().
......@@ -259,6 +261,8 @@ class Class1NeuralNetwork(object):
config : dict
weights : list of array, optional
Network weights to restore
weights_loader : callable, optional
Function to call (no arguments) to load weights when needed
Returns
-------
......@@ -269,8 +273,14 @@ class Class1NeuralNetwork(object):
assert all(hasattr(instance, key) for key in config), config.keys()
instance.__dict__.update(config)
instance.network_weights = weights
instance.network_weights_loader = weights_loader
return instance
def load_weights(self):
if self.network_weights_loader:
self.network_weights = self.network_weights_loader()
self.network_weights_loader = None
def get_weights(self):
"""
Get the network weights
......@@ -281,6 +291,7 @@ class Class1NeuralNetwork(object):
or None if there is no network
"""
self.update_network_description()
self.load_weights()
return self.network_weights
def __getstate__(self):
......@@ -293,7 +304,7 @@ class Class1NeuralNetwork(object):
"""
self.update_network_description()
self.update_network_description()
self.load_weights()
result = dict(self.__dict__)
result['_network'] = None
return result
......
......@@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]):
tqdm.tqdm(
worker_pool.imap_unordered(
train_model_entrypoint, work_items, chunksize=1),
ascii=True,
total=len(work_items)),
key=lambda pair: pair[0])
]
......@@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]):
alleles,
chunksize=1)
for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
for result in tqdm.tqdm(results, total=len(alleles)):
predictor.allele_to_percent_rank_transform.update(result)
print("Done calibrating %d additional alleles." % len(alleles))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment