Machine Learning in Elixir: Chapter 1 doesn't work with Axon 0.7 (page 26)

Hello @seanmor5,

After upgrading Axon to 0.7, evaluation fails with:

    ** (ArgumentError) argument at position 3 is not compatible with compiled function template.

    %{i: #Nx.Tensor<
        s32
      >, model_state: #Inspect.Error<
      got Protocol.UndefinedError with message:

          """
          protocol Enumerable not implemented for #Nx.Tensor<
            f32[3]
          > of type Nx.Defn.TemplateDiff (a struct). This protocol is implemented for the following type(s): Date.Range, Explorer.Series.Iterator, File.Stream, Function, GenEvent.Stream, HashDict, HashSet, IO.Stream, Kino.Control, Kino.Input, Kino.JS.Live, List, Map, MapSet, Range, Stream, Table.Mapper, Table.Zipper
          """

      while inspecting:

          %{
            data: %{
              "dense_0" => %{
                "bias" => #Nx.Tensor<
                  f32[3]
                >,
                "kernel" => #Nx.Tensor<
                  f32[4][3]
                >
              }
            },
            state: %{},
            __struct__: Axon.ModelState,
            parameters: %{"dense_0" => ["bias", "kernel"]},
            frozen_parameters: %{}
          }

      Stacktrace:

        (elixir 1.17.2) lib/enum.ex:1: Enumerable.impl_for!/1
        (elixir 1.17.2) lib/enum.ex:166: Enumerable.reduce/3
        (elixir 1.17.2) lib/enum.ex:4423: Enum.reduce/3
        (axon 0.7.0) lib/axon/model_state.ex:359: anonymous fn/2 in Inspect.Axon.ModelState.get_param_info/1
        (stdlib 6.0) maps.erl:860: :maps.fold_1/4
        (axon 0.7.0) lib/axon/model_state.ex:359: anonymous fn/2 in Inspect.Axon.ModelState.get_param_info/1
        (stdlib 6.0) maps.erl:860: :maps.fold_1/4
        (axon 0.7.0) lib/axon/model_state.ex:320: Inspect.Axon.ModelState.inspect/2

    >, loss: 
      <<<<< Expected <<<<<
      #Nx.Tensor<
        f32
      >
      ==========
      #Nx.Tensor<
        f64
      >
      >>>>> Argument >>>>>
      , optimizer_state: {%{scale: #Nx.Tensor<
           f32
         >}}, loss_scale_state: %{}, y_true: #Nx.Tensor<
        u8[120][3]
      >, y_pred: #Nx.Tensor<
        f64[120][3]
      >}

        (nx 0.9.1) lib/nx/defn.ex:342: anonymous fn/7 in Nx.Defn.compile_flatten/5
        (nx 0.9.1) lib/nx/lazy_container.ex:73: anonymous fn/3 in Nx.LazyContainer.Map.traverse/3
        (elixir 1.17.2) lib/enum.ex:1829: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
        (elixir 1.17.2) lib/enum.ex:1829: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
        (nx 0.9.1) lib/nx/lazy_container.ex:72: Nx.LazyContainer.Map.traverse/3
        (nx 0.9.1) lib/nx/defn.ex:339: Nx.Defn.compile_flatten/5
        (nx 0.9.1) lib/nx/defn.ex:331: anonymous fn/4 in Nx.Defn.compile/3
        #cell:75r7qhxms33esmsj:5: (file)

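It seems to be a type mismatch: the iris dataset is :f64, whereas :f32 is expected. The model parameters are f32, but the f64 inputs make y_pred (and therefore the loss) f64, which no longer matches the compiled function template.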
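A minimal Nx-only sketch of why the f64 data leaks into the loss, assuming Nx 0.9. A plain `Nx.dot/2` stands in for the dense layer here (a simplification, not Axon's actual implementation), and the feature values are made up for illustration. Nx promotes a mix of `f64` and `f32` operands to `f64`, so f32 parameters applied to f64 inputs produce f64 predictions.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# One iris-like row of features stored as 64-bit floats (illustrative values)
x = Nx.tensor([[5.1, 3.5, 1.4, 0.2]], type: :f64)

# A stand-in for the f32[4][3] dense kernel shown in the error message
kernel = Nx.broadcast(Nx.tensor(0.1, type: :f32), {4, 3})

# Mixed f64/f32 operands are promoted to f64
y_pred = Nx.dot(x, kernel)
IO.inspect(Nx.type(y_pred), label: "without cast")

# Casting the features to f32 first keeps the whole computation in f32
y_pred_f32 = x |> Nx.as_type(:f32) |> Nx.dot(kernel)
IO.inspect(Nx.type(y_pred_f32), label: "with cast")
```

This prints `without cast: {:f, 64}` followed by `with cast: {:f, 32}`, which mirrors the f64 `y_pred` versus f32 template in the error.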

I ran into this too and fixed it by explicitly converting the features to :f32:

    x_train =
      train_data[@columns]
      |> Nx.stack(axis: 1)
      |> Nx.as_type(:f32)

    y_train =
      train_data["species"]
      |> Nx.stack(axis: -1)
      |> Nx.equal(Nx.iota({1, 3}, axis: -1))

    x_test =
      test_data[@columns]
      |> Nx.stack(axis: 1)
      |> Nx.as_type(:f32)

    y_test =
      test_data["species"]
      |> Nx.stack(axis: -1)
      |> Nx.equal(Nx.iota({1, 3}, axis: -1))
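Only the features need the cast. Below is a sketch of the one-hot label encoding used above, assuming Nx 0.9 and hypothetical integer class ids in place of the `species` column; `Nx.new_axis/2` stands in for the `Nx.stack(axis: -1)` step. Comparing an `{n, 1}` column of labels against `Nx.iota({1, 3}, axis: -1)` broadcasts to an `{n, 3}` matrix of type `u8`, and because Nx promotes `u8` combined with `f32` to `f32`, the labels do not pull the loss up to f64 (the multiplication is only a simplified stand-in for what a loss function does with them).

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical integer class ids for three samples (0, 1, 2 = three species)
labels = Nx.tensor([0, 2, 1])

one_hot =
  labels
  # shape {3} -> {3, 1}
  |> Nx.new_axis(-1)
  # compare against [[0, 1, 2]]; broadcasting yields a {3, 3} u8 matrix
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))

IO.inspect(Nx.to_list(one_hot), label: "one-hot")
IO.inspect(Nx.type(one_hot), label: "label type")

# Illustrative f32 predictions; combining them with u8 labels stays in f32
y_pred = Nx.tensor([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.2, 0.6, 0.2]], type: :f32)
IO.inspect(Nx.type(Nx.multiply(one_hot, y_pred)), label: "combined type")
```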

Cool! I applied the cast at the dataframe level instead, since Nx still scares me.

    cols = ~w(sepal_width sepal_length petal_length petal_width)

    normalized_iris =
      DF.mutate(
        iris,
        for col <- across(^cols) do
          {col.name, Explorer.Series.cast((col - mean(col)) / standard_deviation(col), :f32)}
        end
      )
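A small sketch of what the dataframe-level cast does to the types, assuming Explorer 0.9 and Nx 0.9 and a made-up column of values: a float series starts out as `{:f, 64}`, `Explorer.Series.cast/2` to `:f32` changes its dtype, and the tensor built from it is then `f32`.

```elixir
Mix.install([{:explorer, "~> 0.9"}, {:nx, "~> 0.9"}])

alias Explorer.Series

# Illustrative sepal widths; Explorer infers 64-bit floats by default
widths = Series.from_list([3.5, 3.0, 3.2])
IO.inspect(Series.dtype(widths), label: "before cast")

# Cast to 32-bit floats before handing the data to Nx
widths_f32 = Series.cast(widths, :f32)
IO.inspect(Series.dtype(widths_f32), label: "after cast")

# The tensor built from the series keeps the f32 type
tensor = Series.to_tensor(widths_f32)
IO.inspect(Nx.type(tensor), label: "tensor type")
```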