This repository has been archived by the owner on Jan 6, 2021. It is now read-only.
<!DOCTYPE html>
<html>
<link rel="shortcut icon" href="favicon.ico">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="highlight.css">
<meta name="author" content="sunprinceS" >
<meta property="og:image" content="joy.png"/>
<title>Machine Learning (2017, Spring)</title>
<xmp theme="cerulean" style="display:none;">
# Problem 4: Analyze the Model by Plotting the Saliency Map
Problem description:
* Given an image and its corresponding class, rank the pixels of the original image by how strongly they influence the distribution of the model's final output.
* Using your trained CNN, compute the gradient of the output with respect to the input image and plot it. Alternatively, use one of the other methods mentioned in class to plot the saliency map.
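
To make "gradient with respect to the input" concrete, here is a minimal sketch on a toy linear-softmax classifier rather than the CNN from the assignment. The weights `W`, the 4-pixel "image" `x`, and the function names are all made up for illustration; only the idea (differentiate the probability of the chosen class with respect to each pixel, then rank pixels by absolute gradient) comes from the problem description.

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def class_prob_gradient(W, b, x, c):
    """Gradient of p_c = softmax(Wx + b)[c] with respect to each input pixel.

    For a linear-softmax model:
    dp_c/dx_j = p_c * (W[c][j] - sum_k p_k * W[k][j]).
    """
    logits = [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]
    p = softmax(logits)
    grad = []
    for j in range(len(x)):
        expected_w = sum(p[k] * W[k][j] for k in range(len(W)))
        grad.append(p[c] * (W[c][j] - expected_w))
    return p, grad

# Hypothetical toy model: 2 classes, 4 "pixels".
W = [[2.0, 0.1, -1.0, 0.0],
     [0.0, 0.0, 0.5, 0.0]]
b = [0.0, 0.0]
x = [0.5, 0.9, 0.2, 0.1]

p, grad = class_prob_gradient(W, b, x, c=0)
saliency = [abs(g) for g in grad]  # magnitude of influence per pixel
ranking = sorted(range(len(x)), key=lambda j: saliency[j], reverse=True)
print(ranking)  # pixel indices, most influential first
```

For a CNN the gradient is not available in closed form, so it is normally obtained from the framework's automatic differentiation; the ranking step stays the same.
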
## Example
* **[Note] Do not submit the TA's figures as your own homework**
* **[Note] Do not use this example image**
<img src="17.png" alt="Drawing" style="width: 200px;"/>
#### Heatmap
<center>Original image</center>| <center>Saliency map</center> | <center>Low-heat regions masked out</center>
:-------------------------:|:-------------------------:|:------------------------:
<img src="17.png" alt="Drawing" style="width: 400px;"/>|<img src="17-cmap.png" alt="Drawing" style="width: 400px;"/>|<img src="17-psee.png" alt="Drawing" style="width: 400px;"/>
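
The third panel can be produced by thresholding the saliency map. The following is a minimal sketch, assuming the heatmap has already been scaled to [0, 1]. The 3x3 arrays and the name `mask_low_heat` are illustrative; the threshold of 0.5 and the choice to fill masked pixels with the image mean follow the TA script further down.

```python
def mask_low_heat(image, heatmap, thres=0.5):
    """Return a copy of `image` where pixels with heat <= thres are set to the image mean."""
    flat = [v for row in image for v in row]
    mean = sum(flat) / len(flat)
    return [[pix if heat > thres else mean
             for pix, heat in zip(img_row, heat_row)]
            for img_row, heat_row in zip(image, heatmap)]

# Hypothetical 3x3 grayscale image and a matching heatmap in [0, 1].
image = [[10.0, 200.0, 30.0],
         [40.0, 250.0, 60.0],
         [70.0, 80.0, 90.0]]
heatmap = [[0.1, 0.9, 0.2],
           [0.0, 1.0, 0.3],
           [0.4, 0.6, 0.5]]

masked = mask_low_heat(image, heatmap)
for row in masked:
    print(row)
```

Building a new list instead of writing into `image` keeps the original pixels intact, which is convenient when the same image is plotted again in another panel.
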
## TA hour
<i class="fa fa-diamond"></i> Keywords: `keras.backend`, `gradients`
<div class="highlight"><pre><span></span><span class="ch">#!/usr/bin/env python</span>
<span class="c1"># -*- coding: utf-8 -*-</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">argparse</span>
<span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">load\_model</span>
<span class="kn">from</span> <span class="nn">termcolor</span> <span class="kn">import</span> <span class="n">colored</span><span class="p">,</span><span class="n">cprint</span>
<span class="kn">import</span> <span class="nn">keras.backend</span> <span class="kn">as</span> <span class="nn">K</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span>
<span class="n">base\_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">realpath</span><span class="p">(</span><span class="vm">\_\_file\_\_</span><span class="p">)))</span>
<span class="n">img\_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">base\_dir</span><span class="p">,</span> <span class="s1">'image'</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">img\_dir</span><span class="p">):</span>
    <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">img\_dir</span><span class="p">)</span>
<span class="n">cmap\_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">img\_dir</span><span class="p">,</span> <span class="s1">'cmap'</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cmap\_dir</span><span class="p">):</span>
    <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">cmap\_dir</span><span class="p">)</span>
<span class="n">partial\_see\_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">img\_dir</span><span class="p">,</span><span class="s1">'partial\_see'</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">partial\_see\_dir</span><span class="p">):</span>
    <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">partial\_see\_dir</span><span class="p">)</span>
<span class="n">model\_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">base\_dir</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
    <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">prog</span><span class="o">=</span><span class="s1">'plot\_saliency.py'</span><span class="p">,</span>
            <span class="n">description</span><span class="o">=</span><span class="s1">'ML-Assignment3 visualize attention heat map.'</span><span class="p">)</span>
    <span class="n">parser</span><span class="o">.</span><span class="n">add\_argument</span><span class="p">(</span><span class="s1">'--epoch'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'<#epoch>'</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">80</span><span class="p">)</span>
    <span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse\_args</span><span class="p">()</span>
    <span class="n">model\_name</span> <span class="o">=</span> <span class="s2">"model-</span><span class="si">%s</span><span class="s2">.h5"</span> <span class="o">%</span><span class="nb">str</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">epoch</span><span class="p">)</span>
    <span class="n">model\_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">model\_dir</span><span class="p">,</span> <span class="n">model\_name</span><span class="p">)</span>
    <span class="n">emotion\_classifier</span> <span class="o">=</span> <span class="n">load\_model</span><span class="p">(</span><span class="n">model\_path</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="n">colored</span><span class="p">(</span><span class="s2">"Loaded model from {}"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">model\_name</span><span class="p">),</span> <span class="s1">'yellow'</span><span class="p">,</span> <span class="n">attrs</span><span class="o">=</span><span class="p">[</span><span class="s1">'bold'</span><span class="p">]))</span>
    <span class="n">private\_pixels</span> <span class="o">=</span> <span class="n">load\_pickle</span><span class="p">(</span><span class="s1">'../fer2013/privateTest\_pixels.pkl'</span><span class="p">)</span>
    <span class="n">private\_pixels</span> <span class="o">=</span> <span class="p">[</span> <span class="n">np</span><span class="o">.</span><span class="n">fromstring</span><span class="p">(</span><span class="n">private\_pixels</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s1">' '</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
                       <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">private\_pixels</span><span class="p">))</span> <span class="p">]</span>
    <span class="n">input\_img</span> <span class="o">=</span> <span class="n">emotion\_classifier</span><span class="o">.</span><span class="n">input</span>
    <span class="n">img\_ids</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"image ids from which you want to make heatmaps"</span><span class="p">]</span>
    <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="n">img\_ids</span><span class="p">:</span>
        <span class="n">val\_proba</span> <span class="o">=</span> <span class="n">emotion\_classifier</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">private\_pixels</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span>
        <span class="n">pred</span> <span class="o">=</span> <span class="n">val\_proba</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">target</span> <span class="o">=</span> <span class="n">K</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">emotion\_classifier</span><span class="o">.</span><span class="n">output</span><span class="p">[:,</span> <span class="n">pred</span><span class="p">])</span>
        <span class="n">grads</span> <span class="o">=</span> <span class="n">K</span><span class="o">.</span><span class="n">gradients</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">input\_img</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">fn</span> <span class="o">=</span> <span class="n">K</span><span class="o">.</span><span class="n">function</span><span class="p">([</span><span class="n">input\_img</span><span class="p">,</span> <span class="n">K</span><span class="o">.</span><span class="n">learning\_phase</span><span class="p">()],</span> <span class="p">[</span><span class="n">grads</span><span class="p">])</span>
        <span class="n">heatmap</span> <span class="o">=</span> <span class="bp">None</span>
        <span class="sd">'''</span>
<span class="sd">        Implement your heatmap processing here!</span>
<span class="sd">        Hint: apply some normalization or smoothing to grads</span>
<span class="sd">        '''</span>
        <span class="n">thres</span> <span class="o">=</span> <span class="mf">0.5</span>
        <span class="n">see</span> <span class="o">=</span> <span class="n">private\_pixels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">48</span><span class="p">,</span> <span class="mi">48</span><span class="p">)</span>
        <span class="n">see</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">heatmap</span> <span class="o"><=</span> <span class="n">thres</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">see</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">heatmap</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">jet</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">colorbar</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">tight\_layout</span><span class="p">()</span>
        <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">gcf</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">draw</span><span class="p">()</span>
        <span class="n">fig</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">cmap\_dir</span><span class="p">,</span> <span class="s1">'privateTest'</span><span class="p">,</span> <span class="s1">'{}.png'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">idx</span><span class="p">)),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">see</span><span class="p">,</span><span class="n">cmap</span><span class="o">=</span><span class="s1">'gray'</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">colorbar</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">tight\_layout</span><span class="p">()</span>
        <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">gcf</span><span class="p">()</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">draw</span><span class="p">()</span>
        <span class="n">fig</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">partial\_see\_dir</span><span class="p">,</span> <span class="s1">'privateTest'</span><span class="p">,</span> <span class="s1">'{}.png'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">idx</span><span class="p">)),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="k">if</span> <span class="vm">\_\_name\_\_</span> <span class="o">==</span> <span class="s2">"\_\_main\_\_"</span><span class="p">:</span>
    <span class="n">main</span><span class="p">()</span>
</pre></div>
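
The script above leaves the heatmap processing open. As one possible direction (not the TA's reference solution), the sketch below takes the absolute value of a gradient array, min-max scales it to [0, 1] so that the 0.5 threshold is meaningful, and optionally smooths it with a 3x3 mean filter. The 3x3 `grads` values and the helper names are invented for illustration. In the actual script the gradient would first have to be obtained by evaluating `fn`, presumably as `fn([private_pixels[idx], 0])[0]` with 0 selecting test mode, and reshaped to 48x48; treat that call as an assumption to check against your Keras version.

```python
def normalize_heatmap(grads, eps=1e-8):
    """Absolute value followed by min-max scaling to [0, 1]."""
    mag = [[abs(g) for g in row] for row in grads]
    lo = min(min(row) for row in mag)
    hi = max(max(row) for row in mag)
    return [[(v - lo) / (hi - lo + eps) for v in row] for row in mag]

def box_smooth(heatmap):
    """3x3 mean filter; border pixels average over the neighbours that exist."""
    h, w = len(heatmap), len(heatmap[0])
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            window = [heatmap[rr][cc]
                      for rr in range(max(0, r - 1), min(h, r + 2))
                      for cc in range(max(0, c - 1), min(w, c + 2))]
            out[r][c] = sum(window) / len(window)
    return out

# Hypothetical 3x3 gradient values standing in for a 48x48 gradient image.
grads = [[-0.4, 0.0, 0.1],
         [0.2, -0.8, 0.0],
         [0.0, 0.3, -0.1]]

heatmap = normalize_heatmap(grads)
smoothed = box_smooth(heatmap)
```

With NumPy arrays the same normalization is a one-liner over the whole array; plain lists are used here only to keep the sketch dependency-free.
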
## Reference
→ [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/pdf/1312.6034.pdf)
</xmp>
<script src="strapdown.js"></script>
<footer>
<center><a href="./index.html"><i class="fa fa-home"></i></a></center>
<center><i class="fa fa-github"></i> Posted by: <a href="https://github.com/sunprinceS/" target="_blank">sunprinceS</a> </center>
<center><i class="fa fa-envelope"></i> Contact information: <a href="mailto:"> [email protected] </a>.</center>
<center><i class="fa fa-mortar-board"></i> Course information: <a href="http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17.html" target="_blank">Machine Learning (2017, Spring) @ National Taiwan University</a>.</center>
</footer>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-59748795-2', 'auto');
ga('send', 'pageview');
</script>
</html>