forked from pengyuhe/Nematus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hypgraph.py
131 lines (116 loc) · 4.1 KB
/
hypgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import defaultdict
class HypGraph(object):
    """Tree of hypotheses stored as labelled nodes plus parent/child edges.

    Every node is identified by a string built from its own label followed
    by the labels on the path back to the root, most recent first and joined
    with ``-`` (e.g. word 7 after history [0, 5] has the id ``"7-5-0"``).
    The root always carries the label ``0``.
    """

    def __init__(self):
        self.nodes = defaultdict(str)  # {node_id: label}
        self.edges = []  # [(parent_node_id, child_node_id), ...]
        self.costs = defaultdict(float)  # {node_id: cost}
        self.word_probs = defaultdict(float)  # {node_id: word_prob}

    def get_id(self, word, history):
        """
        Builds the node id for `word` reached through `history`.
        @param word the label of the node
        @param history sequence of labels preceding `word`, oldest first
        @return `str(word)` when the history is empty, otherwise the word
        followed by the history labels in reverse order, joined by '-'.
        """
        # len() instead of `== []` so that an empty tuple (or any other
        # empty sequence) is recognised as "no history" as well.
        if len(history) == 0:
            return str(word)
        reversed_history = '-'.join(str(h) for h in reversed(history))
        return '%s-%s' % (word, reversed_history)

    def get_ids(self, words):
        """
        Builds the node id of every position in `words`.
        @param words sequence of labels, oldest first
        @return list with one id per element, each using the elements
        before it as its history.
        """
        return [self.get_id(w, words[:i]) for i, w in enumerate(words)]

    def add(self, word, history, word_prob=None, cost=None):
        """
        Adds a node for `word` and an edge from the last node of `history`.
        @param word the label of the new node
        @param history sequence of labels leading to `word` (root excluded)
        @param word_prob optional word probability stored for the node
        @param cost optional cost stored for the node
        """
        # Prepend the root label so that the first word also has a parent.
        path = [0] + list(history)
        # The parent is the last element of the path; only its id is needed,
        # so the ids of the earlier prefixes are not computed.
        parent_id = self.get_id(path[-1], path[:-1])
        word_id = self.get_id(word, path)
        # store
        self.nodes[word_id] = word
        self.edges.append((parent_id, word_id))
        # Identity checks: a probability or cost of 0 must still be stored.
        if word_prob is not None:
            self.word_probs[word_id] = word_prob
        if cost is not None:
            self.costs[word_id] = cost
