Neural Autoregressive Distribution Estimation (NADE)

Hello everyone,
I have recently been learning about different kinds of generative models, so I thought it would be helpful if we shared thoughts and questions about the different models, the challenges they pose, and tips for implementing them.

Here is the PyTorch model I used for NADE, but it seems to have an error that I cannot figure out.

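For reference, these are the binary NADE equations as I understand them; in the code below, W and b are self.hidden's weight and bias, V is alpha_weights, and c is alpha_bias:

$$p(\mathbf{x}) = \prod_{d=1}^{D} p(x_d \mid \mathbf{x}_{<d}), \qquad \mathbf{h}_d = \sigma\Big(\mathbf{b} + \sum_{j<d} x_j \mathbf{W}_{:,j}\Big), \qquad p(x_d = 1 \mid \mathbf{x}_{<d}) = \sigma\big(c_d + \mathbf{V}_{d,:} \cdot \mathbf{h}_d\big)$$
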
import torch
import torch.nn as nn
import torch.distributions as dist

# Device used throughout (assumed to be defined once near the top of the script).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class NADE(nn.Module):

    # Initialize.
    def __init__(self, inp_dimensions, latent_dimensions):
        super().__init__()
        self.inp_dimensions = inp_dimensions
        self.latent_dimensions = latent_dimensions
        self.hidden = nn.Linear(inp_dimensions, latent_dimensions)
        self.alpha_weights = nn.Parameter(torch.FloatTensor(inp_dimensions, latent_dimensions).uniform_(1e-6, 1 - 1e-6))
        self.alpha_bias = nn.Parameter(torch.FloatTensor(inp_dimensions).uniform_(1e-6, 10 - 1e-6))

        # Strictly lower-triangular mask used in the forward pass: row d has ones in columns 0..d-1,
        # so multiplying by it sums the per-dimension contributions of the dimensions preceding d.
        self.sum_matrix = torch.tril(torch.ones(inp_dimensions, inp_dimensions, device=device), diagonal=-1)
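        # Illustration: for inp_dimensions = 3 the mask is
        #   [[0, 0, 0],
        #    [1, 0, 0],
        #    [1, 1, 0]]
        # so dimension 0 conditions on nothing, dimension 1 on x_0, and dimension 2 on x_0 and x_1.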
    
    # For a given input x, obtain the mean vectors describing the Bernoulli distributions for each dimension, and each sample.
    def mean_vectors(self, x):
        # Expand each sample into a diagonal matrix so that every input dimension contributes to the
        # hidden layer separately instead of all dimensions being summed together.
        x_diag = torch.stack([torch.diag(x_j) for x_j in x])  # [batch_size, inp_dim, inp_dim]

        # Apply the shared hidden layer to each diagonal row: row d of the result is x_d * W[:, d] + b
        # (the bias is broadcast to every row), i.e. the pre-activation contribution of dimension d alone.
        dot_products = self.hidden(x_diag)  # [batch_size, inp_dim, latent_dim]
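        # Illustration (single sample, inp_dimensions = 3): x = [x_0, x_1, x_2] becomes
        #   [[x_0, 0,   0  ],
        #    [0,   x_1, 0  ],
        #    [0,   0,   x_2]]
        # so each row of dot_products depends on exactly one input dimension.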
        
        # Because the model is autoregressive, the hidden state used to predict dimension d may only
        # depend on the dimensions that precede it. Multiplying by sum_matrix sums the contributions
        # of those preceding dimensions: with inp_dim = 3, row 0 is the hidden state used to predict
        # x_0 (conditioned on nothing), row 1 conditions on x_0, and row 2 conditions on x_0 and x_1.
        # Note that this prefix sum also adds the hidden bias d times in row d, whereas the standard
        # NADE formulation (and the sample() method below) adds it exactly once.
        hidden_activations = torch.sigmoid(torch.matmul(self.sum_matrix, dot_products))  # [batch_size, inp_dim, latent_dim]

        # alpha_weights are the hidden-to-output weights. Multiplying element-wise gives, for every
        # output dimension, its weighted hidden activations: e.g. with 3 inputs and 2 hidden units this
        # is a [3, 2] matrix whose element [d, k] is the activation of hidden unit k (for dimension d)
        # times the weight connecting hidden unit k to the output for dimension d.
        hidden_to_output_elem_wise_mul = torch.mul(hidden_activations, self.alpha_weights)  # [batch_size, inp_dim, latent_dim]
        # Up to this point every output dimension has its own row of weighted hidden activations.
        # Conceptually all dimensions share the same hidden units, so we sum over the latent dimension
        # to collapse each row into a single pre-activation value.
        sum_hidden_to_output_elem_wise_mul = torch.sum(hidden_to_output_elem_wise_mul, dim=2)  # [batch_size, inp_dim]

        return torch.sigmoid(sum_hidden_to_output_elem_wise_mul + self.alpha_bias)  # [batch_size, inp_dim]

    # Forward pass to compute log-likelihoods for each input separately.
    def forward(self, x):
        ########################################CPT Estimation##################################################################
        bernoulli_means = self.mean_vectors(x)
        # bernoulli_means_clampped = torch.clamp(bernoulli_means, 1e-19, 1 - 1e-19)
        log_bernoulli_means = torch.log(bernoulli_means + 1e-10)

        #######################################Parametric Model Calculations####################################################
        # Bernoulli log-likelihood per dimension: x * log(p) + (1 - x) * log(1 - p).
        log_likelihoods = x * log_bernoulli_means + (1 - x) * torch.log(1 - bernoulli_means + 1e-10)
        return torch.sum(log_likelihoods, dim=1)  # [batch_size]

    def zero_grad_for_extra_weights(self):
        pass
        
    # Draw num_samples binary vectors autoregressively, one dimension at a time.
    def sample(self, num_samples):
        samples = torch.zeros(num_samples, self.inp_dimensions, device = device)
        for sample_num in range(num_samples):
            sample = torch.zeros(self.inp_dimensions, device = device)
            for dim in range(self.inp_dimensions):        
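                # Entries of `sample` at positions >= dim are still zero, so the hidden state below
                # conditions only on the dimensions sampled so far.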
                h_dim = torch.sigmoid(self.hidden(sample))
                bernoulli_mean_dim = torch.sigmoid(self.alpha_weights[dim].dot(h_dim) + self.alpha_bias[dim])
                distribution = dist.bernoulli.Bernoulli(probs=bernoulli_mean_dim)
                sample[dim] = distribution.sample()
            samples[sample_num] = sample
        return samples
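
In case it is useful, here is the minimal training and sampling sketch I wrap around the model. The data iterator binary_batches and the sizes are placeholders (I use binarized inputs of shape [batch_size, inp_dimensions]), not part of the model itself:

inp_dim, latent_dim = 784, 500                      # placeholder sizes, e.g. binarized MNIST
model = NADE(inp_dim, latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for x in binary_batches:                        # placeholder iterable of [batch_size, inp_dim] {0, 1} tensors
        x = x.to(device)
        log_likelihood = model(x)                   # [batch_size] log p(x) under the model
        loss = -log_likelihood.mean()               # maximizing likelihood = minimizing the mean NLL
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

with torch.no_grad():
    generated = model.sample(16)                    # [16, inp_dim] binary samples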