Machine Learning in Elixir: chapter 7 CNN model accuracy no better than MLP (page 160)

When training the cnn_model, I get the following output; the trained model state the loop returns is shown after the log:

Epoch: 0, Batch: 150, accuracy: 0.4985513 loss: 7.6424022
Epoch: 1, Batch: 163, accuracy: 0.4992854 loss: 7.6783161
Epoch: 2, Batch: 176, accuracy: 0.5000441 loss: 7.6865749
Epoch: 3, Batch: 139, accuracy: 0.4983259 loss: 7.6991839
Epoch: 4, Batch: 152, accuracy: 0.4988766 loss: 7.6995916

%{
  "conv_0" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82179>
      [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]
    >,
    "kernel" => #Nx.Tensor<
      f32[3][3][3][32]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82180>
      [
        [
          [
            [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN],
            [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...],
            ...
          ],
          ...
        ],
        ...
      ]
    >
  },
  "conv_1" => %{
    "bias" => #Nx.Tensor<
      f32[64]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82181>
      [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, -0.0071477023884654045, NaN, NaN, NaN, NaN, 0.0, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3][3][32][64]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82182>
      [
        [
          [
            [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...],
            ...
          ],
          ...
        ],
        ...
      ]
    >
  },
  "conv_2" => %{
    "bias" => #Nx.Tensor<
      f32[128]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82183>
      [0.0, NaN, NaN, NaN, NaN, NaN, NaN, NaN, 0.005036031361669302, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3][3][64][128]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82184>
      [
        [
          [
            [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...],
            ...
          ],
          ...
        ],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[128]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82185>
      [NaN, -0.005992305930703878, -0.006005365401506424, -0.004664595704525709, NaN, NaN, NaN, -5.619042203761637e-4, 0.0, NaN, -0.005999671295285225, -6.131592726887902e-6, NaN, 0.0, NaN, 0.0, 0.0, NaN, NaN, -0.006002828478813171, -0.00600335793569684, 0.0, NaN, NaN, NaN, -0.006002923008054495, -0.006005282513797283, -0.00600528996437788, -0.0060048955492675304, -0.006004981696605682, NaN, -0.006004655733704567, -0.006005233619362116, NaN, -0.006004724185913801, -0.006005335133522749, -0.006005051080137491, -0.006004408933222294, NaN, -0.006005355156958103, 0.0, -0.006005344912409782, 0.0, NaN, -0.005991040728986263, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[18432][128]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82186>
      [
        [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82187>
      [NaN]
    >,
    "kernel" => #Nx.Tensor<
      f32[128][1]
      EXLA.Backend<host:0, 0.1357844422.1979580433.82188>
      [
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        [NaN],
        ...
      ]
    >
  }
}

The mlp_model evaluated at Batch: 6, accuracy: 0.5078125, while this cnn_model evaluates at Batch: 6, accuracy: 0.4944196, so it is slightly worse rather than the "significantly better" the book describes.

I reviewed all the code to make sure I hadn't missed anything, but I couldn't find anything that didn't match the book.

I’m guessing the NaNs in the trained model state are a problem, but I’m not sure how to fix that.

Now that I look again, my mlp_trained_model_state also has a bunch of NaNs.
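
For what it's worth, here is the quick check I used to see how far the divergence spread (just a sketch: cnn_trained_model_state is a placeholder for whatever variable holds the map dumped above, and Nx.is_nan/1 flags the bad entries per parameter):

# Sketch only: count NaNs per parameter in the trained state
# (a plain map of layer name => %{"bias" => tensor, "kernel" => tensor}).
for {layer, params} <- cnn_trained_model_state, {name, tensor} <- params do
  nan_count = tensor |> Nx.is_nan() |> Nx.sum() |> Nx.to_number()
  {layer, name, nan_count}
end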

On page 147, the output in the book has a different shape than what I get.

The book shows:

[
  {#Nx.Tensor<
     f32[channels: 3][height: 96][width: 96]
     EXLA.Backend<cuda:0, 0.2231265192.3745644570.217183>
     [
       [
         [0.5490195751190186, 0.5647058486938477, 0.5568627119064331, ...],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<cuda:0, 0.2231265192.3745644570.217184>
     [0]
   >},
  ...
]

Here’s mine:

[
  {#Nx.Tensor<
     f32[height: 96][width: 96][channels: 3]
     EXLA.Backend<host:0, 0.1357844422.1980104721.75624>
     [
       [
         [0.6941176652908325, 0.658823549747467, 0.49803921580314636],
         [0.6941176652908325, 0.658823549747467, 0.49803921580314636],
         [0.6980392336845398, 0.6627451181411743, 0.501960813999176],
         [0.7137255072593689, 0.6784313917160034, 0.5176470875740051],
         [0.7215686440467834, 0.686274528503418, 0.5254902243614197],
         [0.7215686440467834, 0.686274528503418, 0.5254902243614197],
         [0.7215686440467834, 0.686274528503418, 0.5254902243614197],
         [0.7176470756530762, 0.6823529601097107, 0.5215686559677124],
         [0.729411780834198, 0.6941176652908325, 0.5333333611488342],
         [0.7568627595901489, 0.7176470756530762, 0.545098066329956],
         [0.7607843279838562, 0.7176470756530762, 0.5372549295425415],
         [0.7568627595901489, 0.7137255072593689, 0.5333333611488342],
         [0.7568627595901489, 0.7137255072593689, 0.5333333611488342],
         [0.7607843279838562, 0.7215686440467834, 0.5372549295425415],
         [0.7529411911964417, 0.7176470756530762, 0.5176470875740051],
         [0.8196078538894653, 0.7882353067398071, 0.6509804129600525],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<host:0, 0.1357844422.1977483290.251661>
     [0]
   >},
  {#Nx.Tensor<
     f32[height: 96][width: 96][channels: 3]
     EXLA.Backend<host:0, 0.1357844422.1980104721.75626>
     [
       [
         [0.9333333373069763, 0.9372549057006836, 0.9333333373069763],
         [0.9411764740943909, 0.9490196108818054, 0.9254902005195618],
         [0.9411764740943909, 0.9529411792755127, 0.9098039269447327],
         [0.95686274766922, 0.9411764740943909, 0.8980392217636108],
         [0.9215686321258545, 0.8588235378265381, 0.7647058963775635],
         [0.8470588326454163, 0.6980392336845398, 0.5529412031173706],
         [0.8509804010391235, 0.6549019813537598, 0.49803921580314636],
         [0.8901960849761963, 0.7058823704719543, 0.5411764979362488],
         [0.8156862854957581, 0.6235294342041016, 0.4627451002597809],
         [0.772549033164978, 0.6117647290229797, 0.47058823704719543],
         [0.9450980424880981, 0.8627451062202454, 0.7568627595901489],
         [0.8745098114013672, 0.7843137383460999, 0.6549019813537598],
         [0.8313725590705872, 0.7215686440467834, 0.5490196347236633],
         [0.8235294222831726, 0.6745098233222961, 0.4941176474094391],
         [0.729411780834198, 0.5176470875740051, 0.3450980484485626],
         [0.6941176652908325, 0.46666666865348816, ...],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<host:0, 0.1357844422.1977483293.251460>
     [1]
   >},
  {#Nx.Tensor<
     f32[height: 96][width: 96][channels: 3]
     EXLA.Backend<host:0, 0.1357844422.1980104721.75628>
     [
       [
         [0.007843137718737125, 0.007843137718737125, 0.0],
         [0.007843137718737125, 0.007843137718737125, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.003921568859368563, 0.003921568859368563, 0.0],
         [0.007843137718737125, 0.007843137718737125, 0.0],
         [0.007843137718737125, 0.007843137718737125, 0.0],
         [0.007843137718737125, 0.007843137718737125, 0.0],
         [0.0117647061124444, 0.0117647061124444, 0.003921568859368563],
         [0.0117647061124444, 0.0117647061124444, 0.003921568859368563],
         [0.01568627543747425, 0.01568627543747425, 0.007843137718737125],
         [0.019607843831181526, 0.019607843831181526, 0.0117647061124444],
         [0.027450980618596077, ...],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<host:0, 0.1357844422.1977483293.251462>
     [0]
   >},
  {#Nx.Tensor<
     f32[height: 96][width: 96][channels: 3]
     EXLA.Backend<host:0, 0.1357844422.1980104721.75630>
     [
       [
         [0.5607843399047852, 0.5686274766921997, 0.6196078658103943],
         [0.6117647290229797, 0.6039215922355652, 0.6666666865348816],
         [0.6549019813537598, 0.6431372761726379, 0.6784313917160034],
         [0.6470588445663452, 0.6431372761726379, 0.6509804129600525],
         [0.6196078658103943, 0.6196078658103943, 0.6078431606292725],
         [0.5882353186607361, 0.5843137502670288, 0.6078431606292725],
         [0.572549045085907, 0.5686274766921997, 0.5960784554481506],
         [0.529411792755127, 0.529411792755127, 0.5490196347236633],
         [0.5098039507865906, 0.5058823823928833, 0.529411792755127],
         [0.5372549295425415, 0.5137255191802979, 0.5333333611488342],
         [0.529411792755127, 0.49803921580314636, 0.5215686559677124],
         [0.5058823823928833, 0.4745098054409027, 0.49803921580314636],
         [0.4901960790157318, 0.4627451002597809, 0.47843137383461],
         [0.4941176474094391, 0.4627451002597809, 0.48235294222831726],
         [0.4901960790157318, 0.4588235318660736, 0.47843137383461],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<host:0, 0.1357844422.1977483298.253085>
     [1]
   >},
  {#Nx.Tensor<
     f32[height: 96][width: 96][channels: 3]
     EXLA.Backend<host:0, 0.1357844422.1980104721.75632>
     [
       [
         [0.05098039284348488, 0.05882352963089943, 0.03529411926865578],
         [0.054901961237192154, 0.06666667014360428, 0.03529411926865578],
         [0.054901961237192154, 0.06666667014360428, 0.027450980618596077],
         [0.07058823853731155, 0.0784313753247261, 0.0313725508749485],
         [0.0784313753247261, 0.08627451211214066, 0.04313725605607033],
         [0.07450980693101883, 0.08235294371843338, 0.03921568766236305],
         [0.08235294371843338, 0.09019608050584793, 0.0470588244497776],
         [0.09019608050584793, 0.09803921729326248, 0.054901961237192154],
         [0.09019608050584793, 0.09803921729326248, 0.054901961237192154],
         [0.0941176488995552, 0.10196078568696976, 0.05882352963089943],
         [0.09803921729326248, 0.10588235408067703, 0.062745101749897],
         [0.10588235408067703, 0.11372549086809158, 0.07058823853731155],
         [0.11372549086809158, 0.12156862765550613, 0.0784313753247261],
         [0.11372549086809158, 0.12156862765550613, 0.0784313753247261],
         [0.11764705926179886, 0.125490203499794, ...],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     s64[1]
     EXLA.Backend<host:0, 0.1357844422.1977483298.253088>
     [1]
   >}
]

The book says channels-last is what we want, but the output it displays is channels-first (f32[channels: 3][height: 96][width: 96]), while mine is channels-last. I don't think this is related to my issue, though.
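
Not that it matters here, since my tensors already come back channels-last, but if one ever came back channels-first it can be re-ordered via its named axes (a sketch; img stands in for one of the image tensors above):

# Sketch only: re-order a channels-first image tensor into channels-last
# using the named axes Nx already attaches to it.
img_channels_last = Nx.transpose(img, axes: [:height, :width, :channels])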

Update: switching to Axon 0.7 resolved the issue.
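
For anyone hitting the same thing, this is roughly the dependency change in the notebook's setup block (only the Axon constraint is the point; the Nx and EXLA constraints are my assumption of versions compatible with Axon 0.7, so adjust them to match your other deps):

Mix.install([
  # Bumping Axon is the change that fixed it for me; the Nx and EXLA
  # constraints are assumed compatible versions, not taken from the book.
  {:axon, "~> 0.7"},
  {:nx, "~> 0.9"},
  {:exla, "~> 0.9"}
])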