Machine Learning in Elixir: Chapter 1 - Unable to train model (page 19)

I'm really enjoying the book so far, but I came across an issue in the first chapter when running:

trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(data_stream, %{}, iterations: 500, epochs: 10)

I'm still too new to debug this, but it looks like a type mismatch (the compiled training step expects an f32 loss but receives an f64 one) and/or a problem with passing a parameter map instead of an %Axon.ModelState{}:


13:42:29.641 [warning] passing parameter map to initialization is deprecated, use %Axon.ModelState{} instead
Epoch: 0, Batch: 0, accuracy: 0.4750000 loss: 0.0000000
** (ArgumentError) argument at position 3 is not compatible with compiled function template.

%{i: #Nx.Tensor<
    s32
  >, model_state: #Inspect.Error<
  got Protocol.UndefinedError with message:

      """
      protocol Enumerable not implemented for type Nx.Defn.TemplateDiff (a struct). This protocol is implemented for the following type(s): Date.Range, Explorer.Series.Iterator, File.Stream, Function, GenEvent.Stream, HashDict, HashSet, IO.Stream, Kino.Control, Kino.Input, Kino.JS.Live, List, Map, MapSet, Range, Stream, Table.Mapper, Table.Zipper

      Got value:

          #Nx.Tensor<
            f32[3]
          >
      """

  while inspecting:

      %{
        data: %{
          "dense_0" => %{
            "bias" => #Nx.Tensor<
              f32[3]
            >,
            "kernel" => #Nx.Tensor<
              f32[4][3]
            >
          }
        },
        state: %{},
        __struct__: Axon.ModelState,
        parameters: %{"dense_0" => ["bias", "kernel"]},
        frozen_parameters: %{}
      }

  Stacktrace:

    (elixir 1.18.3) lib/enum.ex:1: Enumerable.impl_for!/1
    (elixir 1.18.3) lib/enum.ex:166: Enumerable.reduce/3
    (elixir 1.18.3) lib/enum.ex:4515: Enum.reduce/3
    (axon 0.7.0) lib/axon/model_state.ex:359: anonymous fn/2 in Inspect.Axon.ModelState.get_param_info/1
    (stdlib 6.2.2.1) maps.erl:860: :maps.fold_1/4
    (axon 0.7.0) lib/axon/model_state.ex:359: anonymous fn/2 in Inspect.Axon.ModelState.get_param_info/1
    (stdlib 6.2.2.1) maps.erl:860: :maps.fold_1/4
    (axon 0.7.0) lib/axon/model_state.ex:320: Inspect.Axon.ModelState.inspect/2

>, y_true: #Nx.Tensor<
    u8[120][3]
  >, y_pred: #Nx.Tensor<
    f64[120][3]
  >, loss: 
  <<<<< Expected <<<<<
  #Nx.Tensor<
    f32
  >
  ==========
  #Nx.Tensor<
    f64
  >
  >>>>> Argument >>>>>
  , optimizer_state: {%{scale: #Nx.Tensor<
       f32
     >}}, loss_scale_state: %{}}

    (nx 0.10.0) lib/nx/defn.ex:342: anonymous fn/7 in Nx.Defn.compile_flatten/5
    (nx 0.10.0) lib/nx/lazy_container.ex:73: anonymous fn/3 in Nx.LazyContainer.Map.traverse/3
    (elixir 1.18.3) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (elixir 1.18.3) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.10.0) lib/nx/lazy_container.ex:72: Nx.LazyContainer.Map.traverse/3
    (nx 0.10.0) lib/nx/defn.ex:339: Nx.Defn.compile_flatten/5
    (nx 0.10.0) lib/nx/defn.ex:331: anonymous fn/4 in Nx.Defn.compile/3
    #cell:3r6bhsjthve53hp7:5: (file)
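The "Expected f32 / Argument f64" mismatch on the loss can be reproduced in isolation. The following is a minimal sketch, assuming only the nx package. The tensors are placeholders standing in for an Iris feature batch (f64, as Explorer produces) and a dense-layer kernel (f32, as in the log above):

```elixir
# Requires only the :nx package (no EXLA needed for this check).
# Placeholder for a batch of Iris features loaded via Explorer (f64)
features = Nx.tensor([[5.1, 3.5, 1.4, 0.2]], type: :f64)

# Placeholder for a dense-layer kernel; the log shows Axon's parameters are f32
kernel = Nx.broadcast(Nx.tensor(0.1, type: :f32), {4, 3})

# Mixing f64 inputs with f32 parameters promotes the result to f64
promoted_type = features |> Nx.dot(kernel) |> Nx.type()
IO.inspect(promoted_type, label: "without cast")

# Casting the inputs first keeps the computation in f32
cast_type = features |> Nx.as_type(:f32) |> Nx.dot(kernel) |> Nx.type()
IO.inspect(cast_type, label: "with cast")
```

This prints "without cast: {:f, 64}" followed by "with cast: {:f, 32}", mirroring the f64 y_pred in the log.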

In the terminal running Livebook, I also see the same deprecation warning:

 [warning] passing parameter map to initialization is deprecated, use %Axon.ModelState{} instead

but I don't yet know how to use %Axon.ModelState{} instead. Please point me in the right direction. Thank you.
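For the deprecation warning, here is a minimal sketch of passing an empty model state instead of %{} as the initial state. It assumes Axon 0.7, which provides Axon.ModelState.empty/0. The tiny model and toy data are placeholders, not the book's Iris setup:

```elixir
# Assumes {:axon, "~> 0.7"} and {:nx, "~> 0.10"} are installed
# (e.g. in the Livebook setup cell).

# Placeholder model: 4 input features -> 3 softmax outputs
model =
  Axon.input("iris_features", shape: {nil, 4})
  |> Axon.dense(3, activation: :softmax)

# Placeholder batch: f32 features and one-hot f32 labels
x = Nx.iota({8, 4}, type: :f32) |> Nx.divide(32)

y =
  Nx.iota({8})
  |> Nx.remainder(3)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
  |> Nx.as_type(:f32)

data_stream = Stream.repeatedly(fn -> {x, y} end)

# Axon.ModelState.empty() replaces the plain %{} initial state
trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(data_stream, Axon.ModelState.empty(), iterations: 5, epochs: 1)
```

With the book's model and data_stream, the change is just replacing %{} with Axon.ModelState.empty() in the Axon.Loop.run call. On its own it should only address the warning, not the f32/f64 mismatch.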

I got the versions by adding the debug: true option:

elixir 1.18.3
nx 0.10.0
axon 0.7.0

The solution is in this thread:

https://devtalk.com/t/machine-learning-in-elixir-chapter-1-doesnt-work-with-axon-0-7-page-26/173984

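Explicitly converting the features and labels of both the training and test sets to :f32 fixes the issue, and training runs. Explorer stores float columns as f64, so the stacked feature tensors are f64. Combined with the model's f32 parameters, the predictions and loss come out as f64 (y_pred is f64[120][3] in the log above), which no longer matches the f32 loss the compiled training step expects: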

feature_columns = [
  "sepal_length",
  "sepal_width",
  "petal_length",
  "petal_width"
]

label_column = "species"

# Explorer float columns are f64; cast to f32 to match the model's parameters
x_train =
  train_df[feature_columns]
  |> Nx.stack(axis: 1)
  |> Nx.as_type(:f32)

y_train =
  train_df
  |> DF.pull(label_column)
  |> Explorer.Series.to_list()
  |> Enum.map(fn
    "Iris-setosa" -> 0
    "Iris-versicolor" -> 1
    "Iris-virginica" -> 2
  end)
  |> Nx.tensor(type: :u8)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
  |> Nx.as_type(:f32)

# Same f32 cast for the test features
x_test =
  test_df[feature_columns]
  |> Nx.stack(axis: 1)
  |> Nx.as_type(:f32)

y_test =
  test_df
  |> DF.pull(label_column)
  |> Explorer.Series.to_list()
  |> Enum.map(fn
    "Iris-setosa" -> 0
    "Iris-versicolor" -> 1
    "Iris-virginica" -> 2
  end)
  |> Nx.tensor(type: :u8)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
  |> Nx.as_type(:f32)
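The y_train and y_test pipelines are identical apart from the input frame. Here is a minimal sketch of factoring them into one anonymous function. The name to_one_hot_labels is hypothetical, and the block assumes the explorer and nx packages with Explorer.DataFrame aliased as DF:

```elixir
# Assumes the :explorer and :nx packages are installed.
alias Explorer.DataFrame, as: DF

# Hypothetical helper: maps Iris species names to class indices,
# one-hot encodes them, and casts to f32 to match the model's parameters.
to_one_hot_labels = fn df, label_column ->
  df
  |> DF.pull(label_column)
  |> Explorer.Series.to_list()
  |> Enum.map(fn
    "Iris-setosa" -> 0
    "Iris-versicolor" -> 1
    "Iris-virginica" -> 2
  end)
  |> Nx.tensor(type: :u8)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
  |> Nx.as_type(:f32)
end

# Tiny placeholder frame to exercise the helper
sample_df = DF.new(%{"species" => ["Iris-setosa", "Iris-virginica"]})
sample_labels = to_one_hot_labels.(sample_df, "species")
```

With the real data, this becomes y_train = to_one_hot_labels.(train_df, label_column) and y_test = to_one_hot_labels.(test_df, label_column).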