Hello everyone,
Would you be able to help me with the CausalAttention function? Everything seems correct, but for some reason 2 unit tests failed. I have a feeling it's related to the way the inner closure functions need to be called - my output is very close to the expected output, but not the same. I tried to debug it but got stuck. Any help would be much appreciated. Thanks in advance!
# UNQ_C5
# GRADED FUNCTION: CausalAttention
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Transformer-style multi-headed causal attention.

    Args:
        d_feature (int): dimensionality of feature embedding.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): closure around compute_attention_heads.
        dot_product_self_attention (function): dot_product_self_attention function.
        compute_attention_output_closure (function): closure around compute_attention_output.
        mode (str): 'train' or 'eval'.

    Returns:
        trax.layers.combinators.Serial: Multi-headed self-attention model.
    """
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    ### START CODE HERE ###
    # (REPLACE INSTANCES OF 'None' WITH YOUR CODE)

    # HINT: The second argument to tl.Fn() is an uncalled function (without the parentheses).
    # Since you are dealing with closures, you might need to call the outer
    # function with the correct parameters to get the actual uncalled function.
    # Use 'compute_attention_heads_closure'.
    ComputeAttentionHeads = tl.Fn('AttnHeads', compute_attention_heads_closure, n_out=1)

    return tl.Serial(
        tl.Branch(  # creates three towers for one input; takes activations and creates queries, keys, and values
            [tl.Dense(d_feature), ComputeAttentionHeads],  # queries
            [tl.Dense(d_feature), ComputeAttentionHeads],  # keys
            [tl.Dense(d_feature), ComputeAttentionHeads],  # values
        ),
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1),  # takes QKV
        # HINT: The second argument to tl.Fn() is an uncalled function.
        # Since you are dealing with closures, you might need to call the outer
        # function with the correct parameters to get the actual uncalled function.
        # Use 'compute_attention_output_closure'.
        tl.Fn('AttnOutput', compute_attention_output_closure, n_out=1),  # to allow for parallel
        tl.Dense(d_feature),
    )
    ### END CODE HERE ###
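For context, the pattern the hints describe is a closure: an outer function that is called once with its configuration and returns the still-uncalled inner function. A minimal plain-Python sketch of that pattern (the names here are illustrative stand-ins, not the actual assignment helpers):

```python
# Minimal sketch of the "closure returns an uncalled function" pattern
# the assignment hints describe. Names are illustrative only.

def compute_heads_closure(n_heads, d_head):
    """Outer function: captures the configuration, returns the inner function."""
    def compute_heads(x):
        # Inner function: can use the captured n_heads and d_head.
        return (n_heads, d_head, x)
    return compute_heads  # returned WITHOUT parentheses: still uncalled

# Calling the closure with its parameters yields a one-argument function,
# which is the kind of object tl.Fn expects as its second argument.
attn_heads_fn = compute_heads_closure(n_heads=2, d_head=8)
result = attn_heads_fn("activations")
print(result)  # (2, 8, 'activations')
```

The inner function is only executed later, when the layer actually runs on data.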
Output:
Serial_in3[
Branch_in2_out3[
[Dense_512, AttnHeads_in2]
[Dense_512, AttnHeads_in2]
[Dense_512, AttnHeads_in2]
]
DotProductAttn_in3
AttnOutput_in2
Dense_512
]
Expected Output:
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Unit tests output:
Causal Attention layer is correctly defined Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Causal Attention layer is correctly defined Serial[
Branch_out3[
[Dense_16, AttnHeads]
[Dense_16, AttnHeads]
[Dense_16, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_16
]
6 Tests passed
2 Tests failed
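One observation that may help with the debugging: if I understand correctly, tl.Fn infers a layer's number of inputs from the signature of the function it wraps, so the extra `_in2` suffixes on `AttnHeads` and `AttnOutput` would suggest those layers were built from two-argument functions (the uncalled outer closures) rather than the one-argument inner functions the closures return. A plain-Python illustration of that signature difference, using hypothetical stand-in names:

```python
import inspect

# Hypothetical stand-in for the assignment's helper: the outer closure
# takes two configuration arguments; the inner function it returns takes one.
def compute_heads_closure(n_heads, d_head):
    def compute_heads(x):
        return x
    return compute_heads

# The uncalled outer closure exposes a two-parameter signature,
# which would explain an "_in2" suffix on the layer name:
outer_arity = len(inspect.signature(compute_heads_closure).parameters)
print(outer_arity)  # 2

# Calling the closure first yields a one-parameter function instead:
inner_arity = len(inspect.signature(compute_heads_closure(2, 8)).parameters)
print(inner_arity)  # 1
```

Note that `DotProductAttn_in3` appears in both the actual and expected outputs, which is consistent: dot_product_self_attention genuinely takes three inputs (queries, keys, values).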