class HypGraphRenderer(object):
    """Renders a hypothesis graph as an image using pygraphviz.

    The renderer keeps references to (not copies of) the node, edge, cost
    and word-probability containers of the graph it is given, so
    `wordify()` rewrites the labels of that graph in place.
    """

    def __init__(self, hyp_graph):
        """
        @param hyp_graph object exposing `nodes`, `edges`, `costs` and
        `word_probs` attributes.
        """
        self.nodes = hyp_graph.nodes
        self.edges = hyp_graph.edges
        self.costs = hyp_graph.costs
        self.word_probs = hyp_graph.word_probs
        # constants
        self.BOS_SYMBOLS = ['0']
        self.EOS_SYMBOLS = ['<eos>']

    def _escape_label(self, label):
        """
        Prefixes every '<' and '>' in `label` with a backslash.
        @param label the label text
        @return the escaped label
        """
        # Explicit '\\' keeps the literal backslash without relying on the
        # invalid escape sequences '\<' and '\>'.
        return label.replace('<', '\\<').replace('>', '\\>')

    def _label_text(self, label):
        """
        Converts a node label to text.
        @param label a byte string (decoded as UTF-8), a text string, or any
        other object (e.g. an integer label before `wordify()` was called)
        @return the label as text
        """
        if isinstance(label, bytes):
            return label.decode('utf-8')
        return u'%s' % (label,)

    def _render(self, costs=False, word_probs=False, highlight_best=False):
        """
        Builds the pygraphviz graph and stores it in `self.graph`.
        @param costs whether node labels include the hypothesis cost
        @param word_probs whether node labels include the word probability
        @param highlight_best whether to highlight the best hypothesis
        """
        # Imported here so the module can be loaded without pygraphviz.
        from pygraphviz import AGraph
        graph = AGraph(directed=True)
        for node_id in self.nodes:
            attributes = self._node_attr(node_id, costs=costs, word_probs=word_probs)
            graph.add_node(node_id, **attributes)
        for (parent_node_id, child_node_id) in self.edges:
            graph.add_edge(parent_node_id, child_node_id)
        self.graph = graph
        if highlight_best:
            self._highlight_best()

    def _node_attr(self, node_id, costs=False, word_probs=False):
        """
        Builds the graphviz attributes of one node.
        @param node_id id of the node
        @param costs whether the label includes the node's cost
        @param word_probs whether the label includes the word probability
        @return dict with a 'label' entry and, when costs or probabilities
        are shown, a 'shape' entry set to "record".
        """
        word = self._label_text(self.nodes[node_id])
        # .get() so that a node without a stored value is shown as 0.0
        # without inserting a new key into the shared containers.
        cost = self.costs.get(node_id, 0.0)
        prob = self.word_probs.get(node_id, 0.0)
        attr = {}
        if costs and word_probs:
            attr['shape'] = "record"
            attr['label'] = "{{%s|%.3f}|%.3f}" % (word, prob, cost)
        elif costs:
            attr['shape'] = "record"
            attr['label'] = "{{%s}|%.3f}" % (word, cost)
        elif word_probs:
            attr['shape'] = "record"
            attr['label'] = "{{%s|%.3f}}" % (word, prob)
        else:
            attr['label'] = word
        attr['label'] = self._escape_label(attr['label'])
        return attr

    def _best_leaf_node_id(self):
        """
        Finds the end-of-sentence node with the lowest cost.
        @return the node id, or None when no node is labelled with one of
        `EOS_SYMBOLS`. On equal cost the node seen first wins.
        """
        best_id = None
        best_cost = None
        for node_id, label in self.nodes.items():
            if label not in self.EOS_SYMBOLS:
                continue
            cost = self.costs.get(node_id, 0.0)
            if best_cost is None or cost < best_cost:
                best_id = node_id
                best_cost = cost
        return best_id

    def _highlight_best(self):
        """
        Fills the nodes on the path from the best leaf up to the root.
        Requires `self.graph`, i.e. must run after the graph was built.
        """
        best_hyp_bg_color = '#CDE9EC'
        best_hyp_leaf_node_id = self._best_leaf_node_id()
        if best_hyp_leaf_node_id is None:
            return
        current_node = self.graph.get_node(best_hyp_leaf_node_id)
        while True:
            current_node.attr['style'] = 'filled'
            current_node.attr['fillcolor'] = best_hyp_bg_color
            parents = self.graph.predecessors(current_node)
            if not parents:
                # reached the root
                break
            current_node = parents[0]

    def wordify(self, word_dict):
        """
        Replace node labels (usually integers) with words, subwords, or
        characters.
        @param word_dict mapping from the current labels to the new labels;
        a label missing from it raises KeyError.
        """
        # Snapshot the items: the values are reassigned while looping.
        for node_id, label in list(self.nodes.items()):
            self.nodes[node_id] = word_dict[label]

    def save_png(self, filepath, detailed=False, highlight_best=False):
        """
        Renders the graph as PNG image.
        @param filepath the target file
        @param detailed whether to include word probabilities and
        hypothesis costs.
        @param highlight_best whether to highlight the best hypothesis.
        """
        show_details = bool(detailed)
        self._render(costs=show_details, word_probs=show_details,
                     highlight_best=highlight_best)
        self.graph.draw(filepath, prog="dot")