I am not sure how to specify the input parameters for the dot_product_self_attention layer. The function is supposed to take three parameters (queries, keys, and values) from the previous layer. How do I refer to them when calling this function? Thanks!
return tl.Serial(
    tl.Branch(  # creates three towers for one input: takes activations and creates queries, keys, and values
        [tl.Dense(d_feature), ComputeAttentionHeads],  # queries
        [tl.Dense(d_feature), ComputeAttentionHeads],  # keys
        [tl.Dense(d_feature), ComputeAttentionHeads],  # values
    ),
    tl.Fn('DotProductAttn', dot_product_self_attention(?,?,?), n_out=1),  # takes Q, K, V  <-- the line in question
    ...
)
Please help.
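
A hedged note on the mechanism, for what it's worth: as far as I understand Trax, tl.Fn expects the function object itself rather than a call. It works out how many inputs to take from the previous layer by counting the function's positional arguments. Under that assumption the line would read tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1), with no arguments spelled out. The sketch below is not Trax code. `toy_fn` and the placeholder body of `dot_product_self_attention` are hypothetical, standard-library stand-ins that only illustrate the idea of passing a function uncalled and letting the wrapper feed it the three branch outputs.

```python
import inspect


def toy_fn(name, f, n_out=1):
    """Toy stand-in for a function-wrapping layer (NOT the Trax implementation)."""
    # Count f's positional parameters to learn how many inputs it consumes.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    n_in = sum(
        1
        for p in inspect.signature(f).parameters.values()
        if p.kind in positional_kinds
    )

    def layer(stack):
        # Take the first n_in items (e.g. the three Branch outputs) as arguments.
        args, rest = stack[:n_in], stack[n_in:]
        out = f(*args)
        outs = [out] if n_out == 1 else list(out)
        # Put the result(s) back in front of whatever was left on the stack.
        return outs + rest

    layer.name = name
    layer.n_in = n_in
    return layer


def dot_product_self_attention(q, k, v):
    # Hypothetical placeholder body: just records which inputs it received.
    return ("attn", q, k, v)


# The function is passed by name, without parentheses or arguments.
attn_layer = toy_fn('DotProductAttn', dot_product_self_attention, n_out=1)

# Pretend the previous layer produced queries, keys, values (plus one extra item).
stack = ["Q", "K", "V", "extra"]
result = attn_layer(stack)
```

In this toy version, `attn_layer.n_in` comes out as 3 because the function has three positional parameters, so the wrapper hands it the queries, keys, and values in that order. The order of the branches therefore has to match the order of the function's arguments.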