Sorry if this is a stupid question, but I just can't wrap my head around this. I am trying to build my first neural network, which takes MNIST images (28x28) of hand-drawn digits 0-9 and outputs which digit the network thinks each one is. In the last layer I need a softmax function, which outputs a probability for each digit so that the probabilities sum to 1.
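```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability
    exps = np.exp(z - z.max())
    # Return the probabilities together with the input z (kept as a cache)
    return exps / np.sum(exps), z
```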
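As a sanity check on the "sum to 1" property, here is a minimal sketch. It assumes the scores are stored as a (classes, samples) array, which is only a guess about the layout, and `softmax_columns` is a hypothetical name rather than part of the network code. The point it illustrates: `np.sum(exps)` without an `axis` normalizes over the whole array, so for a batch the probabilities would sum to 1 in total rather than per sample.

```python
import numpy as np

def softmax_columns(z):
    # Hypothetical helper: softmax applied independently to each column.
    # Assumes z has shape (num_classes, num_samples).
    shifted = z - z.max(axis=0, keepdims=True)  # per-column max, for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
z = rng.normal(size=(10, 5))    # 10 classes, 5 samples (made-up data)
probs = softmax_columns(z)
col_sums = probs.sum(axis=0)    # one sum per sample
print(np.allclose(col_sums, 1.0))
```

If the data is laid out as (samples, classes) instead, the same idea applies with `axis=1`.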
Up to this point, everything should be fine. But now I get to the backpropagation part. I found this softmax function for backpropagation on the internet:
```python
def softmax_backward(dA, Z):
    x, _ = softmax(dA)
    s = x.reshape(-1, 1)
    return (np.diagflat(s) - np.dot(s, s.T))
```
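For context, the expression `np.diagflat(s) - np.dot(s, s.T)` is the Jacobian of softmax for a single sample, where `s` is the softmax output as a column vector. The following is a minimal sketch that checks this numerically for one small vector; `softmax_vec` and `softmax_jacobian` are hypothetical names, and the input values are made up.

```python
import numpy as np

def softmax_vec(z):
    # Softmax of a single 1-D score vector
    exps = np.exp(z - z.max())
    return exps / exps.sum()

def softmax_jacobian(z):
    # J[i, j] = d softmax_i / d z_j = s_i * (delta_ij - s_j)
    s = softmax_vec(z).reshape(-1, 1)
    return np.diagflat(s) - s @ s.T

z = np.array([1.0, 2.0, 0.5, -1.0])
J = softmax_jacobian(z)

# Central finite differences as an independent check
eps = 1e-6
J_num = np.zeros((z.size, z.size))
for j in range(z.size):
    step = np.zeros_like(z)
    step[j] = eps
    J_num[:, j] = (softmax_vec(z + step) - softmax_vec(z - step)) / (2 * eps)

print(np.allclose(J, J_num, atol=1e-6))
```

Note that this Jacobian is an (n, n) matrix for one sample with n classes, so its size depends on how many entries are in the vector it is built from.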
Question 1: Is this softmax derivative function suitable for my NN?
If it is suitable, then I have an error somewhere else. This is my error:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in
---> 26 parameters = model(x_testone, y_testone, layer_dims)

in model(X, y, layer_dims, learning_rate, epochs, print_cots, activation)
     10         zCache = zCaches[l+1]
     11
---> 12     grads = L_model_backward(Al, y, linCaches, zCaches, activation)
     13
     14     parameters = update_parameters(parameters, grads, learning_rate)

in L_model_backward(Al, y, linCaches, zCaches, activation)
---> 11 grads["dA" + str(L-1)], grads["dW" + str(L)], grads["db" + str(L)] = liner_activ_backward(dAl, zCaches[L-1], linCaches[L-1], "softmax")
     12

in liner_activ_backward(dA, zCache, linCache, activation)
     20         dZ = softmax_backward(dA, Z)
---> 21     dA_prev, dW, db = linear_backward(dZ, linCache)
     22     return dA_prev, dW, db
     23

in linear_backward(dZ, linCache)
----> 7 dW = (1/m) * np.dot(dZ, A_prev.T)
      8 db = (1/m) * np.sum(dZ, axis=1, keepdims=True)
      9 dA_prev = np.dot(W.T, dZ)

ValueError: shapes (10000,10000) and (20,1000) not aligned: 10000 (dim 1) != 20 (dim 0)
```
Now I think that my error is in the linear_backward method, because it's not compatible with softmax. Am I right about this, or totally wrong?
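To make the shape question concrete, here is a minimal sketch under stated assumptions. It assumes a batch of 1000 samples with 10 classes stored as (classes, samples), which would be consistent with the 10000 in the error (10 × 1000) but is not confirmed by the traceback. It also assumes a cross-entropy loss with one-hot labels, which the post does not state; under that assumption a commonly used shortcut is `dZ = A - Y`, which keeps the shape of `Z`. The names `Y`, `labels`, and `A_prev` and the hidden size of 20 are made up for illustration.

```python
import numpy as np

num_classes, m = 10, 1000
rng = np.random.default_rng(0)
Z = rng.normal(size=(num_classes, m))   # made-up scores for the last layer

# Flattening the whole batch with reshape(-1, 1) gives one long column,
# so diagflat(s) - s @ s.T would be square with that length on each side.
s = Z.reshape(-1, 1)
flat_shape = s.shape                     # (10000, 1)
jac_shape = (s.shape[0], s.shape[0])     # (10000, 10000), not actually allocated here

# Assumption: cross-entropy loss, one-hot labels Y of shape (num_classes, m).
exps = np.exp(Z - Z.max(axis=0, keepdims=True))
A = exps / exps.sum(axis=0, keepdims=True)   # column-wise softmax
labels = rng.integers(0, num_classes, size=m)
Y = np.eye(num_classes)[:, labels]           # one-hot columns
dZ = A - Y                                   # same shape as Z

# Hypothetical previous-layer activations with 20 units
A_prev = rng.normal(size=(20, m))
dW = (1 / m) * dZ @ A_prev.T                 # shape (num_classes, 20)
print(flat_shape, jac_shape, dZ.shape, dW.shape)
```

With `dZ` shaped like `Z`, the matrix product in a linear backward step lines up; whether this matches the actual network depends on the loss and array layout used there.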
Question 2: What method should I use instead of the linear_backward method?
Many thanks for any help!