
Commit

add gamma for regularization
L-M-Sherlock committed Jan 2, 2025
1 parent 2905136 commit d3132d0
Showing 1 changed file with 3 additions and 1 deletion.
src/fsrs_optimizer/fsrs_optimizer.py (3 additions, 1 deletion)
@@ -301,6 +301,7 @@ def __init__(
         init_w: List[float],
         n_epoch: int = 1,
         lr: float = 1e-2,
+        gamma: float = 0.01,
         batch_size: int = 256,
         max_seq_len: int = 64,
         float_delta_t: bool = False,
@@ -312,6 +313,7 @@ def __init__(
         self.model = FSRS(init_w, float_delta_t)
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
         self.clipper = ParameterClipper()
+        self.gamma = gamma
         self.batch_size = batch_size
         self.max_seq_len = max_seq_len
         self.build_dataset(train_set, test_set)
@@ -364,7 +366,7 @@ def train(self, verbose: bool = True):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 loss = (self.loss_fn(retentions, labels) * weights).sum()
                 penalty = torch.mean(torch.abs(self.model.w - default_params_tensor) / (torch.abs(self.model.w) + torch.abs(default_params_tensor)) * 2)
-                loss += penalty * 0.01
+                loss += penalty * self.gamma / epoch_len
                 loss.backward()
                 if self.float_delta_t or not self.enable_short_term:
                     for param in self.model.parameters():
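The substantive change is that the previously hard-coded 0.01 regularization weight becomes a configurable gamma, and the per-batch penalty is additionally divided by epoch_len, which plausibly keeps the total regularization pressure per epoch roughly independent of how many batches the training set produces. Below is a minimal sketch of the resulting penalty term, assuming w is the trainable FSRS parameter tensor and default_params the default parameter tensor; the standalone function name, signature, and example values are illustrative and not part of the commit.

import torch

def regularization_penalty(w: torch.Tensor,
                           default_params: torch.Tensor,
                           gamma: float = 0.01,
                           epoch_len: int = 100) -> torch.Tensor:
    # Symmetric relative deviation of each parameter from its default,
    # bounded in [0, 2) per parameter, averaged over all parameters
    # (mirrors the penalty expression in the diff above).
    penalty = torch.mean(
        torch.abs(w - default_params)
        / (torch.abs(w) + torch.abs(default_params))
        * 2
    )
    # Scale by gamma and spread over the batches of one epoch, so the
    # accumulated regularization per epoch stays on the order of gamma.
    return penalty * gamma / epoch_len

# Example usage with made-up parameter values:
# w = torch.tensor([0.5, 1.5, 4.0], requires_grad=True)
# defaults = torch.tensor([0.4, 1.2, 3.2])
# loss = data_loss + regularization_penalty(w, defaults, gamma=0.01, epoch_len=50)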
