Hey, I am struggling with the optional assignment (`pool_backward`). I get the wrong output in “max” mode and an error when running “average” mode:
```python
def pool_backward(dA, cache, mode = "max"):
    # Retrieve information from cache (≈1 line)
    (A_prev, hparameters) = ...
    # Retrieve hyperparameters from "hparameters" (≈2 lines)
    stride = ...
    f = ...
    # Retrieve dimensions from A_prev's shape and dA's shape (≈2 lines)
    m, n_H_prev, n_W_prev, n_C_prev = ...
    print("A_prev.shape: ", A_prev.shape)
    m, n_H, n_W, n_C = ...
    # Initialize dA_prev with zeros (≈1 line)
    dA_prev = ...
    print("dA_prev.shape: ", dA_prev.shape)
    for i in range(...): # loop over the training examples
        # select training example from A_prev (≈1 line)
        a_prev = ...
        for h in range(n_H): # loop on the vertical axis
            for w in range(n_W): # loop on the horizontal axis
                for c in range(n_C): # loop over the channels (depth)
                    # Find the corners of the current "slice" (≈4 lines)
                    vert_start = ...
                    vert_end = ...
                    horiz_start = ...
                    horiz_end = ...
                    # Compute the backward propagation in both modes.
                    if mode == "max":
                        # Use the corners and "c" to define the current slice from a_prev (≈1 line)
                        a_prev_slice = ...
                        # Create the mask from a_prev_slice (≈1 line)
                        mask = ...
                        # Set dA_prev to be dA_prev + (the mask multiplied by the correct entry of dA) (≈1 line)
                        dA_prev[i, vert_start: vert_end, horiz_start: horiz_end, c] += ...
                        #print("dA_prev.shape: ", dA_prev.shape)
                    elif mode == "average":
                        # Get the value da from dA (≈1 line)
                        da = ...
                        print("dA[i, h, w, c]: ", dA[i, h, w, c])
                        print("da: ", da)
                        # Define the shape of the filter as fxf (≈1 line)
                        shape = ...
                        # Distribute it to get the correct slice of dA_prev. i.e. Add the distributed value of da. (≈1 line)
                        dA_prev[i, h, w, c] += ...
    # YOUR CODE STARTS HERE
    # YOUR CODE ENDS HERE
    # Making sure your output shape is correct
    assert(dA_prev.shape == A_prev.shape)
    return dA_prev
```
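One thing that stands out in the “average” branch: the last line adds into `dA_prev[i, h, w, c]`, which is a single scalar entry, while distributing `da` over an f×f window produces an f×f array. NumPy cannot store an array in one element, so that line raises a `ValueError`. Below is a minimal standalone sketch with toy shapes; the helper `spread_average` and all values are hypothetical, not the assignment's own names. It shows the failing scalar target next to the window-slice target used in the “max” branch.

```python
import numpy as np

def spread_average(da, shape):
    """Hypothetical helper: split the scalar gradient da evenly over an (n_H, n_W) window."""
    n_H, n_W = shape
    return np.full((n_H, n_W), da / (n_H * n_W))

# Toy setup (assumed values): one example, 4x4 input, 1 channel, f = 2, stride = 2.
f, stride = 2, 2
dA_prev = np.zeros((1, 4, 4, 1))
i, h, w, c = 0, 1, 1, 0
da = 8.0  # gradient of one pooled output cell

window_grad = spread_average(da, (f, f))  # shape (2, 2), every entry 2.0

# Adding a 2x2 block into a single scalar element fails.
try:
    dA_prev[i, h, w, c] += window_grad
except ValueError as err:
    print("scalar target failed:", type(err).__name__)

# Adding it to the f x f slice that produced output cell (h, w) works.
vert_start, horiz_start = h * stride, w * stride
dA_prev[i, vert_start:vert_start + f, horiz_start:horiz_start + f, c] += window_grad
print(dA_prev[0, :, :, 0])
```

The failed attempt leaves `dA_prev` untouched, so only the slice update contributes: the bottom-right 2×2 block ends up holding 2.0 in each entry, and the total is still `da`.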
My output is:

```
mode = max
mean of dA = 0.14571390272918056
dA_prev1[1,1] = [[ 0. 0. ]
 [10.11330283 -0.49726956]
 [ 0. 0. ]]
A_prev.shape: (5, 5, 3, 2)
dA_prev.shape: (5, 5, 3, 2)
```
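For checking the “max” branch by hand, here is a small sketch of the mask idea the template comments describe: in max pooling only the entry that held the window maximum affected the output, so the single scalar `dA[h, w]` is routed to that entry and nothing else. The helper `max_mask`, the 4×4 input, and the gradient values are hypothetical toy data (one example, one channel, f = 2, stride = 2), chosen so the result can be verified on paper.

```python
import numpy as np

def max_mask(window):
    """Hypothetical helper: True where the window holds its maximum (ties would all be True)."""
    return window == np.max(window)

# Toy single-example, single-channel input (assumed values), f = 2, stride = 2.
a_prev = np.array([[1.0, 3.0, 0.0, 2.0],
                   [4.0, 2.0, 1.0, 5.0],
                   [0.0, 1.0, 7.0, 3.0],
                   [2.0, 6.0, 1.0, 0.0]])
dA = np.array([[10.0, 20.0],
               [30.0, 40.0]])  # gradient w.r.t. the 2x2 pooled output
f, stride = 2, 2
dA_prev = np.zeros_like(a_prev)

for h in range(dA.shape[0]):
    for w in range(dA.shape[1]):
        vert_start, horiz_start = h * stride, w * stride
        window = a_prev[vert_start:vert_start + f, horiz_start:horiz_start + f]
        # Route the single scalar dA[h, w] to the argmax position of this window.
        dA_prev[vert_start:vert_start + f, horiz_start:horiz_start + f] += max_mask(window) * dA[h, w]

print(dA_prev)
```

Each window's maximum (4, 5, 6, 7) receives exactly one value from `dA`, so the entries of `dA_prev` sum to the sum of `dA`. This is a quick invariant to compare against when the real output looks off.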
I would appreciate some help! Thanks!