diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index fb19b09f87144973872305b716b4d195c6a42393..fad3aa11af26ed788d2f498aa0e81592e00902f2 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -2,13 +2,13 @@ import collections import hashlib import json import logging -import sys import time import warnings -from os.path import join, exists +from os.path import join, exists, abspath from os import mkdir from socket import gethostname from getpass import getuser +from functools import partial import mhcnames import numpy @@ -294,9 +294,14 @@ class Class1AffinityPredictor(object): for (_, row) in manifest_df.iterrows(): weights_filename = Class1AffinityPredictor.weights_path( models_dir, row.model_name) - weights = Class1AffinityPredictor.load_weights(weights_filename) config = json.loads(row.config_json) - model = Class1NeuralNetwork.from_config(config, weights=weights) + + # We will lazy-load weights when the network is used. + model = Class1NeuralNetwork.from_config( + config, + weights_loader=partial( + Class1AffinityPredictor.load_weights, + abspath(weights_filename))) if row.allele == "pan-class1": class1_pan_allele_models.append(model) else: diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 9cd528aefed4624aac91b901dc4996043e4cd7ab..5cf87569b625d4243ee98ac7123933c7aa047525 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -148,6 +148,7 @@ class Class1NeuralNetwork(object): self._network = None self.network_json = None self.network_weights = None + self.network_weights_loader = None self.loss_history = None self.fit_seconds = None @@ -217,6 +218,7 @@ class Class1NeuralNetwork(object): keras.models.Model """ if self._network is None and self.network_json is not None: + self.load_weights() if borrow: return self.borrow_cached_network( self.network_json, @@ -250,7 +252,7 @@ class Class1NeuralNetwork(object): return result @classmethod - def from_config(cls, config, weights=None): + def from_config(cls, config, weights=None, weights_loader=None): """ deserialize from a dict returned by get_config(). @@ -259,6 +261,8 @@ class Class1NeuralNetwork(object): config : dict weights : list of array, optional Network weights to restore + weights_loader : callable, optional + Function to call (no arguments) to load weights when needed Returns ------- @@ -269,8 +273,14 @@ class Class1NeuralNetwork(object): assert all(hasattr(instance, key) for key in config), config.keys() instance.__dict__.update(config) instance.network_weights = weights + instance.network_weights_loader = weights_loader return instance + def load_weights(self): + if self.network_weights_loader: + self.network_weights = self.network_weights_loader() + self.network_weights_loader = None + def get_weights(self): """ Get the network weights @@ -281,6 +291,7 @@ class Class1NeuralNetwork(object): or None if there is no network """ self.update_network_description() + self.load_weights() return self.network_weights def __getstate__(self): @@ -293,7 +304,7 @@ class Class1NeuralNetwork(object): """ self.update_network_description() - self.update_network_description() + self.load_weights() result = dict(self.__dict__) result['_network'] = None return result diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 37462b7bc7b1d9cc58e2f438f1770f0444bed47c..30a8ae6934461631bec68fad5d8c6fcb4467cc07 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]): tqdm.tqdm( worker_pool.imap_unordered( train_model_entrypoint, work_items, chunksize=1), - ascii=True, total=len(work_items)), key=lambda pair: pair[0]) ] @@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]): alleles, chunksize=1) - for result in tqdm.tqdm(results, ascii=True, total=len(alleles)): + for result in tqdm.tqdm(results, total=len(alleles)): predictor.allele_to_percent_rank_transform.update(result) print("Done calibrating %d additional alleles." % len(alleles))