forked from iseesaw/Pinyin2ChineseChars
-
Notifications
You must be signed in to change notification settings - Fork 0
/
count_for_hmm.py
153 lines (128 loc) · 3.78 KB
/
count_for_hmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# -*- coding: utf-8 -*-
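'''
HMM parameter estimation
init  - initial (sentence-start) probability of each Chinese character
trans - transition probability between Chinese characters
emiss - emission probability from (polyphonic) Chinese characters to pinyin
'''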
import pypinyin
import json
import math
from count_for_bigram import load
# Directory in which the model parameter JSON files are stored
dirname = 'hmm_params'
def count_init(seqs):
    """Estimate the log-probability of each character starting a sentence.

    The first character of every sentence is counted, the counts are turned
    into relative frequencies, and the natural log of each frequency is
    written out through ``save('init_prob', ...)``.

    Args:
        seqs: Sized collection of sentences. Each sentence must support
            ``len()`` and ``seq[0]`` (e.g. a string of characters).

    Empty sentences carry no first character, so they are skipped instead of
    raising ``IndexError``; the normalizer is the number of sentences that
    were actually counted (identical to ``len(seqs)`` when none are empty).
    """
    init_prob = {}
    len_ = len(seqs)
    for num, seq in enumerate(seqs, 1):
        if len(seq) > 0:
            first = seq[0]
            init_prob[first] = init_prob.get(first, 0) + 1
        # Progress report every 10000 sentences.
        if not num % 10000:
            print('{}/{}'.format(num, len_))

    # Normalize to relative frequencies, then move to log space.
    total = sum(init_prob.values())
    init_prob = {
        char: math.log(hits / total) for char, hits in init_prob.items()
    }
    save('init_prob', init_prob)
def count_emiss(seqs):
    """Estimate character -> pinyin emission log-probabilities.

    Every sentence is annotated with pypinyin (tone marks kept, ``ü`` used
    instead of ``v``). For each character (the HMM state) the number of times
    it was paired with each pinyin (the observation) is recorded, then
    normalized per character and converted to natural logs. The result,
    written through ``save('emiss_prob', ...)``, has the shape::

        {
            char1: {pinyin11: logp11, pinyin12: logp12, ...},
            char2: {pinyin21: logp21, pinyin22: logp22, ...},
            ...
        }

    Args:
        seqs: Sized collection of sentences.
    """
    emiss_prob = {}
    len_ = len(seqs)
    for num, seq in enumerate(seqs, 1):
        # Sentence to pinyin: with tone marks, and using ü.
        pinyin = pypinyin.lazy_pinyin(
            seq, style=pypinyin.Style.TONE, v_to_u=True)
        # NOTE(review): zip pairs the i-th pinyin item with the i-th
        # character, which presumes one pinyin item per character of `seq`
        # -- confirm that load() yields sentences for which this holds.
        for py, word in zip(pinyin, seq):
            per_word = emiss_prob.setdefault(word, {})
            per_word[py] = per_word.get(py, 0) + 1
        # Progress report every 10000 sentences.
        if not num % 10000:
            print('{}/{}'.format(num, len_))

    # Normalize each character's counts and move to log space.
    for counts in emiss_prob.values():
        total = sum(counts.values())
        for py in counts:
            counts[py] = math.log(counts[py] / total)
    save('emiss_prob', emiss_prob)
def count_trans(seqs):
    """Estimate log transition scores between adjacent characters (states).

    Each sentence is wrapped with the markers ``'BOS'`` and ``'EOS'`` and
    every adjacent pair ``(pre, post)`` is counted under
    ``trans_prob[post][pre]``. For each ``post`` the counts are divided by
    the total number of pairs ending in that ``post`` and converted to
    natural logs; the table is written through ``save('trans_prob', ...)``.

    NOTE(review): the normalization is per successor (``post``), not per
    predecessor -- confirm this is what the decoder consuming
    ``trans_prob`` expects.

    Args:
        seqs: Sized collection of sentences (iterables of characters).
    """
    trans_prob = {}
    len_ = len(seqs)
    for num, seq in enumerate(seqs, 1):
        chars = ['BOS', *seq, 'EOS']
        # Walk neighbouring pairs: (BOS, c1), (c1, c2), ..., (cn, EOS).
        for pre, post in zip(chars, chars[1:]):
            incoming = trans_prob.setdefault(post, {})
            incoming[pre] = incoming.get(pre, 0) + 1
        # Progress report every 10000 sentences.
        if not num % 10000:
            print('{}/{}'.format(num, len_))

    # Normalize the predecessor counts of every character, in log space.
    for incoming in trans_prob.values():
        total = sum(incoming.values())
        for pre in incoming:
            incoming[pre] = math.log(incoming[pre] / total)
    save('trans_prob', trans_prob)
def count_pinyin_states():
    """Build the pinyin -> candidate characters index (homophones).

    Reads ``emiss_prob.json`` from the parameter directory, inverts it so
    that every pinyin maps to the list of characters that were seen emitting
    it (the possible states for that observation), and writes the result to
    ``pinyin_states.json`` in the same directory.
    """
    with open(dirname+'/emiss_prob.json') as src:
        emiss_prob = json.load(src)

    data = {}
    for char, pinyins in emiss_prob.items():
        for pinyin in pinyins:
            data.setdefault(pinyin, []).append(char)

    with open(dirname+'/pinyin_states.json', 'w') as dst:
        json.dump(data, dst)
def save(filename, data):
    """Write a probability table to ``<dirname>/<filename>.json``.

    Args:
        filename: Base name of the output file, without the ``.json`` suffix.
        data: JSON-serializable object (here: nested dicts of log values).

    The parameter directory is created when it does not exist yet, so a
    first run on a fresh checkout no longer fails with ``FileNotFoundError``.
    """
    import os

    os.makedirs(dirname, exist_ok=True)
    # NOTE(review): the file is opened with the platform default encoding
    # (no `encoding=` argument); whatever reads these files must use the
    # same encoding. Left unchanged so existing readers keep working.
    with open(dirname+'/' + filename + '.json', 'w') as f:
        # Note: keep non-ASCII characters as-is instead of \u escapes.
        json.dump(data, f, indent=2, ensure_ascii=False)
def count():
    """Load the corpus and compute the HMM parameter tables.

    Runs the initial, emission and transition counts, in that order, on the
    sentences returned by ``load()``, printing a banner before each stage.
    """
    seqs = load()
    stages = (
        ('Count init prob...', count_init),
        ('Count emiss prob...', count_emiss),
        ('Count trans prob...', count_trans),
    )
    for banner, stage in stages:
        print(banner)
        stage(seqs)
    # The pinyin-state index is not rebuilt here; the step is disabled:
    # print('Count pinyin states...')
    # count_pinyin_states()
    print('That is all...')
# Run the full parameter count only when executed as a script.
if __name__ == '__main__':
    count()