Skip to content
Snippets Groups Projects
Commit e112e7bc authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

json serialization for models

parent 1cbd1946
No related branches found
No related tags found
No related merge requests found
......@@ -23,10 +23,11 @@ from __future__ import (
from os import listdir
from os.path import exists, join
from itertools import groupby
import json
import numpy as np
import pandas as pd
from keras.models import model_from_json
from keras.models import model_from_config
from .class1_allele_specific_hyperparameters import MAX_IC50
from .data_helpers import index_encoding, normalize_allele_name
......@@ -47,7 +48,7 @@ class Mhc1BindingPredictor(object):
raise ValueError(
"No MHC prediction models found in %s" % (model_directory,))
original_allele_name = allele
self.allele = normalize_allele_name(allele)
allele = self.allele = normalize_allele_name(allele)
if self.allele not in _allele_model_cache:
json_filename = self.allele + ".json"
json_path = join(model_directory, json_filename)
......@@ -62,10 +63,11 @@ class Mhc1BindingPredictor(object):
raise ValueError("Missing model weights for allele %s" % (
original_allele_name,))
with open(hdf_path, "r") as f:
self.model = model_from_json(f.read())
with open(json_path, "r") as f:
json_string = json.load(f)
self.model = model_from_config(json_string)
self.model.load_weights(hdf_path)
_allele_model_cache[self.allele] = self.model
else:
self.model = _allele_model_cache[self.allele]
......
......@@ -38,7 +38,7 @@ from __future__ import (
unicode_literals
)
from shutil import rmtree
from os import makedirs
from os import makedirs, remove
from os.path import exists, join
import argparse
......@@ -81,7 +81,8 @@ parser.add_argument(
default=CSV_PATH,
help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns")
parser.add_argument("--min-samples-per-allele",
parser.add_argument(
"--min-samples-per-allele",
default=5,
help="Don't train predictors for alleles with fewer samples than this",
type=int)
......@@ -124,19 +125,38 @@ if __name__ == "__main__":
continue
n_allele = len(allele_data.Y)
print("%s: total count = %d" % (allele_name, n_allele))
filename = allele_name + ".hdf"
path = join(args.output_dir, filename)
if exists(path) and not args.overwrite:
json_filename = allele_name + ".json"
json_path = join(args.output_dir, json_filename)
hdf_filename = allele_name + ".hdf"
hdf_path = join(args.output_dir, hdf_filename)
if exists(json_path) and exists(hdf_path) and not args.overwrite:
print("-- already exists, skipping")
continue
if n_allele < args.min_samples_per_allele:
print("-- too few data points, skipping")
continue
if exists(json_path):
print("-- removing old model description %s" % json_path)
remove(json_path)
if exists(hdf_path):
print("-- removing old weights file %s" % hdf_path)
remove(hdf_path)
model.set_weights(old_weights)
model.fit(
allele_data.X,
allele_data.Y,
nb_epoch=N_EPOCHS,
show_accuracy=True)
print("Saving model for %s to %s" % (allele_name, path))
model.save_weights(path)
print("Saving model description for %s to %s" % (
allele_name, json_path))
with open(json_path, "w") as f:
f.write(model.to_json())
print("Saving model weights for %s to %s" % (
allele_name, hdf_path))
model.save_weights(hdf_path)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment