Skip to content

Commit 9614e40

Browse files
authored
Added main files
1 parent 3697080 commit 9614e40

8 files changed

+367
-0
lines changed

Diff for: ExploreTrainedFoodClassificationNetwork.mlx

1.95 MB
Binary file not shown.

Diff for: ImageDataPropertiesModel.m

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
classdef ImageDataPropertiesModel
    % ImageDataPropertiesModel Access properties and random images of an image datastore
    %
    % ImageDataPropertiesModel Properties:
    %
    %   ImageDatastore         - The input image datastore object
    %   TotalObservations      - Total number of observations in the datastore
    %   Classes                - Unique classes in image datastore
    %   NumClasses             - Number of unique classes in image datastore
    %   NumObservationsByClass - Number of observations per unique class
    %   NumObsToShow           - Number of observations to show when accessing
    %                            image datastore
    %
    % ImageDataPropertiesModel Methods:
    %
    %   getRandomData        - Extract random images from datastore
    %   getRandomDataByLabel - Extract random images from datastore by class
    %                          label
    %
    % Copyright 2020 The MathWorks, Inc.

    properties(SetAccess = private)
        % Properties of the image data, cached from the datastore at
        % construction time.
        ImageDatastore
        Classes
        NumObservationsByClass
        TotalObservations
    end

    properties(Dependent)
        % Number of unique classes; derived from Classes on access.
        NumClasses
    end

    properties
        % Default number of observations to show when accessing image
        % datastore
        NumObsToShow = 16;
    end

    methods
        function imageData = ImageDataPropertiesModel(imds)
            % Construct the model from an image datastore and cache its
            % label statistics.
            imageData.ImageDatastore = imds;
            imageData.TotalObservations = length(imds.Labels);
            imageData.Classes = unique(imds.Labels);
            imageData.NumObservationsByClass = arrayfun(...
                @(x)sum(x == imds.Labels), imageData.Classes);
        end

        function numClasses = get.NumClasses(imageData)
            % Getter for the Dependent property NumClasses. Without this
            % method, reading NumClasses errors with "There is no get
            % method for Dependent property 'NumClasses'".
            numClasses = numel(imageData.Classes);
        end

        function [data, labels] = getRandomData(imageData, numImages)
            % Select up to NumObsToShow random images from the datastore.
            %
            % NOTE(review): despite its name, numImages is forwarded to
            % imresize as a target output size, not used as an image
            % count — confirm intent with callers before renaming.

            % Guard against datastores with fewer than NumObsToShow
            % observations: randperm(n,k) errors when k > n. This mirrors
            % the guard already present in getRandomDataByLabel.
            numObsToShow = min(imageData.NumObsToShow, imageData.TotalObservations);
            idx = randperm(imageData.TotalObservations, numObsToShow);

            subImds = imageData.ImageDatastore.subset(idx);
            cellData = subImds.readall();

            % Resize only when a target size was supplied.
            if nargin > 1
                data = cellfun(@(x)imresize(x, numImages), cellData, ...
                    "UniformOutput", false);
            else
                data = cellData;
            end
            labels = imageData.ImageDatastore.Labels(idx);
        end

        function [data, labels] = getRandomDataByLabel(imageData, chosenClass, numImages)
            % Select up to NumObsToShow random images whose label matches
            % chosenClass; optionally resize them (see getRandomData note
            % on the numImages argument).

            isIncluded = ismember(imageData.ImageDatastore.Labels, categorical(chosenClass));
            subImds = imageData.ImageDatastore.subset(isIncluded);

            % A class may contain fewer observations than NumObsToShow.
            numObsToShow = min(imageData.NumObsToShow, length(subImds.Labels));

            idx = randperm(length(subImds.Labels), numObsToShow);
            subImds = subImds.subset(idx);

            cellData = subImds.readall();

            if nargin > 2
                data = cellfun(@(x)imresize(x, numImages), cellData, ...
                    "UniformOutput", false);
            else
                data = cellData;
            end
            labels = subImds.Labels;
        end
    end
end
92+

