forked from alumae/kiirkirjutaja
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compounder.py
94 lines (74 loc) · 2.41 KB
/
compounder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#! /usr/bin/env python
from __future__ import print_function
import sys
import pywrapfst as fst
import pdb
'''
Script that inserts symbols for compound-word reconstruction (+C+ and +D+,
the latter for dash-separated words) between tokens, using a "hidden event LM",
i.e. an LM that also includes +C+ and +D+.
'''
def make_sentence_fsa(syms, word_ids):
    """Build a linear acceptor for one sentence with optional marker slots.

    Each word becomes one arc. Between every pair of consecutive words an
    extra state is inserted; that gap can be crossed by ``<eps>``, ``+C+`` or
    ``+D+``. Composing the result with the hidden-event LM lets the best path
    choose, at each word boundary, whether to insert a marker.

    Args:
        syms: mapping from symbol string to integer label; must contain
            ``<eps>``, ``+C+`` and ``+D+`` whenever there are two or more words.
        word_ids: sequence of integer labels, one per word.

    Returns:
        A ``fst.VectorFst`` with state 0 as its start state and its last state
        as its only final state. Every arc (input label == output label) and
        the final state carry the same weight ``1``. Every path has the same
        number of arcs, so all paths have equal total weight.
    """
    acceptor = fst.VectorFst()
    current = acceptor.add_state()
    assert current == 0
    acceptor.set_start(current)
    for position, word_id in enumerate(word_ids):
        if position > 0:
            # Word boundary: one arc per marker choice into a fresh state.
            after_gap = acceptor.add_state()
            assert after_gap == current + 1
            for marker in ("<eps>", "+C+", "+D+"):
                acceptor.add_arc(
                    current, fst.Arc(syms[marker], syms[marker], 1, after_gap))
            current = after_gap
        # The word itself: one arc into the next (new) state.
        acceptor.add_state()
        acceptor.add_arc(current, fst.Arc(word_id, word_id, 1, current + 1))
        current += 1
    acceptor.set_final(current, 1)
    return acceptor
class Compounder:
    """Insert compound-word markers (+C+, +D+) into token sequences.

    Decoding composes a linear sentence acceptor (see ``make_sentence_fsa``)
    with the hidden-event LM ``self.g`` and reads the labels off the
    shortest path.

    Attributes:
        g: the FST loaded from ``fst_filename``.
        syms: dict mapping word string -> integer label id.
        syms_list: list mapping integer label id -> word string; ids that do
            not occur in the symbol table hold ``None``.
        unk_id: label id of ``<unk>``.
    """

    def __init__(self, fst_filename, words_filename):
        """Load the LM FST and its word symbol table.

        Args:
            fst_filename: path of the FST, read with ``fst.Fst.read``.
            words_filename: symbol table with one whitespace-separated
                ``<word> <id>`` pair per line; blank lines are ignored.

        Raises:
            KeyError: if the symbol table has no ``<unk>`` entry.
            IndexError / ValueError: on a malformed non-blank line.
        """
        self.g = fst.Fst.read(fst_filename)
        self.syms = {}
        self.syms_list = []
        # ``with`` closes the file; the original left the handle open.
        with open(words_filename) as words_file:
            for line in words_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                word, word_id = fields[0], int(fields[1])
                self.syms[word] = word_id
                # Index by label id rather than by file position, so the
                # id -> word lookup stays correct even if the table is not
                # sorted by id or has gaps.
                if word_id >= len(self.syms_list):
                    self.syms_list.extend(
                        [None] * (word_id + 1 - len(self.syms_list)))
                self.syms_list[word_id] = word
        self.unk_id = self.syms["<unk>"]

    def apply_compounder(self, words):
        """Return ``words`` with +C+ / +D+ markers inserted by the LM.

        Args:
            words: iterable of token strings. Tokens missing from the symbol
                table are decoded as ``<unk>`` and restored verbatim in the
                output, in their original order.

        Returns:
            List of token strings read from the shortest path of the
            composition. It contains the input tokens plus any ``+C+`` /
            ``+D+`` marker tokens chosen at word boundaries.
        """
        unks = []
        word_ids = []
        for word in words:
            word_id = self.syms.get(word, self.unk_id)
            word_ids.append(word_id)
            if word_id == self.unk_id:
                # Remember the surface form so it can replace <unk> later.
                unks.append(word)
        sentence = make_sentence_fsa(self.syms, word_ids)
        # Composition requires arcs sorted on the matching label of one side.
        sentence.arcsort(sort_type="olabel")
        composed = fst.compose(sentence, self.g)
        alignment = fst.shortestpath(composed)
        alignment.rmepsilon()
        # Topological sort renumbers states so that iterating states() below
        # walks the (linear) best path from start to end.
        alignment.topsort()
        labels = []
        next_unk = 0
        for state in alignment.states():
            for arc in alignment.arcs(state):
                if arc.olabel <= 0:
                    continue  # skip epsilon output labels
                if arc.olabel == self.unk_id:
                    labels.append(unks[next_unk])
                    next_unk += 1
                else:
                    labels.append(self.syms_list[arc.olabel])
        return labels
if __name__ == '__main__':
    if len(sys.argv) != 3:
        print("Usage: %s G.fst words.txt" % sys.argv[0], file=sys.stderr)
        # Stop here; otherwise the sys.argv indexing below raises IndexError.
        sys.exit(1)
    compounder = Compounder(sys.argv[1], sys.argv[2])
    # Read with readline() rather than iterating over sys.stdin: iteration
    # may read ahead in blocks (notably on Python 2), which would delay
    # output when this script runs as a streaming filter.
    while True:
        line = sys.stdin.readline()
        if not line:
            break
        labels = compounder.apply_compounder(line.split())
        print(" ".join(labels))
        # Flush per line so a downstream consumer sees each result at once.
        sys.stdout.flush()