diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 55729a44c0a973242e4d143637ece369858e4c1c..9cb41d3f8b654d3b62603ae80209333d16a1d297 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -933,6 +933,7 @@ class Class1NeuralNetwork(object):
                     verbose=verbose)
                 needs_initialization = False
 
+            epoch_start = time.time()
             fit_history = self.network().fit(
                 x_dict_with_random_negatives,
                 y_dict_with_random_negatives,
@@ -943,6 +944,7 @@ class Class1NeuralNetwork(object):
                 initial_epoch=i,
                 validation_split=self.hyperparameters['validation_split'],
                 sample_weight=sample_weights_with_random_negatives)
+            epoch_time = time.time() - epoch_start
 
             for (key, value) in fit_history.history.items():
                 fit_info[key].extend(value)
@@ -953,10 +955,11 @@ class Class1NeuralNetwork(object):
                     not last_progress_print or (
                         time.time() - last_progress_print > progress_print_interval)):
                 print((progress_preamble + " " +
-                       "Epoch %3d / %3d: loss=%g. "
+                       "Epoch %3d / %3d [%0.2f sec]: loss=%g. "
                        "Min val loss (%s) at epoch %s" % (
                            i,
                            self.hyperparameters['max_epochs'],
+                           epoch_time,
                            fit_info['loss'][-1],
                            str(min_val_loss),
                            min_val_loss_iteration)).strip())
diff --git a/mhcflurry/cluster_parallelism.py b/mhcflurry/cluster_parallelism.py
index ac548f31ba02648c8cc5da3fbb8d1fb284d1639c..26f65613156ac157c4fd763feef47c02ca78b37a 100644
--- a/mhcflurry/cluster_parallelism.py
+++ b/mhcflurry/cluster_parallelism.py
@@ -6,6 +6,7 @@ import signal
 import argparse
 import pickle
 import subprocess
+import shutil
 
 from .local_parallelism import call_wrapped_kwargs
 from .class1_affinity_predictor import Class1AffinityPredictor
@@ -34,6 +35,7 @@ def add_cluster_parallelism_args(parser):
         '--cluster-script-prefix-path',
         help="",
     )
+    group.add_argument('--cluster-max-retries', type=int, help="", default=3)
 
 
 def cluster_results_from_args(
@@ -60,7 +62,8 @@ def cluster_results(
         submit_command="sh",
         results_workdir="./cluster-workdir",
         script_prefix_path=None,
-        result_serialization_method="pickle"):
+        result_serialization_method="pickle",
+        max_retries=3):
 
     constant_payload = {
         'constant_data': constant_data,
@@ -82,7 +85,7 @@ def cluster_results(
     else:
         script_prefix = "#!/bin/bash"
 
-    result_paths = []
+    result_items = []
 
     for (i, item) in enumerate(work_items):
         item_workdir = os.path.join(
@@ -124,23 +127,58 @@ def cluster_results(
         subprocess.check_call(launch_command, shell=True)
         print("Invoked", launch_command)
 
-        result_paths.append(
-            (item_finished_path, item_result_path, item_error_path))
+        result_items.append({
+            'work_dir': item_workdir,
+            'finished_path': item_finished_path,
+            'result_path': item_result_path,
+            'error_path': item_error_path,
+            'retry_num': 0,
+            'launch_command': launch_command,
+        })
 
     def result_generator():
         start = time.time()
-        for (complete_dir, result_path, error_path) in result_paths:
-            while not os.path.exists(complete_dir):
-                print("[%0.1f sec elapsed] waiting on" % (time.time() - start),
-                    complete_dir)
-                time.sleep(60)
-            print("Complete", complete_dir)
+        while result_items:
+            while True:
+                result_item = None
+                for d in result_items:
+                    if os.path.exists(d['finished_path']):
+                        result_item = d
+                        break
+                if result_item is None:
+                    time.sleep(60)
+                else:
+                    result_items.remove(result_item)
+                    break
+
+            complete_dir = result_item['finished_path']
+            result_path = result_item['result_path']
+            error_path = result_item['error_path']
+            retry_num = result_item['retry_num']
+            launch_command = result_item['launch_command']
+
+            print("[%0.1f sec elapsed] processing item %s" % (
+                time.time() - start, result_item))
 
             if os.path.exists(error_path):
                 print("Error path exists", error_path)
                 with open(error_path, "rb") as fd:
                     exception = pickle.load(fd)
-                raise exception
+                print(exception)
+                if retry_num < max_retries:
+                    print("Relaunching", launch_command)
+                    attempt_dir = os.path.join(
+                        result_item['work_dir'], "attempt.%d" % retry_num)
+                    shutil.move(complete_dir, attempt_dir)
+                    shutil.move(error_path, attempt_dir)
+                    subprocess.check_call(launch_command, shell=True)
+                    print("Invoked", launch_command)
+                    result_item['retry_num'] += 1
+                    result_items.append(result_item)
+                    continue
+                else:
+                    print("Max retries exceeded", max_retries)
+                    raise exception
 
             if os.path.exists(result_path):
                 print("Result path exists", error_path)
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 78c29da608131097071bff805d5ba42bbce9b2d5..a2bf3944912767fbb07dee585487afc7ba28b4ce 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -366,6 +366,9 @@ def main(args):
 
         results_generator = worker_pool.imap_unordered(
             partial(call_wrapped_kwargs, train_model),
+
+
+
             work_items,
             chunksize=1)
     else:
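
Note on the retry scheme introduced in cluster_results() / result_generator() above: outstanding jobs are tracked as dicts, the generator polls until some job writes its "finished" marker, and a job that left an error pickle behind is re-launched up to --cluster-max-retries times before the exception is finally raised. Below is a minimal, self-contained sketch of that pattern, not part of the diff; poll_with_retries and the relaunch callable are illustrative names only, and stale-marker cleanup is delegated to relaunch (the diff itself moves the old markers into an attempt.N directory).

import os
import time


def poll_with_retries(result_items, relaunch, max_retries=3, poll_interval=60):
    # result_items: list of dicts with 'finished_path', 'error_path',
    # 'result_path' and 'retry_num' keys, like those built in cluster_results().
    # relaunch: callable that clears the stale markers and re-submits the job.
    start = time.time()
    while result_items:
        # Block until some outstanding item has written its "finished" marker.
        finished = None
        while finished is None:
            for item in result_items:
                if os.path.exists(item['finished_path']):
                    finished = item
                    break
            else:
                time.sleep(poll_interval)
        result_items.remove(finished)
        print("[%0.1f sec elapsed] processing %s" % (time.time() - start, finished))

        if os.path.exists(finished['error_path']):
            if finished['retry_num'] < max_retries:
                finished['retry_num'] += 1
                relaunch(finished)            # re-submit and keep waiting on it
                result_items.append(finished)
                continue
            raise RuntimeError("Max retries exceeded: %s" % finished)

        yield finished['result_path']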