From cd8e0e8274d8852ecfa3adf39d0f69367fe208d3 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 27 Aug 2019 13:58:04 -0400
Subject: [PATCH] Reduce peak memory use in percent rank calibration

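Reduce peak memory use during percent rank calibration:

- Store the BLOSUM62 matrix as int8 rather than the default int64.
- Time peptide generation and peptide encoding separately.
- Free the raw peptide list once the EncodableSequences copy exists, and
  drop the local encoded_peptides reference once it is stashed in
  GLOBAL_DATA.
- Add a clear_constant_data option to cluster_results() so the constant
  payload can be released in the submitting process after it has been
  pickled to disk for the workers.
- Switch the test entry point to exercise the cluster-parallelism path.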
---
 mhcflurry/amino_acid.py                         |  2 +-
 mhcflurry/calibrate_percentile_ranks_command.py | 16 ++++++++++------
 mhcflurry/cluster_parallelism.py                | 12 +++++++++---
 test/test_calibrate_percentile_ranks_command.py |  4 ++--
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mhcflurry/amino_acid.py b/mhcflurry/amino_acid.py
index f8c7047b..1192f340 100644
--- a/mhcflurry/amino_acid.py
+++ b/mhcflurry/amino_acid.py
@@ -68,7 +68,7 @@ W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1  1 -4 -3 -2 11  2 -3  0
 Y -2 -2 -2 -3 -2 -1 -2 -3  2 -1 -1 -2 -1  3 -3 -2 -2  2  7 -1  0
 V  0 -3 -3 -3 -1 -2 -2 -3 -3  3  1 -2  1 -1 -2 -2  0 -3 -1  4  0
 X  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1
-"""), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS]
+"""), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS].astype("int8")
 assert (BLOSUM62_MATRIX == BLOSUM62_MATRIX.T).all().all()
 
 ENCODING_DATA_FRAMES = {
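
Note on the int8 cast: pandas parses the whitespace-delimited BLOSUM62
table as int64 by default, so the cast shrinks the matrix (and any arrays
that inherit its dtype) roughly 8x. A minimal sketch of the effect, not
part of this patch, using a zero-filled 21x21 frame as a stand-in for the
real matrix:

    import numpy as np
    import pandas as pd

    # Stand-in for the parsed BLOSUM62 table: 21 x 21 small integers.
    matrix = pd.DataFrame(np.zeros((21, 21), dtype="int64"))
    print(matrix.values.nbytes)                 # 3528 bytes at int64
    print(matrix.astype("int8").values.nbytes)  # 441 bytes at int8
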
diff --git a/mhcflurry/calibrate_percentile_ranks_command.py b/mhcflurry/calibrate_percentile_ranks_command.py
index 3ba3f7ab..e06d0bfd 100644
--- a/mhcflurry/calibrate_percentile_ranks_command.py
+++ b/mhcflurry/calibrate_percentile_ranks_command.py
@@ -120,25 +120,27 @@ def run(argv=sys.argv[1:]):
 
     start = time.time()
 
-    print("Percent rank calibration for %d alleles. Encoding peptides." % (
+    print("Percent rank calibration for %d alleles. Generating peptides." % (
         len(alleles)))
-
     peptides = []
     lengths = range(args.length_range[0], args.length_range[1] + 1)
     for length in lengths:
         peptides.extend(
             random_peptides(
                 args.num_peptides_per_length, length, distribution=distribution))
+    print("Done generating peptides in %0.2f sec." % (time.time() - start))
+    print("Encoding %d peptides." % len(peptides))
+    start = time.time()
+
     encoded_peptides = EncodableSequences.create(peptides)
+    del peptides
 
     # Now we encode the peptides for each neural network, so the encoding
     # becomes cached.
     for network in predictor.neural_networks:
         network.peptides_to_network_input(encoded_peptides)
     assert encoded_peptides.encoding_cache  # must have cached the encoding
-    print("Finished encoding peptides for percent ranks in %0.2f sec." % (
-        time.time() - start))
-    print("Calibrating percent rank calibration for %d alleles." % len(alleles))
+    print("Finished encoding peptides in %0.2f sec." % (time.time() - start))
 
     # Store peptides in global variable so they are in shared memory
     # after fork, instead of needing to be pickled (when doing a parallel run).
@@ -149,6 +151,7 @@ def run(argv=sys.argv[1:]):
         'summary_top_peptide_fractions': args.summary_top_peptide_fraction,
         'verbose': args.verbosity > 0
     }
+    del encoded_peptides
 
     serial_run = not args.cluster_parallelism and args.num_jobs == 0
     worker_pool = None
@@ -167,7 +170,8 @@ def run(argv=sys.argv[1:]):
             work_function=do_calibrate_percentile_ranks,
             work_items=work_items,
             constant_data=GLOBAL_DATA,
-            result_serialization_method="pickle")
+            result_serialization_method="pickle",
+            clear_constant_data=True)
     else:
         worker_pool = worker_pool_with_gpu_assignments_from_args(args)
         print("Worker pool", worker_pool)
diff --git a/mhcflurry/cluster_parallelism.py b/mhcflurry/cluster_parallelism.py
index c756239c..31af13c1 100644
--- a/mhcflurry/cluster_parallelism.py
+++ b/mhcflurry/cluster_parallelism.py
@@ -43,7 +43,8 @@ def cluster_results_from_args(
         work_function,
         work_items,
         constant_data=None,
-        result_serialization_method="pickle"):
+        result_serialization_method="pickle",
+        clear_constant_data=False):
     return cluster_results(
         work_function=work_function,
         work_items=work_items,
@@ -51,7 +52,8 @@ def cluster_results_from_args(
         submit_command=args.cluster_submit_command,
         results_workdir=args.cluster_results_workdir,
         script_prefix_path=args.cluster_script_prefix_path,
-        result_serialization_method=result_serialization_method
+        result_serialization_method=result_serialization_method,
+        clear_constant_data=clear_constant_data
     )
 
 
@@ -63,7 +65,8 @@ def cluster_results(
         results_workdir="./cluster-workdir",
         script_prefix_path=None,
         result_serialization_method="pickle",
-        max_retries=3):
+        max_retries=3,
+        clear_constant_data=False):
 
     constant_payload = {
         'constant_data': constant_data,
@@ -78,6 +81,9 @@ def cluster_results(
     with open(constant_payload_path, "wb") as fd:
         pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
     print("Wrote:", constant_payload_path)
+    if clear_constant_data:
+        constant_data.clear()
+        print("Cleared constant data to free up memory.")
 
     if script_prefix_path:
         with open(script_prefix_path) as fd:
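
Note on clear_constant_data: in the cluster path each worker re-reads the
payload from disk, so the submitting process can release its copy as soon
as the pickle is written. Clearing the dict in place (rather than
rebinding the name) matters because the caller passes the module-level
GLOBAL_DATA, and clear() empties that shared object too. A minimal
standalone sketch of the idea, not part of this patch:

    import os
    import pickle
    import tempfile

    constant_data = {"encoded_peptides": ["SIINFEKL"] * 100000}
    GLOBAL_DATA = constant_data  # second reference, as in the caller

    # Persist the payload for workers to re-read from disk.
    path = os.path.join(tempfile.mkdtemp(), "constant_data.pkl")
    with open(path, "wb") as fd:
        pickle.dump(constant_data, fd, protocol=pickle.HIGHEST_PROTOCOL)

    # Now free the submitter's copy; this empties GLOBAL_DATA as well.
    constant_data.clear()
    assert GLOBAL_DATA == {}
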
diff --git a/test/test_calibrate_percentile_ranks_command.py b/test/test_calibrate_percentile_ranks_command.py
index 2a87c192..2a891cb4 100644
--- a/test/test_calibrate_percentile_ranks_command.py
+++ b/test/test_calibrate_percentile_ranks_command.py
@@ -73,6 +73,6 @@ def test_run_cluster_parallelism(delete=True):
 
 
 if __name__ == "__main__":
-    run_and_check(n_jobs=0, delete=False)
+    # run_and_check(n_jobs=0, delete=False)
     # run_and_check(n_jobs=2, delete=False)
-    # test_run_cluster_parallelism(delete=False)
+    test_run_cluster_parallelism(delete=False)
-- 
GitLab