Machine Learning in Elixir: Chapter 4. Bug with avg_loss counting

@seanmor5

Currently, avg_loss is computed like this:

{loss, params} ->
  {batch_loss, new_params} = step(params, x, y)
  avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1) 
  IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}") 
  {avg_loss, new_params}

but this is incorrect: the accumulator carries an already-averaged value, which on the next iteration is added to a fresh batch loss and divided by j + 1 again, so the printed value is not a true running mean. I think it should instead accumulate the total loss and divide only when printing, as in the snippet below:

{total_loss, params} ->
  {batch_loss, new_params} = step(params, x, y)
  total_loss = Nx.add(Nx.mean(batch_loss), total_loss) 
  IO.write("\rEpoch: #{i}, Loss: #{total_loss |> Nx.divide(j + 1) |> Nx.to_number()}")
  {total_loss, new_params}

This only affects the loss values printed during training (the updated parameters are unaffected), but it is still worth fixing.