diff --git a/kaldi_active_grammar/compiler.py b/kaldi_active_grammar/compiler.py index 25f4d32..db1531e 100644 --- a/kaldi_active_grammar/compiler.py +++ b/kaldi_active_grammar/compiler.py @@ -338,7 +338,7 @@ def add_word(self, word, phones=None, lazy_compilation=False, allow_online_pronu def prepare_for_compilation(self): if self._lexicon_files_stale: self.model.generate_lexicon_files() - self.model.load_words() + self.model.load_words() # FIXME: This re-loading from the words.txt file may be unnecessary now that we have/use NativeWFST + SymbolTable, but it's not clear if it's safe to remove it. self.decoder.load_lexicon() if self._agf_compiler: # TODO: Just update the necessary files in the config diff --git a/kaldi_active_grammar/model.py b/kaldi_active_grammar/model.py index 8068a9f..34cae6f 100644 --- a/kaldi_active_grammar/model.py +++ b/kaldi_active_grammar/model.py @@ -252,13 +252,14 @@ def __init__(self, model_dir=None, tmp_dir=None, tmp_dir_needed=False): if self.fst_cache.cache_is_new or files_are_not_current(necessary_files + non_lazy_files): self.generate_lexicon_files() + self.words_table = SymbolTable() self.load_words() def load_words(self, words_file=None): if words_file is None: words_file = self.files_dict['words.txt'] _log.debug("loading words from %r", words_file) invalid_words = " !SIL #0 ".lower().split() - self.words_table = SymbolTable(words_file) + self.words_table.load_text_file(words_file) self.longest_word = max(self.words_table.word_to_id_map.keys(), key=len) return self.words_table @@ -281,7 +282,7 @@ def add_word(self, word, phones=None, lazy_compilation=False, allow_online_pronu word = word.strip().lower() if phones is None: - # Generate pronunciations, then call ourselves recursively + # Not given pronunciation(s), so generate pronunciation(s), then call ourselves recursively for each individual pronunciation pronunciations = Lexicon.generate_pronunciations(word, model_dir=self.model_dir, allow_online_pronunciations=allow_online_pronunciations) pronunciations = sum([ self.add_word(word, phones, lazy_compilation=True) diff --git a/kaldi_active_grammar/wfst.py b/kaldi_active_grammar/wfst.py index 83f277b..a7b7a2c 100644 --- a/kaldi_active_grammar/wfst.py +++ b/kaldi_active_grammar/wfst.py @@ -356,11 +356,20 @@ def compile_text(cls, fst_text, isymbols_filename, osymbols_filename): class SymbolTable(object): - def __init__(self, filename): + def __init__(self, filename=None): + self.word_to_id_map = dict() + self.id_to_word_map = dict() + self.max_term_word_id = -1 + if filename is not None: + self.load_text_file(filename) + + def load_text_file(self, filename): with open(filename, 'r', encoding='utf-8') as file: word_id_pairs = [line.strip().split() for line in file] - self.word_to_id_map = { word: int(id) for (word, id) in word_id_pairs } - self.id_to_word_map = { id: word for (word, id) in self.word_to_id_map.items() } + self.word_to_id_map.clear() + self.id_to_word_map.clear() + self.word_to_id_map.update({ word: int(id) for (word, id) in word_id_pairs }) + self.id_to_word_map.update({ id: word for (word, id) in self.word_to_id_map.items() }) self.max_term_word_id = max(id for (word, id) in self.word_to_id_map.items() if not word.startswith('#nonterm')) def add_word(self, word, id=None):