Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Fix Problem.feature_info.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 219247790
  • Loading branch information
afrozenator authored and Copybara-Service committed Oct 30, 2018
1 parent 66afb76 commit 091373c
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions tensor2tensor/data_generators/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,23 +716,17 @@ def feature_info(self):
assert self._hparams is not None

hp = self.get_hparams()
input_mods = hp.modality["inputs"]
target_mod = hp.modality["targets"]
vocabs = hp.vocabulary
if self.has_inputs:
in_id = hp.input_space_id
out_id = hp.target_space_id

features = collections.defaultdict(FeatureInfo)
for feature_name, modality_cls in six.iteritems(hp.modality):
finfo = features[feature_name]
finfo.modality = modality_cls
finfo.vocab_size = modality_cls.top_dimensionality

for name, mod in six.iteritems(input_mods):
finfo = features[name]
finfo.modality = mod
finfo.vocab_size = mod.top_dimensionality

features["targets"].modality = target_mod
features["targets"].vocab_size = target_mod.top_dimensionality

vocabs = hp.vocabulary
for name, encoder in six.iteritems(vocabs):
features[name].encoder = encoder

Expand Down

0 comments on commit 091373c

Please sign in to comment.