simstringdb.py
9.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
#!/usr/bin/env python
import glob
import os
import sys
from common import ProtocolError
from message import Messager
from os.path import join as path_join, sep as path_sep
try:
from config import BASE_DIR, WORK_DIR
except ImportError:
# for CLI use; assume we're in brat server/src/ and config is in root
from sys import path as sys_path
from os.path import dirname
sys_path.append(path_join(dirname(__file__), '../..'))
from config import BASE_DIR, WORK_DIR
# Filename extension used for DB file.
SS_DB_FILENAME_EXTENSION = 'ss.db'
# Default similarity measure
DEFAULT_SIMILARITY_MEASURE = 'cosine'
# Default similarity threshold
DEFAULT_THRESHOLD = 0.7
# Length of n-grams in simstring DBs
DEFAULT_NGRAM_LENGTH = 3
# Whether to include marks for begins and ends of strings
DEFAULT_INCLUDE_MARKS = False
SIMSTRING_MISSING_ERROR = '''Error: failed to import the simstring library.
This library is required for approximate string matching DB lookup.
Please install simstring and its Python bindings from
http://www.chokkan.org/software/simstring/'''
class NoSimStringError(ProtocolError):
def __str__(self):
return (u'No SimString bindings found, please install them from: '
u'http://www.chokkan.org/software/simstring/')
def json(self, json_dic):
json_dic['exception'] = 'noSimStringError'
class ssdbNotFoundError(Exception):
def __init__(self, fn):
self.fn = fn
def __str__(self):
return u'Simstring database file "%s" not found' % self.fn
# Note: The only reason we use a function call for this is to delay the import
def __set_db_measure(db, measure):
try:
import simstring
except ImportError:
Messager.error(SIMSTRING_MISSING_ERROR, duration=-1)
raise NoSimStringError
ss_measure_by_str = {
'cosine': simstring.cosine,
'overlap': simstring.overlap,
}
db.measure = ss_measure_by_str[measure]
def __ssdb_path(db):
'''
Given a simstring DB name/path, returns the path for the file that
is expected to contain the simstring DB.
'''
# Assume we have a path relative to the brat root if the value
# contains a separator, name only otherwise.
# TODO: better treatment of name / path ambiguity, this doesn't
# allow e.g. DBs to be located in brat root
if path_sep in db:
base = BASE_DIR
else:
base = WORK_DIR
return path_join(base, db+'.'+SS_DB_FILENAME_EXTENSION)
def ssdb_build(strs, dbname, ngram_length=DEFAULT_NGRAM_LENGTH,
include_marks=DEFAULT_INCLUDE_MARKS):
'''
Given a list of strings, a DB name, and simstring options, builds
a simstring DB for the strings.
'''
try:
import simstring
except ImportError:
Messager.error(SIMSTRING_MISSING_ERROR, duration=-1)
raise NoSimStringError
dbfn = __ssdb_path(dbname)
try:
# only library defaults (n=3, no marks) supported just now (TODO)
assert ngram_length == 3, "Error: unsupported n-gram length"
assert include_marks == False, "Error: begin/end marks not supported"
db = simstring.writer(dbfn)
for s in strs:
db.insert(s)
db.close()
except:
print >> sys.stderr, "Error building simstring DB"
raise
return dbfn
def ssdb_delete(dbname):
'''
Given a DB name, deletes all files associated with the simstring
DB.
'''
dbfn = __ssdb_path(dbname)
os.remove(dbfn)
for fn in glob.glob(dbfn+'.*.cdb'):
os.remove(fn)
def ssdb_open(dbname):
'''
Given a DB name, opens it as a simstring DB and returns the handle.
The caller is responsible for invoking close() on the handle.
'''
try:
import simstring
except ImportError:
Messager.error(SIMSTRING_MISSING_ERROR, duration=-1)
raise NoSimStringError
try:
return simstring.reader(__ssdb_path(dbname))
except IOError:
Messager.error('Failed to open simstring DB %s' % dbname)
raise ssdbNotFoundError(dbname)
def ssdb_lookup(s, dbname, measure=DEFAULT_SIMILARITY_MEASURE,
threshold=DEFAULT_THRESHOLD):
'''
Given a string and a DB name, returns the strings matching in the
associated simstring DB.
'''
db = ssdb_open(dbname)
__set_db_measure(db, measure)
db.threshold = threshold
result = db.retrieve(s)
db.close()
# assume simstring DBs always contain UTF-8 - encoded strings
result = [r.decode('UTF-8') for r in result]
return result
def ngrams(s, out=None, n=DEFAULT_NGRAM_LENGTH, be=DEFAULT_INCLUDE_MARKS):
'''
Extracts n-grams from the given string s and adds them into the
given set out (or a new set if None). Returns the set. If be is
True, affixes begin and end markers to strings.
'''
if out is None:
out = set()
# implementation mirroring ngrams() in ngram.h in simstring-1.0
# distribution.
mark = '\x01'
src = ''
if be:
# affix begin/end marks
for i in range(n-1):
src += mark
src += s
for i in range(n-1):
src += mark
elif len(s) < n:
# pad strings shorter than n
src = s
for i in range(n-len(s)):
src += mark
else:
src = s
# count n-grams
stat = {}
for i in range(len(src)-n+1):
ngram = src[i:i+n]
stat[ngram] = stat.get(ngram, 0) + 1
# convert into a set
for ngram, count in stat.items():
out.add(ngram)
# add ngram affixed with number if it appears more than once
for i in range(1, count):
out.add(ngram+str(i+1))
return out
def ssdb_supstring_lookup(s, dbname, threshold=DEFAULT_THRESHOLD,
with_score=False):
'''
Given a string s and a DB name, returns the strings in the
associated simstring DB that likely contain s as an (approximate)
substring. If with_score is True, returns pairs of (str,score)
where score is the fraction of n-grams in s that are also found in
the matched string.
'''
try:
import simstring
except ImportError:
Messager.error(SIMSTRING_MISSING_ERROR, duration=-1)
raise NoSimStringError
db = ssdb_open(dbname.encode('UTF-8'))
__set_db_measure(db, 'overlap')
db.threshold = threshold
result = db.retrieve(s)
db.close()
# assume simstring DBs always contain UTF-8 - encoded strings
result = [r.decode('UTF-8') for r in result]
# The simstring overlap measure is symmetric and thus does not
# differentiate between substring and superstring matches.
# Replicate a small bit of the simstring functionality (mostly the
# ngrams() function) to filter to substrings only.
s_ngrams = ngrams(s)
filtered = []
for r in result:
if s in r:
# avoid calculation: simple containment => score=1
if with_score:
filtered.append((r,1.0))
else:
filtered.append(r)
else:
r_ngrams = ngrams(r)
overlap = s_ngrams & r_ngrams
if len(overlap) >= len(s_ngrams) * threshold:
if with_score:
filtered.append((r, 1.0*len(overlap)/len(s_ngrams)))
else:
filtered.append(r)
return filtered
def ssdb_supstring_exists(s, dbname, threshold=DEFAULT_THRESHOLD):
'''
Given a string s and a DB name, returns whether at least one
string in the associated simstring DB likely contains s as an
(approximate) substring.
'''
try:
import simstring
except ImportError:
Messager.error(SIMSTRING_MISSING_ERROR, duration=-1)
raise NoSimStringError
if threshold == 1.0:
# optimized (not hugely, though) for this common case
db = ssdb_open(dbname.encode('UTF-8'))
__set_db_measure(db, 'overlap')
db.threshold = threshold
result = db.retrieve(s)
db.close()
# assume simstring DBs always contain UTF-8 - encoded strings
result = [r.decode('UTF-8') for r in result]
for r in result:
if s in r:
return True
return False
else:
# naive implementation for everything else
return len(ssdb_supstring_lookup(s, dbname, threshold)) != 0
if __name__ == "__main__":
# test
dbname = "TEMP-TEST-DB"
# strings = [
# "Cellular tumor antigen p53",
# "Nucleoporin NUP53",
# "Tumor protein p53-inducible nuclear protein 2",
# "p53-induced protein with a death domain",
# "TP53-regulating kinase",
# "Tumor suppressor p53-binding protein 1",
# "p53 apoptosis effector related to PMP-22",
# "p53 and DNA damage-regulated protein 1",
# "Tumor protein p53-inducible protein 11",
# "TP53RK-binding protein",
# "TP53-regulated inhibitor of apoptosis 1",
# "Apoptosis-stimulating of p53 protein 2",
# "Tumor protein p53-inducible nuclear protein 1",
# "TP53-target gene 1 protein",
# "Accessory gland protein Acp53Ea",
# "p53-regulated apoptosis-inducing protein 1",
# "Tumor protein p53-inducible protein 13",
# "TP53-target gene 3 protein",
# "Apoptosis-stimulating of p53 protein 1",
# "Ribosome biogenesis protein NOP53",
# ]
strings = [
"0",
"01",
"012",
"0123",
"01234",
"-12345",
"012345",
]
print 'strings:', strings
ssdb_build(strings, dbname)
for t in ['0', '012', '012345', '0123456', '0123456789']:
print 'lookup for', t
for s in ssdb_supstring_lookup(t, dbname):
print s, 'contains', t, '(threshold %f)' % DEFAULT_THRESHOLD
ssdb_delete(dbname)