Commit a8fe22801cb7abdca543ea6b592a7b058475d5de
1 parent
87c8a521
Short tutorial how to evaluate test examples
Showing
1 changed file
with
188 additions
and
0 deletions
for_investigation.ipynb
0 → 100644
1 | +{ | |
2 | + "cells": [ | |
3 | + { | |
4 | + "cell_type": "code", | |
5 | + "execution_count": 1, | |
6 | + "metadata": {}, | |
7 | + "outputs": [ | |
8 | + { | |
9 | + "name": "stderr", | |
10 | + "output_type": "stream", | |
11 | + "text": [ | |
12 | + "Using TensorFlow backend.\n" | |
13 | + ] | |
14 | + } | |
15 | + ], | |
16 | + "source": [ | |
17 | + "from keras.models import Model\n", | |
18 | + "from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization\n", | |
19 | + "from keras.optimizers import SGD, Adam\n", | |
20 | + "import numpy as np" | |
21 | + ] | |
22 | + }, | |
23 | + { | |
24 | + "cell_type": "code", | |
25 | + "execution_count": 2, | |
26 | + "metadata": { | |
27 | + "collapsed": true | |
28 | + }, | |
29 | + "outputs": [], | |
30 | + "source": [ | |
31 | + "filename = 'test_set.csv'\n", | |
32 | + "raw_data = open(filename, 'rt')\n", | |
33 | + "test_data = np.loadtxt(raw_data, delimiter= '\\t')" | |
34 | + ] | |
35 | + }, | |
36 | + { | |
37 | + "cell_type": "code", | |
38 | + "execution_count": 3, | |
39 | + "metadata": { | |
40 | + "collapsed": true | |
41 | + }, | |
42 | + "outputs": [], | |
43 | + "source": [ | |
44 | + "number_of_features = 1126\n", | |
45 | + "test_set = test_data[:,0:1126]\n", | |
46 | + "test_labels = test_data[:,1126] #last column consists of labels" | |
47 | + ] | |
48 | + }, | |
49 | + { | |
50 | + "cell_type": "markdown", | |
51 | + "metadata": {}, | |
52 | + "source": [ | |
53 | + "# Neural network configuration" | |
54 | + ] | |
55 | + }, | |
56 | + { | |
57 | + "cell_type": "code", | |
58 | + "execution_count": 4, | |
59 | + "metadata": { | |
60 | + "collapsed": true | |
61 | + }, | |
62 | + "outputs": [], | |
63 | + "source": [ | |
64 | + "inputs = Input(shape=(number_of_features,))\n", | |
65 | + "output_from_1st_layer = Dense(1000, activation='relu')(inputs)\n", | |
66 | + "output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)\n", | |
67 | + "output_from_1st_layer = BatchNormalization()(output_from_1st_layer)\n", | |
68 | + "output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)\n", | |
69 | + "output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)\n", | |
70 | + "output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)\n", | |
71 | + "output = Dense(1, activation='sigmoid')(output_from_2nd_layer)\n", | |
72 | + "\n", | |
73 | + "model = Model(inputs, output)\n", | |
74 | + "model.compile(optimizer='Adam',loss='binary_crossentropy',metrics=['accuracy'])" | |
75 | + ] | |
76 | + }, | |
77 | + { | |
78 | + "cell_type": "markdown", | |
79 | + "metadata": {}, | |
80 | + "source": [ | |
81 | + "Let's load weights learnt earlier" | |
82 | + ] | |
83 | + }, | |
84 | + { | |
85 | + "cell_type": "code", | |
86 | + "execution_count": 5, | |
87 | + "metadata": { | |
88 | + "collapsed": true | |
89 | + }, | |
90 | + "outputs": [], | |
91 | + "source": [ | |
92 | + "model.load_weights(\"weights_2017_05_10.h5\")" | |
93 | + ] | |
94 | + }, | |
95 | + { | |
96 | + "cell_type": "markdown", | |
97 | + "metadata": {}, | |
98 | + "source": [ | |
99 | + "# Evaluation" | |
100 | + ] | |
101 | + }, | |
102 | + { | |
103 | + "cell_type": "markdown", | |
104 | + "metadata": {}, | |
105 | + "source": [ | |
106 | + "First, calculate predictions for test set" | |
107 | + ] | |
108 | + }, | |
109 | + { | |
110 | + "cell_type": "code", | |
111 | + "execution_count": 6, | |
112 | + "metadata": { | |
113 | + "collapsed": true | |
114 | + }, | |
115 | + "outputs": [], | |
116 | + "source": [ | |
117 | + " predictions = model.predict(test_set)" | |
118 | + ] | |
119 | + }, | |
120 | + { | |
121 | + "cell_type": "markdown", | |
122 | + "metadata": {}, | |
123 | + "source": [ | |
124 | + "Now we can calculate basic metrics" | |
125 | + ] | |
126 | + }, | |
127 | + { | |
128 | + "cell_type": "code", | |
129 | + "execution_count": 7, | |
130 | + "metadata": {}, | |
131 | + "outputs": [ | |
132 | + { | |
133 | + "name": "stdout", | |
134 | + "output_type": "stream", | |
135 | + "text": [ | |
136 | + "Accuracy:0.7316259444607988\n", | |
137 | + "Precision: 0.7378337531486147\n", | |
138 | + "Recall: 0.7185752134236091\n", | |
139 | + "F1: 0.7280771525154107\n" | |
140 | + ] | |
141 | + } | |
142 | + ], | |
143 | + "source": [ | |
144 | + " true_positives = 0.0\n", | |
145 | + " false_positives = 0.0\n", | |
146 | + " true_negatives = 0.0\n", | |
147 | + " false_negatives = 0.0\n", | |
148 | + "\n", | |
149 | + " for i in range(len(test_set)):\n", | |
150 | + " if (predictions[i]<0.5 and test_labels[i]==0): true_negatives += 1 \n", | |
151 | + " if (predictions[i]<0.5 and test_labels[i]==1): false_negatives += 1\n", | |
152 | + " if (predictions[i]>=0.5 and test_labels[i]==1): true_positives += 1\n", | |
153 | + " if (predictions[i]>=0.5 and test_labels[i]==0): false_positives += 1 \n", | |
154 | + " \n", | |
155 | + " accuracy = (true_positives+true_negatives)/len(test_set)\n", | |
156 | + " precision = true_positives/(true_positives+false_positives)\n", | |
157 | + " recall = true_positives/(true_positives+false_negatives)\n", | |
158 | + " f1 = 2*(precision*recall)/(precision+recall)\n", | |
159 | + "\n", | |
160 | + " print ('Accuracy:' + repr(accuracy))\n", | |
161 | + " print ('Precision: ' + repr(precision))\n", | |
162 | + " print ('Recall: ' + repr(recall))\n", | |
163 | + " print ('F1: ' + repr(f1))" | |
164 | + ] | |
165 | + } | |
166 | + ], | |
167 | + "metadata": { | |
168 | + "kernelspec": { | |
169 | + "display_name": "Python 2", | |
170 | + "language": "python", | |
171 | + "name": "python2" | |
172 | + }, | |
173 | + "language_info": { | |
174 | + "codemirror_mode": { | |
175 | + "name": "ipython", | |
176 | + "version": 2 | |
177 | + }, | |
178 | + "file_extension": ".py", | |
179 | + "mimetype": "text/x-python", | |
180 | + "name": "python", | |
181 | + "nbconvert_exporter": "python", | |
182 | + "pygments_lexer": "ipython2", | |
183 | + "version": "2.7.6" | |
184 | + } | |
185 | + }, | |
186 | + "nbformat": 4, | |
187 | + "nbformat_minor": 2 | |
188 | +} | |
... | ... |