In ungraded lab 1 of Week 4, I see very frequent use of the tie_in function from jax.lax, but I don’t understand its use case. As per the documentation here, it returns y, which is one of the function's own parameters. Why would we use this function anywhere instead of simply using y in the first place? It seems pretty redundant to me.
This is a good question. I had the same one but failed to find a detailed, definitive answer. My conclusion is that this is how the gradient graph is kept intact.
According to the explanation in the Notebook:
tie_in: Some non-numeric operations must be invoked during backpropagation. Normally, the gradient compute graph would determine invocation but these functions are not included. To force re-evaluation, they are ‘tied’ to other numeric operations using tie_in.
To me, this suggests that y is “tied” back to x so that the gradient compute graph is not broken. I think that keeping y in the gradient graph (or re-including it) is the reason for using this function instead of “simply using y”. Most of the time these non-numeric operations are related to masks, so it makes sense, but I failed to find the details of the mechanism.
Maybe someone will elaborate more?
Thanks a lot @arvyzukai for your valuable inputs on this.
Thanks @Elemento and @arvyzukai for this discussion. I too was wondering about the tie_in function. When you go to the source (jax._src.lax.lax in the JAX documentation), it says this is deprecated. It sounds like an important function, so why would they remove it?
def tie_in(x: Any, y: T) -> T:
    """Deprecated. Ignores ``x`` and returns ``y``."""
Don’t quote me on this, I’m just loosely speculating. If I had to guess, they were cleaning up the library, and the tie_in function seems to be a “hack” of some kind to keep the gradient.
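To make the "Ignores x and returns y" wording concrete, here is a minimal sketch. It does not call jax.lax.tie_in itself, since that function may be missing from recent JAX releases. Instead it uses a hypothetical local stand-in, tie_in_noop, that mirrors the deprecated signature; the mask values and function names are also made up for illustration. It checks that the gradient is the same with and without the "tie".

```python
import jax
import jax.numpy as jnp


def tie_in_noop(x, y):
    # Hypothetical stand-in that mirrors the deprecated docstring:
    # ignore x and return y unchanged.
    return y


def masked_sum_tied(x):
    mask = jnp.array([1.0, 0.0, 1.0])  # does not depend on x
    mask = tie_in_noop(x, mask)        # "tie" the mask to x (a no-op here)
    return jnp.sum(x * mask)


def masked_sum_plain(x):
    mask = jnp.array([1.0, 0.0, 1.0])  # same mask, used directly
    return jnp.sum(x * mask)


x = jnp.array([2.0, 3.0, 4.0])

# d/dx sum(x * mask) is the mask itself, in both versions.
grad_tied = jax.jit(jax.grad(masked_sum_tied))(x)
grad_plain = jax.jit(jax.grad(masked_sum_plain))(x)

print(grad_tied)
print(grad_plain)
```

If the deprecated behavior really is "return y", then, as in this sketch, replacing tie_in(x, y) with y should not change the results.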
The previous doc string stated:
"""Gives ``y`` a fake data dependence on ``x``.
When staging to XLA (e.g. running under jit or pmap), values that don't depend
on computation inputs are computed op-by-op, and folded into the XLA
computation as constants.
``tie_in`` provides a way to explicitly stage values into the computation.
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
is ``y``, but staged to XLA. Downstream use of the result will also be staged
Maybe they found another way to get around this, or it might even have been some kind of security flaw. We can only speculate, or we could contact the JAX developers and try to find out, but I guess the world moves on and some things are left to be forgotten.
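One way to check the current behavior instead of speculating is to inspect the jaxpr of a jitted function that builds a value with no data dependence on its input. This is only a sketch: the function name add_ones is made up, and the printed jaxpr will differ between JAX versions. My understanding is that the "omnistaging" change in JAX made such values get staged into the computation automatically, which would explain why tie_in turned into a no-op, but please treat that as an assumption to verify against the JAX changelog.

```python
import jax
import jax.numpy as jnp


def add_ones(x):
    # `ones` has no data dependence on x. This is the situation the old
    # docstring describes as being computed op-by-op and folded in as a constant.
    ones = jnp.ones(3)
    return x + ones


# Print the staged computation. If the construction of `ones` appears as an
# operation inside the jaxpr, it is being staged without any tie_in.
jaxpr_text = str(jax.make_jaxpr(add_ones)(jnp.zeros(3)))
print(jaxpr_text)

# The numerical result is the same either way.
result = jax.jit(add_ones)(jnp.zeros(3))
print(result)
```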