Unconstrained Channel Pruning

Alvin Wan · Hanxiang Hao · Kaushik Patnaik · Yueyang Xu · Omer Hadad · David Güera · Zhile Ren · Qi Shan

Our export library, called UPSCALE, improves the accuracy of any channel pruning algorithm by removing pruning constraints.

Publish Date: Jul 26, 2023 · International Conference on Machine Learning (ICML)


As neural networks grow in size and complexity, inference speeds decline. To combat this, channel pruning, one of the most effective compression techniques, removes channels from weights. However, in multi-branch segments of a model, channel removal can introduce inference-time memory copies. These copies in turn increase inference latency, so much so that the pruned model can be slower than the unpruned model.
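To make the memory-copy problem concrete, here is a minimal sketch in plain Python, assuming a residual block in which two pruned branches feed an elementwise add, with each channel reduced to a single scalar. The helper `scatter_add` is hypothetical and is not part of UPSCALE. When the branches keep different channels, the add can only run after both outputs are copied into a zero-filled buffer covering the union of kept channels; that buffer and those copies are the inference-time overhead described above.

```python
# Hypothetical illustration, not UPSCALE code: two pruned branches feed y = a + b.
# Each channel is modeled as one scalar so the bookkeeping stays visible.

def scatter_add(out_a, kept_a, out_b, kept_b):
    """Add two pruned branch outputs that kept possibly different channel subsets.

    out_a/out_b hold one value per kept channel; kept_a/kept_b hold the
    original channel indices those values belong to.
    Returns (output_channel_indices, output_values, needs_copy).
    """
    union = sorted(set(kept_a) | set(kept_b))
    # If both branches kept exactly the same channels in the same order, their
    # outputs already line up and can be added directly, with no extra buffer.
    if list(kept_a) == union and list(kept_b) == union:
        return union, [x + y for x, y in zip(out_a, out_b)], False
    # Otherwise allocate a buffer over the union and copy each branch into place.
    pos = {c: i for i, c in enumerate(union)}
    y = [0] * len(union)
    for value, c in zip(out_a, kept_a):
        y[pos[c]] += value
    for value, c in zip(out_b, kept_b):
        y[pos[c]] += value
    return union, y, True


# Branch A kept channels 0, 1, 3; branch B kept 1, 2, 3: a copy is required.
print(scatter_add([1, 2, 3], [0, 1, 3], [10, 20, 30], [1, 2, 3]))
# → ([0, 1, 2, 3], [1, 12, 20, 33], True)

# Both branches kept channels 0 and 2: outputs are added directly.
print(scatter_add([1, 2], [0, 2], [10, 20], [0, 2]))
# → ([0, 2], [11, 22], False)
```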

As a workaround, pruners conventionally constrain certain channels to be pruned together. This fully eliminates memory copies but, as we show, significantly impairs accuracy. We now have a dilemma: remove constraints but increase latency, or add constraints and impair accuracy.
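The sketch below illustrates the constraint with made-up integer importance scores. It assumes magnitude-style importance, where a higher score means a channel matters more, and uses total retained importance only as a rough stand-in for accuracy. All helpers are hypothetical. Unconstrained pruning lets each branch keep its own top-k channels; constrained pruning forces both branches to keep the same channels, chosen here by summed importance.

```python
# Hypothetical illustration, not UPSCALE code. Scores are toy values; retained
# importance is only a proxy for accuracy.

def top_k(scores, k):
    """Indices of the k highest-scoring channels, in ascending index order."""
    ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    return sorted(ranked[:k])


def unconstrained(scores_a, scores_b, k):
    # Each branch independently keeps its own most important channels.
    return top_k(scores_a, k), top_k(scores_b, k)


def constrained(scores_a, scores_b, k):
    # Both branches must keep the same channels, ranked by combined importance.
    joint = [a + b for a, b in zip(scores_a, scores_b)]
    kept = top_k(joint, k)
    return kept, kept


def retained(scores_a, scores_b, kept_a, kept_b):
    """Total importance kept across both branches."""
    return sum(scores_a[c] for c in kept_a) + sum(scores_b[c] for c in kept_b)


scores_a = [9, 7, 1, 3]  # branch A relies mostly on channels 0 and 1
scores_b = [2, 1, 8, 5]  # branch B relies mostly on channels 2 and 3

free = unconstrained(scores_a, scores_b, k=2)
tied = constrained(scores_a, scores_b, k=2)
print(free, retained(scores_a, scores_b, *free))  # → ([0, 1], [2, 3]) 29
print(tied, retained(scores_a, scores_b, *tied))  # → ([0, 2], [0, 2]) 20
```

In this toy case the unconstrained masks keep more importance but differ between branches, which is exactly the mismatch that forces memory copies in multi-branch segments; the constrained masks avoid copies at the cost of discarding important channels.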

In response, our insight is to reorder channels at export time, (1) reducing latency by reducing memory copies and (2) improving accuracy by removing constraints. Using this insight, we design UPSCALE, a generic algorithm that prunes models with any pruning pattern. By removing constraints from existing pruners, we improve ImageNet accuracy for post-training pruned models by 2.1 points on average, with large gains for DenseNet (+16.9), EfficientNetV2 (+7.9), and ResNet (+6.2). Furthermore, by reordering channels, UPSCALE improves inference speed by up to 2x over a baseline export.
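A simplified sketch of why reordering helps, assuming one producer layer whose output channels are read by two downstream consumers that each use a different subset. In PyTorch, reading a contiguous channel range such as `x[:, 0:3]` is basic slicing and returns a view, while reading scattered channels such as `x[:, [0, 2, 4]]` is advanced indexing and returns a copy. The helpers below, including the grouping heuristic in `reorder_two_consumers`, are hypothetical illustrations rather than UPSCALE's actual export algorithm.

```python
# Hypothetical illustration, not UPSCALE's algorithm. A producer's output
# channels are permuted at export time so each consumer reads one contiguous
# slice (a zero-copy view) instead of gathering scattered channels (a copy).

def runs(positions):
    """Count contiguous runs; 1 run means a single slice suffices."""
    ps = sorted(positions)
    return sum(1 for i, p in enumerate(ps) if i == 0 or p != ps[i - 1] + 1)


def reorder_two_consumers(num_channels, used_a, used_b):
    """Order channels as A-only, shared, B-only, then unused.

    With exactly two consumers this grouping makes both subsets contiguous.
    """
    a, b = set(used_a), set(used_b)

    def group(c):
        if c in a and c not in b:
            return 0  # read only by consumer A
        if c in a and c in b:
            return 1  # read by both consumers
        if c in b:
            return 2  # read only by consumer B
        return 3      # read by neither

    # sorted() is stable, so channels keep their original order within a group.
    return sorted(range(num_channels), key=group)


def positions(perm, used):
    """Where the channels in `used` land after applying permutation `perm`."""
    where = {c: i for i, c in enumerate(perm)}
    return [where[c] for c in used]


used_a, used_b = [0, 2, 4], [1, 2, 3, 4]
print(runs(used_a), runs(used_b))  # → 3 1  (A needs a gather before reordering)
perm = reorder_two_consumers(6, used_a, used_b)
print(perm)  # → [0, 2, 4, 1, 3, 5]
print(runs(positions(perm, used_a)), runs(positions(perm, used_b)))  # → 1 1
```

Applying `perm` means permuting the producer's output-channel weights and remapping each consumer's input channels to match. With three or more consumers whose subsets overlap arbitrarily, no single ordering may make every subset contiguous, so reordering in general reduces, rather than always eliminates, memory copies.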

Code is on GitHub.

Getting Started

Installation is just one line.

pip install apple-upscale

Run on any model of your choosing.

  import torch, torchvision
  from upscale import MaskingManager, PruningManager
  x = torch.rand((1, 3, 224, 224), device='cuda')
  model = torchvision.models.get_model('resnet18', weights='DEFAULT').cuda()  # get any pytorch model; `pretrained=True` is deprecated in torchvision