From e112e7bce61f09ec4388303228923a046ac372f2 Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Fri, 23 Oct 2015 12:05:35 -0400
Subject: [PATCH] json serialization for models

---
 mhcflurry/mhc1_binding_predictor.py           | 12 ++++---
 .../train-class1-allele-specific-models.py    | 34 +++++++++++++++----
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py
index bd41aa45..3c5ce32c 100644
--- a/mhcflurry/mhc1_binding_predictor.py
+++ b/mhcflurry/mhc1_binding_predictor.py
@@ -23,10 +23,11 @@ from __future__ import (
 from os import listdir
 from os.path import exists, join
 from itertools import groupby
+import json
 
 import numpy as np
 import pandas as pd
-from keras.models import model_from_json
+from keras.models import model_from_config
 
 from .class1_allele_specific_hyperparameters import MAX_IC50
 from .data_helpers import index_encoding, normalize_allele_name
@@ -47,7 +48,7 @@ class Mhc1BindingPredictor(object):
             raise ValueError(
                 "No MHC prediction models found in %s" % (model_directory,))
         original_allele_name = allele
-        self.allele = normalize_allele_name(allele)
+        allele = self.allele = normalize_allele_name(allele)
         if self.allele not in _allele_model_cache:
             json_filename = self.allele + ".json"
             json_path = join(model_directory, json_filename)
@@ -62,10 +63,11 @@ class Mhc1BindingPredictor(object):
                 raise ValueError("Missing model weights for allele %s" % (
                     original_allele_name,))
 
-            with open(hdf_path, "r") as f:
-                self.model = model_from_json(f.read())
-
+            with open(json_path, "r") as f:
+                json_string = json.load(f)
+            self.model = model_from_config(json_string)
             self.model.load_weights(hdf_path)
+
             _allele_model_cache[self.allele] = self.model
         else:
             self.model = _allele_model_cache[self.allele]
diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py
index a10e68b1..de9519d7 100755
--- a/scripts/train-class1-allele-specific-models.py
+++ b/scripts/train-class1-allele-specific-models.py
@@ -38,7 +38,7 @@ from __future__ import (
     unicode_literals
 )
 from shutil import rmtree
-from os import makedirs
+from os import makedirs, remove
 from os.path import exists, join
 import argparse
 
@@ -81,7 +81,8 @@ parser.add_argument(
     default=CSV_PATH,
     help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns")
 
-parser.add_argument("--min-samples-per-allele",
+parser.add_argument(
+    "--min-samples-per-allele",
     default=5,
     help="Don't train predictors for alleles with fewer samples than this",
     type=int)
@@ -124,19 +125,38 @@ if __name__ == "__main__":
             continue
         n_allele = len(allele_data.Y)
         print("%s: total count = %d" % (allele_name, n_allele))
-        filename = allele_name + ".hdf"
-        path = join(args.output_dir, filename)
-        if exists(path) and not args.overwrite:
+
+        json_filename = allele_name + ".json"
+        json_path = join(args.output_dir, json_filename)
+
+        hdf_filename = allele_name + ".hdf"
+        hdf_path = join(args.output_dir, hdf_filename)
+
+        if exists(json_path) and exists(hdf_path) and not args.overwrite:
             print("-- already exists, skipping")
             continue
+
         if n_allele < args.min_samples_per_allele:
             print("-- too few data points, skipping")
             continue
+
+        if exists(json_path):
+            print("-- removing old model description %s" % json_path)
+            remove(json_path)
+        if exists(hdf_path):
+            print("-- removing old weights file %s" % hdf_path)
+            remove(hdf_path)
+
         model.set_weights(old_weights)
         model.fit(
             allele_data.X,
             allele_data.Y,
             nb_epoch=N_EPOCHS,
             show_accuracy=True)
-        print("Saving model for %s to %s" % (allele_name, path))
-        model.save_weights(path)
+        print("Saving model description for %s to %s" % (
+            allele_name, json_path))
+        with open(json_path, "w") as f:
+            f.write(model.to_json())
+        print("Saving model weights for %s to %s" % (
+            allele_name, hdf_path))
+        model.save_weights(hdf_path)
-- 
GitLab