Skip to content

Commit

Permalink
fix updating of SymbolTable multiple times for new words, so that there is only one instance for a single model
Browse files Browse the repository at this point in the history
  • Loading branch information
daanzu committed Nov 24, 2021
1 parent 7a10503 commit d324915
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kaldi_active_grammar/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def add_word(self, word, phones=None, lazy_compilation=False, allow_online_pronu
def prepare_for_compilation(self):
if self._lexicon_files_stale:
self.model.generate_lexicon_files()
self.model.load_words()
self.model.load_words() # FIXME: This re-loading from the words.txt file may be unnecessary now that we have/use NativeWFST + SymbolTable, but it's not clear if it's safe to remove it.
self.decoder.load_lexicon()
if self._agf_compiler:
# TODO: Just update the necessary files in the config
Expand Down
5 changes: 3 additions & 2 deletions kaldi_active_grammar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,14 @@ def __init__(self, model_dir=None, tmp_dir=None, tmp_dir_needed=False):
if self.fst_cache.cache_is_new or files_are_not_current(necessary_files + non_lazy_files):
self.generate_lexicon_files()

self.words_table = SymbolTable()
self.load_words()

def load_words(self, words_file=None):
    """(Re)load the word list from *words_file* into the model's symbol table.

    Updates the existing ``self.words_table`` in place rather than replacing
    it, so there is only one SymbolTable instance per model and any outside
    references to it stay valid. Also refreshes ``self.longest_word``.

    :param words_file: path to a ``words.txt`` symbol-table text file;
        defaults to the model's own ``words.txt``.
    :return: the (updated) SymbolTable.
    """
    if words_file is None: words_file = self.files_dict['words.txt']
    _log.debug("loading words from %r", words_file)
    # NOTE(review): this list is computed but not used in this method —
    # presumably filtering of these symbols happens elsewhere; confirm
    # before removing.
    invalid_words = "<eps> !SIL <UNK> #0 <s> </s>".lower().split()
    # Mutate the shared table in place; do NOT construct a new SymbolTable
    # here (the old `SymbolTable(words_file)` rebinding defeated the
    # single-instance guarantee).
    self.words_table.load_text_file(words_file)
    self.longest_word = max(self.words_table.word_to_id_map.keys(), key=len)
    return self.words_table

Expand All @@ -281,7 +282,7 @@ def add_word(self, word, phones=None, lazy_compilation=False, allow_online_pronu
word = word.strip().lower()

if phones is None:
# Generate pronunciations, then call ourselves recursively
# Not given pronunciation(s), so generate pronunciation(s), then call ourselves recursively for each individual pronunciation
pronunciations = Lexicon.generate_pronunciations(word, model_dir=self.model_dir, allow_online_pronunciations=allow_online_pronunciations)
pronunciations = sum([
self.add_word(word, phones, lazy_compilation=True)
Expand Down
15 changes: 12 additions & 3 deletions kaldi_active_grammar/wfst.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,20 @@ def compile_text(cls, fst_text, isymbols_filename, osymbols_filename):

class SymbolTable(object):

def __init__(self, filename=None):
    """Create an empty symbol table, optionally pre-loaded from *filename*.

    The two maps are created exactly once here and thereafter only mutated
    in place (see ``load_text_file``), so external references to them — and
    to this SymbolTable — remain valid across reloads.

    :param filename: optional path to a ``words.txt``-style text file; when
        given, the table is populated immediately. Making it optional is
        backward compatible with the previous required-``filename`` form.
    """
    self.word_to_id_map = dict()
    self.id_to_word_map = dict()
    # Sentinel meaning "no terminal words loaded yet".
    self.max_term_word_id = -1
    if filename is not None:
        self.load_text_file(filename)

def load_text_file(self, filename):
    """(Re)load word/id pairs from a ``words.txt``-style text file.

    Mutates the existing maps in place (``clear`` + ``update``) instead of
    rebinding them to new dicts, so any outside references to the maps and
    to this SymbolTable see the refreshed contents.

    :param filename: path to a UTF-8 text file of ``word id`` lines.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        word_id_pairs = [line.strip().split() for line in file]
    # In-place refresh: never replace the dict objects themselves.
    self.word_to_id_map.clear()
    self.id_to_word_map.clear()
    self.word_to_id_map.update({ word: int(id) for (word, id) in word_id_pairs })
    self.id_to_word_map.update({ id: word for (word, id) in self.word_to_id_map.items() })
    # Highest id among ordinary (terminal) words; '#nonterm...' entries are
    # special nonterminal symbols and are excluded.
    self.max_term_word_id = max(id for (word, id) in self.word_to_id_map.items() if not word.startswith('#nonterm'))

def add_word(self, word, id=None):
Expand Down

0 comments on commit d324915

Please sign in to comment.