From 3b89017f7588b5109800847837786cb4dd618fbe Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 26 Aug 2019 13:02:44 -0400
Subject: [PATCH] Add cluster parallelism to calibration and pan-allele model
 selection; split models_class1_pan_unselected download into parts

---
 .../calibrate_percentile_ranks_command.py     | 53 +++++++++++-------
 mhcflurry/downloads.yml                       |  4 +-
 mhcflurry/downloads_command.py                | 56 ++++++++++++++-----
 mhcflurry/select_pan_allele_models_command.py | 40 +++++++++----
 4 files changed, 106 insertions(+), 47 deletions(-)

diff --git a/mhcflurry/calibrate_percentile_ranks_command.py b/mhcflurry/calibrate_percentile_ranks_command.py
index 5cd1334e..bf8621f2 100644
--- a/mhcflurry/calibrate_percentile_ranks_command.py
+++ b/mhcflurry/calibrate_percentile_ranks_command.py
@@ -11,7 +11,6 @@ import collections
 from functools import partial
 
 import pandas
-import numpy
 
 from mhcnames import normalize_allele_name
 import tqdm  # progress bar
@@ -20,11 +19,13 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .encodable_sequences import EncodableSequences
 from .common import configure_logging, random_peptides, amino_acid_distribution
-from .amino_acid import BLOSUM62_MATRIX
 from .local_parallelism import (
     add_local_parallelism_args,
     worker_pool_with_gpu_assignments_from_args,
     call_wrapped)
+from .cluster_parallelism import (
+    add_cluster_parallelism_args,
+    cluster_results_from_args)
 
 
 # To avoid pickling large matrices to send to child processes when running in
@@ -85,6 +86,7 @@ parser.add_argument(
     default=0)
 
 add_local_parallelism_args(parser)
+add_cluster_parallelism_args(parser)
 
 
 def run(argv=sys.argv[1:]):
@@ -140,33 +142,36 @@ def run(argv=sys.argv[1:]):
     # Store peptides in global variable so they are in shared memory
     # after fork, instead of needing to be pickled (when doing a parallel run).
     GLOBAL_DATA["calibration_peptides"] = encoded_peptides
-
-    worker_pool = worker_pool_with_gpu_assignments_from_args(args)
-
-    constant_kwargs = {
+    GLOBAL_DATA["predictor"] = predictor
+    GLOBAL_DATA["args"] = {
         'motif_summary': args.motif_summary,
         'summary_top_peptide_fraction': args.summary_top_peptide_fraction,
-        'verbose': args.verbosity > 0,
+        'verbose': args.verbosity > 0
     }
 
-    if worker_pool is None:
+    serial_run = not args.cluster_parallelism and args.num_jobs == 0
+    worker_pool = None
+    start = time.time()
+    if serial_run:
         # Serial run
         print("Running in serial.")
         results = (
-            calibrate_percentile_ranks(
-                allele=allele,
-                predictor=predictor,
-                peptides=encoded_peptides,
-                **constant_kwargs,
-            )
-            for allele in alleles)
+            do_calibrate_percentile_ranks(allele) for allele in alleles)
+    elif args.cluster_parallelism:
+        # Run using separate processes on an HPC cluster.
+        print("Running on cluster.")
+        results = cluster_results_from_args(
+            args,
+            work_function=do_calibrate_percentile_ranks,
+            work_items=alleles,
+            constant_data=GLOBAL_DATA,
+            result_serialization_method="pickle")
     else:
-        # Parallel run
+        worker_pool = worker_pool_with_gpu_assignments_from_args(args)
+        print("Worker pool", worker_pool)
+        assert worker_pool is not None
         results = worker_pool.imap_unordered(
-            partial(
-                partial(call_wrapped, calibrate_percentile_ranks),
-                predictor=predictor,
-                **constant_kwargs),
+            partial(call_wrapped, do_calibrate_percentile_ranks),
             alleles,
             chunksize=1)
 
@@ -197,6 +202,14 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.models_dir)
 
 
+def do_calibrate_percentile_ranks(allele):
+    return calibrate_percentile_ranks(
+        allele,
+        GLOBAL_DATA['predictor'],
+        peptides=GLOBAL_DATA['calibration_peptides'],
+        **GLOBAL_DATA["args"])
+
+
 def calibrate_percentile_ranks(
         allele,
         predictor,
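
Note: the GLOBAL_DATA pattern above relies on copy-on-write fork semantics:
anything assigned to a module-level object before the worker pool is created
is visible in forked children without being pickled per task. A minimal,
self-contained sketch of the pattern (names here are illustrative, not
mhcflurry's API):

    import multiprocessing

    GLOBAL_DATA = {}  # module-level; populated in the parent before forking

    def work(key):
        # Workers read the inherited module global instead of receiving
        # the large object as a pickled argument.
        payload = GLOBAL_DATA["payload"]
        return (key, len(payload))

    if __name__ == "__main__":
        GLOBAL_DATA["payload"] = list(range(10 ** 6))  # stand-in for encoded peptides
        # Under the "fork" start method (the Linux default) children inherit
        # GLOBAL_DATA; under "spawn" the module is re-imported and the dict
        # would be empty, so this pattern is fork-specific.
        with multiprocessing.Pool(2) as pool:
            for key, n in pool.imap_unordered(work, ["a", "b"], chunksize=1):
                print(key, n)
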
diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml
index f20f6dc5..44c8a96e 100644
--- a/mhcflurry/downloads.yml
+++ b/mhcflurry/downloads.yml
@@ -25,7 +25,9 @@ releases:
               default: false
 
             - name: models_class1_pan_unselected
-              url: https://github.com/openvax/mhcflurry/releases/download/pan-dev1/model_class1_pan_unselected.manual_build.20190731.tar.bz2
+              part_urls:
+                - https://github.com/openvax/mhcflurry/releases/download/pan-dev1/models_class1_pan_unselected.20190826.tar.bz2.part.aa
+                - https://github.com/openvax/mhcflurry/releases/download/pan-dev1/models_class1_pan_unselected.20190826.tar.bz2.part.ab
               default: false
 
             - name: data_iedb
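
Note: part_urls lets a single download be distributed as several pieces,
presumably to stay under GitHub's per-release-asset size cap. The fetch code
simply concatenates the parts in order, so the split must be a plain
byte-level split. A hedged sketch of how such parts could be produced
(split_file is a hypothetical helper; the actual release artifacts were
likely made with a tool like split(1)):

    import itertools
    import string

    def split_file(path, part_bytes=1900 * 1024 ** 2):
        # Writes path.part.aa, path.part.ab, ... like "split -b". Each
        # part is buffered in memory for brevity; a real tool would
        # stream in smaller blocks.
        suffixes = (
            a + b
            for a, b in itertools.product(string.ascii_lowercase, repeat=2))
        with open(path, "rb") as fd:
            for suffix in suffixes:
                chunk = fd.read(part_bytes)
                if not chunk:
                    break
                with open("%s.part.%s" % (path, suffix), "wb") as out:
                    out.write(chunk)

    split_file("models_class1_pan_unselected.20190826.tar.bz2")
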
diff --git a/mhcflurry/downloads_command.py b/mhcflurry/downloads_command.py
index 34403ab7..8bf1d213 100644
--- a/mhcflurry/downloads_command.py
+++ b/mhcflurry/downloads_command.py
@@ -27,7 +27,8 @@ import os
 from pipes import quote
 import errno
 import tarfile
-from tempfile import mkstemp
+from shutil import copyfileobj
+from tempfile import NamedTemporaryFile
 from tqdm import tqdm
 tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
@@ -181,27 +182,46 @@ def fetch_subcommand(args):
         "DOWNLOAD NAME", "ALREADY DOWNLOADED?", "WILL DOWNLOAD NOW?", "URL"))
 
     for (item, info) in downloads.items():
+        urls = (
+            [info['metadata']["url"]]
+            if "url" in info['metadata']
+            else info['metadata']["part_urls"])
+        url_description = urls[0]
+        if len(urls) > 1:
+            url_description += " + %d more parts" % (len(urls) - 1)
+
         qprint(format_string % (
             item,
             yes_no(info['downloaded']),
             yes_no(item in items_to_fetch),
-            info['metadata']["url"]))
+            url_description))
 
     # TODO: may want to extract into somewhere temporary and then rename to
     # avoid making an incomplete extract if the process is killed.
     for item in items_to_fetch:
         metadata = downloads[item]['metadata']
-        (temp_fd, target_path) = mkstemp()
+        urls = (
+            [metadata["url"]] if "url" in metadata else metadata["part_urls"])
+        temp = NamedTemporaryFile(delete=False, suffix=".tar.bz2")
         try:
-            qprint("Downloading: %s" % metadata['url'])
-            urlretrieve(
-                metadata['url'],
-                target_path,
-                reporthook=TqdmUpTo(
-                    unit='B', unit_scale=True, miniters=1).update_to)
-            qprint("Downloaded to: %s" % quote(target_path))
-
-            tar = tarfile.open(target_path, 'r:bz2')
+            for (url_num, url) in enumerate(urls):
+                qprint("Downloading [part %d/%d]: %s" % (
+                    url_num + 1, len(urls), url))
+                (downloaded_path, _) = urlretrieve(
+                    url,
+                    temp.name if len(urls) == 1 else None,
+                    reporthook=TqdmUpTo(
+                        unit='B', unit_scale=True, miniters=1).update_to)
+                qprint("Downloaded to: %s" % quote(downloaded_path))
+
+                if downloaded_path != temp.name:
+                    qprint("Appending to: %s" % temp.name)
+                    with open(downloaded_path, "rb") as fd:
+                        copyfileobj(fd, temp, length=64*1024*1024)
+                    os.remove(downloaded_path)
+
+            temp.close()
+            tar = tarfile.open(temp.name, 'r:bz2')
             names = tar.getnames()
             logging.debug("Extracting: %s" % names)
             bad_names = [
@@ -221,7 +241,7 @@ def fetch_subcommand(args):
                 len(names), quote(result_dir)))
         finally:
             if not args.keep:
-                os.remove(target_path)
+                os.remove(temp.name)
 
 
 def info_subcommand(args):
@@ -257,10 +277,18 @@ def info_subcommand(args):
     print(format_string % ("DOWNLOAD NAME", "DOWNLOADED?", "URL"))
 
     for (item, info) in downloads.items():
+        urls = (
+            [info['metadata']["url"]]
+            if "url" in info['metadata']
+            else info['metadata']["part_urls"])
+        url_description = urls[0]
+        if len(urls) > 1:
+            url_description += " + %d more parts" % (len(urls) - 1)
+
         print(format_string % (
             item,
             yes_no(info['downloaded']),
-            info['metadata']["url"]))
+            url_description))
 
 
 def path_subcommand(args):
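
Note: a condensed, standalone sketch of the multi-part reassembly logic added
above, assuming Python 3 and hypothetical URLs (the patch itself also handles
the single-URL case by downloading directly into the temp file):

    import os
    import tarfile
    from shutil import copyfileobj
    from tempfile import NamedTemporaryFile
    from urllib.request import urlretrieve

    # Hypothetical part URLs; any multi-part release follows this shape.
    part_urls = [
        "https://example.com/archive.tar.bz2.part.aa",
        "https://example.com/archive.tar.bz2.part.ab",
    ]

    temp = NamedTemporaryFile(delete=False, suffix=".tar.bz2")
    try:
        for url in part_urls:
            # With no filename argument, urlretrieve downloads to its own
            # temporary file and returns its path.
            (downloaded_path, _) = urlretrieve(url)
            with open(downloaded_path, "rb") as fd:
                copyfileobj(fd, temp, length=64 * 1024 * 1024)  # stream-append
            os.remove(downloaded_path)
        temp.close()  # flush before reading the reassembled archive
        with tarfile.open(temp.name, "r:bz2") as tar:
            print(tar.getnames())
    finally:
        temp.close()  # no-op if already closed
        os.remove(temp.name)
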
diff --git a/mhcflurry/select_pan_allele_models_command.py b/mhcflurry/select_pan_allele_models_command.py
index 8102e7fe..fb031a59 100644
--- a/mhcflurry/select_pan_allele_models_command.py
+++ b/mhcflurry/select_pan_allele_models_command.py
@@ -1,5 +1,5 @@
 """
-Model select class1 single allele models.
+Model select class1 pan allele models.
 """
 import argparse
 import os
@@ -7,13 +7,11 @@ import signal
 import sys
 import time
 import traceback
-import random
 import hashlib
 from pprint import pprint
 
 import numpy
 import pandas
-from scipy.stats import kendalltau, percentileofscore, pearsonr
 
 import tqdm  # progress bar
 tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
@@ -22,7 +20,12 @@ from .class1_affinity_predictor import Class1AffinityPredictor
 from .encodable_sequences import EncodableSequences
 from .allele_encoding import AlleleEncoding
 from .common import configure_logging, random_peptides
-from .local_parallelism import worker_pool_with_gpu_assignments_from_args, add_local_parallelism_args
+from .local_parallelism import (
+    worker_pool_with_gpu_assignments_from_args,
+    add_local_parallelism_args)
+from .cluster_parallelism import (
+    add_cluster_parallelism_args,
+    cluster_results_from_args)
 from .regression_target import from_ic50
 
 
@@ -83,6 +86,7 @@ parser.add_argument(
     default=0)
 
 add_local_parallelism_args(parser)
+add_cluster_parallelism_args(parser)
 
 
 def mse(
@@ -216,21 +220,33 @@ def run(argv=sys.argv[1:]):
         allele_to_sequence=input_predictor.allele_to_sequence,
         metadata_dataframes=metadata_dfs)
 
-    worker_pool = worker_pool_with_gpu_assignments_from_args(args)
-
+    serial_run = not args.cluster_parallelism and args.num_jobs == 0
+    worker_pool = None
     start = time.time()
-
-    if worker_pool is None:
+    if serial_run:
         # Serial run
         print("Running in serial.")
         results = (do_model_select_task(item) for item in work_items)
+    elif args.cluster_parallelism:
+        # Run using separate processes on an HPC cluster.
+        print("Running on cluster.")
+        results = cluster_results_from_args(
+            args,
+            work_function=do_model_select_task,
+            work_items=work_items,
+            constant_data=GLOBAL_DATA,
+            result_serialization_method="pickle")
     else:
+        worker_pool = worker_pool_with_gpu_assignments_from_args(args)
+        print("Worker pool", worker_pool)
+        assert worker_pool is not None
+
+        print("Processing %d work items in parallel." % len(work_items))
+        assert not serial_run
+
         # Parallel run
-        random.shuffle(alleles)
         results = worker_pool.imap_unordered(
-            do_model_select_task,
-            work_items,
-            chunksize=1)
+            do_model_select_task, work_items, chunksize=1)
 
     models_by_fold = {}
     summary_dfs = []
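
Note: in the worker-pool branch, the task function is passed through
partial() so it stays picklable for multiprocessing. In the calibration
command it is additionally wrapped in call_wrapped from
mhcflurry.local_parallelism; a simplified sketch of what such an
error-forwarding wrapper might do (the real helper may differ in details):

    import sys
    import traceback
    from functools import partial

    def call_wrapped(function, *args, **kwargs):
        # Run the task and, on failure, print the worker-side traceback
        # before re-raising, since tracebacks do not survive the trip
        # back to the parent process intact.
        try:
            return function(*args, **kwargs)
        except Exception:
            print("Exception in worker process:", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            raise

    # Usage mirroring the patch:
    # results = worker_pool.imap_unordered(
    #     partial(call_wrapped, do_model_select_task), work_items, chunksize=1)
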
-- 
GitLab