From e112e7bce61f09ec4388303228923a046ac372f2 Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com> Date: Fri, 23 Oct 2015 12:05:35 -0400 Subject: [PATCH] json serialization for models --- mhcflurry/mhc1_binding_predictor.py | 12 ++++--- .../train-class1-allele-specific-models.py | 34 +++++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py index bd41aa45..3c5ce32c 100644 --- a/mhcflurry/mhc1_binding_predictor.py +++ b/mhcflurry/mhc1_binding_predictor.py @@ -23,10 +23,11 @@ from __future__ import ( from os import listdir from os.path import exists, join from itertools import groupby +import json import numpy as np import pandas as pd -from keras.models import model_from_json +from keras.models import model_from_config from .class1_allele_specific_hyperparameters import MAX_IC50 from .data_helpers import index_encoding, normalize_allele_name @@ -47,7 +48,7 @@ class Mhc1BindingPredictor(object): raise ValueError( "No MHC prediction models found in %s" % (model_directory,)) original_allele_name = allele - self.allele = normalize_allele_name(allele) + allele = self.allele = normalize_allele_name(allele) if self.allele not in _allele_model_cache: json_filename = self.allele + ".json" json_path = join(model_directory, json_filename) @@ -62,10 +63,11 @@ class Mhc1BindingPredictor(object): raise ValueError("Missing model weights for allele %s" % ( original_allele_name,)) - with open(hdf_path, "r") as f: - self.model = model_from_json(f.read()) - + with open(json_path, "r") as f: + json_string = json.load(f) + self.model = model_from_config(json_string) self.model.load_weights(hdf_path) + _allele_model_cache[self.allele] = self.model else: self.model = _allele_model_cache[self.allele] diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py index a10e68b1..de9519d7 100755 --- a/scripts/train-class1-allele-specific-models.py +++ b/scripts/train-class1-allele-specific-models.py @@ -38,7 +38,7 @@ from __future__ import ( unicode_literals ) from shutil import rmtree -from os import makedirs +from os import makedirs, remove from os.path import exists, join import argparse @@ -81,7 +81,8 @@ parser.add_argument( default=CSV_PATH, help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns") -parser.add_argument("--min-samples-per-allele", +parser.add_argument( + "--min-samples-per-allele", default=5, help="Don't train predictors for alleles with fewer samples than this", type=int) @@ -124,19 +125,38 @@ if __name__ == "__main__": continue n_allele = len(allele_data.Y) print("%s: total count = %d" % (allele_name, n_allele)) - filename = allele_name + ".hdf" - path = join(args.output_dir, filename) - if exists(path) and not args.overwrite: + + json_filename = allele_name + ".json" + json_path = join(args.output_dir, json_filename) + + hdf_filename = allele_name + ".hdf" + hdf_path = join(args.output_dir, hdf_filename) + + if exists(json_path) and exists(hdf_path) and not args.overwrite: print("-- already exists, skipping") continue + if n_allele < args.min_samples_per_allele: print("-- too few data points, skipping") continue + + if exists(json_path): + print("-- removing old model description %s" % json_path) + remove(json_path) + if exists(hdf_path): + print("-- removing old weights file %s" % hdf_path) + remove(hdf_path) + model.set_weights(old_weights) model.fit( allele_data.X, allele_data.Y, nb_epoch=N_EPOCHS, show_accuracy=True) - print("Saving model for %s to %s" % (allele_name, path)) - model.save_weights(path) + print("Saving model description for %s to %s" % ( + allele_name, json_path)) + with open(json_path, "w") as f: + f.write(model.to_json()) + print("Saving model weights for %s to %s" % ( + allele_name, hdf_path)) + model.save_weights(hdf_path) -- GitLab