-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathTSNEPlotter.m
137 lines (110 loc) · 4.58 KB
/
TSNEPlotter.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
classdef TSNEPlotter < handle
    % TSNEPlotter Sets up a t-SNE scatter plot with additional functionality
    %
    %   t = TSNEPlotter(ax, h, reduced, imds, predClass) returns a
    %   TSNEPlotter associated with the axes ax, the collection of t-SNE
    %   scatter plot handles h, the 2-D reduced representation of the data
    %   reduced, the imageDatastore imds, and the predicted class labels
    %   predClass. reduced, imds and predClass are indexed with the same
    %   observation index, so they must list the observations in the same
    %   order.
    %
    %   Use this after you have created the t-SNE plot, e.g. with gscatter.
    %   It handles setting up on-click functions for the lines in the
    %   plot, so that clicking a point displays the image corresponding to
    %   that t-SNE datapoint, titled with its true and predicted classes.
    %
    %   Example:
    %       reduced = tsne(acts);
    %       colors = lines(numClasses);
    %
    %       % t-SNE scatter plot
    %       h = gscatter(ax, reduced(:,1), reduced(:,2), g, ...
    %           colors, [], 8);
    %
    %       plotter = TSNEPlotter(ax, h, reduced, imds, predClass);
    %       plotter.setUpOnClickBehavior();

    %   Copyright 2020 The MathWorks, Inc.

    properties
        % ImageDisplayAxes (empty, or axes handle)
        %   If this is empty, or refers to an axes that no longer exists,
        %   each new image is plotted in a new figure. If it is a valid
        %   axes handle, the image is plotted there.
        ImageDisplayAxes
    end

    properties(Access = private)
        % Axes (axes)
        %   Handle to the axes containing the scatter plot.
        Axes

        % ScatterHandles (array of graphics handles)
        %   Points to which the ButtonDownFcn will be applied. For a t-SNE
        %   scatter plot made with gscatter, this is the vector of line
        %   handles which is the output of gscatter.
        ScatterHandles

        % ReducedData (Nx2 numeric array)
        %   Array of x-y points in the reduced representation.
        ReducedData

        % ImageDatastore (imageDatastore with N observations)
        %   Datastore holding the original data from which ReducedData was
        %   extracted.
        ImageDatastore

        % PredictedClass
        %   Predicted class labels for the images in ImageDatastore, in the
        %   same observation order.
        PredictedClass
    end

    methods
        function this = TSNEPlotter(ax, scatterHandles, reducedData, imds, predResults)
            % TSNEPlotter Store the plot handles and the data needed to
            % map a click back to an image. See the class help for the
            % meaning of each argument.
            this.Axes = ax;
            this.ScatterHandles = scatterHandles;
            this.ReducedData = reducedData;
            this.ImageDatastore = imds;
            this.PredictedClass = predResults;
        end

        function setUpOnClickBehavior(this)
            % setUpOnClickBehavior Attach the click handler to every
            % graphics handle in ScatterHandles.
            h = this.ScatterHandles;
            % numel (not length) so every element is visited even when
            % the handles are stored as a matrix rather than a vector.
            for i = 1:numel(h)
                h(i).ButtonDownFcn = @this.handleLineClicked;
            end
        end
    end

    methods(Access = private)
        function handleLineClicked(this, ~, evt)
            % Fired when a line in ax is clicked.
            % Find which point in the reduced representation the click
            % corresponds to, i.e. the index of the observation being
            % clicked, then show that observation's image.
            idx = iFindNearestPoint(evt.IntersectionPoint(1:2), this.ReducedData);
            imgFilename = this.ImageDatastore.Files{idx};
            trueClass = this.ImageDatastore.Labels(idx);
            predClass = this.PredictedClass(idx);
            this.displayImage(imgFilename, trueClass, predClass);
        end

        function displayImage(this, imgFilename, trueClass, predClass)
            % Display the image in ImageDisplayAxes when it refers to a
            % live graphics object; otherwise create a new figure and axes.
            % The isgraphics check covers an axes that has been deleted
            % (e.g. its figure was closed), which imshow cannot draw into.
            if isempty(this.ImageDisplayAxes) || ~isgraphics(this.ImageDisplayAxes)
                ax = axes(figure);
            else
                ax = this.ImageDisplayAxes;
            end
            img = imread(imgFilename);
            imgTitle = iImageTitle(imgFilename, trueClass, predClass);
            imshow(img, "Parent", ax);
            title(ax, imgTitle, "Interpreter", "none")
            ax.ActivePositionProperty = 'outerposition';
        end
    end
end
function idx = iFindNearestPoint(xy, reducedData)
% iFindNearestPoint Index of the datapoint closest to a clicked location.
%
%   idx = iFindNearestPoint(xy, reducedData) returns the row index of
%   reducedData (an N-by-2 array of x-y points) that is nearest, in
%   Euclidean (L2) distance, to the 2-element point xy. xy may be given
%   as a row or a column vector. If several rows are equally near, the
%   first one is returned; if reducedData is empty, idx is empty.

% Force xy to a 1-by-2 row so implicit expansion subtracts it from every
% row of reducedData. A 2-by-1 column would otherwise either error or,
% when N == 2, silently expand along the wrong dimension.
xy = reshape(xy, 1, []);
d = reducedData - xy;
% The squared distance has the same minimizer as the L2 distance, so the
% square root is not needed.
[~, idx] = min(sum(d.^2, 2));
end
function imgTitle = iImageTitle(imgFilename, trueClass, predClass)
% iImageTitle Build a two-line title for a displayed image.
%
%   imgTitle = iImageTitle(imgFilename, trueClass, predClass) returns a
%   1-by-2 cell array of strings. The first line shows the true and
%   predicted class labels; the second shows the file name of
%   imgFilename (name plus extension, without the folder).

[~, name, ext] = fileparts(imgFilename);
% Concatenate as strings rather than with strcat: strcat removes trailing
% whitespace from char inputs, so strcat('photo ', '.png') would report
% "photo .png" as "photo.png". String + keeps every character.
imageName = string(name) + string(ext);
imgTitle = {"True class: " + string(trueClass) + " Predicted class: " + string(predClass), ...
    "File name: " + imageName};
end
% Copyright 2020 The MathWorks, Inc.