Fixing the Training Function for an LSTM model built from scratch

I’m trying to build a custom LSTM model. My plan is to implement a standard LSTM from scratch first and then modify it step by step. Right now, however, I can’t get the model to learn at all.

Here is the code for my LSTM model.

import math

import torch
import torch.nn as nn


class LSTM_Model(nn.Module):
    def __init__(self, input_sz, hidden_sz, out_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.fc1 = nn.Linear(hidden_sz, out_sz)
        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        b_sz, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(b_sz, self.hidden_size).to(x.device),
                        torch.zeros(b_sz, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states

        for t in range(seq_sz):
            x_t = x[:, t, :]
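            # Defensive zero-padding of a short batch; x_t is sliced from x, so it normally already has b_sz rows.
            if x_t.size(0) != b_sz:
                fill_in = torch.zeros(b_sz - x_t.size(0), x_t.size(1), device=x.device)
                x_t = torch.cat((x_t, fill_in), 0)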

            gates = torch.matmul(x_t, self.W) + torch.matmul(h_t, self.U) + self.bias
            i_t, f_t, g_t, o_t = gates.chunk(4, 1)
            # Gate activations: sigmoid for the input, forget, and output gates; tanh for the candidate.
            i_t, f_t, g_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.tanh(g_t), torch.sigmoid(o_t)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        out = self.fc1(h_t)
        return out, (h_t, c_t)

Here is my training function.

def train_model(data_loader, model, loss_function, optimizer):
    num_batches = len(data_loader)
    total_loss = 0
    model.train()
    hidden = None
    #torch.autograd.set_detect_anomaly(True)
    for X, y in data_loader:
        output, hidden = model(X, hidden)
        loss = loss_function(output, y)

        optimizer.zero_grad()
        loss.backward()
        #loss.backward(retain_graph=True)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_batches
    print(f"Train loss: {avg_loss}")
    return avg_loss

Running it raises an initial error that tells me to pass retain_graph=True to loss.backward(), but after I do that, I get more errors.

I also noticed that if I stop passing the hidden variable to the model in the training function, the program runs and gives an output, but the model does not learn at all. I think that is probably because it is throwing away the hidden state.
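A common cause of an error that asks for retain_graph=True in a loop like this is that the hidden state returned for one batch is still attached to that batch's autograd graph, so the next loss.backward() tries to backpropagate through a graph that has already been freed. Below is a minimal sketch of one usual remedy, detaching the carried state between batches. It assumes the model returns an (h_t, c_t) tuple as in the code above; the helper name detach_hidden, the TinyRecurrent stand-in model, and the random data are all hypothetical, used only to keep the example self-contained.

```python
import torch
import torch.nn as nn


def detach_hidden(hidden):
    """Return the (h, c) state cut off from the graph that produced it."""
    if hidden is None:
        return None
    h, c = hidden
    return h.detach(), c.detach()


class TinyRecurrent(nn.Module):
    """Hypothetical stand-in that returns (output, (h, c)) like the posted model."""

    def __init__(self, input_sz, hidden_sz, out_sz):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.cell = nn.LSTMCell(input_sz, hidden_sz)
        self.fc1 = nn.Linear(hidden_sz, out_sz)

    def forward(self, x, init_states=None):
        b_sz, seq_sz, _ = x.size()
        if init_states is None:
            h = x.new_zeros(b_sz, self.hidden_sz)
            c = x.new_zeros(b_sz, self.hidden_sz)
        else:
            h, c = init_states
        for t in range(seq_sz):
            h, c = self.cell(x[:, t, :], (h, c))
        return self.fc1(h), (h, c)


def train_steps(model, batches, loss_function, optimizer):
    hidden = None
    losses = []
    for X, y in batches:
        # Keep the state values but drop the link to the previous batch's graph.
        hidden = detach_hidden(hidden)
        optimizer.zero_grad()
        output, hidden = model(X, hidden)
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
model = TinyRecurrent(input_sz=3, hidden_sz=8, out_sz=1)
batches = [(torch.randn(4, 5, 3), torch.randn(4, 1)) for _ in range(3)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
losses = train_steps(model, batches, nn.MSELoss(), optimizer)
print(losses)
```

Carrying the state across batches only makes sense when consecutive batches are contiguous segments of the same sequences and every batch has the same size. With a shuffled data loader, starting each batch from zero states (passing None) is the more usual choice, so discarding the hidden state is not by itself a reason for the model to stop learning.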

I have no idea what’s going wrong. Please help.


It looks like you are zeroing the gradients in the middle of the training iteration, after the forward pass. The usual pattern is to zero them first, then invoke the model and the loss function, call backward() to compute the gradients, and then apply them with optimizer.step(). Rinse and repeat …
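As a sanity check on the ordering question, here is a minimal sketch of where PyTorch populates gradients. In PyTorch, the .grad fields are filled by loss.backward(), not by the forward pass, so calling optimizer.zero_grad() before or after the forward pass should give the same gradients as long as it happens before backward(). The function name grads_after_step and the tiny nn.Linear model are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


def grads_after_step(zero_before_forward):
    """Run one forward/backward pass with zero_grad() placed before or after the forward pass."""
    torch.manual_seed(0)  # same weights and data for both orderings
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    X, y = torch.randn(4, 3), torch.randn(4, 1)

    if zero_before_forward:
        optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(X), y)
    # The forward pass alone does not create gradients.
    grad_after_forward = model.weight.grad
    if not zero_before_forward:
        optimizer.zero_grad()
    loss.backward()  # gradients are computed here
    return grad_after_forward, model.weight.grad.clone()


before_a, grad_a = grads_after_step(zero_before_forward=True)
before_b, grad_b = grads_after_step(zero_before_forward=False)
print(before_a is None, before_b is None, torch.equal(grad_a, grad_b))
```

If both orderings yield identical gradients, as this sketch is meant to show, then moving zero_grad() is a style improvement rather than a fix for a model that fails to learn.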

Note that most everything here uses TF and it’s only the GANs specialization that uses PyTorch, so this may not be the best place to get torch advice. But if this is your first attempt at using PyTorch there are lots of tutorials and documentation on the torch website. :nerd_face: Or you could take the GANs specialization and they give you a nice intro to torch and you get to see lots of examples of how to write the code to run and train a model as you go through GANs.


Fixing this does not work. Changing the position of the zero_grad call does not make the model learn. It is still training like this:


Well, there are quite a few lines of code there, so more debugging is required. You don’t actually show how you invoke train_model. What loss function are you using?

Do you have any “worked examples” of models that you have successfully defined and trained using torch? You might compare the current non-working example to a working example and try to reason about what is different.
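One way to make such a comparison concrete is to check a single step of the hand-written LSTM against torch.nn.LSTMCell with the same weights. The sketch below assumes the gate layout used in the posted model (one matrix chunked into i, f, g, o), which matches the gate order documented for nn.LSTMCell. The function name custom_lstm_step is hypothetical, and the step is rewritten as a standalone function rather than taken from the class above.

```python
import torch
import torch.nn as nn


def custom_lstm_step(x_t, h_t, c_t, W, U, bias):
    """One LSTM step with gates laid out as (i, f, g, o) along dim 1."""
    gates = x_t @ W + h_t @ U + bias
    i_t, f_t, g_t, o_t = gates.chunk(4, 1)
    i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
    g_t = torch.tanh(g_t)
    c_next = f_t * c_t + i_t * g_t
    h_next = o_t * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
input_sz, hidden_sz, b_sz = 3, 5, 2
reference = nn.LSTMCell(input_sz, hidden_sz)

# nn.LSTMCell stores weights as (4 * hidden_sz, in_features), so transpose them
# to match the (in_features, 4 * hidden_sz) layout of W and U in the post.
W = reference.weight_ih.detach().t()
U = reference.weight_hh.detach().t()
bias = (reference.bias_ih + reference.bias_hh).detach()

x_t = torch.randn(b_sz, input_sz)
h_0 = torch.randn(b_sz, hidden_sz)
c_0 = torch.randn(b_sz, hidden_sz)

h_custom, c_custom = custom_lstm_step(x_t, h_0, c_0, W, U, bias)
with torch.no_grad():
    h_ref, c_ref = reference(x_t, (h_0, c_0))

max_diff = (h_custom - h_ref).abs().max().item()
print(f"max |h_custom - h_ref| = {max_diff:.2e}")
```

A difference near floating-point precision suggests the step matches the reference; a large difference points to a gate that is wired or activated differently, which narrows the search before looking at the training loop.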