Diff for: README.md

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Explore Deep Network Explainability Using an App
2+
This repository creates an app for understanding network predictions for image classification (UNPIC).
3+
4+
UNPIC is an app which can be used to explore the predictions of an image classification network using several deep learning visualization techniques. Using the app, you can:
5+
- Calculate network accuracy and the prediction scores of an image.
6+
- Investigate network predictions and misclassifications with occlusion sensitivity, Grad-CAM, and gradient attribution.
7+
- Visualize activations, maximally activating images, and deep dream.
8+
- Compute and examine the t-SNE plot to better understand network misclassifications.
9+
10+
![](App_Images/tsne_misclassified.png)![](App_Images/occlusionSensitivity_app.png)
11+
12+
* [📦 Requirements](#requirements)
13+
14+
* [🐆 Quick Start](#quick-start)
15+
16+
* [📂 Open the App](#open-app)
17+
18+
* [🚂 Open App with Own Trained Network](#open-app-with-own-trained-network)
19+
20+
* [🤓 Explainers](#explainers)
21+
22+
* [📚 References](#references)
23+
24+
## Requirements
25+
- [X] [MATLAB R2020a](https://www.mathworks.com/products/matlab.html) or later
26+
- [X] [Deep Learning Toolbox](https://www.mathworks.com/products/deep-learning.html)
27+
- [ ] [Statistics and Machine Learning Toolbox](https://www.mathworks.com/products/statistics.html) (only required for t-SNE)
28+
- [ ] [Parallel Computing Toolbox](https://www.mathworks.com/products/parallel-computing.html) (only required for training using a GPU)
29+
30+
31+
## Quick Start
32+
Download or clone the repository and then run the script `startUNPIC.mlx` to open the app with a trained image classification network.
33+
34+
## Open App
35+
UNPIC is an app for interactively comparing different deep learning visualization techniques. The easiest way to get started with UNPIC is to download the repository and then open and run a live script example.
36+
* Click **Run** on the `ExploreTrainedFoodClassificationNetwork.mlx` live script to open the app with a pretrained network and image datastore.
37+
<details>
38+
<summary>Example Steps</summary>
39+
40+
1. Import image data and trained network
41+
2. Explore the trained network using several visualization techniques in an app
42+
</details>
43+
44+
* Click **Run** on the `VisualizeTrainedNetworkExample.mlx` live script to build and train a network and then open the app.
45+
<details>
46+
<summary>Example Steps</summary>
47+
48+
1. Import image data and create a datastore object
49+
2. Prepare a pretrained network for transfer learning
50+
3. Train a food image classification network
51+
4. Explore the trained network using several visualization techniques in an app
52+
</details>
53+
54+
* Open the app with your own trained network `net` and image datastore `imds` using `UNPIC(net,imds)`. For more information, see [Open App with Own Trained Network](#open-app-with-own-trained-network).
55+
56+
The UNPIC app is attached as a supporting file (`UNPIC.mlapp`) and includes several visualization techniques to help you explain what a network has learned. Once you have the app open, click **Help** to learn more about the techniques and methods implemented.
57+
58+
UNPIC is created using App Designer. You can use App Designer to edit the underlying settings of the methods or add additional methods to the app.
59+
60+
## Open App with Own Trained Network
61+
To use the app, you must have a trained network and an image datastore. The network must be trained on images with the same class labels as the image datastore object. See the example live scripts to learn how to prepare an `imageDatastore` object and train a network. Using the image data, you can explore what the trained network has learned. For example, if you have a trained network called `net` and a validation image datastore called `imdsVal`, you can open the app to explore the trained network:
62+
```
63+
UNPIC(net,imdsVal);
64+
```
65+
66+
This app is created for use with a DAG or series image classification network, trained on RGB or grayscale images stored in an image datastore object. To use many of the visualization methods, your network must also have a `softmaxLayer`.
67+
68+
You are advised to run the live scripts to get a feel for how the app works before exploring your own trained network. For more advanced use cases, such as large data, large number of classes, or nonimage data, you will need to adapt the code using App Designer. For more information on using App Designer to create apps interactively, see [Develop Apps Using App Designer](https://www.mathworks.com/help/matlab/app-designer.html).
69+
70+
71+
## Explainers
72+
The app illustrates several explanation techniques.
73+
##### Occlusion Sensitivity
74+
Occlusion sensitivity is a simple technique for understanding which parts of an image are most important for a deep network's classification [[1]](#references). You can measure a network's sensitivity to occlusion in different regions of the data using small perturbations of the data. For more information, see [Understand Network Predictions Using Occlusion](https://www.mathworks.com/help/deeplearning/ug/understand-network-predictions-using-occlusion.html).
75+
##### Grad-CAM
76+
Grad-CAM, invented by Selvaraju and coauthors [[2]](#references), uses the gradient of the classification score with respect to the convolutional features determined by the network in order to understand which parts of the image are most important for classification. For more information, see [Grad-CAM Reveals the Why Behind Deep Learning Decisions](https://www.mathworks.com/help/deeplearning/ug/gradcam-explains-why.html).
77+
##### Gradient Attribution
78+
Gradient attribution provides pixel-resolution maps that show which pixels are most important to the network's classification [[3]](#references). Intuitively, the map shows which pixels most affect the class score when changed. For more information, see [Investigate Classification Decisions Using Gradient Attribution Techniques](https://www.mathworks.com/help/deeplearning/ug/investigate-classification-decisions-using-gradient-attribution-techniques.html).
79+
80+
![](App_Images/app_techniques.png)
81+
82+
##### Confusion Matrix
83+
The confusion matrix plot displays the predicted class vs the true class. Use this to see which classes the network struggles the most with.
84+
85+
##### Deep Dream
86+
Deep Dream is a feature visualization technique in deep learning that synthesizes images that strongly activate network layers [[4]](#references). By visualizing these images, you can highlight the image features learned by a network. These images are useful for understanding and diagnosing network behavior. For more information, see [Deep Dream Images Using GoogLeNet](https://www.mathworks.com/help/deeplearning/ug/deep-dream-images-using-googlenet.html).
87+
##### t-SNE
88+
Use t-SNE to visualize the network activations and gain an understanding of how the network responds [[5]](#references). You can use t-SNE to visualize how deep learning networks change the representation of input data as it passes through the network layers. t-SNE is good for reducing high-dimensional activations into an easy-to-use 2-D "map" of the data. For more information, see [View Network Behavior Using tsne](https://www.mathworks.com/help/deeplearning/ug/view-network-behavior-using-tsne.html).
89+
90+
![](App_Images/app_techniques2.png)
91+
92+
93+
## References
94+
[1] Zeiler M.D., Fergus R. (2014) Visualizing and Understanding Convolutional Networks. In: Fleet D., Pajdla T., Schiele B., Tuytelaars T. (eds) Computer Vision – ECCV 2014. ECCV 2014. Lecture Notes in Computer Science, vol 8689. Springer, Cham
95+
96+
[2] Selvaraju, R. R., M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization." In IEEE International Conference on Computer Vision (ICCV), 2017, pp. 618–626. Available at Grad-CAM on the Computer Vision Foundation Open Access website.
97+
98+
[3] Simonyan, Karen, Andrea Vedaldi, and Andrew Zisserman. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps.” ArXiv:1312.6034 [Cs], April 19, 2014. http://arxiv.org/abs/1312.6034.
99+
100+
[4] DeepDreaming with TensorFlow. https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/deepdream.ipynb
101+
102+
[5] van der Maaten, Laurens, and Geoffrey Hinton. "Visualizing Data using t-SNE." Journal of Machine Learning Research 9, 2008, pp. 2579–2605.
103+
104+
*Copyright 2020 The MathWorks, Inc.*

Diff for: TSNEPlotter.m

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
classdef TSNEPlotter < handle
    % TSNEPlotter Sets up a t-SNE scatter plot with additional functionality
    %
    %   t = TSNEPlotter(ax, h, reduced, imds, predClass) returns a
    %   TSNEPlotter associated with the axes ax, the collection of t-SNE
    %   scatter plot handles h, the 2-D reduced representation of the data
    %   reduced, the imageDatastore imds, and the predicted class labels
    %   predClass.
    %
    %   Use this after you have created the t-SNE plot, e.g. with gscatter.
    %   It will handle setting up on-click functions for the lines in the
    %   plot, so clicking them displays the image corresponding to the t-SNE
    %   datapoint.
    %
    %   Example:
    %       reduced = tsne(acts);
    %       colors = lines(numClasses);
    %
    %       % t-SNE scatter plot
    %       h = gscatter(ax, reduced(:,1), reduced(:,2), g, ...
    %           colors, [], 8);
    %
    %       plotter = TSNEPlotter(ax, h, reduced, imds, predClass);
    %       plotter.setUpOnClickBehavior();

    % Copyright 2020 The MathWorks, Inc.

    properties
        % ImageDisplayAxes (empty, or axes handle)
        % If this is empty, each new image will be plotted in a new figure.
        % If it is an axes handle, the image will be plotted there.
        ImageDisplayAxes
    end

    properties(Access = private)
        % Axes (axes)
        % Handle to the axes containing the scatter plot.
        % NOTE(review): stored but not read by any method visible in this
        % file — possibly kept for future use; confirm before removing.
        Axes

        % ScatterHandles (array of graphics handles)
        % Points to which the ButtonDownFcn will be applied to. For a t-SNE
        % scatter plot made with gscatter, this is the vector of line
        % handles which is the output of gscatter.
        ScatterHandles

        % ReducedData (Nx2 numeric array)
        % Array of x-y points in the reduced representation
        ReducedData

        % ImageDatastore (imageDatastore with N observations)
        % Datastore holding the original data from which ReducedData was
        % extracted.
        ImageDatastore

        % PredictedClass
        % Predicted class labels for images in image datastore
        PredictedClass
    end

    methods
        function this = TSNEPlotter(ax, scatterHandles, reducedData, imds, predResults)
            % Store the plot handles and the data needed to resolve a
            % click back to an observation in the datastore.
            this.Axes = ax;

            this.ScatterHandles = scatterHandles;
            this.ReducedData = reducedData;
            this.ImageDatastore = imds;
            this.PredictedClass = predResults;

        end

        function setUpOnClickBehavior(this)
            % Wire the click callback onto every scatter handle so that
            % clicking a datapoint displays its source image.

            % Separately apply a ButtonDownFcn to each graphics handle that
            % needs it.
            h = this.ScatterHandles;
            for i=1:length(h)
                h(i).ButtonDownFcn = @this.handleLineClicked;
            end
        end
    end

    methods(Access = private)
        function handleLineClicked(this, ~, evt)
            % Fired when a line is clicked in ax.

            % Find which point in the reduced representation this
            % corresponds to, i.e. the index of the observation being
            % clicked. Only the x-y components of the click are used.
            idx = iFindNearestPoint(evt.IntersectionPoint(1:2), this.ReducedData);

            % Look up the file and labels for that observation.
            imgFilename = this.ImageDatastore.Files{idx};
            trueClass = this.ImageDatastore.Labels(idx);
            predClass = this.PredictedClass(idx);

            this.displayImage(imgFilename, trueClass, predClass);
        end

        function displayImage(this, imgFilename, trueClass, predClass)
            % Show the clicked image with its true/predicted class title.

            % Display in a new figure and axes, unless an axes is provided.
            if isempty(this.ImageDisplayAxes)
                ax = axes(figure);
            else
                ax = this.ImageDisplayAxes;
            end

            img = imread(imgFilename);
            imgTitle = iImageTitle(imgFilename, trueClass, predClass);

            imshow(img, "Parent", ax);
            title(ax, imgTitle, "Interpreter", "none")
            % NOTE(review): 'outerposition' keeps the multi-line title from
            % being clipped; ActivePositionProperty is superseded by
            % PositionConstraint in newer releases — confirm minimum
            % supported MATLAB version before migrating.
            ax.ActivePositionProperty = 'outerposition';

        end
    end
end
116+
117+
function idx = iFindNearestPoint(xy, reducedData)
% Return the row index of the point in reducedData that lies closest
% (Euclidean distance) to the clicked location xy.

% Distance from the click to every reduced datapoint; the observation
% with the minimum distance is the one the user clicked.
distances = vecnorm(reducedData - xy, 2, 2);

[~, idx] = min(distances);
end
127+
128+
function imgTitle = iImageTitle(imgFilename, trueClass, predClass)
% Build a two-line title for the image: class information on the first
% line, the file name (without its directory) on the second.
[~, baseName, extension] = fileparts(imgFilename);

classLine = "True class: " + string(trueClass) + ...
    " Predicted class: " + string(predClass);
fileLine = "File name: " + baseName + extension;

imgTitle = {classLine, fileLine};
end
136+
137+
% Copyright 2020 The MathWorks, Inc.

Diff for: UNPIC.mlapp

237 KB
Binary file not shown.

Diff for: VisualizeTrainedNetworkExample.mlx

1.88 MB
Binary file not shown.

Diff for: downloadExampleFoodImagesData.m

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
function downloadExampleFoodImagesData(url,dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019-2020 The MathWorks, Inc.

zipName = "ExampleFoodImageDataset.zip";
zipFullPath = fullfile(dataDir,zipName);

% Fetch the .zip archive, unless a previous run already downloaded it.
if exist(zipFullPath,"file")
    fprintf("Skipping download, file already exists...\n");
else
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(zipFullPath,url);
    fprintf("Download finished...\n");
end

% Extract the archive, unless it was already unzipped. The presence of
% one of the class folders ("pizza") marks a completed extraction.
markerFolderFullPath = fullfile(dataDir,"pizza");
if exist(markerFolderFullPath,"dir")
    fprintf("Skipping unzipping, file already unzipped...\n");
else
    fprintf("Unzipping file...\n");
    unzip(zipFullPath,dataDir);
    fprintf("Unzipping finished...\n");
end
fprintf("Done.\n");

end

Diff for: startUNPIC.mlx

517 KB
Binary file not shown.

0 commit comments

Comments
 (0)