The calculation of average loss doesn’t seem right to me:
```elixir
avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1)
```
I believe `loss` is the average loss returned by the previous iteration, which already covers `j` batches, so I would think we'd need to weight it by multiplying by `j` before dividing by `j + 1`. Something like this?
```elixir
avg_loss = Nx.add(Nx.mean(batch_loss), Nx.multiply(loss, j)) |> Nx.divide(j + 1)
```
Or am I confusing myself?
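The weighting argument can be checked with the incremental-mean identity. This is a sketch that assumes `j` is the 0-based batch index, that `loss` holds the mean of the previous `j` per-batch means, and that each batch is weighted equally. Here $\bar{\ell}_j$ stands for `loss`, and $b_j$ stands for `Nx.mean(batch_loss)` of batch $j$ (both symbols are introduced here for illustration):

```latex
\bar{\ell}_{j+1}
  = \frac{1}{j+1}\sum_{i=0}^{j} b_i
  = \frac{1}{j+1}\Bigl(\sum_{i=0}^{j-1} b_i + b_j\Bigr)
  = \frac{j\,\bar{\ell}_j + b_j}{j+1},
\qquad \text{since } \sum_{i=0}^{j-1} b_i = j\,\bar{\ell}_j .
```

A quick numeric check with hypothetical batch means 4, 2, 6 (true mean 4) illustrates the difference. The weighted update gives 4, then (1·4 + 2)/2 = 3, then (2·3 + 6)/3 = 4. The unweighted update gives 4, then (2 + 4)/2 = 3, then (6 + 3)/3 = 3, which drifts away from the true mean once j ≥ 2.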