diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index fa530d152d80ed09ee66bd7e1d749d33c5931540..2c16f263842c102f38664c3f173c3bb401e9ed8a 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -416,6 +416,9 @@ def train_model(
         predictor,
         save_to):
 
+    import keras.backend as K
+    K.clear_session()
+
     df = GLOBAL_DATA["train_data"]
     folds_df = GLOBAL_DATA["folds_df"]
     allele_encoding = GLOBAL_DATA["allele_encoding"]