Hello,
In Lab 3 of Module 4, we prune a resnet18 model and then fine-tune it. The notebook says: "During this fine-tuning process, the optimizer updates weight_orig, while the weight_mask ensures that the pruned (zeroed-out) connections remain zero. This allows the model to recover performance by adjusting the remaining non-zero weights."
What does "weight_mask ensures" mean here?
If I understand it correctly then it means that during training:
in every epoch, weight_mask zeroes the selected weight_orig weights so that they do not affect the prediction
in every epoch, weight_mask zeroes the gradients of those weights so that the corresponding weights are not updated
Your understanding is partially correct. While you are right about the prediction phase, your hypothesis about the gradients is technically incorrect.
Here is the precise behavior:
Prediction (Correct): You are right. In every forward pass, the weight_mask is multiplied by weight_orig. This forces the selected weights to become zero before they are used for any calculation, ensuring they do not affect the prediction.
Training/Updates (Incorrect Mechanism): The weight_mask is not applied to the gradients as an explicit step; because the forward pass computes weight = weight_orig * weight_mask, the chain rule already makes the loss gradient zero at the pruned positions. Even so, the optimizer does update weight_orig, and the underlying values at the pruned connections can still change during training (e.g., due to weight decay or momentum).
The “assurance” that connections remain zero comes from the mask acting as a filter, not a lock. No matter how the optimizer updates the hidden values in weight_orig, the mask forces them to zero every time the model attempts to use them.
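The filter-vs-lock behavior is easy to check in a few lines. This is a minimal sketch using torch.nn.utils.prune; the 4×4 layer and the manual perturbation are just illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # zero out 50% of the weights

# Perturb the hidden values everywhere, including at the pruned positions
with torch.no_grad():
    layer.weight_orig.add_(1.0)

# A forward pass re-applies the mask: weight = weight_orig * weight_mask
_ = layer(torch.randn(1, 4))

pruned = layer.weight_mask == 0
assert torch.all(layer.weight[pruned] == 0)        # the filter still forces zeros
assert torch.all(layer.weight_orig[pruned] != 0)   # even though the hidden values moved
```

No matter what happens to weight_orig, the effective weight seen by the forward pass is zero wherever the mask is zero.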
Sorry, I haven’t gotten to PyTorch C3 yet and have never looked at pruning before, but just from a math p.o.v. the pruned weight values are irrelevant. You could do back prop either way: apply the masks to the gradients or not. It makes no difference in the net effect, so it’s just a computational efficiency question. The gradients are derived from zero values for the pruned weights, so if you update the “real” version of those pruned weights the values become sort of nonsensical. But we’re not using them anyway, so it doesn’t matter. You just save some Hadamard multiplies during the application of the gradients.
At least that’s my interpretation of the points that Mubsi makes above. Let me know if I’m missing the point here.
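To make the math point concrete, here is a quick sketch (assuming torch.nn.utils.prune and plain SGD; the tensors are illustrative): the loss gradient at the pruned positions of weight_orig is already zero by the chain rule, yet an optimizer with weight decay can still move those hidden values, because decay acts on the parameter itself rather than on the loss gradient.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)

layer(torch.randn(2, 4)).sum().backward()
pruned = layer.weight_mask == 0

# Chain rule through weight = weight_orig * weight_mask:
# the loss gradient at pruned positions is exactly zero
assert torch.all(layer.weight_orig.grad[pruned] == 0)

# But weight decay is added on top of the loss gradient, so the
# hidden values at pruned positions still drift after a step
opt = torch.optim.SGD(layer.parameters(), lr=0.1, weight_decay=0.1)
before = layer.weight_orig.detach().clone()
opt.step()
assert not torch.allclose(layer.weight_orig[pruned], before[pruned])
```

So masking the gradients explicitly would indeed be redundant for the loss term; the only way the pruned entries of weight_orig move is through optimizer-side terms like decay or momentum, and the mask filters those changes out anyway.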
When I tried to research weight masks and pruning in PyTorch, I found the following:
How Weight Masks Work in PyTorch
When you apply pruning to a model using the built-in torch.nn.utils.prune functionality, PyTorch implements it as follows:
The original weight parameter is renamed to weight_orig.
A new binary tensor, weight_mask, is created.
The actual weight attribute used in the forward pass is a dynamic computation: weight = weight_orig * weight_mask.
During backpropagation, the optimizer updates weight_orig. Because the forward pass multiplies by weight_mask, the chain rule makes the loss gradients for the pruned (zeroed-out) connections zero, so they receive no loss-driven updates (though, as noted above, optimizer terms such as weight decay can still move the hidden values).
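This reparameterization is easy to inspect directly (a minimal sketch; the layer size and pruning amount are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# 'weight' is no longer a Parameter: the original Parameter was renamed
# to weight_orig, the mask lives in a buffer, and 'weight' is recomputed
# on each forward pass as weight_orig * weight_mask by a forward pre-hook
params = dict(layer.named_parameters())
assert "weight" not in params and "weight_orig" in params
assert "weight_mask" in dict(layer.named_buffers())
assert torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
```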
Pruning and Epoch Training Workflow
The standard practice for training a pruned model across epochs is an iterative three-step process:
Pre-training: Train the original dense network for a number of epochs to allow the weights to stabilize.
Pruning: At a certain epoch or after full convergence, identify and prune a percentage of the least important weights (for example, those with the smallest absolute values) by updating the weight_mask.
Fine-tuning: Continue training the pruned model for more epochs. The existing non-zero weights are adjusted to recover performance, while the masked weights are kept at zero. This cycle can be repeated to achieve higher sparsity levels.
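The three steps above can be sketched end to end. This is a toy setup: a single linear layer and random tensors stand in for a real model and dataset, and the epoch counts and pruning amounts are arbitrary:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(64, 10), torch.randn(64, 1)
loss_fn = nn.MSELoss()

def train(epochs):
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()

train(5)  # 1) pre-train the dense network

for _ in range(3):  # repeat the prune -> fine-tune cycle for higher sparsity
    # 2) prune by L1 magnitude; successive calls stack masks on top of each other
    prune.l1_unstructured(model, name="weight", amount=0.3)
    # 3) fine-tune: the optimizer still holds the same Parameter object
    #    (now named weight_orig), so no re-registration is needed
    train(5)

_ = model(x)  # refresh weight = weight_orig * weight_mask
assert torch.all(model.weight[model.weight_mask == 0] == 0)
```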
How it is implemented in PyTorch
torch.nn.utils.prune: You apply pruning using functions like prune.random_unstructured or prune.l1_unstructured at specific epochs or iterations by calling the relevant methods within your training loop. After the final fine-tuning, you can use prune.remove to make the pruning permanent and remove the weight_orig and weight_mask buffers.
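Making the pruning permanent with prune.remove looks like this (a minimal sketch; the 4×4 layer and 25% amount are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.25)
assert hasattr(layer, "weight_orig") and hasattr(layer, "weight_mask")

# Bake the zeros into an ordinary Parameter and drop the extra tensors
prune.remove(layer, "weight")
assert not hasattr(layer, "weight_orig")
assert not hasattr(layer, "weight_mask")
assert (layer.weight == 0).sum().item() == 4  # 25% of the 16 weights
```

Note that prune.remove does not undo the pruning; it just freezes the current masked values into weight and removes the reparameterization.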
PyTorch Lightning: The ModelPruning callback can automate this process, allowing you to define the pruning method and when it should run (for example, at the end of a training epoch).
I also found an interesting topic on the PyTorch forums from 2023 where a learner creates neural network pruning from scratch.
Sharing the link in case you want to explore. Notice that the topic creator was not able to prune the model parameters, unlike the staff response. From that thread, what I understood is that a pruned model, when trained again on the test data, can end up with more parameters; probably because, as I noticed, the topic creator was using the original model instead of the pruned model.
Here is another Stack Overflow link describing the model pruning steps.