Hello Guys,
I have been recently learning different kinds of generative models , So I think it would be helpful of we share thoughts and questions about different kinds of models and challenges and tips implementing them.
Here is the python model i used for NADE , but it seems that it has an error that i cannot figure out
class NADE(nn.Module):
# Initialize.
def __init__(self, inp_dimensions, latent_dimensions):
super().__init__()
self.inp_dimensions = inp_dimensions
self.latent_dimensions = latent_dimensions
self.hidden = nn.Linear(inp_dimensions, latent_dimensions)
self.alpha_weights = nn.Parameter(torch.FloatTensor(inp_dimensions, latent_dimensions).uniform_(1e-6, 1 - 1e-6))
self.alpha_bias = nn.Parameter(torch.FloatTensor(inp_dimensions).uniform_(1e-6, 10 - 1e-6))
# Helper matrix to compute prefix sums of dot-products for the forward pass.
self.sum_matrix = torch.ones(inp_dimensions, inp_dimensions, requires_grad=False, device = device)
for rownum, row in enumerate(self.sum_matrix):
row[rownum:] = 0
# For a given input x, obtain the mean vectors describing the Bernoulli distributions for each dimension, and each sample.
def mean_vectors(self, x):
# Expand each sample as a diagonal matrix to make dimensions in sample each produce a separate weight e.g. not summed up
x_diag = torch.stack([torch.diag(x_j) for x_j in x]) #[inp_dim*inp_dim]
#n is input vector
#multiply the [n* n]input vector by weights[n* latent_dim] + bias[n*1] n is by broadcasting e.g. all have same bias
#result is [n * latent_dim] where each row corresponds to 1 value that is the activation for the hidden layer as a result
#of considering the current dim's value only
dot_products = self.hidden(x_diag) #[inp_dim * latent_dim]
#since we are in an auto regressive model, so the weights of the inputs that activate neurons corresponding to the current i/p
#are accounted for by summing all the weights preceeding it and using it to calculate the activation for the hidden neuron
#e.g. if we input_dim = 3 then the first row is the activation of the hidden neuron that outputs the prob for X1
# row#2 is the activation of the hidden neuron(s) that activates the o/p neuron that o/ps X2 probability, that has X1 * W1 + B1 weight
#accounted for
hidden_activations = torch.sigmoid(torch.matmul(self.sum_matrix, dot_products)) #[inp_dim*latent_dim]
#alpha_weights are hidden->o/p weights
# Then multiply element-wise with alpha_weights to train the weights associated with each input dim and each hidden layer
#e.x if we have 3 i/ps and 2 hidden layers then we have a matrix of size [3*2] where for example element [1,1] corresponds to the weight
# that binds i/p 1 with hidden layer neuron 1
hidden_to_output_elem_wise_mul = torch.mul(hidden_activations, self.alpha_weights);
#until the line above we can visualize the inputs as if each is connected to a different hidden layer neuron
#conseptually all the i/p share the hidden neurons so we sum over the columns such that all the weights associated with a certain hidden neuron
#are summed up to be then added to the bias and passed through the sigmoid to get the o/p layer results
sum_hidden_to_output_elem_wise_mul = torch.sum(hidden_to_output_elem_wise_mul, dim=2);
return torch.sigmoid(sum_hidden_to_output_elem_wise_mul + self.alpha_bias) #[1, inp_dim]
# Forward pass to compute log-likelihoods for each input separately.
def forward(self, x):
########################################CPT Estimation##################################################################
bernoulli_means = self.mean_vectors(x)
# bernoulli_means_clampped = torch.clamp(bernoulli_means, 1e-19, 1 - 1e-19)
log_bernoulli_means = torch.log(bernoulli_means + 1e-10)
#######################################Parametric Model Calculations####################################################
log_likelihoods = x * (log_bernoulli_means) + (1 - x) * (1 - log_bernoulli_means)
return torch.sum(log_likelihoods, dim=1) #[batch_size, 1]
def zero_grad_for_extra_weights(self):
pass
# Sample.
def sample(self, num_samples):
samples = torch.zeros(num_samples, self.inp_dimensions, device = device)
for sample_num in range(num_samples):
sample = torch.zeros(self.inp_dimensions, device = device)
for dim in range(self.inp_dimensions):
h_dim = torch.sigmoid(self.hidden(sample))
bernoulli_mean_dim = torch.sigmoid(self.alpha_weights[dim].dot(h_dim) + self.alpha_bias[dim])
distribution = dist.bernoulli.Bernoulli(probs=bernoulli_mean_dim)
sample[dim] = distribution.sample()
samples[sample_num] = sample
return samples