Neural-Backed Decision Trees

Alvin Wan, Lisa Dunlap*, Daniel Ho*, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez

Our models, termed Neural-Backed Decision Trees, improve both accuracy and interpretability of modern neural networks on image classification.

Try the demo

Provide our classification model with an image of your choice, or pick one of our suggested images. Unlike a run-of-the-mill neural network, our NBDT returns sequential decisions leading up to a prediction.

Decision 1: Animal (97.2% probability)

Decision 2: Chordate (97.2% probability)

Decision 3: Carnivore (97.2% probability)

Prediction: Dog (97.2% probability)

Authors

Alvin Wan, Lisa Dunlap*, Daniel Ho*, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez

*denotes equal contribution

Affiliations

University of California, Berkeley

Boston University

Publish Date

May 4, 2021, International Conference on Learning Representations (ICLR)

Abstract

Machine learning applications such as finance and medicine demand accurate and justifiable predictions, barring most deep learning methods from use. In response, previous work combines decision trees with deep learning, yielding models that:

  1. sacrifice interpretability to maintain accuracy OR
  2. sacrifice accuracy to maintain interpretability.

We forgo this dilemma by proposing Neural-Backed Decision Trees (NBDTs). NBDTs replace a neural network’s final linear layer with a differentiable sequence of decisions and a surrogate loss. This forces the model to learn high-level concepts and lessens reliance on highly uncertain decisions, yielding both:

  1. improved accuracy: NBDTs match or outperform modern neural networks on CIFAR and ImageNet and generalize better to unseen classes, by up to 16%. Furthermore, our surrogate loss improves the original model’s accuracy by up to 2%.
  2. improved interpretability: NBDTs improve human trust by clearly identifying model mistakes and assist in dataset debugging.
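
To make the idea of a differentiable sequence of decisions concrete, here is a minimal sketch in plain Python. It is not the nbdt library's implementation. The toy hierarchy, the weight values, and the helper names (TREE, LEAF_WEIGHTS, node_vector, child_probs, leaf_probs, decision_path) are all hypothetical. The sketch assumes that each leaf is represented by a class weight vector (standing in for a row of the final linear layer), that each inner node is represented by the mean of its children's vectors, and that a leaf's probability is the product of the softmax decision probabilities along its path. The surrogate loss is not shown.

```python
import math

# Hypothetical toy hierarchy: each inner node maps to its children.
TREE = {
    "root": ["animal", "vehicle"],
    "animal": ["dog", "cat"],
    "vehicle": ["car", "truck"],
}

# Hypothetical class weight vectors, standing in for the rows of a
# trained network's final linear layer (one row per class).
LEAF_WEIGHTS = {
    "dog": [1.0, 0.2, 0.0],
    "cat": [0.8, 0.4, 0.0],
    "car": [0.0, 0.3, 1.0],
    "truck": [0.1, 0.1, 0.9],
}


def node_vector(node):
    """Leaf: its own weights. Inner node: mean of its children's vectors
    (an assumption of this sketch)."""
    if node in LEAF_WEIGHTS:
        return LEAF_WEIGHTS[node]
    children = [node_vector(child) for child in TREE[node]]
    return [sum(values) / len(children) for values in zip(*children)]


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def child_probs(node, features):
    """One decision: softmax over inner products with each child's vector."""
    children = TREE[node]
    scores = [
        sum(w * x for w, x in zip(node_vector(child), features))
        for child in children
    ]
    return dict(zip(children, softmax(scores)))


def leaf_probs(features, node="root", prob=1.0):
    """Probability of each leaf: product of decision probabilities on its path."""
    if node in LEAF_WEIGHTS:
        return {node: prob}
    result = {}
    for child, p in child_probs(node, features).items():
        result.update(leaf_probs(features, child, prob * p))
    return result


def decision_path(features):
    """Greedy root-to-leaf trace, useful for displaying sequential decisions."""
    node, path = "root", []
    while node in TREE:
        probs = child_probs(node, features)
        node = max(probs, key=probs.get)
        path.append((node, probs[node]))
    return path


features = [0.9, 0.1, 0.05]  # hypothetical feature vector for one image
probs = leaf_probs(features)
prediction = max(probs, key=probs.get)

for name, p in decision_path(features):
    print(f"{name}: {p:.1%}")
print("prediction:", prediction)
```

Because every step is a softmax or a product, the leaf probabilities remain differentiable with respect to the features and weights, which is what would allow a loss defined on them to be trained end to end. The greedy trace is only for display; the prediction here comes from the path-product probabilities.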

Code and pretrained NBDTs are on GitHub.

Takeaways

Our work culminates in three key contributions that you can take away for future research:

Getting Started

Installation is just one line.

pip install nbdt

Run on any image of your choosing.

nbdt https://bit.ly/3eiuCId