Understanding Image Classifiers at the Dataset Level with Diffusion Models

by Greg Kondas (MLD3 Group)


Wouldn’t it be great to understand the dataset-level behavior of your image classifier?

Intro 

This post is adapted from DEPICT: Diffusion-Enabled Permutation Importance for Image Classification Tasks [ECCV 2024].

Image classifiers are machine learning models trained to predict a label from an image. For example, an image classifier could be trained to identify the presence of pneumonia in a chest X-ray. With this powerful ability, image classifiers have become ubiquitous, powering everything from your phone’s facial recognition to security cameras and medical imaging systems.

However, our ability to fully understand the predictions made by image classifiers remains limited. This has led to problematic behavior, such as models relying on irrelevant correlations or “shortcuts” present in training data. In high-stakes settings such as healthcare, such failures could harm patients: imagine a model trained to detect pneumonia from chest X-ray images. As pneumonia is more common in children than adults, the model might simply be predicting pneumonia based on the torso size in the image, even though body size doesn’t actually help with pneumonia prediction. Figuring out what a model is using to make predictions can help catch such issues before they can cause harm to patients.

We’re not the first to think about how to explain an image classifier’s predictions. For example, tools like GradCAM and LIME provide a “heatmap”-style visualization, highlighting the pixels that most affected the model’s prediction. On one particular image, GradCAM and LIME might suggest that a model was relying on the presence of a green grassy background to predict the presence of a dog. But without checking every single image, it’s not clear whether this is a one-time mistake by the model or a systemic problem.

To address this limitation, we propose a new method called DEPICT to detect whether an image model is using a particular feature to make predictions across an entire dataset. Our method relies on the idea of “concepts” in computer vision, where a “concept” refers to a recognizable object or pattern, such as a ‘person’, ‘couch’, or ‘bed’. DEPICT produces a ranking of concepts, giving insight into general model behavior rather than explaining a single prediction.

For example, given the task of predicting whether an image was taken outside, DEPICT might output the following ranking, from most to least important: bed, couch, person. This ranking tells us that the presence of a bed is highly informative for this task. We can then check this against common sense: beds are typically found only indoors, so if a bed appears in an image, it’s extremely likely that the image was taken indoors. It’s therefore reasonable that our image classifier would find beds informative. With feature importance rankings from DEPICT, we can understand systematic model behavior beyond a single prediction.

DEPICT builds on an explanation framework known as “permutation feature importance,” which is traditionally used on tabular data (data organized into rows and columns like an Excel sheet) to produce feature importance rankings. Of course, images aren’t exactly the same as an Excel sheet. Let’s first walk through an example of permutation importance in the tabular setting before discussing how we can extend it to image models using DEPICT.

Tabular Permutation Feature Importance

Suppose we have a trained model that predicts whether a dog is a Bichon based on features such as height, paw diameter, and fur color. When we provide the model with a dog’s features, it outputs a simple “yes” or “no” prediction of whether the dog is a Bichon. Let’s assume we’ve collected these features for several dogs and organized them into a table, which we’ll call X. It looks something like this:

For clarity, we’ve marked each row of the dataset (the features of a different dog) in red, orange, and green. We now want to understand how important each of these features (height, paw diameter, and fur color) is in helping the model decide whether a dog is a Bichon. With permutation importance, we measure how the model’s accuracy, i.e., the percentage of correct predictions, is affected by permuting (scrambling) one of the feature columns in X. For example, let’s say we’ve scrambled the values in the height column of X, producing a new table we’ll call X_height:

Notice that by permuting the height column, we destroy the relationship between the height feature and the is_bichon label. In other words, given a row from the X_height table, the height feature should no longer be helpful when predicting the is_bichon label: after scrambling, it’s basically just a randomly generated feature with no correlation to the label we’re trying to predict. A model that heavily relies on height in its predictions will perform worse on this scrambled version of the data. Comparing model performance before and after scrambling a feature is the central idea behind permutation feature importance.

To calculate height’s feature importance score, we first measure the model’s accuracy on X, with all features in their original order, and then on X_height, with the height values scrambled. Height’s importance score is the difference between these two accuracies. By repeating this process for each feature, we obtain importance scores for all features, which we then rank to identify the most and least important ones. To summarize, to rank all features by importance, we:

  • For each feature:
    • Permute the feature column in dataset X to generate X_scrambled.
    • Compute the feature’s importance score as the difference in model performance between X and X_scrambled.
  • Rank features by importance score in descending order to create the feature importance ranking.
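
This entire procedure takes only a few lines of code. Below is a minimal sketch in Python, assuming a fitted scikit-learn-style classifier with a .predict method and a pandas DataFrame of features; the function name and the n_repeats averaging are illustrative choices rather than part of any particular library.

```python
import numpy as np

def permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Rank features by how much scrambling each one hurts accuracy.

    model: any fitted classifier with a .predict(X) method
    X: pandas DataFrame of features (e.g., height, paw_diameter, fur_color)
    y: array-like of labels (e.g., is_bichon)
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    baseline = np.mean(model.predict(X) == y)  # accuracy on the unscrambled data

    scores = {}
    for feature in X.columns:
        drops = []
        for _ in range(n_repeats):
            X_scrambled = X.copy()
            # Permute (scramble) just this one feature column.
            X_scrambled[feature] = rng.permutation(X_scrambled[feature].values)
            permuted_accuracy = np.mean(model.predict(X_scrambled) == y)
            drops.append(baseline - permuted_accuracy)
        scores[feature] = float(np.mean(drops))  # average drop in accuracy

    # The most important features are those whose scrambling hurts accuracy the most.
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)
```

scikit-learn also implements this idea as sklearn.inspection.permutation_importance, but the version above makes the “scramble one column, re-measure accuracy” loop explicit.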

DEPICT: Extending Permutation Feature Importance to Images

Our method, DEPICT, extends permutation importance to image classifiers. While permuting features in a tabular dataset is trivial (we’re just shuffling a column in a table), permuting features across images is not so easy. For example, scrambling an image’s pixels destroys a lot of information at once, rather than just the information associated with one specific feature, and it’s unclear exactly what information was destroyed.

DEPICT bypasses pixel-level manipulations and instead leverages higher-level concepts: recognizable visual patterns, such as a ‘person,’ ‘couch,’ or ‘table,’ that serve as high-level features for image classifiers. Concepts are advantageous because they are tied to natural language and also correspond to a visual pattern in an image’s pixels. We can transform the list of concepts in any image into a caption that simply counts the concepts present. DEPICT takes advantage of this vision-language connection to permute concepts across images. Starting with captions that describe the concepts originally present in each image, DEPICT first permutes concepts across captions, producing a new set of captions with the concept of interest scrambled. Then, using the concept-scrambled captions, a text-to-image diffusion model generates a new set of images with the concept also scrambled. With both the original images and the concept-permuted images, we can observe how permuting concepts affects model performance and determine the feature importance of each concept, just like in the tabular case!

Let’s dive deeper into how DEPICT scrambles concepts across images. DEPICT leverages dataset annotations to assign each image a caption that lists the number of each concept present. For example, the following image from the MS COCO dataset would be given the caption “1 person, 1 dog, 1 skateboard”:

This caption feels a lot like a “row” in the tabular dataset: you could imagine, say, an Excel sheet with columns “# of people, # of dogs, # of skateboards” that represents a bunch of images. With these captions, we can easily permute concepts across the captions of a set of images using the same technique as we showed on tabular data earlier. Let’s permute the “person” concept across these captions: 

Note that permuting the concept only changed the captions. We now have two sets of captions: one describing the original concepts in our dataset of images, and one with the “person” concept permuted. Then, using a text-to-image diffusion model, we generate two corresponding image sets: 1) images with the original concepts present, and 2) images with the person concept permuted across images.
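
To make this concrete, here is a small sketch of concept permutation over captions, assuming each image’s annotations have already been reduced to a dictionary of concept counts; permute_concept and to_caption are illustrative helper names, not functions from DEPICT’s codebase.

```python
import random

def permute_concept(concept_counts, concept, seed=0):
    """Scramble one concept's counts across images, leaving other concepts untouched.

    concept_counts: one dict of counts per image, e.g.
        [{"person": 1, "dog": 1, "skateboard": 1}, {"person": 2, "couch": 1}, ...]
    concept: the concept to permute, e.g. "person"
    """
    rng = random.Random(seed)
    counts = [img_counts.get(concept, 0) for img_counts in concept_counts]
    rng.shuffle(counts)  # same multiset of counts, reassigned to random images

    permuted = []
    for img_counts, new_count in zip(concept_counts, counts):
        row = dict(img_counts)   # copy so the original captions are preserved
        row[concept] = new_count
        permuted.append(row)
    return permuted

def to_caption(counts):
    """Turn concept counts into a counting caption, e.g. "1 person, 1 dog, 1 skateboard"."""
    return ", ".join(f"{count} {name}" for name, count in counts.items() if count > 0)
```

Feeding both the original and the permuted concept counts through to_caption gives the two caption sets described above.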

These two sets of images can then be used to quantify the importance of the ‘person’ concept by calculating the difference in model performance between them. We then repeat this process for each concept of interest, calculate their respective importance scores, and rank them to produce a concept importance ranking. This pipeline is markedly similar to permutation importance in tabular data: instead of a row of features, we have a caption that counts concepts. Instead of permuting a column in a table, we permute concepts across captions in text and generate corresponding images. For a more detailed discussion of the method, please see the full manuscript linked above.
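
Putting the pieces together, the overall loop mirrors the tabular pseudocode almost exactly. The sketch below reuses the permute_concept and to_caption helpers from above; generate_image stands in for whichever text-to-image diffusion model is used (e.g., a Stable Diffusion pipeline) and score for whichever performance metric you care about, so treat this as an outline of the idea rather than the paper’s actual implementation.

```python
def concept_importance(classifier, concept_counts, labels, concepts,
                       generate_image, score):
    """Rank concepts by how much scrambling each one hurts model performance.

    concept_counts: one dict of concept counts per image (as above)
    generate_image: placeholder for a text-to-image diffusion model,
        mapping a caption string to a generated image
    score: function (classifier, images, labels) -> performance, e.g. accuracy
    """
    # Both image sets are generated, so the only difference between them
    # is whether the concept of interest has been permuted.
    original_images = [generate_image(to_caption(c)) for c in concept_counts]
    baseline = score(classifier, original_images, labels)

    importances = {}
    for concept in concepts:
        permuted_counts = permute_concept(concept_counts, concept)
        permuted_images = [generate_image(to_caption(c)) for c in permuted_counts]
        # Importance = drop in performance when this concept is scrambled.
        importances[concept] = baseline - score(classifier, permuted_images, labels)

    return sorted(importances.items(), key=lambda item: item[1], reverse=True)
```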

Empirical Results

We compare DEPICT to other heatmap-based explainability techniques (GradCAM and LIME) on a subset of MS COCO, an image segmentation dataset. Since different segmentation masks correspond to different objects, we can treat each object type as a “concept.”

Our experiment proceeds as follows. We train a model that heavily relies on the “person” concept, and then evaluate whether each explanation approach correctly assigns “person” the highest importance. To convert heatmaps to feature importance scores, we measure how much each heatmap overlaps with a concept’s segmentation mask, using intersection-over-union (IoU). Indeed, DEPICT successfully identified ‘person’ as the most important concept and recognized that other concepts, like couches and ovens, were significantly less informative, while GradCAM and LIME struggled to do so. We ran similar experiments on chest X-ray images and a synthetic dataset of shape-filled images; see the full manuscript for details.
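
As a rough illustration of the heatmap-to-score conversion, the IoU between a binarized heatmap and a concept’s segmentation mask for a single image can be computed as below; the binarization threshold is an assumption made for illustration, and the exact conversion used in our experiments is described in the paper.

```python
import numpy as np

def heatmap_concept_iou(heatmap, mask, threshold=0.5):
    """IoU between a binarized explanation heatmap and a concept's segmentation mask.

    heatmap: 2D array of per-pixel attribution values, scaled to [0, 1]
    mask: 2D boolean array, True where the concept appears in the image
    threshold: cutoff for binarizing the heatmap (illustrative choice)
    """
    highlighted = heatmap >= threshold
    intersection = np.logical_and(highlighted, mask).sum()
    union = np.logical_or(highlighted, mask).sum()
    return intersection / union if union > 0 else 0.0
```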

Concluding Remarks

Understanding image classifiers is important to make sure that machine learning models are safe in high-stakes settings. DEPICT offers a new way to understand image classifiers at the dataset level by adapting permutation importance for tabular data to image data, using text-to-image models to “permute” concepts in images.

About the authors

Sarah Jabbour is a 5th year CSE PhD student who works on computer vision and human-computer interaction. In her past work, Sarah has explored how “shortcut” learning in vision models can lead to bias in chest X-ray prediction models. Sarah’s current research focuses on how clinicians’ decision making is impacted by the use of AI, as well as on electronic health record representation learning. Sarah is advised by Professors Jenna Wiens and David Fouhey.
Gregory Kondas is a full-time research assistant in Professor Jenna Wiens’s MLD3 group and a recent Computer Science B.S. graduate from Michigan. His current research focuses on representation learning for electronic health records and wearable time-series data. Greg is currently applying to PhD programs with start dates in Fall 2025.

Editors: Trenton Chang, Vaibhav Balloli, Aurelia Bunescu