Skip to content
Snippets Groups Projects
Commit de6b5380 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

Merge pull request #24 from hammerlab/expose-pretrain-decay

expose pretrain decay factor function
parents 802ab1c0 f992d4f5
No related branches found
No related tags found
No related merge requests found
......@@ -252,6 +252,7 @@ class Class1BindingPredictor(PredictorBase):
Y_pretrain=None,
sample_weights_pretrain=None,
n_random_negative_samples=0,
pretrain_decay=lambda epoch: np.exp(-epoch),
n_training_epochs=200,
verbose=False,
batch_size=128):
......@@ -277,6 +278,10 @@ class Class1BindingPredictor(PredictorBase):
Y_pretrain : array
Labels for extra samples, shape
pretrain_decay : int -> float function
decay function for pretraining, mapping epoch number to decay
factor
sample_weights_pretrain : array
Initial weights for the rows of X_pretrain. If not specified then
initialized to ones.
......@@ -309,16 +314,8 @@ class Class1BindingPredictor(PredictorBase):
100 * total_train_sample_weight / total_combined_sample_weight))
for epoch in range(n_training_epochs):
# weights for synthetic points can be shrunk as:
# ~ 1 / (1+epoch)**2
# or
# 2 ** -epoch
# or
# e ** -epoch
#
# TODO: explore the best scheme for shrinking imputation weight
#
decay_factor = 2.0 ** -epoch
decay_factor = pretrain_decay(epoch)
# if the contribution of synthetic samples is less than a
# thousandth of the actual data, then stop using it
pretrain_contribution = total_pretrain_sample_weight * decay_factor
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment