Skip to content

Commit

Permalink
Use log1p(x) instead of log(1+x) (#1286)
Browse files Browse the repository at this point in the history
This function is more accurate than torch.log() for small values of input - https://pytorch.org/docs/stable/generated/torch.log1p.html

Found with https://github.com/pytorch-labs/torchfix/
  • Loading branch information
kit1980 committed Sep 19, 2024
1 parent a308b4e commit cdef4d4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mnist_forward_forward/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ def train(self, x_pos, x_neg):
for i in range(self.num_epochs):
g_pos = self.forward(x_pos).pow(2).mean(1)
g_neg = self.forward(x_neg).pow(2).mean(1)
loss = torch.log(
1
+ torch.exp(
loss = torch.log1p(
torch.exp(
torch.cat([-g_pos + self.threshold, g_neg - self.threshold])
)
).mean()
Expand Down

0 comments on commit cdef4d4

Please sign in to comment.