From 33e8c10a9819402b241396e09b6c8775e3a5ab18 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Mon, 5 Feb 2018 20:19:14 -0500 Subject: [PATCH] lazy-load model weights --- mhcflurry/class1_affinity_predictor.py | 13 +++++++++---- mhcflurry/class1_neural_network.py | 15 +++++++++++++-- mhcflurry/train_allele_specific_models_command.py | 3 +-- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index fb19b09f..fad3aa11 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -2,13 +2,13 @@ import collections import hashlib import json import logging -import sys import time import warnings -from os.path import join, exists +from os.path import join, exists, abspath from os import mkdir from socket import gethostname from getpass import getuser +from functools import partial import mhcnames import numpy @@ -294,9 +294,14 @@ class Class1AffinityPredictor(object): for (_, row) in manifest_df.iterrows(): weights_filename = Class1AffinityPredictor.weights_path( models_dir, row.model_name) - weights = Class1AffinityPredictor.load_weights(weights_filename) config = json.loads(row.config_json) - model = Class1NeuralNetwork.from_config(config, weights=weights) + + # We will lazy-load weights when the network is used. + model = Class1NeuralNetwork.from_config( + config, + weights_loader=partial( + Class1AffinityPredictor.load_weights, + abspath(weights_filename))) if row.allele == "pan-class1": class1_pan_allele_models.append(model) else: diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 9cd528ae..5cf87569 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -148,6 +148,7 @@ class Class1NeuralNetwork(object): self._network = None self.network_json = None self.network_weights = None + self.network_weights_loader = None self.loss_history = None self.fit_seconds = None @@ -217,6 +218,7 @@ class Class1NeuralNetwork(object): keras.models.Model """ if self._network is None and self.network_json is not None: + self.load_weights() if borrow: return self.borrow_cached_network( self.network_json, @@ -250,7 +252,7 @@ class Class1NeuralNetwork(object): return result @classmethod - def from_config(cls, config, weights=None): + def from_config(cls, config, weights=None, weights_loader=None): """ deserialize from a dict returned by get_config(). @@ -259,6 +261,8 @@ class Class1NeuralNetwork(object): config : dict weights : list of array, optional Network weights to restore + weights_loader : callable, optional + Function to call (no arguments) to load weights when needed Returns ------- @@ -269,8 +273,14 @@ class Class1NeuralNetwork(object): assert all(hasattr(instance, key) for key in config), config.keys() instance.__dict__.update(config) instance.network_weights = weights + instance.network_weights_loader = weights_loader return instance + def load_weights(self): + if self.network_weights_loader: + self.network_weights = self.network_weights_loader() + self.network_weights_loader = None + def get_weights(self): """ Get the network weights @@ -281,6 +291,7 @@ class Class1NeuralNetwork(object): or None if there is no network """ self.update_network_description() + self.load_weights() return self.network_weights def __getstate__(self): @@ -293,7 +304,7 @@ class Class1NeuralNetwork(object): """ self.update_network_description() - self.update_network_description() + self.load_weights() result = dict(self.__dict__) result['_network'] = None return result diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 37462b7b..30a8ae69 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]): tqdm.tqdm( worker_pool.imap_unordered( train_model_entrypoint, work_items, chunksize=1), - ascii=True, total=len(work_items)), key=lambda pair: pair[0]) ] @@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]): alleles, chunksize=1) - for result in tqdm.tqdm(results, ascii=True, total=len(alleles)): + for result in tqdm.tqdm(results, total=len(alleles)): predictor.allele_to_percent_rank_transform.update(result) print("Done calibrating %d additional alleles." % len(alleles)) -- GitLab