@seanmor5
Right now avg_loss is computed like this:
{loss, params} ->
{batch_loss, new_params} = step(params, x, y)
avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1)
IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
{avg_loss, new_params}
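but that gives the wrong result, because the running average gets divided by j + 1 again on every step. I guess it should accumulate the total and divide only when printing, something like the snippet below: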
{total_loss, params} ->
{batch_loss, new_params} = step(params, x, y)
total_loss = Nx.add(Nx.mean(batch_loss), total_loss)
IO.write("\rEpoch: #{i}, Loss: #{total_loss |> Nx.divide(j + 1) |> Nx.to_number()}")
{total_loss, new_params}
It only affects the numbers printed during training, but still.
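
A small worked example of the difference, as a sketch under assumptions that are mine and not from the code above: the accumulator starts at 0, and every batch has the same mean loss of 3. Here m_j is the mean loss of batch j, a_j is the value printed by the first snippet, and T_j is the running total from the second snippet.

```latex
% First snippet: the previous average is added and then divided again
a_j = \frac{m_j + a_{j-1}}{j + 1}, \qquad a_{-1} = 0
% With m_j = 3 for every batch:
a_0 = \frac{3 + 0}{1} = 3, \quad
a_1 = \frac{3 + 3}{2} = 3, \quad
a_2 = \frac{3 + 3}{3} = 2, \quad
a_3 = \frac{3 + 2}{4} = 1.25

% Second snippet: accumulate the total, divide only for printing
T_j = T_{j-1} + m_j, \qquad T_{-1} = 0
\frac{T_j}{j + 1} = \frac{3\,(j + 1)}{j + 1} = 3 \quad \text{for every } j
```

So with a constant batch loss the first version prints a value that drifts toward 0, while the second keeps printing the true mean.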