L2 regularization
L-M-Sherlock committed Jan 6, 2025
1 parent d3132d0 commit af9b79f
Showing 1 changed file with 37 additions and 3 deletions.
src/fsrs_optimizer/fsrs_optimizer.py: 37 additions & 3 deletions
@@ -67,6 +67,33 @@
S_MIN = 0.01


+DEFAULT_PARAMS_TENSOR = torch.tensor(DEFAULT_PARAMETER, dtype=torch.float)
+DEFAULT_PARAMS_STDDEV_TENSOR = torch.tensor(
+    [
+        6.61,
+        9.52,
+        17.69,
+        27.74,
+        0.55,
+        0.28,
+        0.67,
+        0.12,
+        0.4,
+        0.18,
+        0.34,
+        0.27,
+        0.08,
+        0.14,
+        0.57,
+        0.25,
+        1.03,
+        0.27,
+        0.39,
+    ],
+    dtype=torch.float,
+)
+
+
class FSRS(nn.Module):
def __init__(self, w: List[float], float_delta_t: bool = False):
super(FSRS, self).__init__()
@@ -344,7 +371,6 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame

def train(self, verbose: bool = True):
self.verbose = verbose
-default_params_tensor = torch.tensor(DEFAULT_PARAMETER, dtype=torch.float)
best_loss = np.inf
epoch_len = len(self.train_set.y_train)
if verbose:
@@ -365,8 +391,11 @@ def train(self, verbose: bool = True):
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
loss = (self.loss_fn(retentions, labels) * weights).sum()
-penalty = torch.mean(torch.abs(self.model.w - default_params_tensor) / (torch.abs(self.model.w) + torch.abs(default_params_tensor)) * 2)
-loss += penalty * self.gamma / epoch_len
+penalty = torch.sum(
+    torch.square(self.model.w - DEFAULT_PARAMS_TENSOR)
+    / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
+)
+loss += penalty * self.gamma * real_batch_size / epoch_len
loss.backward()
if self.float_delta_t or not self.enable_short_term:
for param in self.model.parameters():
@@ -417,6 +446,11 @@ def eval(self):
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
loss = (self.loss_fn(retentions, labels) * weights).mean()
+penalty = torch.sum(
+    torch.square(self.model.w - DEFAULT_PARAMS_TENSOR)
+    / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
+)
+loss += penalty * self.gamma / len(self.train_set.y_train)
losses.append(loss)
self.avg_train_losses.append(losses[0])
self.avg_eval_losses.append(losses[1])
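The penalty introduced by this commit is an L2 term: each parameter's squared deviation from its default is divided by that parameter's squared standard deviation from DEFAULT_PARAMS_STDDEV_TENSOR, so parameters with a wide population spread are penalized less than tightly clustered ones. In train() the term is scaled by gamma * real_batch_size / epoch_len, and in eval() by gamma / len(self.train_set.y_train). The snippet below is a minimal standalone sketch of that computation under assumed placeholder values; the three-element tensors, gamma, real_batch_size, and epoch_len are illustrative stand-ins, not the optimizer's real state.

```python
import torch

# Illustrative stand-ins for DEFAULT_PARAMS_TENSOR / DEFAULT_PARAMS_STDDEV_TENSOR
# (three entries for brevity; the real tensors hold all 19 FSRS parameters).
default_params = torch.tensor([0.4, 1.2, 3.1], dtype=torch.float)
params_stddev = torch.tensor([6.61, 9.52, 17.69], dtype=torch.float)

w = torch.tensor([0.9, 0.8, 4.0], dtype=torch.float, requires_grad=True)  # stands in for model.w
gamma = 1.0            # regularization strength (self.gamma)
real_batch_size = 512  # size of the current mini-batch
epoch_len = 10000      # len(self.train_set.y_train)

# Squared deviation from the defaults, normalized by the per-parameter variance,
# so loosely constrained parameters can drift further before being penalized.
penalty = torch.sum(torch.square(w - default_params) / torch.square(params_stddev))

# Training-time scaling from the diff: gamma * batch_size / epoch_len per step,
# which sums to roughly gamma over one pass through the training set.
loss_term = penalty * gamma * real_batch_size / epoch_len
loss_term.backward()  # the gradient pulls w back toward default_params
print(w.grad)
```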
