Commit 3634f1f2dfd8df37bc3f0fcd6e094c9eddfa81b2
1 parent
13ac1d32
code cleanup
Showing
1 changed file
with
54 additions
and
835 deletions
TrainingAndEval.ipynb
@@ -2,23 +2,10 @@ | @@ -2,23 +2,10 @@ | ||
2 | "cells": [ | 2 | "cells": [ |
3 | { | 3 | { |
4 | "cell_type": "code", | 4 | "cell_type": "code", |
5 | - "execution_count": 1, | 5 | + "execution_count": null, |
6 | "id": "97d0c9ab", | 6 | "id": "97d0c9ab", |
7 | "metadata": {}, | 7 | "metadata": {}, |
8 | - "outputs": [ | ||
9 | - { | ||
10 | - "name": "stderr", | ||
11 | - "output_type": "stream", | ||
12 | - "text": [ | ||
13 | - "2023-04-11 11:17:29.095631: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", | ||
14 | - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", | ||
15 | - "2023-04-11 11:17:29.331444: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", | ||
16 | - "2023-04-11 11:17:30.167497: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", | ||
17 | - "2023-04-11 11:17:30.167593: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", | ||
18 | - "2023-04-11 11:17:30.167603: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" | ||
19 | - ] | ||
20 | - } | ||
21 | - ], | 8 | + "outputs": [], |
22 | "source": [ | 9 | "source": [ |
23 | "import importlib\n", | 10 | "import importlib\n", |
24 | "\n", | 11 | "\n", |
@@ -40,7 +27,7 @@ | @@ -40,7 +27,7 @@ | ||
40 | }, | 27 | }, |
41 | { | 28 | { |
42 | "cell_type": "code", | 29 | "cell_type": "code", |
43 | - "execution_count": 2, | 30 | + "execution_count": null, |
44 | "id": "c41d6630", | 31 | "id": "c41d6630", |
45 | "metadata": {}, | 32 | "metadata": {}, |
46 | "outputs": [], | 33 | "outputs": [], |
@@ -51,36 +38,10 @@ | @@ -51,36 +38,10 @@ | ||
51 | }, | 38 | }, |
52 | { | 39 | { |
53 | "cell_type": "code", | 40 | "cell_type": "code", |
54 | - "execution_count": 3, | 41 | + "execution_count": null, |
55 | "id": "f30d7b7c", | 42 | "id": "f30d7b7c", |
56 | "metadata": {}, | 43 | "metadata": {}, |
57 | - "outputs": [ | ||
58 | - { | ||
59 | - "name": "stdout", | ||
60 | - "output_type": "stream", | ||
61 | - "text": [ | ||
62 | - "1 Physical GPUs, 1 Logical GPUs\n" | ||
63 | - ] | ||
64 | - }, | ||
65 | - { | ||
66 | - "name": "stderr", | ||
67 | - "output_type": "stream", | ||
68 | - "text": [ | ||
69 | - "2023-04-11 11:17:31.717262: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
70 | - "2023-04-11 11:17:31.762533: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
71 | - "2023-04-11 11:17:31.763529: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
72 | - "2023-04-11 11:17:31.765670: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", | ||
73 | - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", | ||
74 | - "2023-04-11 11:17:31.769196: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
75 | - "2023-04-11 11:17:31.770058: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
76 | - "2023-04-11 11:17:31.770816: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
77 | - "2023-04-11 11:17:32.722287: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
78 | - "2023-04-11 11:17:32.723281: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
79 | - "2023-04-11 11:17:32.724062: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
80 | - "2023-04-11 11:17:32.724846: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20480 MB memory: -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:00:05.0, compute capability: 8.0\n" | ||
81 | - ] | ||
82 | - } | ||
83 | - ], | 44 | + "outputs": [], |
84 | "source": [ | 45 | "source": [ |
85 | "# https://www.tensorflow.org/guide/gpu\n", | 46 | "# https://www.tensorflow.org/guide/gpu\n", |
86 | "gpus = tf.config.list_physical_devices('GPU')\n", | 47 | "gpus = tf.config.list_physical_devices('GPU')\n", |
@@ -98,33 +59,12 @@ | @@ -98,33 +59,12 @@ | ||
98 | }, | 59 | }, |
99 | { | 60 | { |
100 | "cell_type": "code", | 61 | "cell_type": "code", |
101 | - "execution_count": 4, | 62 | + "execution_count": null, |
102 | "id": "89afdb1e", | 63 | "id": "89afdb1e", |
103 | "metadata": { | 64 | "metadata": { |
104 | "scrolled": true | 65 | "scrolled": true |
105 | }, | 66 | }, |
106 | - "outputs": [ | ||
107 | - { | ||
108 | - "name": "stdout", | ||
109 | - "output_type": "stream", | ||
110 | - "text": [ | ||
111 | - "/device:GPU:0\n", | ||
112 | - "2.10.0\n" | ||
113 | - ] | ||
114 | - }, | ||
115 | - { | ||
116 | - "name": "stderr", | ||
117 | - "output_type": "stream", | ||
118 | - "text": [ | ||
119 | - "2023-04-11 11:17:32.739308: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
120 | - "2023-04-11 11:17:32.740224: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
121 | - "2023-04-11 11:17:32.740975: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
122 | - "2023-04-11 11:17:32.741809: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
123 | - "2023-04-11 11:17:32.742586: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", | ||
124 | - "2023-04-11 11:17:32.743322: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /device:GPU:0 with 20480 MB memory: -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:00:05.0, compute capability: 8.0\n" | ||
125 | - ] | ||
126 | - } | ||
127 | - ], | 67 | + "outputs": [], |
128 | "source": [ | 68 | "source": [ |
129 | "print(tf.test.gpu_device_name())\n", | 69 | "print(tf.test.gpu_device_name())\n", |
130 | "print(tf.__version__)" | 70 | "print(tf.__version__)" |
@@ -132,63 +72,22 @@ | @@ -132,63 +72,22 @@ | ||
132 | }, | 72 | }, |
133 | { | 73 | { |
134 | "cell_type": "code", | 74 | "cell_type": "code", |
135 | - "execution_count": 5, | 75 | + "execution_count": null, |
136 | "id": "2b0ab576", | 76 | "id": "2b0ab576", |
137 | "metadata": {}, | 77 | "metadata": {}, |
138 | - "outputs": [ | ||
139 | - { | ||
140 | - "name": "stderr", | ||
141 | - "output_type": "stream", | ||
142 | - "text": [ | ||
143 | - "Found cached dataset pdb_c_beta (/home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1)\n" | ||
144 | - ] | ||
145 | - }, | ||
146 | - { | ||
147 | - "data": { | ||
148 | - "application/vnd.jupyter.widget-view+json": { | ||
149 | - "model_id": "55f181333dc44c7a811c515cc55c4988", | ||
150 | - "version_major": 2, | ||
151 | - "version_minor": 0 | ||
152 | - }, | ||
153 | - "text/plain": [ | ||
154 | - " 0%| | 0/3 [00:00<?, ?it/s]" | ||
155 | - ] | ||
156 | - }, | ||
157 | - "metadata": {}, | ||
158 | - "output_type": "display_data" | ||
159 | - } | ||
160 | - ], | 78 | + "outputs": [], |
161 | "source": [ | 79 | "source": [ |
162 | "pdbc_dataset = load_dataset('pdb_c_beta')" | 80 | "pdbc_dataset = load_dataset('pdb_c_beta')" |
163 | ] | 81 | ] |
164 | }, | 82 | }, |
165 | { | 83 | { |
166 | "cell_type": "code", | 84 | "cell_type": "code", |
167 | - "execution_count": 6, | 85 | + "execution_count": null, |
168 | "id": "2f4c317a", | 86 | "id": "2f4c317a", |
169 | "metadata": { | 87 | "metadata": { |
170 | "scrolled": true | 88 | "scrolled": true |
171 | }, | 89 | }, |
172 | - "outputs": [ | ||
173 | - { | ||
174 | - "name": "stderr", | ||
175 | - "output_type": "stream", | ||
176 | - "text": [ | ||
177 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-ff2490f308f7f25b.arrow\n", | ||
178 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-cbb40b0e978ab6ee.arrow\n", | ||
179 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-3facbd810991cd6c.arrow\n", | ||
180 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-e54a8628e59de21f.arrow\n", | ||
181 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-9692de6b8224e758.arrow\n", | ||
182 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-4042ffa1dc5d9323.arrow\n", | ||
183 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-fb250709424f85ec.arrow\n", | ||
184 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-1f6ce0a488a89d56.arrow\n", | ||
185 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-2ae4daf5101c7aa2.arrow\n", | ||
186 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-a1686820d15bcf04.arrow\n", | ||
187 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-fe2c12481861f4bd.arrow\n", | ||
188 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-da5a875c385c3570.arrow\n" | ||
189 | - ] | ||
190 | - } | ||
191 | - ], | 90 | + "outputs": [], |
192 | "source": [ | 91 | "source": [ |
193 | "import importlib\n", | 92 | "import importlib\n", |
194 | "\n", | 93 | "\n", |
@@ -203,20 +102,10 @@ | @@ -203,20 +102,10 @@ | ||
203 | }, | 102 | }, |
204 | { | 103 | { |
205 | "cell_type": "code", | 104 | "cell_type": "code", |
206 | - "execution_count": 7, | 105 | + "execution_count": null, |
207 | "id": "de1966ed", | 106 | "id": "de1966ed", |
208 | "metadata": {}, | 107 | "metadata": {}, |
209 | - "outputs": [ | ||
210 | - { | ||
211 | - "name": "stderr", | ||
212 | - "output_type": "stream", | ||
213 | - "text": [ | ||
214 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-1dfcf507d62f6da8.arrow\n", | ||
215 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-264c0111246b25c1.arrow\n", | ||
216 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-6a40675124a412f0.arrow\n" | ||
217 | - ] | ||
218 | - } | ||
219 | - ], | 108 | + "outputs": [], |
220 | "source": [ | 109 | "source": [ |
221 | "features = pdbc_dataset_spines['train'].features\n", | 110 | "features = pdbc_dataset_spines['train'].features\n", |
222 | "pdbc_dataset_spines_cont = pdbc_dataset_spines.filter(\n", | 111 | "pdbc_dataset_spines_cont = pdbc_dataset_spines.filter(\n", |
@@ -226,41 +115,17 @@ | @@ -226,41 +115,17 @@ | ||
226 | }, | 115 | }, |
227 | { | 116 | { |
228 | "cell_type": "code", | 117 | "cell_type": "code", |
229 | - "execution_count": 8, | 118 | + "execution_count": null, |
230 | "id": "33ff295b", | 119 | "id": "33ff295b", |
231 | "metadata": {}, | 120 | "metadata": {}, |
232 | - "outputs": [ | ||
233 | - { | ||
234 | - "data": { | ||
235 | - "text/plain": [ | ||
236 | - "DatasetDict({\n", | ||
237 | - " train: Dataset({\n", | ||
238 | - " features: ['corp_id', 'sent_id', 'tokens', 'lemmas', 'cposes', 'poses', 'tags', 'heads', 'deprels', 'nonterminals', 'spines', 'anchors', 'anchor_hs'],\n", | ||
239 | - " num_rows: 15903\n", | ||
240 | - " })\n", | ||
241 | - " validation: Dataset({\n", | ||
242 | - " features: ['corp_id', 'sent_id', 'tokens', 'lemmas', 'cposes', 'poses', 'tags', 'heads', 'deprels', 'nonterminals', 'spines', 'anchors', 'anchor_hs'],\n", | ||
243 | - " num_rows: 1980\n", | ||
244 | - " })\n", | ||
245 | - " test: Dataset({\n", | ||
246 | - " features: ['corp_id', 'sent_id', 'tokens', 'lemmas', 'cposes', 'poses', 'tags', 'heads', 'deprels', 'nonterminals', 'spines', 'anchors', 'anchor_hs'],\n", | ||
247 | - " num_rows: 1990\n", | ||
248 | - " })\n", | ||
249 | - "})" | ||
250 | - ] | ||
251 | - }, | ||
252 | - "execution_count": 8, | ||
253 | - "metadata": {}, | ||
254 | - "output_type": "execute_result" | ||
255 | - } | ||
256 | - ], | 121 | + "outputs": [], |
257 | "source": [ | 122 | "source": [ |
258 | "pdbc_dataset_spines_cont" | 123 | "pdbc_dataset_spines_cont" |
259 | ] | 124 | ] |
260 | }, | 125 | }, |
261 | { | 126 | { |
262 | "cell_type": "code", | 127 | "cell_type": "code", |
263 | - "execution_count": 9, | 128 | + "execution_count": null, |
264 | "id": "a8ddbc1f", | 129 | "id": "a8ddbc1f", |
265 | "metadata": {}, | 130 | "metadata": {}, |
266 | "outputs": [], | 131 | "outputs": [], |
@@ -270,7 +135,7 @@ | @@ -270,7 +135,7 @@ | ||
270 | }, | 135 | }, |
271 | { | 136 | { |
272 | "cell_type": "code", | 137 | "cell_type": "code", |
273 | - "execution_count": 10, | 138 | + "execution_count": null, |
274 | "id": "8029594b", | 139 | "id": "8029594b", |
275 | "metadata": {}, | 140 | "metadata": {}, |
276 | "outputs": [], | 141 | "outputs": [], |
@@ -288,30 +153,24 @@ | @@ -288,30 +153,24 @@ | ||
288 | }, | 153 | }, |
289 | { | 154 | { |
290 | "cell_type": "code", | 155 | "cell_type": "code", |
291 | - "execution_count": 36, | 156 | + "execution_count": null, |
292 | "id": "be8e93fa", | 157 | "id": "be8e93fa", |
293 | "metadata": {}, | 158 | "metadata": {}, |
294 | "outputs": [], | 159 | "outputs": [], |
295 | "source": [ | 160 | "source": [ |
296 | - "def crop(dataset, n):\n", | ||
297 | - " return dataset.filter(lambda example: len(example['tokens']) <= n)\n", | ||
298 | - "\n", | ||
299 | "spines_pdbc = ClassificationTask(\n", | 161 | "spines_pdbc = ClassificationTask(\n", |
300 | " 'spines_pdbc',\n", | 162 | " 'spines_pdbc',\n", |
301 | " pdbc_dataset_spines,\n", | 163 | " pdbc_dataset_spines,\n", |
302 | - " #crop(pdbc_dataset, 6),\n", | ||
303 | ")\n", | 164 | ")\n", |
304 | "\n", | 165 | "\n", |
305 | "spines_pdbc_cont = ClassificationTask(\n", | 166 | "spines_pdbc_cont = ClassificationTask(\n", |
306 | " 'spines_pdbc_cont',\n", | 167 | " 'spines_pdbc_cont',\n", |
307 | " pdbc_dataset_spines_cont,\n", | 168 | " pdbc_dataset_spines_cont,\n", |
308 | - " #crop(pdbc_dataset, 6),\n", | ||
309 | ")\n", | 169 | ")\n", |
310 | "\n", | 170 | "\n", |
311 | "spines_pdbc_compressed = ClassificationTask(\n", | 171 | "spines_pdbc_compressed = ClassificationTask(\n", |
312 | " 'spines_pdbc_compressed',\n", | 172 | " 'spines_pdbc_compressed',\n", |
313 | " pdbc_dataset_spines_compressed,\n", | 173 | " pdbc_dataset_spines_compressed,\n", |
314 | - " #crop(pdbc_dataset, 6),\n", | ||
315 | ")\n", | 174 | ")\n", |
316 | "\n", | 175 | "\n", |
317 | "TASK = spines_pdbc_compressed\n", | 176 | "TASK = spines_pdbc_compressed\n", |
@@ -320,7 +179,7 @@ | @@ -320,7 +179,7 @@ | ||
320 | }, | 179 | }, |
321 | { | 180 | { |
322 | "cell_type": "code", | 181 | "cell_type": "code", |
323 | - "execution_count": 37, | 182 | + "execution_count": null, |
324 | "id": "7824fcee", | 183 | "id": "7824fcee", |
325 | "metadata": {}, | 184 | "metadata": {}, |
326 | "outputs": [], | 185 | "outputs": [], |
@@ -330,56 +189,12 @@ | @@ -330,56 +189,12 @@ | ||
330 | }, | 189 | }, |
331 | { | 190 | { |
332 | "cell_type": "code", | 191 | "cell_type": "code", |
333 | - "execution_count": 38, | 192 | + "execution_count": null, |
334 | "id": "1eb5f41a", | 193 | "id": "1eb5f41a", |
335 | "metadata": { | 194 | "metadata": { |
336 | "scrolled": false | 195 | "scrolled": false |
337 | }, | 196 | }, |
338 | - "outputs": [ | ||
339 | - { | ||
340 | - "name": "stdout", | ||
341 | - "output_type": "stream", | ||
342 | - "text": [ | ||
343 | - "Loading BERT tokenizer...\n" | ||
344 | - ] | ||
345 | - }, | ||
346 | - { | ||
347 | - "name": "stderr", | ||
348 | - "output_type": "stream", | ||
349 | - "text": [ | ||
350 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-49fe5b05228c3588.arrow\n" | ||
351 | - ] | ||
352 | - }, | ||
353 | - { | ||
354 | - "name": "stdout", | ||
355 | - "output_type": "stream", | ||
356 | - "text": [ | ||
357 | - "Preprocessing the dataset for BERT...\n" | ||
358 | - ] | ||
359 | - }, | ||
360 | - { | ||
361 | - "data": { | ||
362 | - "application/vnd.jupyter.widget-view+json": { | ||
363 | - "model_id": "5f108b00fcab4db8a610f24ae03b7308", | ||
364 | - "version_major": 2, | ||
365 | - "version_minor": 0 | ||
366 | - }, | ||
367 | - "text/plain": [ | ||
368 | - " 0%| | 0/2211 [00:00<?, ?ex/s]" | ||
369 | - ] | ||
370 | - }, | ||
371 | - "metadata": {}, | ||
372 | - "output_type": "display_data" | ||
373 | - }, | ||
374 | - { | ||
375 | - "name": "stderr", | ||
376 | - "output_type": "stream", | ||
377 | - "text": [ | ||
378 | - "Loading cached processed dataset at /home/kkrasnowska/.cache/huggingface/datasets/pdb_c_beta/pdb_c_beta/0.2.0/d9c6dc764ae2a3483fa112c6159db4a0342dba8083bdb3b5981c45435b0692e1/cache-b8e2900fbd9615fd.arrow\n", | ||
379 | - "You're using a HerbertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" | ||
380 | - ] | ||
381 | - } | ||
382 | - ], | 197 | + "outputs": [], |
383 | "source": [ | 198 | "source": [ |
384 | "trainer = training.Trainer(\n", | 199 | "trainer = training.Trainer(\n", |
385 | " MODEL,\n", | 200 | " MODEL,\n", |
@@ -398,21 +213,10 @@ | @@ -398,21 +213,10 @@ | ||
398 | }, | 213 | }, |
399 | { | 214 | { |
400 | "cell_type": "code", | 215 | "cell_type": "code", |
401 | - "execution_count": 39, | 216 | + "execution_count": null, |
402 | "id": "276708cc", | 217 | "id": "276708cc", |
403 | "metadata": {}, | 218 | "metadata": {}, |
404 | - "outputs": [ | ||
405 | - { | ||
406 | - "data": { | ||
407 | - "text/plain": [ | ||
408 | - "('keras_fit_logs_spines_pdbc_compressed', 'models_spines_pdbc_compressed')" | ||
409 | - ] | ||
410 | - }, | ||
411 | - "execution_count": 39, | ||
412 | - "metadata": {}, | ||
413 | - "output_type": "execute_result" | ||
414 | - } | ||
415 | - ], | 219 | + "outputs": [], |
416 | "source": [ | 220 | "source": [ |
417 | "log_dir = f'keras_fit_logs_{TASK.name}'\n", | 221 | "log_dir = f'keras_fit_logs_{TASK.name}'\n", |
418 | "model_dir = f'models_{TASK.name}'\n", | 222 | "model_dir = f'models_{TASK.name}'\n", |
@@ -422,51 +226,12 @@ | @@ -422,51 +226,12 @@ | ||
422 | }, | 226 | }, |
423 | { | 227 | { |
424 | "cell_type": "code", | 228 | "cell_type": "code", |
425 | - "execution_count": 40, | 229 | + "execution_count": null, |
426 | "id": "e8ccde06", | 230 | "id": "e8ccde06", |
427 | "metadata": { | 231 | "metadata": { |
428 | "scrolled": false | 232 | "scrolled": false |
429 | }, | 233 | }, |
430 | - "outputs": [ | ||
431 | - { | ||
432 | - "name": "stdout", | ||
433 | - "output_type": "stream", | ||
434 | - "text": [ | ||
435 | - "The tensorboard extension is already loaded. To reload it, use:\n", | ||
436 | - " %reload_ext tensorboard\n", | ||
437 | - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
438 | - "To disable this warning, you can either:\n", | ||
439 | - "\t- Avoid using `tokenizers` before the fork if possible\n", | ||
440 | - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" | ||
441 | - ] | ||
442 | - }, | ||
443 | - { | ||
444 | - "data": { | ||
445 | - "text/html": [ | ||
446 | - "\n", | ||
447 | - " <iframe id=\"tensorboard-frame-83a6a03964d4187a\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | ||
448 | - " </iframe>\n", | ||
449 | - " <script>\n", | ||
450 | - " (function() {\n", | ||
451 | - " const frame = document.getElementById(\"tensorboard-frame-83a6a03964d4187a\");\n", | ||
452 | - " const url = new URL(\"/\", window.location);\n", | ||
453 | - " const port = 6004;\n", | ||
454 | - " if (port) {\n", | ||
455 | - " url.port = port;\n", | ||
456 | - " }\n", | ||
457 | - " frame.src = url;\n", | ||
458 | - " })();\n", | ||
459 | - " </script>\n", | ||
460 | - " " | ||
461 | - ], | ||
462 | - "text/plain": [ | ||
463 | - "<IPython.core.display.HTML object>" | ||
464 | - ] | ||
465 | - }, | ||
466 | - "metadata": {}, | ||
467 | - "output_type": "display_data" | ||
468 | - } | ||
469 | - ], | 234 | + "outputs": [], |
470 | "source": [ | 235 | "source": [ |
471 | "%load_ext tensorboard\n", | 236 | "%load_ext tensorboard\n", |
472 | "! killall tensorboard\n", | 237 | "! killall tensorboard\n", |
@@ -476,21 +241,12 @@ | @@ -476,21 +241,12 @@ | ||
476 | }, | 241 | }, |
477 | { | 242 | { |
478 | "cell_type": "code", | 243 | "cell_type": "code", |
479 | - "execution_count": 41, | 244 | + "execution_count": null, |
480 | "id": "a5b0da64", | 245 | "id": "a5b0da64", |
481 | "metadata": { | 246 | "metadata": { |
482 | "scrolled": true | 247 | "scrolled": true |
483 | }, | 248 | }, |
484 | - "outputs": [ | ||
485 | - { | ||
486 | - "name": "stdout", | ||
487 | - "output_type": "stream", | ||
488 | - "text": [ | ||
489 | - "CPU times: user 6 ยตs, sys: 1 ยตs, total: 7 ยตs\n", | ||
490 | - "Wall time: 15.7 ยตs\n" | ||
491 | - ] | ||
492 | - } | ||
493 | - ], | 249 | + "outputs": [], |
494 | "source": [ | 250 | "source": [ |
495 | "%%time\n", | 251 | "%%time\n", |
496 | "\n", | 252 | "\n", |
@@ -505,45 +261,10 @@ | @@ -505,45 +261,10 @@ | ||
505 | }, | 261 | }, |
506 | { | 262 | { |
507 | "cell_type": "code", | 263 | "cell_type": "code", |
508 | - "execution_count": 42, | ||
509 | - "id": "e42b2bd4", | ||
510 | - "metadata": {}, | ||
511 | - "outputs": [], | ||
512 | - "source": [ | ||
513 | - "#import importlib\n", | ||
514 | - "#from neural_parser import hybrid_tree_utils\n", | ||
515 | - "#importlib.reload(hybrid_tree_utils)\n", | ||
516 | - "#from neural_parser import data_utils\n", | ||
517 | - "#importlib.reload(data_utils)\n", | ||
518 | - "#from neural_parser import constituency_parser\n", | ||
519 | - "#importlib.reload(constituency_parser)" | ||
520 | - ] | ||
521 | - }, | ||
522 | - { | ||
523 | - "cell_type": "code", | ||
524 | - "execution_count": 43, | 264 | + "execution_count": null, |
525 | "id": "2f65dead", | 265 | "id": "2f65dead", |
526 | "metadata": {}, | 266 | "metadata": {}, |
527 | - "outputs": [ | ||
528 | - { | ||
529 | - "name": "stdout", | ||
530 | - "output_type": "stream", | ||
531 | - "text": [ | ||
532 | - "created 3 classifier(s)\n" | ||
533 | - ] | ||
534 | - }, | ||
535 | - { | ||
536 | - "name": "stderr", | ||
537 | - "output_type": "stream", | ||
538 | - "text": [ | ||
539 | - "Some layers from the model checkpoint at models_spines_pdbc_compressed/model were not used when initializing TFBertForMultiTargetTokenClassification: ['dropout_73']\n", | ||
540 | - "- This IS expected if you are initializing TFBertForMultiTargetTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | ||
541 | - "- This IS NOT expected if you are initializing TFBertForMultiTargetTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | ||
542 | - "All the layers of TFBertForMultiTargetTokenClassification were initialized from the model checkpoint at models_spines_pdbc_compressed/model.\n", | ||
543 | - "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMultiTargetTokenClassification for predictions without further training.\n" | ||
544 | - ] | ||
545 | - } | ||
546 | - ], | 267 | + "outputs": [], |
547 | "source": [ | 268 | "source": [ |
548 | "if not TRAIN:\n", | 269 | "if not TRAIN:\n", |
549 | " from neural_parser import constituency_parser\n", | 270 | " from neural_parser import constituency_parser\n", |
@@ -552,7 +273,7 @@ | @@ -552,7 +273,7 @@ | ||
552 | }, | 273 | }, |
553 | { | 274 | { |
554 | "cell_type": "code", | 275 | "cell_type": "code", |
555 | - "execution_count": 44, | 276 | + "execution_count": null, |
556 | "id": "24edee79", | 277 | "id": "24edee79", |
557 | "metadata": {}, | 278 | "metadata": {}, |
558 | "outputs": [], | 279 | "outputs": [], |
@@ -565,41 +286,10 @@ | @@ -565,41 +286,10 @@ | ||
565 | }, | 286 | }, |
566 | { | 287 | { |
567 | "cell_type": "code", | 288 | "cell_type": "code", |
568 | - "execution_count": 45, | 289 | + "execution_count": null, |
569 | "id": "4a7cd10b", | 290 | "id": "4a7cd10b", |
570 | "metadata": {}, | 291 | "metadata": {}, |
571 | - "outputs": [ | ||
572 | - { | ||
573 | - "name": "stdout", | ||
574 | - "output_type": "stream", | ||
575 | - "text": [ | ||
576 | - "1/1 [==============================] - 10s 10s/step\n" | ||
577 | - ] | ||
578 | - }, | ||
579 | - { | ||
580 | - "data": { | ||
581 | - "text/plain": [ | ||
582 | - "[(['Miaล', 'em', 'kotka', '.'],\n", | ||
583 | - " {'spines': ['ROOT_S_VP_V', '<EMPTY>', 'NP_N', 'Punct'],\n", | ||
584 | - " 'anchors': ['<ROOT>', 'V', 'S', 'ROOT'],\n", | ||
585 | - " 'anchor_hs': ['<ROOT>', '1', '1', '1']}),\n", | ||
586 | - " (['Wlazล', 'kotek', 'na', 'pลotek', 'i', 'mruga', '.'],\n", | ||
587 | - " {'spines': ['VP_V',\n", | ||
588 | - " 'NP_N',\n", | ||
589 | - " 'PrepNP_Prep',\n", | ||
590 | - " 'NP_N',\n", | ||
591 | - " 'ROOT_S_VP_Conj',\n", | ||
592 | - " 'VP_V',\n", | ||
593 | - " 'Punct'],\n", | ||
594 | - " 'anchors': ['VP', 'S', 'VP', 'PrepNP', '<ROOT>', 'VP', 'ROOT'],\n", | ||
595 | - " 'anchor_hs': ['1', '1', '2', '1', '<ROOT>', '1', '1']})]" | ||
596 | - ] | ||
597 | - }, | ||
598 | - "execution_count": 45, | ||
599 | - "metadata": {}, | ||
600 | - "output_type": "execute_result" | ||
601 | - } | ||
602 | - ], | 292 | + "outputs": [], |
603 | "source": [ | 293 | "source": [ |
604 | "parser.parse(sentences)" | 294 | "parser.parse(sentences)" |
605 | ] | 295 | ] |
@@ -616,21 +306,10 @@ | @@ -616,21 +306,10 @@ | ||
616 | }, | 306 | }, |
617 | { | 307 | { |
618 | "cell_type": "code", | 308 | "cell_type": "code", |
619 | - "execution_count": 46, | 309 | + "execution_count": null, |
620 | "id": "4ac4b9df", | 310 | "id": "4ac4b9df", |
621 | "metadata": {}, | 311 | "metadata": {}, |
622 | - "outputs": [ | ||
623 | - { | ||
624 | - "data": { | ||
625 | - "text/plain": [ | ||
626 | - "<module 'neural_parser.constants' from '/home/kkrasnowska/neural-parsing/ICCS/neural_parser/constants.py'>" | ||
627 | - ] | ||
628 | - }, | ||
629 | - "execution_count": 46, | ||
630 | - "metadata": {}, | ||
631 | - "output_type": "execute_result" | ||
632 | - } | ||
633 | - ], | 312 | + "outputs": [], |
634 | "source": [ | 313 | "source": [ |
635 | "from neural_parser import hybrid_tree_utils\n", | 314 | "from neural_parser import hybrid_tree_utils\n", |
636 | "importlib.reload(hybrid_tree_utils)\n", | 315 | "importlib.reload(hybrid_tree_utils)\n", |
@@ -640,86 +319,12 @@ | @@ -640,86 +319,12 @@ | ||
640 | }, | 319 | }, |
641 | { | 320 | { |
642 | "cell_type": "code", | 321 | "cell_type": "code", |
643 | - "execution_count": 47, | ||
644 | - "id": "d1b28792", | ||
645 | - "metadata": {}, | ||
646 | - "outputs": [], | ||
647 | - "source": [ | ||
648 | - "from spacy import displacy\n", | ||
649 | - "\n", | ||
650 | - "def to_deps(tokens, deprels, heads):\n", | ||
651 | - " deps = {'words' : [], 'arcs' : []}\n", | ||
652 | - " for i, (token, deprel, head) in enumerate(zip(tokens, deprels, heads)):\n", | ||
653 | - " deps['words'].append({'text' : token, 'tag' : 'X'})\n", | ||
654 | - " if head >= 0:\n", | ||
655 | - " d = 'left' if head > i else 'right'\n", | ||
656 | - " start, end = sorted((i, head))\n", | ||
657 | - " deps['arcs'].append({'start' : start, 'end' : end, 'label' : deprel, 'dir' : d})\n", | ||
658 | - " return deps\n", | ||
659 | - "\n", | ||
660 | - "def display_deps(tokens, deprels, heads):\n", | ||
661 | - " displacy.render(to_deps(tokens, deprels, heads), manual=True, options={'distance' : 80})\n", | ||
662 | - " \n", | ||
663 | - "import urllib.parse\n", | ||
664 | - "import json\n", | ||
665 | - "\n", | ||
666 | - "def show_tree(tree):\n", | ||
667 | - " tree_json = json.dumps(hybrid_tree_utils.tree2dict(tree)['tree'])\n", | ||
668 | - " src = f'http://127.0.0.1:8010/?tree={urllib.parse.quote(tree_json)}'\n", | ||
669 | - " display(IFrame(src, 950, 550))" | ||
670 | - ] | ||
671 | - }, | ||
672 | - { | ||
673 | - "cell_type": "code", | ||
674 | - "execution_count": 48, | 322 | + "execution_count": null, |
675 | "id": "9f443569", | 323 | "id": "9f443569", |
676 | "metadata": { | 324 | "metadata": { |
677 | "scrolled": true | 325 | "scrolled": true |
678 | }, | 326 | }, |
679 | - "outputs": [ | ||
680 | - { | ||
681 | - "name": "stdout", | ||
682 | - "output_type": "stream", | ||
683 | - "text": [ | ||
684 | - "2211\n", | ||
685 | - "2205\n", | ||
686 | - "['Caลujฤ', '.']\n" | ||
687 | - ] | ||
688 | - }, | ||
689 | - { | ||
690 | - "data": { | ||
691 | - "text/plain": [ | ||
692 | - "{'heads': [None, 0],\n", | ||
693 | - " 'deprels': ['ROOT', 'punct'],\n", | ||
694 | - " 'spines': ['ROOT_S_VP_V', 'Punct'],\n", | ||
695 | - " 'anchors': ['<ROOT>', 'ROOT'],\n", | ||
696 | - " 'anchor_hs': ['<ROOT>', '1']}" | ||
697 | - ] | ||
698 | - }, | ||
699 | - "metadata": {}, | ||
700 | - "output_type": "display_data" | ||
701 | - }, | ||
702 | - { | ||
703 | - "name": "stdout", | ||
704 | - "output_type": "stream", | ||
705 | - "text": [ | ||
706 | - "['Drzemaล', '.']\n" | ||
707 | - ] | ||
708 | - }, | ||
709 | - { | ||
710 | - "data": { | ||
711 | - "text/plain": [ | ||
712 | - "{'heads': [None, 0],\n", | ||
713 | - " 'deprels': ['ROOT', 'punct'],\n", | ||
714 | - " 'spines': ['ROOT_S_VP_V', 'Punct'],\n", | ||
715 | - " 'anchors': ['<ROOT>', 'ROOT'],\n", | ||
716 | - " 'anchor_hs': ['<ROOT>', '1']}" | ||
717 | - ] | ||
718 | - }, | ||
719 | - "metadata": {}, | ||
720 | - "output_type": "display_data" | ||
721 | - } | ||
722 | - ], | 327 | + "outputs": [], |
723 | "source": [ | 328 | "source": [ |
724 | "HDR = [\n", | 329 | "HDR = [\n", |
725 | " 'heads', 'deprels',\n", | 330 | " 'heads', 'deprels',\n", |
@@ -753,49 +358,10 @@ | @@ -753,49 +358,10 @@ | ||
753 | }, | 358 | }, |
754 | { | 359 | { |
755 | "cell_type": "code", | 360 | "cell_type": "code", |
756 | - "execution_count": 49, | 361 | + "execution_count": null, |
757 | "id": "3f53c039", | 362 | "id": "3f53c039", |
758 | "metadata": {}, | 363 | "metadata": {}, |
759 | - "outputs": [ | ||
760 | - { | ||
761 | - "name": "stdout", | ||
762 | - "output_type": "stream", | ||
763 | - "text": [ | ||
764 | - "70/70 [==============================] - 17s 152ms/step\n", | ||
765 | - "69/69 [==============================] - 12s 168ms/step\n", | ||
766 | - "['Caลujฤ', '.']\n" | ||
767 | - ] | ||
768 | - }, | ||
769 | - { | ||
770 | - "data": { | ||
771 | - "text/plain": [ | ||
772 | - "{'spines': ['ROOT_S_VP_V', 'Punct'],\n", | ||
773 | - " 'anchors': ['<ROOT>', 'ROOT'],\n", | ||
774 | - " 'anchor_hs': ['<ROOT>', '1']}" | ||
775 | - ] | ||
776 | - }, | ||
777 | - "metadata": {}, | ||
778 | - "output_type": "display_data" | ||
779 | - }, | ||
780 | - { | ||
781 | - "name": "stdout", | ||
782 | - "output_type": "stream", | ||
783 | - "text": [ | ||
784 | - "['Drzemaล', '.']\n" | ||
785 | - ] | ||
786 | - }, | ||
787 | - { | ||
788 | - "data": { | ||
789 | - "text/plain": [ | ||
790 | - "{'spines': ['ROOT_S_VP_V', 'Punct'],\n", | ||
791 | - " 'anchors': ['<ROOT>', 'ROOT'],\n", | ||
792 | - " 'anchor_hs': ['<ROOT>', '1']}" | ||
793 | - ] | ||
794 | - }, | ||
795 | - "metadata": {}, | ||
796 | - "output_type": "display_data" | ||
797 | - } | ||
798 | - ], | 364 | + "outputs": [], |
799 | "source": [ | 365 | "source": [ |
800 | "def get_predicted_data(TOKENS_TRUE):\n", | 366 | "def get_predicted_data(TOKENS_TRUE):\n", |
801 | " PARSED = parser.parse([' '.join(toks) for toks in TOKENS_TRUE])\n", | 367 | " PARSED = parser.parse([' '.join(toks) for toks in TOKENS_TRUE])\n", |
@@ -821,45 +387,10 @@ | @@ -821,45 +387,10 @@ | ||
821 | }, | 387 | }, |
822 | { | 388 | { |
823 | "cell_type": "code", | 389 | "cell_type": "code", |
824 | - "execution_count": 50, | 390 | + "execution_count": null, |
825 | "id": "17c1d9cb", | 391 | "id": "17c1d9cb", |
826 | "metadata": {}, | 392 | "metadata": {}, |
827 | - "outputs": [ | ||
828 | - { | ||
829 | - "name": "stdout", | ||
830 | - "output_type": "stream", | ||
831 | - "text": [ | ||
832 | - "2211\n", | ||
833 | - "2205\n", | ||
834 | - "['Caลujฤ', '.']\n" | ||
835 | - ] | ||
836 | - }, | ||
837 | - { | ||
838 | - "data": { | ||
839 | - "text/plain": [ | ||
840 | - "{'heads': [None, 0], 'deprels': ['root', 'punct']}" | ||
841 | - ] | ||
842 | - }, | ||
843 | - "metadata": {}, | ||
844 | - "output_type": "display_data" | ||
845 | - }, | ||
846 | - { | ||
847 | - "name": "stdout", | ||
848 | - "output_type": "stream", | ||
849 | - "text": [ | ||
850 | - "['Drzemaล', '.']\n" | ||
851 | - ] | ||
852 | - }, | ||
853 | - { | ||
854 | - "data": { | ||
855 | - "text/plain": [ | ||
856 | - "{'heads': [None, 0], 'deprels': ['root', 'punct']}" | ||
857 | - ] | ||
858 | - }, | ||
859 | - "metadata": {}, | ||
860 | - "output_type": "display_data" | ||
861 | - } | ||
862 | - ], | 393 | + "outputs": [], |
863 | "source": [ | 394 | "source": [ |
864 | "import conllu\n", | 395 | "import conllu\n", |
865 | "\n", | 396 | "\n", |
@@ -894,7 +425,7 @@ | @@ -894,7 +425,7 @@ | ||
894 | }, | 425 | }, |
895 | { | 426 | { |
896 | "cell_type": "code", | 427 | "cell_type": "code", |
897 | - "execution_count": 51, | 428 | + "execution_count": null, |
898 | "id": "004918c6", | 429 | "id": "004918c6", |
899 | "metadata": {}, | 430 | "metadata": {}, |
900 | "outputs": [], | 431 | "outputs": [], |
@@ -913,42 +444,22 @@ | @@ -913,42 +444,22 @@ | ||
913 | "def tree2spans(tree, labeled=True, headed=False):\n", | 444 | "def tree2spans(tree, labeled=True, headed=False):\n", |
914 | " spans = []\n", | 445 | " spans = []\n", |
915 | " _tree2spans(tree, spans, labeled=labeled, headed=headed)\n", | 446 | " _tree2spans(tree, spans, labeled=labeled, headed=headed)\n", |
916 | - " # TODO\n", | ||
917 | - " #try:\n", | ||
918 | - " # assert(len(spans) == len(set(spans)))\n", | ||
919 | - " #except:\n", | ||
920 | - " # show_tree(tree)\n", | ||
921 | - " # (display(spans))\n", | ||
922 | - " # 1/0\n", | ||
923 | " return set(spans)" | 447 | " return set(spans)" |
924 | ] | 448 | ] |
925 | }, | 449 | }, |
926 | { | 450 | { |
927 | "cell_type": "code", | 451 | "cell_type": "code", |
928 | - "execution_count": 52, | 452 | + "execution_count": null, |
929 | "id": "65d493ca", | 453 | "id": "65d493ca", |
930 | "metadata": {}, | 454 | "metadata": {}, |
931 | - "outputs": [ | ||
932 | - { | ||
933 | - "data": { | ||
934 | - "text/plain": [ | ||
935 | - "<module 'neural_parser.hybrid_tree_utils' from '/home/kkrasnowska/neural-parsing/ICCS/neural_parser/hybrid_tree_utils.py'>" | ||
936 | - ] | ||
937 | - }, | ||
938 | - "execution_count": 52, | ||
939 | - "metadata": {}, | ||
940 | - "output_type": "execute_result" | ||
941 | - } | ||
942 | - ], | 455 | + "outputs": [], |
943 | "source": [ | 456 | "source": [ |
944 | - "from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score\n", | ||
945 | - "\n", | ||
946 | - "importlib.reload(hybrid_tree_utils)" | 457 | + "from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score" |
947 | ] | 458 | ] |
948 | }, | 459 | }, |
949 | { | 460 | { |
950 | "cell_type": "code", | 461 | "cell_type": "code", |
951 | - "execution_count": 53, | 462 | + "execution_count": null, |
952 | "id": "e5f88e76", | 463 | "id": "e5f88e76", |
953 | "metadata": { | 464 | "metadata": { |
954 | "scrolled": false | 465 | "scrolled": false |
@@ -964,18 +475,11 @@ | @@ -964,18 +475,11 @@ | ||
964 | " key : {'true' : [], 'pred' : []} for key in ('heads', ('heads', 'deprels'))\n", | 475 | " key : {'true' : [], 'pred' : []} for key in ('heads', ('heads', 'deprels'))\n", |
965 | " }\n", | 476 | " }\n", |
966 | "\n", | 477 | "\n", |
967 | - " k = 0\n", | ||
968 | " i = 0\n", | 478 | " i = 0\n", |
969 | " PROBLEM_TREES = []\n", | 479 | " PROBLEM_TREES = []\n", |
970 | "\n", | 480 | "\n", |
971 | " for toks, true, pred, combo in zip(tokens, tags_true, tags_pred, tags_combo):\n", | 481 | " for toks, true, pred, combo in zip(tokens, tags_true, tags_pred, tags_combo):\n", |
972 | " \n", | 482 | " \n", |
973 | - " #sent = ' '.join(toks)\n", | ||
974 | - " #cats = HDR\n", | ||
975 | - " #true = dict(zip(cats, zip(*true)))\n", | ||
976 | - " #pred = dict(zip(cats, zip(*pred)))\n", | ||
977 | - " #print('----------------------------')\n", | ||
978 | - " #print(sent)\n", | ||
979 | " dummy = {'lemmas' : ['_' for _ in toks], 'tags' : ['_' for _ in toks]}\n", | 483 | " dummy = {'lemmas' : ['_' for _ in toks], 'tags' : ['_' for _ in toks]}\n", |
980 | " true.update(dummy)\n", | 484 | " true.update(dummy)\n", |
981 | " pred.update(dummy)\n", | 485 | " pred.update(dummy)\n", |
@@ -994,12 +498,6 @@ | @@ -994,12 +498,6 @@ | ||
994 | " print('=============================')\n", | 498 | " print('=============================')\n", |
995 | " raise\n", | 499 | " raise\n", |
996 | " tree_pred, problems = None, None\n", | 500 | " tree_pred, problems = None, None\n", |
997 | - " #if 'reattach' in problems:\n", | ||
998 | - " # show_tree(tree_pred)\n", | ||
999 | - " \n", | ||
1000 | - " #if pred['lemmas_corr'] != pred['lemmas']:\n", | ||
1001 | - " # print(pred['lemmas_corr'])\n", | ||
1002 | - " # print(pred['lemmas'])\n", | ||
1003 | " \n", | 501 | " \n", |
1004 | " for key, v in accuracies.items():\n", | 502 | " for key, v in accuracies.items():\n", |
1005 | " if type(key) == str:\n", | 503 | " if type(key) == str:\n", |
@@ -1011,31 +509,11 @@ | @@ -1011,31 +509,11 @@ | ||
1011 | " \n", | 509 | " \n", |
1012 | " spans_true = tree2spans(tree_true, labeled=labeled, headed=headed)\n", | 510 | " spans_true = tree2spans(tree_true, labeled=labeled, headed=headed)\n", |
1013 | " spans_pred = tree2spans(tree_pred, labeled=labeled, headed=headed) if tree_pred else set()\n", | 511 | " spans_pred = tree2spans(tree_pred, labeled=labeled, headed=headed) if tree_pred else set()\n", |
1014 | - " if 'adwokata' in toks:\n", | ||
1015 | - " print(spans_true)\n", | ||
1016 | - " print(spans_pred)\n", | ||
1017 | " tp = len(spans_true.intersection(spans_pred))\n", | 512 | " tp = len(spans_true.intersection(spans_pred))\n", |
1018 | " P[0] += tp\n", | 513 | " P[0] += tp\n", |
1019 | " R[0] += tp\n", | 514 | " R[0] += tp\n", |
1020 | " P[1] += len(spans_pred)\n", | 515 | " P[1] += len(spans_pred)\n", |
1021 | " R[1] += len(spans_true)\n", | 516 | " R[1] += len(spans_true)\n", |
1022 | - " leafs = tree_true.get_yield()\n", | ||
1023 | - " discont = [leaf.from_index for leaf in leafs] != list(range(len(leafs)))\n", | ||
1024 | - " #if k < 5 and len(toks) > 9 and [leaf.features['index'] for leaf in leafs] != list(range(len(leafs))):\n", | ||
1025 | - " #if k < 5 and spans_combo != spans_true:\n", | ||
1026 | - " #if k < 5 and not OK:\n", | ||
1027 | - " #if discont and len(toks) > 12 and k < 0 and spans_pred == spans_true:\n", | ||
1028 | - " if len(toks) == 8 and k < 0:\n", | ||
1029 | - " print('GOLD TREE:')\n", | ||
1030 | - " show_tree(tree_true)\n", | ||
1031 | - " display(true)\n", | ||
1032 | - " #display(_tree2dict(tree_true))\n", | ||
1033 | - " print('PREDICTED TREE:')\n", | ||
1034 | - " show_tree(tree_pred)\n", | ||
1035 | - " display(pred)\n", | ||
1036 | - " print('FP:', spans_pred - spans_true)\n", | ||
1037 | - " print('FN:', spans_true - spans_pred)\n", | ||
1038 | - " k += 1\n", | ||
1039 | " i += 1\n", | 517 | " i += 1\n", |
1040 | " \n", | 518 | " \n", |
1041 | " p, r = P[0]/P[1], R[0]/R[1]\n", | 519 | " p, r = P[0]/P[1], R[0]/R[1]\n", |
@@ -1060,25 +538,12 @@ | @@ -1060,25 +538,12 @@ | ||
1060 | }, | 538 | }, |
1061 | { | 539 | { |
1062 | "cell_type": "code", | 540 | "cell_type": "code", |
1063 | - "execution_count": 54, | 541 | + "execution_count": null, |
1064 | "id": "8f8a771a", | 542 | "id": "8f8a771a", |
1065 | "metadata": { | 543 | "metadata": { |
1066 | "scrolled": false | 544 | "scrolled": false |
1067 | }, | 545 | }, |
1068 | - "outputs": [ | ||
1069 | - { | ||
1070 | - "name": "stdout", | ||
1071 | - "output_type": "stream", | ||
1072 | - "text": [ | ||
1073 | - "unlabeled{((3,), 'SPAN', False), ((2, 3), 'SPAN', False), ((4,), 'SPAN', False), ((0, 1, 2, 3, 4), 'SPAN', False), ((0, 1, 2, 3), 'SPAN', False), ((2,), 'SPAN', False), ((0, 1), 'SPAN', False)}\n", | ||
1074 | - "{((3,), 'SPAN', False), ((2, 3), 'SPAN', False), ((4,), 'SPAN', False), ((0, 1, 2, 3, 4), 'SPAN', False), ((0, 1, 2, 3), 'SPAN', False), ((2,), 'SPAN', False), ((0, 1), 'SPAN', False)}\n", | ||
1075 | - "non-headed{((2,), 'Prep', False), ((4,), 'Punct', False), ((2, 3), 'PrepNP', False), ((3,), 'N', False), ((0, 1, 2, 3), 'S', False), ((0, 1), 'VP', False), ((0, 1), 'V', False), ((3,), 'NP', False), ((0, 1, 2, 3, 4), 'ROOT', False)}\n", | ||
1076 | - "{((2,), 'Prep', False), ((4,), 'Punct', False), ((2, 3), 'PrepNP', False), ((3,), 'N', False), ((0, 1, 2, 3), 'S', False), ((0, 1), 'VP', False), ((0, 1), 'V', False), ((3,), 'NP', False), ((0, 1, 2, 3, 4), 'ROOT', False)}\n", | ||
1077 | - "headed{((0, 1, 2, 3), 'S', True), ((4,), 'Punct', False), ((0, 1), 'VP', True), ((2, 3), 'PrepNP', False), ((0, 1), 'V', True), ((3,), 'NP', False), ((0, 1, 2, 3, 4), 'ROOT', False), ((2,), 'Prep', True), ((3,), 'N', True)}\n", | ||
1078 | - "{((0, 1, 2, 3), 'S', True), ((4,), 'Punct', False), ((0, 1), 'VP', True), ((2, 3), 'PrepNP', False), ((0, 1), 'V', True), ((3,), 'NP', False), ((0, 1, 2, 3, 4), 'ROOT', False), ((2,), 'Prep', True), ((3,), 'N', True)}\n" | ||
1079 | - ] | ||
1080 | - } | ||
1081 | - ], | 546 | + "outputs": [], |
1082 | "source": [ | 547 | "source": [ |
1083 | "EVAL_DATA = {\n", | 548 | "EVAL_DATA = {\n", |
1084 | " '1val' : (TOKENS_VAL, TAGS_VAL, TAGS_P_VAL, TAGS_C_VAL),\n", | 549 | " '1val' : (TOKENS_VAL, TAGS_VAL, TAGS_P_VAL, TAGS_C_VAL),\n", |
@@ -1113,7 +578,7 @@ | @@ -1113,7 +578,7 @@ | ||
1113 | }, | 578 | }, |
1114 | { | 579 | { |
1115 | "cell_type": "code", | 580 | "cell_type": "code", |
1116 | - "execution_count": 55, | 581 | + "execution_count": null, |
1117 | "id": "63192852", | 582 | "id": "63192852", |
1118 | "metadata": {}, | 583 | "metadata": {}, |
1119 | "outputs": [], | 584 | "outputs": [], |
@@ -1123,7 +588,7 @@ | @@ -1123,7 +588,7 @@ | ||
1123 | }, | 588 | }, |
1124 | { | 589 | { |
1125 | "cell_type": "code", | 590 | "cell_type": "code", |
1126 | - "execution_count": 56, | 591 | + "execution_count": null, |
1127 | "id": "78250b1b", | 592 | "id": "78250b1b", |
1128 | "metadata": {}, | 593 | "metadata": {}, |
1129 | "outputs": [], | 594 | "outputs": [], |
@@ -1133,7 +598,7 @@ | @@ -1133,7 +598,7 @@ | ||
1133 | }, | 598 | }, |
1134 | { | 599 | { |
1135 | "cell_type": "code", | 600 | "cell_type": "code", |
1136 | - "execution_count": 57, | 601 | + "execution_count": null, |
1137 | "id": "bba6ed15", | 602 | "id": "bba6ed15", |
1138 | "metadata": {}, | 603 | "metadata": {}, |
1139 | "outputs": [], | 604 | "outputs": [], |
@@ -1143,260 +608,20 @@ | @@ -1143,260 +608,20 @@ | ||
1143 | }, | 608 | }, |
1144 | { | 609 | { |
1145 | "cell_type": "code", | 610 | "cell_type": "code", |
1146 | - "execution_count": 58, | 611 | + "execution_count": null, |
1147 | "id": "543377f8", | 612 | "id": "543377f8", |
1148 | "metadata": {}, | 613 | "metadata": {}, |
1149 | - "outputs": [ | ||
1150 | - { | ||
1151 | - "data": { | ||
1152 | - "text/html": [ | ||
1153 | - "<div>\n", | ||
1154 | - "<style scoped>\n", | ||
1155 | - " .dataframe tbody tr th:only-of-type {\n", | ||
1156 | - " vertical-align: middle;\n", | ||
1157 | - " }\n", | ||
1158 | - "\n", | ||
1159 | - " .dataframe tbody tr th {\n", | ||
1160 | - " vertical-align: top;\n", | ||
1161 | - " }\n", | ||
1162 | - "\n", | ||
1163 | - " .dataframe thead th {\n", | ||
1164 | - " text-align: right;\n", | ||
1165 | - " }\n", | ||
1166 | - "</style>\n", | ||
1167 | - "<table border=\"1\" class=\"dataframe\">\n", | ||
1168 | - " <thead>\n", | ||
1169 | - " <tr style=\"text-align: right;\">\n", | ||
1170 | - " <th></th>\n", | ||
1171 | - " <th></th>\n", | ||
1172 | - " <th></th>\n", | ||
1173 | - " <th></th>\n", | ||
1174 | - " <th>dataset</th>\n", | ||
1175 | - " <th>measure_type</th>\n", | ||
1176 | - " <th>measure</th>\n", | ||
1177 | - " <th>value</th>\n", | ||
1178 | - " </tr>\n", | ||
1179 | - " <tr>\n", | ||
1180 | - " <th>dataset</th>\n", | ||
1181 | - " <th>measure</th>\n", | ||
1182 | - " <th>measure_type</th>\n", | ||
1183 | - " <th></th>\n", | ||
1184 | - " <th></th>\n", | ||
1185 | - " <th></th>\n", | ||
1186 | - " <th></th>\n", | ||
1187 | - " <th></th>\n", | ||
1188 | - " </tr>\n", | ||
1189 | - " </thead>\n", | ||
1190 | - " <tbody>\n", | ||
1191 | - " <tr>\n", | ||
1192 | - " <th rowspan=\"9\" valign=\"top\">test</th>\n", | ||
1193 | - " <th rowspan=\"3\" valign=\"top\">F1</th>\n", | ||
1194 | - " <th>headed</th>\n", | ||
1195 | - " <th>7</th>\n", | ||
1196 | - " <td>test</td>\n", | ||
1197 | - " <td>headed</td>\n", | ||
1198 | - " <td>F1</td>\n", | ||
1199 | - " <td>0.959192</td>\n", | ||
1200 | - " </tr>\n", | ||
1201 | - " <tr>\n", | ||
1202 | - " <th>non-headed</th>\n", | ||
1203 | - " <th>8</th>\n", | ||
1204 | - " <td>test</td>\n", | ||
1205 | - " <td>non-headed</td>\n", | ||
1206 | - " <td>F1</td>\n", | ||
1207 | - " <td>0.965236</td>\n", | ||
1208 | - " </tr>\n", | ||
1209 | - " <tr>\n", | ||
1210 | - " <th>unlabeled</th>\n", | ||
1211 | - " <th>15</th>\n", | ||
1212 | - " <td>test</td>\n", | ||
1213 | - " <td>unlabeled</td>\n", | ||
1214 | - " <td>F1</td>\n", | ||
1215 | - " <td>0.964436</td>\n", | ||
1216 | - " </tr>\n", | ||
1217 | - " <tr>\n", | ||
1218 | - " <th rowspan=\"3\" valign=\"top\">P</th>\n", | ||
1219 | - " <th>headed</th>\n", | ||
1220 | - " <th>9</th>\n", | ||
1221 | - " <td>test</td>\n", | ||
1222 | - " <td>headed</td>\n", | ||
1223 | - " <td>P</td>\n", | ||
1224 | - " <td>0.959611</td>\n", | ||
1225 | - " </tr>\n", | ||
1226 | - " <tr>\n", | ||
1227 | - " <th>non-headed</th>\n", | ||
1228 | - " <th>6</th>\n", | ||
1229 | - " <td>test</td>\n", | ||
1230 | - " <td>non-headed</td>\n", | ||
1231 | - " <td>P</td>\n", | ||
1232 | - " <td>0.965658</td>\n", | ||
1233 | - " </tr>\n", | ||
1234 | - " <tr>\n", | ||
1235 | - " <th>unlabeled</th>\n", | ||
1236 | - " <th>13</th>\n", | ||
1237 | - " <td>test</td>\n", | ||
1238 | - " <td>unlabeled</td>\n", | ||
1239 | - " <td>P</td>\n", | ||
1240 | - " <td>0.964118</td>\n", | ||
1241 | - " </tr>\n", | ||
1242 | - " <tr>\n", | ||
1243 | - " <th rowspan=\"3\" valign=\"top\">R</th>\n", | ||
1244 | - " <th>headed</th>\n", | ||
1245 | - " <th>2</th>\n", | ||
1246 | - " <td>test</td>\n", | ||
1247 | - " <td>headed</td>\n", | ||
1248 | - " <td>R</td>\n", | ||
1249 | - " <td>0.958773</td>\n", | ||
1250 | - " </tr>\n", | ||
1251 | - " <tr>\n", | ||
1252 | - " <th>non-headed</th>\n", | ||
1253 | - " <th>5</th>\n", | ||
1254 | - " <td>test</td>\n", | ||
1255 | - " <td>non-headed</td>\n", | ||
1256 | - " <td>R</td>\n", | ||
1257 | - " <td>0.964815</td>\n", | ||
1258 | - " </tr>\n", | ||
1259 | - " <tr>\n", | ||
1260 | - " <th>unlabeled</th>\n", | ||
1261 | - " <th>0</th>\n", | ||
1262 | - " <td>test</td>\n", | ||
1263 | - " <td>unlabeled</td>\n", | ||
1264 | - " <td>R</td>\n", | ||
1265 | - " <td>0.964754</td>\n", | ||
1266 | - " </tr>\n", | ||
1267 | - " <tr>\n", | ||
1268 | - " <th rowspan=\"9\" valign=\"top\">val</th>\n", | ||
1269 | - " <th rowspan=\"3\" valign=\"top\">F1</th>\n", | ||
1270 | - " <th>headed</th>\n", | ||
1271 | - " <th>14</th>\n", | ||
1272 | - " <td>val</td>\n", | ||
1273 | - " <td>headed</td>\n", | ||
1274 | - " <td>F1</td>\n", | ||
1275 | - " <td>0.957423</td>\n", | ||
1276 | - " </tr>\n", | ||
1277 | - " <tr>\n", | ||
1278 | - " <th>non-headed</th>\n", | ||
1279 | - " <th>4</th>\n", | ||
1280 | - " <td>val</td>\n", | ||
1281 | - " <td>non-headed</td>\n", | ||
1282 | - " <td>F1</td>\n", | ||
1283 | - " <td>0.963231</td>\n", | ||
1284 | - " </tr>\n", | ||
1285 | - " <tr>\n", | ||
1286 | - " <th>unlabeled</th>\n", | ||
1287 | - " <th>1</th>\n", | ||
1288 | - " <td>val</td>\n", | ||
1289 | - " <td>unlabeled</td>\n", | ||
1290 | - " <td>F1</td>\n", | ||
1291 | - " <td>0.962553</td>\n", | ||
1292 | - " </tr>\n", | ||
1293 | - " <tr>\n", | ||
1294 | - " <th rowspan=\"3\" valign=\"top\">P</th>\n", | ||
1295 | - " <th>headed</th>\n", | ||
1296 | - " <th>10</th>\n", | ||
1297 | - " <td>val</td>\n", | ||
1298 | - " <td>headed</td>\n", | ||
1299 | - " <td>P</td>\n", | ||
1300 | - " <td>0.958145</td>\n", | ||
1301 | - " </tr>\n", | ||
1302 | - " <tr>\n", | ||
1303 | - " <th>non-headed</th>\n", | ||
1304 | - " <th>16</th>\n", | ||
1305 | - " <td>val</td>\n", | ||
1306 | - " <td>non-headed</td>\n", | ||
1307 | - " <td>P</td>\n", | ||
1308 | - " <td>0.963958</td>\n", | ||
1309 | - " </tr>\n", | ||
1310 | - " <tr>\n", | ||
1311 | - " <th>unlabeled</th>\n", | ||
1312 | - " <th>11</th>\n", | ||
1313 | - " <td>val</td>\n", | ||
1314 | - " <td>unlabeled</td>\n", | ||
1315 | - " <td>P</td>\n", | ||
1316 | - " <td>0.962762</td>\n", | ||
1317 | - " </tr>\n", | ||
1318 | - " <tr>\n", | ||
1319 | - " <th rowspan=\"3\" valign=\"top\">R</th>\n", | ||
1320 | - " <th>headed</th>\n", | ||
1321 | - " <th>17</th>\n", | ||
1322 | - " <td>val</td>\n", | ||
1323 | - " <td>headed</td>\n", | ||
1324 | - " <td>R</td>\n", | ||
1325 | - " <td>0.956702</td>\n", | ||
1326 | - " </tr>\n", | ||
1327 | - " <tr>\n", | ||
1328 | - " <th>non-headed</th>\n", | ||
1329 | - " <th>12</th>\n", | ||
1330 | - " <td>val</td>\n", | ||
1331 | - " <td>non-headed</td>\n", | ||
1332 | - " <td>R</td>\n", | ||
1333 | - " <td>0.962505</td>\n", | ||
1334 | - " </tr>\n", | ||
1335 | - " <tr>\n", | ||
1336 | - " <th>unlabeled</th>\n", | ||
1337 | - " <th>3</th>\n", | ||
1338 | - " <td>val</td>\n", | ||
1339 | - " <td>unlabeled</td>\n", | ||
1340 | - " <td>R</td>\n", | ||
1341 | - " <td>0.962343</td>\n", | ||
1342 | - " </tr>\n", | ||
1343 | - " </tbody>\n", | ||
1344 | - "</table>\n", | ||
1345 | - "</div>" | ||
1346 | - ], | ||
1347 | - "text/plain": [ | ||
1348 | - " dataset measure_type measure value\n", | ||
1349 | - "dataset measure measure_type \n", | ||
1350 | - "test F1 headed 7 test headed F1 0.959192\n", | ||
1351 | - " non-headed 8 test non-headed F1 0.965236\n", | ||
1352 | - " unlabeled 15 test unlabeled F1 0.964436\n", | ||
1353 | - " P headed 9 test headed P 0.959611\n", | ||
1354 | - " non-headed 6 test non-headed P 0.965658\n", | ||
1355 | - " unlabeled 13 test unlabeled P 0.964118\n", | ||
1356 | - " R headed 2 test headed R 0.958773\n", | ||
1357 | - " non-headed 5 test non-headed R 0.964815\n", | ||
1358 | - " unlabeled 0 test unlabeled R 0.964754\n", | ||
1359 | - "val F1 headed 14 val headed F1 0.957423\n", | ||
1360 | - " non-headed 4 val non-headed F1 0.963231\n", | ||
1361 | - " unlabeled 1 val unlabeled F1 0.962553\n", | ||
1362 | - " P headed 10 val headed P 0.958145\n", | ||
1363 | - " non-headed 16 val non-headed P 0.963958\n", | ||
1364 | - " unlabeled 11 val unlabeled P 0.962762\n", | ||
1365 | - " R headed 17 val headed R 0.956702\n", | ||
1366 | - " non-headed 12 val non-headed R 0.962505\n", | ||
1367 | - " unlabeled 3 val unlabeled R 0.962343" | ||
1368 | - ] | ||
1369 | - }, | ||
1370 | - "execution_count": 58, | ||
1371 | - "metadata": {}, | ||
1372 | - "output_type": "execute_result" | ||
1373 | - } | ||
1374 | - ], | 614 | + "outputs": [], |
1375 | "source": [ | 615 | "source": [ |
1376 | "results.groupby(['dataset', 'measure', 'measure_type'], group_keys=True).apply(lambda x: x)" | 616 | "results.groupby(['dataset', 'measure', 'measure_type'], group_keys=True).apply(lambda x: x)" |
1377 | ] | 617 | ] |
1378 | }, | 618 | }, |
1379 | { | 619 | { |
1380 | "cell_type": "code", | 620 | "cell_type": "code", |
1381 | - "execution_count": 59, | 621 | + "execution_count": null, |
1382 | "id": "0b5d3fe4", | 622 | "id": "0b5d3fe4", |
1383 | "metadata": {}, | 623 | "metadata": {}, |
1384 | - "outputs": [ | ||
1385 | - { | ||
1386 | - "name": "stdout", | ||
1387 | - "output_type": "stream", | ||
1388 | - "text": [ | ||
1389 | - "\\toprule\n", | ||
1390 | - "& \\multicolumn{3}{c}{validation} & \\multicolumn{3}{c}{test} \\\\\n", | ||
1391 | - "& precision & recall & F1 & precision & recall & F1 \\\\\n", | ||
1392 | - "\\midrule\n", | ||
1393 | - "1unlabeled & 96.28\\% & 96.23\\% & 96.26\\% & 96.41\\% & 96.48\\% & 96.44\\% \\\\\n", | ||
1394 | - "2non-headed & 96.40\\% & 96.25\\% & 96.32\\% & 96.57\\% & 96.48\\% & 96.52\\% \\\\\n", | ||
1395 | - "3headed & 95.81\\% & 95.67\\% & 95.74\\% & 95.96\\% & 95.88\\% & 95.92\\% \\\\\n", | ||
1396 | - "\\bottomrule\n" | ||
1397 | - ] | ||
1398 | - } | ||
1399 | - ], | 624 | + "outputs": [], |
1400 | "source": [ | 625 | "source": [ |
1401 | "for t in tex:\n", | 626 | "for t in tex:\n", |
1402 | " print(t, end='')" | 627 | " print(t, end='')" |
@@ -1444,10 +669,6 @@ | @@ -1444,10 +669,6 @@ | ||
1444 | " precisions = precision_score(TRUE, PRED, average=None)\n", | 669 | " precisions = precision_score(TRUE, PRED, average=None)\n", |
1445 | " recalls = recall_score(TRUE, PRED, average=None)\n", | 670 | " recalls = recall_score(TRUE, PRED, average=None)\n", |
1446 | " f1s = f1_score(TRUE, PRED, average=None)\n", | 671 | " f1s = f1_score(TRUE, PRED, average=None)\n", |
1447 | - " #for v, p, r, f in sorted(zip(values, precisions, recalls, f1s), key=lambda x: -x[3]):\n", | ||
1448 | - " # if v.endswith('formarzecz') or v.endswith('formaczas'):\n", | ||
1449 | - " # spine = ' $\\\\rightarrow$ '.join(f'\\\\nt{{{n}}}' for n in v.split('_'))\n", | ||
1450 | - " # print(f'{spine} & {100 * p:.2f}\\\\% & {100 * r:.2f}\\\\% & {100 * f:.2f}\\\\% \\\\\\\\')\n", | ||
1451 | " \n", | 672 | " \n", |
1452 | " ct_pre, cp_pre = Counter(), Counter()\n", | 673 | " ct_pre, cp_pre = Counter(), Counter()\n", |
1453 | " for val in values:\n", | 674 | " for val in values:\n", |
@@ -1458,7 +679,6 @@ | @@ -1458,7 +679,6 @@ | ||
1458 | " rows = []\n", | 679 | " rows = []\n", |
1459 | " \n", | 680 | " \n", |
1460 | " for pre in ct_pre.keys():\n", | 681 | " for pre in ct_pre.keys():\n", |
1461 | - " # TODO\n", | ||
1462 | " if pre == 'ign':\n", | 682 | " if pre == 'ign':\n", |
1463 | " continue\n", | 683 | " continue\n", |
1464 | " if not cp_pre[pre] * ct_pre[pre]:\n", | 684 | " if not cp_pre[pre] * ct_pre[pre]:\n", |
@@ -1472,7 +692,6 @@ | @@ -1472,7 +692,6 @@ | ||
1472 | " spine = ' $\\\\rightarrow$ '.join(f'\\\\nt{{{n}}}' for n in v.split('_'))\n", | 692 | " spine = ' $\\\\rightarrow$ '.join(f'\\\\nt{{{n}}}' for n in v.split('_'))\n", |
1473 | " rws.append(f'{spine} & {100 * p:.2f}\\\\% & {100 * r:.2f}\\\\% & {100 * f:.2f}\\\\% & {ct[v]} \\\\\\\\')\n", | 693 | " rws.append(f'{spine} & {100 * p:.2f}\\\\% & {100 * r:.2f}\\\\% & {100 * f:.2f}\\\\% & {ct[v]} \\\\\\\\')\n", |
1474 | " wp, wr = cp[v] / cp_pre[pre], ct[v] / ct_pre[pre]\n", | 694 | " wp, wr = cp[v] / cp_pre[pre], ct[v] / ct_pre[pre]\n", |
1475 | - " #print(f' {v:36s} {100 * p:6.2f} {wp:7.3f} {100 * r:6.2f} {wr:7.3f}')\n", | ||
1476 | " P += p * wp\n", | 695 | " P += p * wp\n", |
1477 | " R += r * wr\n", | 696 | " R += r * wr\n", |
1478 | " F = 2 * P * R / (P + R)\n", | 697 | " F = 2 * P * R / (P + R)\n", |