It looks like Nx.Random.uniform_split now only comes in /3 and /4 arities, so the current /2 call no longer works.
How about we change it to this?
Nx.Random.uniform_split(new_key, 0, 1, shape: {})
|> NeuralNetwork.predict(w1, b1, w2, b2)
This gives a similar result:
#Nx.Tensor<
  f32
  0.6633665561676025
>