I ran into this too and fixed it with an explicit conversion to `:f32`, e.g.:
# Feature tensors: stack the selected columns along axis 1, then cast
# explicitly to :f32 (the explicit cast is the fix being demonstrated).
x_train = Nx.as_type(Nx.stack(train_data[@columns], axis: 1), :f32)

# Label tensors: stack "species" along the last axis and compare it against
# a {1, 3} iota, giving one boolean-like flag per class position.
# NOTE(review): presumably "species" holds integer class ids 0..2, which makes
# this a one-hot encoding — confirm against how the column was prepared.
y_train = Nx.equal(Nx.stack(train_data["species"], axis: -1), Nx.iota({1, 3}, axis: -1))

# Same transformations applied to the held-out split.
x_test = Nx.as_type(Nx.stack(test_data[@columns], axis: 1), :f32)

y_test = Nx.equal(Nx.stack(test_data["species"], axis: -1), Nx.iota({1, 3}, axis: -1))