diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 9a8038826cb679d1bb8216a9941072097405646d..498972992f0b8d2cb116e3dc45908f89f32edd37 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -127,6 +127,50 @@ class Class1LigandomePredictor(object):
 
         allele_peptide_merged = concatenate([peptides_repeated, allele_flat])
 
+        layer_names = [
+            layer.name for layer in merged_ensemble.layers
+        ]
+
+        pan_allele_layer_initial_names = [
+            'allele', 'peptide',
+            'allele_representation', 'flattened_0', 'allele_flat',
+            'allele_peptide_merged', 'dense_0', 'dropout_0',
+        ]
+
+        def startswith(lst, prefix):
+            return lst[:len(prefix)] == prefix
+
+        assert startswith(layer_names, pan_allele_layer_initial_names), layer_names
+
+        layers = merged_ensemble.layers[
+            pan_allele_layer_initial_names.index(
+                "allele_peptide_merged") + 1:
+        ]
+        node = allele_peptide_merged
+        layer_name_to_new_node = {
+            "allele_peptide_merged": allele_peptide_merged,
+        }
+        for layer in layers:
+            assert layer.name not in layer_name_to_new_node
+            input_layer_names = []
+            for inbound_node in layer._inbound_nodes:
+                for inbound_layer in inbound_node.inbound_layers:
+                    input_layer_names.append(inbound_layer.name)
+            input_nodes = [
+                layer_name_to_new_node[name]
+                for name in input_layer_names
+            ]
+
+            if len(input_nodes) == 1:
+                lifted = TimeDistributed(layer)
+                result = lifted(input_nodes[0])
+            else:
+                print(layer, layer.name, node, lifted)
+                result = layer(input_nodes)
+
+            layer_name_to_new_node[layer.name] = result
+
+        """
         dense_0 = merged_ensemble.get_layer("dense_0")
         td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged)
         td_dense0 = Dropout(0.5)(td_dense0)
@@ -137,10 +181,11 @@ class Class1LigandomePredictor(object):
 
         output = merged_ensemble.get_layer("output")
         td_output = TimeDistributed(output)(td_dense1)
+        """
 
         network = Model(
             inputs=[input_peptides, input_alleles],
-            outputs=td_output,
+            outputs=node,
             name="ligandome",
         )
         #print('trainable', network.get_layer("td_dense_0").trainable)
@@ -358,7 +403,7 @@ class Class1LigandomePredictor(object):
             'allele': allele_encoding_input,
         }
         predictions = self.network.predict(x_dict, batch_size=batch_size)
-        return numpy.squeeze(predictions, axis=-1)
+        return numpy.squeeze(predictions)
 
     #def predict(self):