Can someone please explain what ax[0] means? Why are subplot 0 and subplot 1 accessed through ax? Also, what do pos = y_train == 1 and neg = y_train == 0 do in the given code?

I'll try to answer the plotting part (your ax[0] and ax[1]).

ax here comes from Matplotlib’s plt.subplots.
For example, this creates a single figure with two axes (2 rows, 1 column):
fig, ax = plt.subplots(2)

For reference, see [1]:
[1] Creating multiple subplots using plt.subplots — Matplotlib 3.6.0 documentation
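A minimal sketch of the plt.subplots(2) case: the two axes come back as an array, so ax[0] is the top subplot and ax[1] the bottom one. The sine/cosine data, the titles, and the output filename are illustrative choices only; the Agg backend is forced so the script also runs without a display (swap in fig.show()/plt.show() with an interactive backend).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# plt.subplots(2) means nrows=2, ncols=1: two axes stacked vertically
fig, ax = plt.subplots(2)

# ax is a 1-D NumPy array holding the two Axes objects
ax[0].plot(x, np.sin(x))  # draw on the top subplot
ax[0].set_title("sin(x)")
ax[1].plot(x, np.cos(x))  # draw on the bottom subplot
ax[1].set_title("cos(x)")

fig.tight_layout()
fig.savefig("two_rows.png")  # illustrative filename
```

Each ax[i] is an ordinary Axes object, so anything you would call on a single plot (plot, scatter, set_title, legend, ...) works on it.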

plt.subplots returns a tuple of a figure and its axes. How many axes you get depends on its first two arguments, nrows and ncols. For example,
fig, ax = plt.subplots(1, 2)
creates a figure with 1 row and 2 columns of axes. Because there is more than one axis, ax is a NumPy array of Axes objects, so to draw on the first axis you just index it like this: ax[0].
You can then use it as:
ax[0].plot(...) for the first axis
ax[1].plot(...) for the second axis
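A short sketch of how the shape of ax follows nrows and ncols, based on Matplotlib's default squeeze=True behavior: a 1x2 grid gives a 1-D array indexed as ax[0], ax[1]; a 2x2 grid gives a 2-D array indexed as ax[row, col]; and a 1x1 grid gives a single Axes object rather than an array. The plotted values and the variable names fig2/ax2/fig3/ax3 are just for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# 1 row, 2 columns -> ax is a 1-D array of length 2
fig, ax = plt.subplots(1, 2)
ax[0].plot([1, 2, 3], [1, 4, 9])  # left axis
ax[1].plot([1, 2, 3], [3, 2, 1])  # right axis
print(ax.shape)

# 2 rows, 2 columns -> ax2 is a 2-D array, indexed as ax2[row, col]
fig2, ax2 = plt.subplots(2, 2)
ax2[0, 1].set_title("row 0, column 1")
print(ax2.shape)

# 1 row, 1 column -> a single Axes object, not an array (squeeze=True by default)
fig3, ax3 = plt.subplots()
print(isinstance(ax3, Axes), isinstance(ax3, np.ndarray))
```

If you would rather always get a 2-D array regardless of the grid size, plt.subplots also accepts squeeze=False.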