Commit 07ecdd418fac471bb15d7286c7187931b111c0d5
1 parent
c11a89c8
A neural network model of a binary classifier for mention pairs
Showing
1 changed file
with
285 additions
and
0 deletions
mention-pair-classifier.ipynb
0 → 100644
1 | +{ | |
2 | + "cells": [ | |
3 | + { | |
4 | + "cell_type": "code", | |
5 | + "execution_count": null, | |
6 | + "metadata": { | |
7 | + "collapsed": false, | |
8 | + "deletable": true, | |
9 | + "editable": true | |
10 | + }, | |
11 | + "outputs": [], | |
12 | + "source": [ | |
13 | + "from keras.models import Model\n", | |
14 | + "from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization\n", | |
15 | + "from keras.optimizers import SGD, Adam\n", | |
16 | + "import numpy as np" | |
17 | + ] | |
18 | + }, | |
19 | + { | |
20 | + "cell_type": "markdown", | |
21 | + "metadata": { | |
22 | + "deletable": true, | |
23 | + "editable": true | |
24 | + }, | |
25 | + "source": [ | |
26 | + "# Data preparation" | |
27 | + ] | |
28 | + }, | |
29 | + { | |
30 | + "cell_type": "code", | |
31 | + "execution_count": null, | |
32 | + "metadata": { | |
33 | + "collapsed": true, | |
34 | + "deletable": true, | |
35 | + "editable": true | |
36 | + }, | |
37 | + "outputs": [], | |
38 | + "source": [ | |
39 | + "filename = 'input_data.csv'\n", | |
40 | + "raw_data = open(filename, 'rt')\n", | |
41 | + "data = np.loadtxt(raw_data, delimiter= '\\t')" | |
42 | + ] | |
43 | + }, | |
44 | + { | |
45 | + "cell_type": "code", | |
46 | + "execution_count": null, | |
47 | + "metadata": { | |
48 | + "collapsed": false, | |
49 | + "deletable": true, | |
50 | + "editable": true | |
51 | + }, | |
52 | + "outputs": [], | |
53 | + "source": [ | |
54 | + "print data.shape" | |
55 | + ] | |
56 | + }, | |
57 | + { | |
58 | + "cell_type": "markdown", | |
59 | + "metadata": { | |
60 | + "deletable": true, | |
61 | + "editable": true | |
62 | + }, | |
63 | + "source": [ | |
64 | + "Our dataset consists of ~466K examples (pairs of mentions). Each example is described by 1126 features, followed by a label in the last column that says whether the two mentions belong to the same cluster (1) or not (0)." | |
65 | + ] | |
66 | + }, | |
67 | + { | |
68 | + "cell_type": "code", | |
69 | + "execution_count": null, | |
70 | + "metadata": { | |
71 | + "collapsed": false, | |
72 | + "deletable": true, | |
73 | + "editable": true | |
74 | + }, | |
75 | + "outputs": [], | |
76 | + "source": [ | |
77 | + "size_of_dataset = len(data)\n", | |
78 | + "number_of_features = 1126\n", | |
79 | + "\n", | |
80 | + "X = data[:,0:1126]\n", | |
81 | + "Y = data[:,1126] #last column consists of labels\n" | |
82 | + ] | |
83 | + }, | |
84 | + { | |
85 | + "cell_type": "markdown", | |
86 | + "metadata": { | |
87 | + "deletable": true, | |
88 | + "editable": true | |
89 | + }, | |
90 | + "source": [ | |
91 | + "Now let's split the data into training and test sets (90/10)." | |
92 | + ] | |
93 | + }, | |
94 | + { | |
95 | + "cell_type": "code", | |
96 | + "execution_count": null, | |
97 | + "metadata": { | |
98 | + "collapsed": true, | |
99 | + "deletable": true, | |
100 | + "editable": true | |
101 | + }, | |
102 | + "outputs": [], | |
103 | + "source": [ | |
104 | + "np.random.seed(999) #seed fixed for reproducibility\n", | |
105 | + "mask = np.random.rand(size_of_dataset) < 0.9 #array of boolean variables\n", | |
106 | + "\n", | |
107 | + "training_set = X[mask]\n", | |
108 | + "training_labels = Y[mask]\n", | |
109 | + "\n", | |
110 | + "test_set = X[~mask]\n", | |
111 | + "test_labels = Y[~mask]" | |
112 | + ] | |
113 | + }, | |
114 | + { | |
115 | + "cell_type": "markdown", | |
116 | + "metadata": { | |
117 | + "deletable": true, | |
118 | + "editable": true | |
119 | + }, | |
120 | + "source": [ | |
121 | + "# Neural network configuration" | |
122 | + ] | |
123 | + }, | |
124 | + { | |
125 | + "cell_type": "code", | |
126 | + "execution_count": null, | |
127 | + "metadata": { | |
128 | + "collapsed": false, | |
129 | + "deletable": true, | |
130 | + "editable": true | |
131 | + }, | |
132 | + "outputs": [], | |
133 | + "source": [ | |
134 | + "inputs = Input(shape=(number_of_features,))\n", | |
135 | + "output_from_1st_layer = Dense(1000, activation='relu')(inputs)\n", | |
136 | + "output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)\n", | |
137 | + "output_from_1st_layer = BatchNormalization()(output_from_1st_layer)\n", | |
138 | + "output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)\n", | |
139 | + "output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)\n", | |
140 | + "output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)\n", | |
141 | + "output = Dense(1, activation='sigmoid')(output_from_2nd_layer)\n", | |
142 | + "\n", | |
143 | + "model = Model(inputs, output)\n", | |
144 | + "model.compile(optimizer='Adam',loss='binary_crossentropy',metrics=['accuracy'])" | |
145 | + ] | |
146 | + }, | |
147 | + { | |
148 | + "cell_type": "markdown", | |
149 | + "metadata": {}, | |
150 | + "source": [ | |
151 | + "# Training" | |
152 | + ] | |
153 | + }, | |
154 | + { | |
155 | + "cell_type": "code", | |
156 | + "execution_count": null, | |
157 | + "metadata": { | |
158 | + "collapsed": false, | |
159 | + "deletable": true, | |
160 | + "editable": true | |
161 | + }, | |
162 | + "outputs": [], | |
163 | + "source": [ | |
164 | + "model.fit(training_set, training_labels, batch_size=256, nb_epoch=25)" | |
165 | + ] | |
166 | + }, | |
167 | + { | |
168 | + "cell_type": "markdown", | |
169 | + "metadata": { | |
170 | + "collapsed": false, | |
171 | + "deletable": true, | |
172 | + "editable": true | |
173 | + }, | |
174 | + "source": [ | |
175 | + "# Evaluation" | |
176 | + ] | |
177 | + }, | |
178 | + { | |
179 | + "cell_type": "code", | |
180 | + "execution_count": null, | |
181 | + "metadata": { | |
182 | + "collapsed": false, | |
183 | + "deletable": true, | |
184 | + "editable": true, | |
185 | + "scrolled": true | |
186 | + }, | |
187 | + "outputs": [], | |
188 | + "source": [ | |
189 | + "scores = model.evaluate(test_set, test_labels)\n", | |
190 | + "print(\"%s: %.2f%%\" % (model.metrics_names[1], scores[1]*100))" | |
191 | + ] | |
192 | + }, | |
193 | + { | |
194 | + "cell_type": "markdown", | |
195 | + "metadata": {}, | |
196 | + "source": [ | |
197 | + "# Playing with the model" | |
198 | + ] | |
199 | + }, | |
200 | + { | |
201 | + "cell_type": "markdown", | |
202 | + "metadata": { | |
203 | + "deletable": true, | |
204 | + "editable": true | |
205 | + }, | |
206 | + "source": [ | |
207 | + "You can save the weights of the model to a file and later restore them without retraining: rebuild the same network architecture and call model.load_weights(\"my_weights.h5\")." | |
208 | + ] | |
209 | + }, | |
210 | + { | |
211 | + "cell_type": "code", | |
212 | + "execution_count": null, | |
213 | + "metadata": { | |
214 | + "collapsed": true, | |
215 | + "deletable": true, | |
216 | + "editable": true | |
217 | + }, | |
218 | + "outputs": [], | |
219 | + "source": [ | |
220 | + "model.save_weights(\"my_weights.h5\")" | |
221 | + ] | |
222 | + }, | |
223 | + { | |
224 | + "cell_type": "markdown", | |
225 | + "metadata": {}, | |
226 | + "source": [ | |
227 | + "To get predictions for the whole test set, run:" | |
228 | + ] | |
229 | + }, | |
230 | + { | |
231 | + "cell_type": "code", | |
232 | + "execution_count": null, | |
233 | + "metadata": { | |
234 | + "collapsed": false, | |
235 | + "deletable": true, | |
236 | + "editable": true | |
237 | + }, | |
238 | + "outputs": [], | |
239 | + "source": [ | |
240 | + "predictions = model.predict(test_set)" | |
241 | + ] | |
242 | + }, | |
243 | + { | |
244 | + "cell_type": "markdown", | |
245 | + "metadata": {}, | |
246 | + "source": [ | |
247 | + "and for a single example (sliced so that the input stays two-dimensional):" | |
248 | + ] | |
249 | + }, | |
250 | + { | |
251 | + "cell_type": "code", | |
252 | + "execution_count": null, | |
253 | + "metadata": { | |
254 | + "collapsed": false | |
255 | + }, | |
256 | + "outputs": [], | |
257 | + "source": [ | |
258 | + "single_example = test_set[4:5,:] #example number 5 from the test set\n", | |
259 | + "prediction = model.predict(single_example)\n", | |
260 | + "print '%.8f' % prediction[0]" | |
261 | + ] | |
262 | + } | |
263 | + ], | |
264 | + "metadata": { | |
265 | + "kernelspec": { | |
266 | + "display_name": "Python 2", | |
267 | + "language": "python", | |
268 | + "name": "python2" | |
269 | + }, | |
270 | + "language_info": { | |
271 | + "codemirror_mode": { | |
272 | + "name": "ipython", | |
273 | + "version": 2 | |
274 | + }, | |
275 | + "file_extension": ".py", | |
276 | + "mimetype": "text/x-python", | |
277 | + "name": "python", | |
278 | + "nbconvert_exporter": "python", | |
279 | + "pygments_lexer": "ipython2", | |
280 | + "version": "2.7.10" | |
281 | + } | |
282 | + }, | |
283 | + "nbformat": 4, | |
284 | + "nbformat_minor": 2 | |
285 | +} | |
... | ... |