From 33e8c10a9819402b241396e09b6c8775e3a5ab18 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 5 Feb 2018 20:19:14 -0500
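Subject: [PATCH] lazy-load model weights

Defer reading network weights from disk until they are actually needed.

Class1NeuralNetwork.from_config() accepts a new optional `weights_loader`
callable that takes no arguments. Class1NeuralNetwork.load_weights() calls
it once, stores the result in `network_weights` and then clears the loader.
load_weights() is invoked when the Keras network is first built, from
get_weights(), and from __getstate__().

While iterating over the manifest, Class1AffinityPredictor now passes
partial(Class1AffinityPredictor.load_weights, abspath(weights_filename))
as the loader instead of loading each weights file eagerly. The path is
made absolute up front so the deferred load does not depend on the working
directory at the time the weights are finally read.

Also in this change:
* replace a duplicated update_network_description() call in __getstate__()
  with load_weights()
* drop the `sys` import from class1_affinity_predictor.py
* stop passing ascii=True to tqdm in train_allele_specific_models_command.py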
---
 mhcflurry/class1_affinity_predictor.py            | 13 +++++++++----
 mhcflurry/class1_neural_network.py                | 15 +++++++++++++--
 mhcflurry/train_allele_specific_models_command.py |  3 +--
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index fb19b09f..fad3aa11 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -2,13 +2,13 @@ import collections
 import hashlib
 import json
 import logging
-import sys
 import time
 import warnings
-from os.path import join, exists
+from os.path import join, exists, abspath
 from os import mkdir
 from socket import gethostname
 from getpass import getuser
+from functools import partial
 
 import mhcnames
 import numpy
@@ -294,9 +294,14 @@ class Class1AffinityPredictor(object):
         for (_, row) in manifest_df.iterrows():
             weights_filename = Class1AffinityPredictor.weights_path(
                 models_dir, row.model_name)
-            weights = Class1AffinityPredictor.load_weights(weights_filename)
             config = json.loads(row.config_json)
-            model = Class1NeuralNetwork.from_config(config, weights=weights)
+
+            # We will lazy-load weights when the network is used.
+            model = Class1NeuralNetwork.from_config(
+                config,
+                weights_loader=partial(
+                    Class1AffinityPredictor.load_weights,
+                    abspath(weights_filename)))
             if row.allele == "pan-class1":
                 class1_pan_allele_models.append(model)
             else:
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 9cd528ae..5cf87569 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -148,6 +148,7 @@ class Class1NeuralNetwork(object):
         self._network = None
         self.network_json = None
         self.network_weights = None
+        self.network_weights_loader = None
 
         self.loss_history = None
         self.fit_seconds = None
@@ -217,6 +218,7 @@ class Class1NeuralNetwork(object):
         keras.models.Model
         """
         if self._network is None and self.network_json is not None:
+            self.load_weights()
             if borrow:
                 return self.borrow_cached_network(
                     self.network_json,
@@ -250,7 +252,7 @@ class Class1NeuralNetwork(object):
         return result
 
     @classmethod
-    def from_config(cls, config, weights=None):
+    def from_config(cls, config, weights=None, weights_loader=None):
         """
         deserialize from a dict returned by get_config().
         
@@ -259,6 +261,8 @@ class Class1NeuralNetwork(object):
         config : dict
         weights : list of array, optional
             Network weights to restore
+        weights_loader : callable, optional
+            Function to call (no arguments) to load weights when needed
 
         Returns
         -------
@@ -269,8 +273,14 @@ class Class1NeuralNetwork(object):
         assert all(hasattr(instance, key) for key in config), config.keys()
         instance.__dict__.update(config)
         instance.network_weights = weights
+        instance.network_weights_loader = weights_loader
         return instance
 
+    def load_weights(self):
+        if self.network_weights_loader:
+            self.network_weights = self.network_weights_loader()
+            self.network_weights_loader = None
+
     def get_weights(self):
         """
         Get the network weights
@@ -281,6 +291,7 @@ class Class1NeuralNetwork(object):
         or None if there is no network
         """
         self.update_network_description()
+        self.load_weights()
         return self.network_weights
 
     def __getstate__(self):
@@ -293,7 +304,7 @@ class Class1NeuralNetwork(object):
 
         """
         self.update_network_description()
-        self.update_network_description()
+        self.load_weights()
         result = dict(self.__dict__)
         result['_network'] = None
         return result
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 37462b7b..30a8ae69 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -231,7 +231,6 @@ def run(argv=sys.argv[1:]):
                 tqdm.tqdm(
                     worker_pool.imap_unordered(
                         train_model_entrypoint, work_items, chunksize=1),
-                    ascii=True,
                     total=len(work_items)),
                 key=lambda pair: pair[0])
         ]
@@ -306,7 +305,7 @@ def run(argv=sys.argv[1:]):
                 alleles,
                 chunksize=1)
 
-        for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
+        for result in tqdm.tqdm(results, total=len(alleles)):
             predictor.allele_to_percent_rank_transform.update(result)
 
         print("Done calibrating %d additional alleles." % len(alleles))
-- 
GitLab