Skip to content
Snippets Groups Projects
Commit 7a688ba5 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

basic support for pan allele model selection

parent ab85ec98
No related merge requests found
......@@ -7,13 +7,8 @@ import signal
import sys
import time
import traceback
import random
from functools import partial
import numpy
import pandas
import yaml
from sklearn.metrics.pairwise import cosine_similarity
from mhcnames import normalize_allele_name
import tqdm # progress bar
tqdm.monitor_interval = 0 # see https://github.com/tqdm/tqdm/issues/481
......
......@@ -72,14 +72,12 @@ def worker_pool_with_gpu_assignments(
max_tasks_per_worker=None,
worker_log_dir=None):
num_workers = num_jobs if num_jobs else cpu_count()
if num_workers == 0:
if num_jobs == 0:
if backend:
set_keras_backend(backend)
return None
worker_init_kwargs = [{} for _ in range(num_workers)]
worker_init_kwargs = [{} for _ in range(num_jobs)]
if num_gpus:
print("Attempting to round-robin assign each worker a GPU.")
if backend != "tensorflow-default":
......@@ -115,7 +113,7 @@ def worker_pool_with_gpu_assignments(
kwargs["worker_log_dir"] = worker_log_dir
worker_pool = make_worker_pool(
processes=num_workers,
processes=num_jobs,
initializer=worker_init,
initializer_kwargs_per_process=worker_init_kwargs,
max_tasks_per_worker=max_tasks_per_worker)
......
This diff is collapsed.
......@@ -505,7 +505,7 @@ def train_model(
# Save model-specific training info
train_peptide_hash = hashlib.sha1()
for peptide in train_data.peptide.values:
for peptide in sorted(train_data.peptide.values):
train_peptide_hash.update(peptide.encode())
model.fit_info[-1]["training_info"] = {
"fold_num": fold_num,
......
Loading...
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment