This repository has been archived by the owner on Jan 6, 2021. It is now read-only.
<!DOCTYPE html>
<html>
<link rel="shortcut icon" href="favicon.ico">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="highlight.css">
<meta name="author" content="sunprinceS" >
<meta property="og:image" content="joy.png"/>
<title>Machine Learning (2017, Spring)</title>
<xmp theme="cerulean" style="display:none;">
# Problem 3: Analyze the Model by Confusion Matrix
Problem Description:
* Build a confusion matrix from the predictions and the true labels of your held-out validation split.
* Describe what you observe.
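
To make the task concrete, here is a minimal sketch of what a confusion matrix counts, using only the standard library. The helper `count_confusions` and the toy labels are made up for illustration; `sklearn.metrics.confusion_matrix` follows the same convention (rows are true labels, columns are predicted labels).

```python
def count_confusions(true_labels, predicted_labels, n_classes):
    """Return an n_classes x n_classes table of counts.

    Entry [t][p] is the number of samples whose true class is t
    and whose predicted class is p.
    """
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix


# Made-up labels for a 3-class toy problem (not real validation data).
toy_true = [0, 0, 1, 1, 2, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0, 2]

toy_matrix = count_confusions(toy_true, toy_pred, n_classes=3)
for row in toy_matrix:
    print(row)
```

Correct predictions land on the diagonal, so every off-diagonal entry is a specific kind of mistake that you can comment on in your report.
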
Hint:
* You can pick a few images and record their predicted probability distributions over the 7 classes.
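
One possible way to record such a distribution is sketched below. The helper `rank_classes` and the numbers in `example_probs` are hypothetical; in practice the vector would be one row of the model's output (for example from `predict`, assuming the last layer is a softmax over the 7 classes).

```python
EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]


def rank_classes(probabilities, class_names=EMOTIONS):
    """Pair each class name with its probability, highest first."""
    if len(probabilities) != len(class_names):
        raise ValueError("expected one probability per class")
    pairs = zip(class_names, probabilities)
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


# Hypothetical softmax output for one image (illustrative numbers only).
example_probs = [0.05, 0.01, 0.30, 0.02, 0.45, 0.07, 0.10]

ranking = rank_classes(example_probs)
for name, prob in ranking[:3]:
    print("{:<8s} {:.2f}".format(name, prob))
```

Looking at the runner-up classes of misclassified images often explains the off-diagonal cells of the confusion matrix: a large second-place probability suggests the model finds the two classes hard to separate.
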
## Example
**[Note] Please do not submit the TA's figure as your own homework.**
![r3](http://i.imgur.com/nTWMqGn.png)
## TA hour
Suppose you have already trained a reasonably good model; use it to predict on the validation data.
<i class="fa fa-diamond"></i> Keywords: `sklearn.metrics.confusion_matrix`, `keras.models.load_model`, `predict_classes`
<div class="highlight"><pre><span></span><span class="ch">#!/usr/bin/env python</span>
<span class="c1"># -*- coding: utf-8 -*-</span>
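<span class="kn">import</span> <span class="nn">itertools</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span>
<span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">load_model</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">confusion_matrix</span>
<span class="c1"># exp_dir, store_path, read_dataset and get_labels are assumed to come from your own project code (not shown here)</span>
<span class="kn">from</span> <span class="nn">marcos</span> <span class="kn">import</span> <span class="n">exp_dir</span>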
<span class="k">def</span> <span class="nf">plot_confusion_matrix</span><span class="p">(</span><span class="n">cm</span><span class="p">,</span> <span class="n">classes</span><span class="p">,</span>
                          <span class="n">title</span><span class="o">=</span><span class="s1">'Confusion matrix'</span><span class="p">,</span>
                          <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">jet</span><span class="p">):</span>
    <span class="sd">"""</span>
<span class="sd">    Plot the confusion matrix, normalized so that each row (true label) sums to 1.</span>
<span class="sd">    """</span>
    <span class="n">cm</span> <span class="o">=</span> <span class="n">cm</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s1">'float'</span><span class="p">)</span> <span class="o">/</span> <span class="n">cm</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">cm</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cmap</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">title</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">colorbar</span><span class="p">()</span>
    <span class="n">tick_marks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">classes</span><span class="p">))</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">tick_marks</span><span class="p">,</span> <span class="n">classes</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">(</span><span class="n">tick_marks</span><span class="p">,</span> <span class="n">classes</span><span class="p">)</span>
    <span class="n">thresh</span> <span class="o">=</span> <span class="n">cm</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">/</span> <span class="mf">2.</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">cm</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">range</span><span class="p">(</span><span class="n">cm</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])):</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="s1">'{:.2f}'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">cm</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]),</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s2">"center"</span><span class="p">,</span>
                 <span class="n">color</span><span class="o">=</span><span class="s2">"white"</span> <span class="k">if</span> <span class="n">cm</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">></span> <span class="n">thresh</span> <span class="k">else</span> <span class="s2">"black"</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">'True label'</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">'Predicted label'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
    <span class="n">model_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">exp_dir</span><span class="p">,</span><span class="n">store_path</span><span class="p">,</span><span class="s1">'model.h5'</span><span class="p">)</span>
    <span class="n">emotion_classifier</span> <span class="o">=</span> <span class="n">load_model</span><span class="p">(</span><span class="n">model_path</span><span class="p">)</span>
    <span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">precision</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">dev_feats</span> <span class="o">=</span> <span class="n">read_dataset</span><span class="p">(</span><span class="s1">'valid'</span><span class="p">)</span>
    <span class="c1"># predict_classes exists only on Sequential models in older Keras releases;</span>
    <span class="c1"># on newer versions use np.argmax(emotion_classifier.predict(dev_feats), axis=-1)</span>
    <span class="n">predictions</span> <span class="o">=</span> <span class="n">emotion_classifier</span><span class="o">.</span><span class="n">predict_classes</span><span class="p">(</span><span class="n">dev_feats</span><span class="p">)</span>
    <span class="n">te_labels</span> <span class="o">=</span> <span class="n">get_labels</span><span class="p">(</span><span class="s1">'valid'</span><span class="p">)</span>
    <span class="c1"># rows of conf_mat are true labels, columns are predicted labels</span>
    <span class="n">conf_mat</span> <span class="o">=</span> <span class="n">confusion_matrix</span><span class="p">(</span><span class="n">te_labels</span><span class="p">,</span><span class="n">predictions</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">()</span>
    <span class="n">plot_confusion_matrix</span><span class="p">(</span><span class="n">conf_mat</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="p">[</span><span class="s2">"Angry"</span><span class="p">,</span><span class="s2">"Disgust"</span><span class="p">,</span><span class="s2">"Fear"</span><span class="p">,</span><span class="s2">"Happy"</span><span class="p">,</span><span class="s2">"Sad"</span><span class="p">,</span><span class="s2">"Surprise"</span><span class="p">,</span><span class="s2">"Neutral"</span><span class="p">])</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
    <span class="n">main</span><span class="p">()</span>
</pre></div>
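
The normalization step in `plot_confusion_matrix` can be easier to follow on a small example. The sketch below uses a made-up 3-class count matrix, and the helper `most_confused_pairs` is hypothetical (it is not part of the TA's code); it lists the largest off-diagonal entries, which may be a useful starting point for describing what you observe.

```python
import numpy as np

# Made-up 3-class count matrix (rows: true label, columns: predicted label).
counts = np.array([[8, 2, 0],
                   [1, 6, 3],
                   [0, 5, 5]])

# Divide each row by its own total. [:, np.newaxis] reshapes the row sums
# from shape (3,) to (3, 1) so that they broadcast across the columns.
row_totals = counts.sum(axis=1)[:, np.newaxis]
normalized = counts.astype("float") / row_totals


def most_confused_pairs(matrix, top_k=2):
    """Return the top_k largest off-diagonal entries as (true, predicted, value)."""
    off_diagonal = matrix.copy()
    np.fill_diagonal(off_diagonal, 0.0)  # ignore correct predictions
    flat_order = np.argsort(off_diagonal, axis=None)[::-1][:top_k]
    rows, cols = np.unravel_index(flat_order, off_diagonal.shape)
    return [(int(r), int(c), float(off_diagonal[r, c])) for r, c in zip(rows, cols)]


print(normalized)
print(most_confused_pairs(normalized))
```

After normalization each row sums to 1, so a cell reads as "the fraction of images of this true class that were predicted as that class". Note that a class with no validation samples would give a zero row total and therefore NaN entries.
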
</xmp> <script src="strapdown.js"></script> </html>
<footer>
<center><a href="./index.html"><i class="fa fa-home"></i></a></center>
<center><i class="fa fa-github"></i> Posted by: <a href="https://github.com/sunprinceS/" target="_blank">sunprinceS</a> </center>
<center><i class="fa fa-envelope"></i> Contact information: <a href="mailto:"> [email protected] </a>.</center>
<center><i class="fa fa-mortar-board"></i> Course information: <a href="http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17.html" target="_blank">Machine Learning (2017, Spring) @ National Taiwan University</a>.</center>
</footer>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-59748795-2', 'auto');
ga('send', 'pageview');
</script>
</html>