Understanding GRU trax code

Hi, I'm trying to understand the source code for GRU.
Could you please help me understand how the different GRU cells are combined?
Since a GRU can handle sequences of any length, the number of GRU cells is not tied to the input sequence length.
Things change both over time (X_1, then X_2) and across multiple GRU cells. I understand the idea of taking each time step and producing a y_hat for that step. But how do multiple GRU cells interact at that moment?
Also, the source code uses cb.Branch, cb.Scan and cb.Select to put everything together. Could you please describe what each of them does in GRU?

Thank you in advance

Hi @Iaroslav_Iatsenko

You might find this thread helpful. TLDR version: there is often confusion between GRU “cells” and GRU “layer” (dimension / units).

They don’t: each “cell” produces a result, and many “cells” produce many results, which together form the output of the “layer” (loosely speaking, many dimensions = many cells).
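To make the "cell" vs. "layer" distinction concrete, here is a minimal pure-Python sketch (not the Trax code). It assumes the common GRU formulation with an update gate, a reset gate and a candidate state, leaves out biases for brevity, and uses made-up helper names (`make_weights`, `gru_step`, `gru_layer`). The point is that `n_units` only sets the width of the state vector, while the sequence length only decides how many times the same step is run.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def make_weights(n_units, n_in, seed=0):
    """Random weights for the update gate (z), reset gate (r), candidate (c)."""
    rng = random.Random(seed)

    def mat():
        return [[rng.uniform(-0.5, 0.5) for _ in range(n_in + n_units)]
                for _ in range(n_units)]

    return {"z": mat(), "r": mat(), "c": mat()}


def gru_step(weights, x, h):
    """One time step: x has n_in entries, h has n_units entries."""
    xh = x + h  # concatenate input and previous state
    z = [sigmoid(a) for a in matvec(weights["z"], xh)]  # update gate
    r = [sigmoid(a) for a in matvec(weights["r"], xh)]  # reset gate
    x_rh = x + [ri * hi for ri, hi in zip(r, h)]
    c = [math.tanh(a) for a in matvec(weights["c"], x_rh)]  # candidate state
    # Each unit blends its old state with its candidate (assumed convention).
    return [zi * hi + (1.0 - zi) * ci for zi, hi, ci in zip(z, h, c)]


def gru_layer(weights, xs, n_units):
    """Run the same step (same weights) over a sequence of any length."""
    h = [0.0] * n_units
    outputs = []
    for x in xs:
        h = gru_step(weights, x, h)
        outputs.append(h)
    return outputs


weights = make_weights(n_units=4, n_in=3)
short = gru_layer(weights, [[1.0, 0.0, 0.0]] * 2, n_units=4)
long = gru_layer(weights, [[1.0, 0.0, 0.0]] * 7, n_units=4)
print(len(short), len(short[0]))  # 2 4
print(len(long), len(long[0]))    # 7 4
```

The same `weights` handle both the 2-step and the 7-step input, and every time step yields one vector with `n_units` entries.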

This thread might also help you understand the underlying calculations.

As for cb.Branch, cb.Scan and cb.Select, there is documentation for them:

In short, these functions are needed to implement the “layer” on top of “GRUCell” (that is the code you might want to spend more time looking into, because the essence of GRU is there). This post and this post might also help you with that.
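As a rough illustration of the roles these three play, here is a plain-Python analogy. The helpers `branch`, `scan`, `select`, `toy_cell` and `toy_rnn` are made up for this sketch and are not the Trax API, and the toy cell is a fixed 50/50 blend rather than a real GRU cell. As far as I can tell from the Trax source, the GRU layer follows the same three stages: Branch keeps the input and adds a zero initial state, Scan applies the cell along the time axis while carrying the state, and Select keeps the outputs and drops the final state.

```python
def branch(*fns):
    """Apply every function to the same input; return all results as a tuple."""
    return lambda x: tuple(f(x) for f in fns)


def scan(step, xs, carry):
    """Run step over xs in order, threading the carry (hidden state) through."""
    ys = []
    for x in xs:
        y, carry = step(x, carry)
        ys.append(y)
    return ys, carry


def select(indices, items):
    """Keep only the items at the given positions, in the given order."""
    return tuple(items[i] for i in indices)


def toy_cell(x, h):
    """Stand-in for a GRU cell: a fixed 50/50 blend of state and input."""
    h_new = 0.5 * h + 0.5 * x
    return h_new, h_new  # (output for this step, new state)


def toy_rnn(xs):
    # 1. Branch: pass the inputs through unchanged and create a zero state.
    inputs, h0 = branch(lambda v: v, lambda v: 0.0)(xs)
    # 2. Scan: apply the same cell at every time step.
    outputs, final_state = scan(toy_cell, inputs, h0)
    # 3. Select: keep the outputs, drop the final state.
    (outputs,) = select([0], (outputs, final_state))
    return outputs


print(toy_rnn([1.0, 1.0, 1.0]))  # [0.5, 0.75, 0.875]
```

Nothing in `toy_rnn` depends on the length of `xs`, which is why the layer can handle any number of time steps.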