From 45664353a0b0ff0fb5f69dc6512a3326dc6f2aa5 Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Thu, 8 Jun 2017 00:44:30 -0700 Subject: [PATCH] Initial commit --- .gitignore | 17 + LICENSE | 201 +++++++++++ NOTICE | 2 + README.md | 76 ++++ docs/Makefile | 20 ++ docs/README.md | 1 + docs/conf.py | 203 +++++++++++ docs/development.md | 72 ++++ docs/faq.md | 4 + docs/index.rst | 31 ++ docs/make.bat | 36 ++ docs/modules.rst | 150 ++++++++ docs/user_documentation.md | 137 +++++++ pre-commit.sh | 42 +++ pytest.ini | 2 + requirements.txt | 2 + setup.cfg | 11 + setup.py | 78 ++++ sockeye.pylintrc | 408 +++++++++++++++++++++ sockeye/__init__.py | 14 + sockeye/arguments.py | 371 +++++++++++++++++++ sockeye/attention.py | 593 +++++++++++++++++++++++++++++++ sockeye/average.py | 196 +++++++++++ sockeye/bleu.py | 107 ++++++ sockeye/callback.py | 272 ++++++++++++++ sockeye/checkpoint_decoder.py | 104 ++++++ sockeye/constants.py | 94 +++++ sockeye/coverage.py | 294 ++++++++++++++++ sockeye/data_io.py | 467 ++++++++++++++++++++++++ sockeye/decoder.py | 452 ++++++++++++++++++++++++ sockeye/embeddings.py | 116 ++++++ sockeye/encoder.py | 413 ++++++++++++++++++++++ sockeye/inference.py | 646 ++++++++++++++++++++++++++++++++++ sockeye/initializer.py | 99 ++++++ sockeye/lexicon.py | 159 +++++++++ sockeye/log.py | 119 +++++++ sockeye/loss.py | 148 ++++++++ sockeye/lr_scheduler.py | 166 +++++++++ sockeye/model.py | 183 ++++++++++ sockeye/output_handler.py | 142 ++++++++ sockeye/rnn.py | 57 +++ sockeye/train.py | 218 ++++++++++++ sockeye/training.py | 311 ++++++++++++++++ sockeye/translate.py | 97 +++++ sockeye/utils.py | 335 ++++++++++++++++++ sockeye/vocab.py | 147 ++++++++ test/__init__.py | 13 + test/test_attention.py | 191 ++++++++++ test/test_bleu.py | 26 ++ test/test_callback.py | 69 ++++ test/test_coverage.py | 138 ++++++++ test/test_data_io.py | 72 ++++ test/test_decoder.py | 97 +++++ test/test_loss.py | 135 +++++++ test/test_lr_scheduler.py | 28 ++ test/test_output_handler.py | 53 +++ test/test_utils.py | 68 ++++ test/test_vocab.py | 54 +++ typechecked-files | 5 + 59 files changed, 8762 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 NOTICE create mode 100644 README.md create mode 100644 docs/Makefile create mode 120000 docs/README.md create mode 100644 docs/conf.py create mode 100644 docs/development.md create mode 100644 docs/faq.md create mode 100644 docs/index.rst create mode 100644 docs/make.bat create mode 100644 docs/modules.rst create mode 100644 docs/user_documentation.md create mode 100755 pre-commit.sh create mode 100644 pytest.ini create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 sockeye.pylintrc create mode 100644 sockeye/__init__.py create mode 100644 sockeye/arguments.py create mode 100644 sockeye/attention.py create mode 100644 sockeye/average.py create mode 100644 sockeye/bleu.py create mode 100644 sockeye/callback.py create mode 100644 sockeye/checkpoint_decoder.py create mode 100644 sockeye/constants.py create mode 100644 sockeye/coverage.py create mode 100644 sockeye/data_io.py create mode 100644 sockeye/decoder.py create mode 100644 sockeye/embeddings.py create mode 100644 sockeye/encoder.py create mode 100644 sockeye/inference.py create mode 100644 sockeye/initializer.py create mode 100644 sockeye/lexicon.py create mode 100644 sockeye/log.py create mode 100644 sockeye/loss.py create mode 100644 sockeye/lr_scheduler.py create mode 100644 sockeye/model.py create mode 100644 
sockeye/output_handler.py create mode 100644 sockeye/rnn.py create mode 100644 sockeye/train.py create mode 100644 sockeye/training.py create mode 100644 sockeye/translate.py create mode 100644 sockeye/utils.py create mode 100644 sockeye/vocab.py create mode 100644 test/__init__.py create mode 100644 test/test_attention.py create mode 100644 test/test_bleu.py create mode 100644 test/test_callback.py create mode 100644 test/test_coverage.py create mode 100644 test/test_data_io.py create mode 100644 test/test_decoder.py create mode 100644 test/test_loss.py create mode 100644 test/test_lr_scheduler.py create mode 100644 test/test_output_handler.py create mode 100644 test/test_utils.py create mode 100644 test/test_vocab.py create mode 100644 typechecked-files diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..e329cbb4d --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +*.pyc +*.pyo +*.class +*~ +*# +/docs/generated/* +/docs/_build +/runpy +/build +.coverage* +.idea** +.history** +.cache** +.eggs** +*.egg** +.*.swp +.mypy_cache diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..8dada3eda --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..9aef69705 --- /dev/null +++ b/NOTICE @@ -0,0 +1,2 @@ +Sockeye +Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 000000000..422e68ce4 --- /dev/null +++ b/README.md @@ -0,0 +1,76 @@ +# Sockeye + +This package contains the Sockeye project, +a sequence-to-sequence framework for Neural Machine Translation based on MXNet. +It implements the well-known encoder-decoder architecture with attention. + +If you are interested in collaborating or have any questions, please submit a pull request or issue. +You can also send questions to *sockeye-dev-at-amazon-dot-com*. 
+ +## Dependencies + +Sockeye requires: +- **Python3** +- [MXNet-0.10.0](https://github.com/dmlc/mxnet/tree/v0.10.0) +- numpy + +Install them with: +```bash +> pip install -r requirements.txt +``` + +Optionally, dmlc's tensorboard fork is supported to track learning curves (````pip install tensorboard````). + +Full dependencies are listed in requirements.txt. + +## Installation + +If you want to just use sockeye without extending it, simply install it via +```bash +> python setup.py install +``` +after cloning the repository from git. After installation, command line tools such as +*sockeye-train, sockeye-translate, sockeye-average* +and *sockeye-embeddings* are available. Alternatively, if the sockeye directory is on your +PYTHONPATH you can run the modules +directly. For example *sockeye-train* can also be invoked as +```bash +> python -m sockeye.train +``` + +## First Steps + +### Train + +In order to train your first Neural Machine Translation model you will need two sets of parallel files: one for training +and one for validation. The latter will be used for computing various metrics during training. +Each set should consist of two files: one with source sentences and one with target sentences (translations). Both files should have the same number of lines, each line containing a single +sentence. Each sentence should be a whitespace delimited list of tokens. + +Say you wanted to train a German to English translation model, then you would call sockeye like this: +```bash +> python -m sockeye.train --source sentences.de \ + --target sentences.en \ + --validation-source sentences.dev.de \ + --validation-target sentences.dev.en \ + --use-cpu \ + --output +``` + +After training the directory ** will contain all model artifacts such as parameters and model +configuration. + + +### Translate + +Input data for translation should be in the same format as the training data (tokenization, preprocessing scheme). +You can translate as follows: + +```bash +> python -m sockeye.translate --models --use-cpu +``` + +This will take the best set of parameters found during training and then translate strings from STDIN and +write translations to STDOUT. + +For more detailed examples check out our user documentation. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..559b1aa81 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = sockeye +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 120000 index 000000000..32d46ee88 --- /dev/null +++ b/docs/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 000000000..481d4c06f --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# sockeye documentation build configuration file, created by +# sphinx-quickstart on Wed May 17 15:38:17 2017. 
+# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import re +import sys +sys.path.insert(0, os.path.abspath('..')) + +ROOT = os.path.dirname(__file__) + +def get_version(): + VERSION_RE = re.compile(r'''__version__ = ['"]([0-9.]+)['"]''') + init = open(os.path.join(ROOT, '../sockeye', '__init__.py')).read() + return VERSION_RE.search(init).group(1) + + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.viewcode', + 'sphinx.ext.githubpages', + 'sphinx_autodoc_typehints', + 'sphinx.ext.imgmath'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +from recommonmark.parser import CommonMarkParser + +source_parsers = { + '.md': CommonMarkParser, +} + +source_suffix = ['.rst', '.md'] + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'sockeye' +copyright = 'Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.' +author = 'Amazon' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = get_version() +# The full version, including alpha/beta/rc tags. +release = get_version() + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. 
+# +html_theme_options = { + 'collapse_navigation': False, + 'display_version': True, + 'navigation_depth': 2, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +#html_static_path = ['_static'] + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'sockeye_doc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'sockeye.tex', 'Sockeye Documentation', + 'amazon', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'sockeye', 'Sockeye Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'sockeye', 'Sockeye Documentation', + author, 'sockeye', 'Sequence-to-Sequence modeling with MXNet', + 'Miscellaneous'), +] + + + +# -- Options for Epub output ---------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project +epub_author = author +epub_publisher = author +epub_copyright = copyright + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'https://docs.python.org/': None} diff --git a/docs/development.md b/docs/development.md new file mode 100644 index 000000000..367775afe --- /dev/null +++ b/docs/development.md @@ -0,0 +1,72 @@ +# Developer Documentation + +## Requirements + +The following packages are required for developers (also see the requirements.txt): + - pytest + - pytest-cov + - sphinx>=1.4 + - sphinx_rtd_theme + - sphinx-autodoc-typehints + - recommonmark + +Install them via +```bash +> pip install -e '.[dev]' +``` + + +## Developer Guidelines + +We welcome contributions to sockeye in form of pull requests on Github. +If you want to develop sockeye, please adhere to the following development guidelines. + + + * Write Python 3.5, PEP8 compatible code. + + * Functions should be documented with Sphinx-style docstrings and + should include type hints for static code analyzers. + + ```python + def foo(bar: ) -> : + """ + . + + :param bar: . + :return: . 
+ """ + ``` + + * When using MXNet operators, preceding symbolic statements + in the code with the resulting, expected shape of the tensor greatly improves readability of the code: + ```python + # (batch_size, num_hidden) + data = mx.sym.Variable('data') + # (batch_size * num_hidden,) + data = mx.sym.reshape(data=data, shape=(-1)) + ``` + + * The desired line length of Python modules should not exceed 120 characters. + + * When writing symbol-generating classes (such as encoders/decoders), initialize variables in the constructor of the + class and re-use them in the class methods. + + * Make sure to pass unit tests before submitting a pull request. + + * Whenever reasonable, write py.test unit tests covering your contribution. + + +## Building the Documentation +Full documentation, including a code reference, can be generated using Sphinx with the following command: +```bash +> python setup.py docs +``` +The results are written to ```docs/_build/html/index.html```. + + +## Unit tests +Unit tests are written using py.test. +They can be run like this: +```bash +> python setup.py test +``` \ No newline at end of file diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 000000000..af59d33ef --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,4 @@ +# Frequently Asked Questions + +### What does Sockeye mean? +Sockeye is a salmon found in the Northern Pacific Ocean. diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 000000000..0d3a54fe6 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,31 @@ +Sockeye Documentation +===================== + +This is the documentation for sockeye, +a framework for sequence-to-sequence modeling with `MXNet `_. +To get started, please read through the :doc:`README `. + +The individual modules and functions of the project are documented under :doc:`Python Modules `. + +For Contributors +---------------- +If you want to contribute or develop for sockeye, please see the :doc:`Developer Guide `. + + +Table of Contents +----------------- + +.. toctree:: + :maxdepth: 4 + + README + user_documentation + development + faq + modules + +Resources +--------- + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..18082c357 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=sockeye + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/modules.rst b/docs/modules.rst new file mode 100644 index 000000000..460f028ee --- /dev/null +++ b/docs/modules.rst @@ -0,0 +1,150 @@ +Python Modules +============== + +sockeye.attention module +------------------------ + +.. automodule:: sockeye.attention + :members: + :show-inheritance: + +sockeye.average module +---------------------- + +.. 
automodule:: sockeye.average + :members: + :show-inheritance: + +sockeye.bleu module +------------------- + +.. automodule:: sockeye.bleu + :members: corpus_bleu + :show-inheritance: + +sockeye.callback module +----------------------- + +.. automodule:: sockeye.callback + :members: + :show-inheritance: + +sockeye.checkpoint_decoder module +--------------------------------- + +.. automodule:: sockeye.checkpoint_decoder + :members: + :show-inheritance: + +sockeye.coverage module +----------------------- + +.. automodule:: sockeye.coverage + :members: + :show-inheritance: + +sockeye.data_io module +---------------------- + +.. automodule:: sockeye.data_io + :members: + :show-inheritance: + +sockeye.decoder module +---------------------- + +.. automodule:: sockeye.decoder + :members: + :show-inheritance: + +sockeye.embeddings module +------------------------- + +.. automodule:: sockeye.embeddings + :members: + :show-inheritance: + +sockeye.encoder module +---------------------- + +.. automodule:: sockeye.encoder + :members: + :show-inheritance: + +sockeye.inference module +------------------------ + +.. automodule:: sockeye.inference + :members: + :show-inheritance: + +sockeye.initializer module +-------------------------- + +.. automodule:: sockeye.initializer + :members: + :show-inheritance: + +sockeye.lexicon module +---------------------- + +.. automodule:: sockeye.lexicon + :members: + :show-inheritance: + +sockeye.loss module +------------------- + +.. automodule:: sockeye.loss + :members: + :show-inheritance: + +sockeye.lr_scheduler module +--------------------------- + +.. automodule:: sockeye.lr_scheduler + :members: + :show-inheritance: + +sockeye.model module +-------------------- + +.. automodule:: sockeye.model + :members: + :show-inheritance: + +sockeye.output_handler module +----------------------------- + +.. automodule:: sockeye.output_handler + :members: + :show-inheritance: + +sockeye.rnn module +------------------ + +.. automodule:: sockeye.rnn + :members: + :show-inheritance: + + +sockeye.training module +----------------------- + +.. automodule:: sockeye.training + :members: + :show-inheritance: + +sockeye.utils module +-------------------- + +.. automodule:: sockeye.utils + :members: + :show-inheritance: + +sockeye.vocab module +-------------------- + +.. automodule:: sockeye.vocab + :members: + :show-inheritance: diff --git a/docs/user_documentation.md b/docs/user_documentation.md new file mode 100644 index 000000000..f0af16756 --- /dev/null +++ b/docs/user_documentation.md @@ -0,0 +1,137 @@ +# User documentation + +## Training + + +Training is carried out by the `sockeye.train` module. Basic usage is given by + +```bash +> python -m sockeye.train +usage: train.py [-h] --source SOURCE --target TARGET --validation-source + VALIDATION_SOURCE --validation-target VALIDATION_TARGET + --output OUTPUT [...] +``` + +Training requires 5 arguments: +* `--source`, `--target`: give the training data files. Gzipped files are supported, provided that their filenames end with .gz. +* `--validation-source`, `--validation-target`: give the validation data files, gzip supported as above. +* `--output`: gives the output directory where the intermediate and final results will be written to. +Intermediate directories will be created if needed. +Logging will be written to `/log` as well as being echoed on the console. + +For a complete list of supported options use the `--help` option. + +### Data format + +All input files files should be UTF-8 encoded, tokenized with standard whitespaces. 
+Each line should contain a single sentence and the source and target files should have the same number of lines. +Vocabularies will automatically be created from the training data and vocabulary +coverage on the validation set during initialization will be reported. + +### Checkpointing and early-stopping + +Training is governed by the concept of "checkpoints", rather than epochs. You +can specify the checkpoint frequency in terms of updates/batches with +`--checkpoint-frequency`. Training performs early-stopping to prevent +overfitting, i.e. training is stopped once a defined evaluation metric computed +on the held-out validation data does not improve for a number of checkpoints +given by the parameter `--max-num-checkpoint-not-improved`. You can specify a +maximum number of updates/batches using `--max-updates`. + +Perplexity is the default metric to be considered for early-stopping, but you +can also choose to optimize accuracy or BLEU using the `--optimized-metric` +argument. In case of optimizing with respect to BLEU, you will need to specify +`--monitor-bleu`. For efficiency reasons, sockeye spawns a sub-processes after each +checkpoint to decode the validation data and compute BLEU. This may introduce +some delay in the reporting of results, i.e. there may be checkpoints with no +BLEU results reported or with results corresponding to older checkpoints. This +is expected behaviour and sockeye internally keeps track of the results in the +correct order. + +Note that evaluation metrics for training data and held-out validation data are written in a +tab-separated file called `metrics`. + +### Monitoring training progress with tensorboard + +Sockeye can write all evaluation metrics in a tensorboard compatible format. +This way you can monitor the training progress in the browser. +If you have not yet installed dmlc's tensorboard fork do so as follows: +```bash +> pip install tensorboard +``` + +Now when training specify the additional command line parameter `--use-tensorboard` to `sockeye.train`. +Then start tensorboard and point it to the model directory (or any parent directory): +```bash +> tensorboard --logdir model_dir +``` + +### CPU/GPU training + +By default, training is carried out on the first GPU device of your machine. +You can specify alternative devices with the `--device-ids` option, with +which you can also activate multi-GPU training (see below). If +`--device-ids -1`, sockeye will try to find a free GPU on your machine and block +until one is available. The locking mechanism is based on files and therefore assumes all processes are running +on the same machine with the same file system. +If this is not the case there is a chance that two processes will be using the same GPU and you run out of GPU memory. +If you do not have or do not want to use a GPU, specify `--use-cpu`. +In this case a drop in performance is expected. + +Training can be carried out on multiple GPUs using the `--device-ids` flag and specifying multiple GPU device ids: +`--device-ids 0 1 2 3`. +This will train using [Data Parallelism](https://github.com/dmlc/mxnet/blob/master/docs/how_to/multi_devices.md). +MXNet will divide the data in each batch and send it to the different devices. +Note that you should increase the batch size, for k GPUs use ``--batch-size k*``. +Also note that this will likely linearly increase your throughput in terms of sentences/second, but not necessarily +increase the model's convergence speed. 
+ + +### Checkpoint averaging + +A common technique for improving model performance is to average the weights for the last checkpoints. +This can be done as follows: +```bash +> python -m sockeye.average -o /model.best.avg.params +``` + +## Translation + +Translating is handled by the `sockeye.translate` module: +```bash +> python -m sockeye.translate +``` + +The only required argument is `--models`, which should point to an `` +folder of trained models. By default, sockeye chooses the parameters from the +best checkpoint and uses these for translation. You can specify parameters +from a specific checkpoint by using `--checkpoints X`. + +You can control the size of the beam using `--beam-size` and the maximum input +length by `--max-input-length`. Sentences that are longer than +`max-input-length` are stripped. + +Input is read from the standard input and the output is written to the standard +output. The CLI will log translation speed once the input is consumed. Like in +the training module, the first GPU device is used by default. Note however that +multi-GPU translation is not currently supported. For CPU decoding use +`--use-cpu`. + +Use the `--help` option to see a full list of options for translation. + +### Ensemble Decoding +Sockeye supports ensemble decoding by specifying multiple model directories and +multiple checkpoints. The given lists must have the same length, such that the +first given checkpoint will be taken from the first model directory, the second +specified checkpoint from the second directory, etc. +```bash +> python -m sockeye.translate --models [ ] --checkpoints [ ] +``` + +### Visualization +The default mode of the translate CLI is to output translations to STDOUT. You +can also print out an ASCII matrix of the alignments using `--output-type +align_text`, or save the alignment matrix as a PNG plot using `--output-type +align_plot`. The PNG files will be written to files beginning with the prefix +given by the `--align-plot-prefix` option, one for each input sentence, indexed +by the sentence id. \ No newline at end of file diff --git a/pre-commit.sh b/pre-commit.sh new file mode 100755 index 000000000..c0c668881 --- /dev/null +++ b/pre-commit.sh @@ -0,0 +1,42 @@ +#!/bin/sh +# +# A pre-commit script that will run before every commit to Sockeye. +# This script contains the same tests that will evenutally run in CI. +# Install by running ln -s ../../pre-commit.sh .git/hooks/pre-commit +# You can remove these checks at any time by running rm .git/hooks/pre-commit +# You can commit bypassing these changes by running git commit --no-verify + +# Stash all non-commited files +STASH_NAME="pre-commit-$(date +%s)" +git stash save -q --keep-index $STASH_NAME + +# Run unit tests +python3 setup.py test +TEST_RESULT=$? + +# Run pylint on the sockeye package, failing on any reported errors. +pylint --rcfile=sockeye.pylintrc sockeye -E +SOCKEYE_LINT_RESULT=$? + +# Run pylint on test package, failing on any reported errors. +pylint --rcfile=sockeye.pylintrc test -E +TESTS_LINT_RESULT=$? + +# Run mypy, we are currently limiting to modules that pass +# Please feel free to fix mypy issues in other modules and add them to typechecked-files +mypy --ignore-missing-imports @typechecked-files +MYPY_RESULT=$? 
+ +# Pop our stashed files +STASHES=$(git stash list) +if [[ $STASHES == "$STASH_NAME" ]]; then + git stash pop -q +fi + +[ $TEST_RESULT -ne 0 ] && echo 'Unit tests failed' && exit 1 +[ $SOCKEYE_LINT_RESULT -ne 0 ] && echo 'pylint found errors in the sockeye package' && exit 1 +[ $TESTS_LINT_RESULT -ne 0 ] && echo 'pylint found errors in the test package' && exit 1 +[ $MYPY_RESULT -ne 0 ] && echo 'mypy found incorrect type usage' && exit 1 + +echo 'all pre-commit checks passed' +exit 0 \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..d05332148 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --cov sockeye test -v diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..32943f591 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +mxnet==0.10.0 +numpy>=1.12 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..7134f3762 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,11 @@ +[aliases] +test=pytest +tests=pytest +doc=build_sphinx +docs=build_sphinx + + +[build_sphinx] +source-dir = docs +build-dir = docs/_build +all_files = 1 diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..b2f753217 --- /dev/null +++ b/setup.py @@ -0,0 +1,78 @@ +import os +import re +import logging +from setuptools import setup, find_packages + +ROOT = os.path.dirname(__file__) + + +def get_long_description(): + with open(os.path.join(ROOT, 'README.md'), encoding='utf-8') as f: + return f.read() + + +def get_version(): + VERSION_RE = re.compile(r'''__version__ = ['"]([0-9.]+)['"]''') + init = open(os.path.join(ROOT, 'sockeye', '__init__.py')).read() + return VERSION_RE.search(init).group(1) + + +def get_requirements(): + with open(os.path.join(ROOT, 'requirements.txt')) as f: + return [line.rstrip() for line in f] + + +try: + from sphinx.setup_command import BuildDoc + cmdclass = {'build_sphinx': BuildDoc} +except: + logging.warning("Package 'sphinx' not found. You will not be able to build docs.") + cmdclass = {} + +args = dict( + name='sockeye', + + version=get_version(), + + description='Sequence-to-Sequence framework for Neural Machine Translation', + long_description=get_long_description(), + + url='https://github.com/awslabs/sockeye', + + author='Amazon', + author_email='sockeye-dev@amazon.com', + maintainer_email='sockeye-dev@amazon.com', + + license='Apache License 2.0', + + packages=find_packages(exclude=("test",)), + + setup_requires=['pytest-runner'], + tests_require=['pytest', 'pytest-cov'], + + extras_require={ + 'optional': ['tensorboard'], + 'dev': [ + 'sphinx>=1.4', + 'sphinx_rtd_theme', + 'sphinx-autodoc-typehints', + 'recommonmark' + ] + }, + + install_requires=get_requirements(), + + entry_points={ + 'console_scripts': [ + 'sockeye-train = sockeye.train:main', + 'sockeye-translate = sockeye.translate:main', + 'sockeye-average = sockeye.average:main', + 'sockeye-embeddings = sockeye.embeddings:main' + ], + }, + + cmdclass=cmdclass, + +) + +setup(**args) diff --git a/sockeye.pylintrc b/sockeye.pylintrc new file mode 100644 index 000000000..d4c419405 --- /dev/null +++ b/sockeye.pylintrc @@ -0,0 +1,408 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. 
The +# regex matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. This option is deprecated +# and it will be removed in Pylint 2.0. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,import-star-module-level,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,long-suffix,old-ne-operator,old-octal-literal,suppressed-message,useless-suppression,bad-whitespace,too-many-instance-attributes,too-many-locals,line-too-long,bad-continuation,missing-docstring,too-few-public-methods + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. 
You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". This option is deprecated +# and it will be removed in Pylint 2.0. +files-output=no + +# Tells whether to display a full report or only the messages +reports=yes + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=foo,bar,baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression which 
should only match function or class names that do +# not require a docstring. +#no-docstring-rgx=^_ +no-docstring-rgx=. + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[ELIF] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=mxnet,mxnet.*,numpy,numpy.* + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. 
Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,future.builtins + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=500 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of boolean expressions in a if statement +max-bool-expr=5 + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=optparse + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/sockeye/__init__.py b/sockeye/__init__.py new file mode 100644 index 000000000..4ecf31617 --- /dev/null +++ b/sockeye/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +__version__ = '1.0.0' diff --git a/sockeye/arguments.py b/sockeye/arguments.py new file mode 100644 index 000000000..da426e348 --- /dev/null +++ b/sockeye/arguments.py @@ -0,0 +1,371 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Defines commandline arguments for the main CLIs with reasonable defaults. +""" +import argparse +from typing import Callable + +import sockeye.constants as C + + +def int_greater_or_equal(threshold: int) -> Callable: + """ + Returns a method that can be used in argument parsing to check that the argument is greater or equal to `threshold`. + + :param threshold: The threshold that we assume the cli argument value is greater or equal to. + :return: A method that can be used as a type in argparse. + """ + + def check_greater_equal(value_to_check): + value_to_check = int(value_to_check) + if value_to_check < threshold: + raise argparse.ArgumentTypeError("must be greater or equal to %d." % threshold) + return value_to_check + + return check_greater_equal + + +def add_io_args(params): + data_params = params.add_argument_group("Data & I/O") + + data_params.add_argument('--source', '-s', + required=True, + help='Source side of parallel training data.') + data_params.add_argument('--target', '-t', + required=True, + help='Target side of parallel training data.') + + data_params.add_argument('--validation-source', '-vs', + required=True, + help='Source side of validation data.') + data_params.add_argument('--validation-target', '-vt', + required=True, + help='Target side of validation data.') + + data_params.add_argument('--output', '-o', + required=True, + help='Folder where model & training results are written to.') + + data_params.add_argument('--source-vocab', + required=False, + default=None, + help='Existing source vocabulary (JSON)') + data_params.add_argument('--target-vocab', + required=False, + default=None, + help='Existing target vocabulary (JSON)') + + data_params.add_argument('--use-tensorboard', + action='store_true', + help='Track metrics through tensorboard. 
Requires installed tensorboard.') + + data_params.add_argument('--quiet', '-q', + default=False, + action="store_true", + help='Suppress console logging.') + return params + + +def add_device_args(params): + device_params = params.add_argument_group("Device parameters") + + device_params.add_argument('--device-ids', default=[-1], + help='List of GPU ids to use. Default: %(default)s. ' + 'Use -1 to automatically acquire a GPU through a file locking mechanism. ' + '(Note that this assumes GPU processes are using automatic sockeye GPU ids).', + nargs='+', type=int) + device_params.add_argument('--use-cpu', + action='store_true', + help='Use CPU device instead of GPU.') + return params + + +def add_model_parameters(params): + model_params = params.add_argument_group("ModelConfig") + + model_params.add_argument('--params', '-p', + type=str, + default=None, + help='Initialize model parameters from file. Overrides random initializations.') + + model_params.add_argument('--num-words', + type=int_greater_or_equal(0), + default=50000, + help='Maximum vocabulary size. Default: %(default)s.') + model_params.add_argument('--word-min-count', + type=int_greater_or_equal(1), + default=1, + help='Minimum frequency of words to be included in vocabularies. Default: %(default)s.') + + model_params.add_argument('--rnn-num-layers', + type=int_greater_or_equal(1), + default=1, + help='Number of layers for encoder and decoder. Default: %(default)s.') + model_params.add_argument('--rnn-cell-type', + choices=[C.LSTM_TYPE, C.GRU_TYPE], + default=C.LSTM_TYPE, + help='RNN cell type for encoder and decoder. Default: %(default)s.') + model_params.add_argument('--rnn-num-hidden', + type=int_greater_or_equal(1), + default=1024, + help='Number of RNN hidden units for encoder and decoder. Default: %(default)s.') + model_params.add_argument('--rnn-residual-connections', + action="store_true", + default=False, + help="Add residual connections to stacked RNNs if --rnn-num-layers > 3. " + "(see Wu ETAL'16). Default: %(default)s.") + + model_params.add_argument('--num-embed', + type=int_greater_or_equal(1), + default=512, + help='Embedding size for source and target tokens. Default: %(default)s.') + model_params.add_argument('--num-embed-source', + type=int_greater_or_equal(1), + default=None, + help='Embedding size for source tokens. Overrides --num-embed. Default: %(default)s') + model_params.add_argument('--num-embed-target', + type=int_greater_or_equal(1), + default=None, + help='Embedding size for target tokens. Overrides --num-embed. Default: %(default)s') + + model_params.add_argument('--attention-type', + choices=["bilinear", "dot", "fixed", "location", "mlp", "coverage"], + default="mlp", + help='Attention model. Choices: {%(choices)s}. ' + 'Default: %(default)s.') + model_params.add_argument('--attention-num-hidden', + default=None, + type=int, + help='Number of hidden units for attention layers. Default: equal to --rnn-num-hidden.') + + model_params.add_argument('--attention-coverage-type', + choices=["tanh", "sigmoid", "relu", "softrelu", "gru", "count"], + default="count", + help="Type of model for updating coverage vectors. 'count' refers to an update method" + "that accumulates attention scores. 'tanh', 'sigmoid', 'relu', 'softrelu' " + "use non-linear layers with the respective activation type, and 'gru' uses a" + "GRU to update the coverage vectors. Default: %(default)s.") + model_params.add_argument('--attention-coverage-num-hidden', + type=int, + default=1, + help="Number of hidden units for coverage vectors. 
Default: %(default)s") + + model_params.add_argument('--lexical-bias', + default=None, + type=str, + help="Specify probabilistic lexicon for lexical biasing (Arthur ETAL'16). " + "Set smoothing value epsilon by appending :") + model_params.add_argument('--learn-lexical-bias', + action='store_true', + help='Adjust lexicon probabilities during training. Default: %(default)s') + + model_params.add_argument('--weight-tying', + action='store_true', + help='Share target embedding and output layer parameter matrix. Default: %(default)s.') + + model_params.add_argument('--max-seq-len', + type=int_greater_or_equal(1), + default=100, + help='Maximum sequence length in tokens. Default: %(default)s') + + model_params.add_argument('--attention-use-prev-word', action="store_true", + help="Feed the previous target embedding into the attention mechanism.") + + model_params.add_argument('--context-gating', action="store_true", + help="Enables a context gate which adaptively weighs the decoder input against the" + "source context vector before each update of the decoder hidden state.") + + return params + + +def add_training_args(params): + train_params = params.add_argument_group("Training parameters") + + train_params.add_argument('--batch-size', '-b', + type=int_greater_or_equal(1), + default=64, + help='Mini-batch size. Default: %(default)s.') + train_params.add_argument('--fill-up', + type=str, + default='replicate', + help=argparse.SUPPRESS) + train_params.add_argument('--no-bucketing', + action='store_true', + help='Disable bucketing: always unroll to the max_len.') + train_params.add_argument('--bucket-width', + type=int_greater_or_equal(1), + default=10, + help='Width of buckets in tokens. Default: %(default)s.') + + train_params.add_argument('--loss', + default=C.CROSS_ENTROPY, + choices=[C.CROSS_ENTROPY, C.SMOOTHED_CROSS_ENTROPY], + help='Loss to optimize. Default: %(default)s.') + train_params.add_argument('--smoothed-cross-entropy-alpha', + default=0.3, + type=float, + help='Smoothing value for smoothed-cross-entropy loss. Default: %(default)s.') + train_params.add_argument('--normalize-loss', + default=False, + action="store_true", + help='Normalize the loss by dividing by the number of non-PAD tokens.') + + train_params.add_argument('--metrics', + nargs='+', + default=[C.PERPLEXITY], + choices=[C.PERPLEXITY, C.ACCURACY], + help='Names of metrics to track on training and validation data. Default: %(default)s.') + train_params.add_argument('--optimized-metric', + default='perplexity', + choices=[C.PERPLEXITY, C.ACCURACY, C.BLEU], + help='Metric to optimize with early stopping {%(choices)s}. ' + 'Default: %(default)s.') + + train_params.add_argument('--max-updates', + type=int, + default=-1, + help='Maximum number of updates/batches to process. -1 for infinite. ' + 'Default: %(default)s.') + train_params.add_argument('--checkpoint-frequency', + type=int_greater_or_equal(1), + default=1000, + help='Checkpoint and evaluate every x updates/batches. Default: %(default)s.') + train_params.add_argument('--max-num-checkpoint-not-improved', + type=int, + default=8, + help='Maximum number of checkpoints the model is allowed to not improve in ' + ' on validation data before training is stopped. ' + 'Default: %(default)s') + + train_params.add_argument('--dropout', + type=float, + default=0., + help='Dropout probability for source embedding and source and target RNNs. 
' + 'Default: %(default)s.') + + train_params.add_argument('--optimizer', + default='adam', + choices=['adam', 'sgd', 'rmsprop'], + help='SGD update rule. Default: %(default)s.') + train_params.add_argument('--initial-learning-rate', + type=float, + default=0.0003, + help='Initial learning rate. Default: %(default)s.') + train_params.add_argument('--weight-decay', + type=float, + default=0.0, + help='Weight decay constant. Default: %(default)s.') + train_params.add_argument('--momentum', + type=float, + default=None, + help='Momentum constant. Default: %(default)s.') + train_params.add_argument('--clip-gradient', + type=float, + default=1.0, + help='Clip absolute gradients values greater than this value. ' + 'Set to negative to disable. Default: %(default)s.') + + train_params.add_argument('--learning-rate-scheduler-type', + default='plateau-reduce', + choices=["fixed-rate-inv-sqrt-t", "fixed-rate-inv-t", "plateau-reduce"], + help='Learning rate scheduler type. Default: %(default)s.') + train_params.add_argument('--learning-rate-reduce-factor', + type=float, + default=0.5, + help="Factor to multiply learning rate with " + "(for 'plateau-reduce' learning rate scheduler). Default: %(default)s.") + train_params.add_argument('--learning-rate-reduce-num-not-improved', + type=int, + default=3, + help="For 'plateau-reduce' learning rate scheduler. Adjust learning rate " + "if did not improve for x checkpoints. Default: %(default)s.") + train_params.add_argument('--learning-rate-half-life', + type=float, + default=10, + help="Half-life of learning rate in checkpoints. For 'fixed-rate-*' " + "learning rate schedulers. Default: 10.") + + train_params.add_argument('--use-fused-rnn', + default=False, + action="store_true", + help='Use FusedRNNCell in encoder (requires GPU device). Speeds up training.') + + train_params.add_argument('--rnn-forget-bias', + default=0.0, + type=float, + help='Initial value of RNN forget biases.') + train_params.add_argument('--rnn-h2h-init', type=str, default=C.RNN_INIT_ORTHOGONAL, + choices=[C.RNN_INIT_ORTHOGONAL, C.RNN_INIT_ORTHOGONAL_STACKED], + help="Initialization method for RNN parameters. Default: %(default)s.") + + train_params.add_argument('--monitor-bleu', + default=0, + type=int, + help='x>0: sample and decode x sentences from validation data and monitor BLEU score. ' + 'x==-1: use full validation data. Default: %(default)s.') + + train_params.add_argument('--seed', + type=int, + default=13, + help='Random seed. Default: %(default)s.') + return params + + +def add_inference_args(params): + decode_params = params.add_argument_group("Inference parameters") + decode_params.add_argument('--models', '-m', + required=True, + nargs='+', + help='Model folder(s). Use multiple for ensemble decoding. ' + 'Model determines config, best parameters and vocab files.') + decode_params.add_argument('--checkpoints', '-c', + default=None, + type=int, + nargs='+', + help='If not given, chooses best checkpoints for model(s). ' + 'If specified, must have the same length as --models and be integer') + + decode_params.add_argument('--beam-size', '-b', + type=int, + default=5, + help='Size of the beam. Default: %(default)s.') + decode_params.add_argument('--ensemble-mode', + type=str, + default='linear', + choices=['linear', 'log_linear'], + help='Ensemble mode: [linear, log-linear]. Default: %(default)s.') + decode_params.add_argument('--max-input-len', '-n', + type=int, + default=None, + help='Maximum sequence length. 
Default: value from model(s).') + decode_params.add_argument('--softmax-temperature', + type=float, + default=None, + help='Controls peakiness of model predictions. Values < 1.0 produce ' + 'peaked predictions, values > 1.0 produce smoothed distributions.') + + decode_params.add_argument('--output-type', + default='translation', + choices=["translation", "translation_with_alignments", "align_plot", "align_text"], + help='Output type. Choices: [translation, translation_with_alignments, ' + 'align_plot, align_text]. Default: %(default)s.') + decode_params.add_argument('--align-plot-prefix', + default="align", + help='Filename prefix for generated alignment visualization. Default: %(default)s') + decode_params.add_argument('--sure-align-threshold', + default=0.9, + type=float, + help='Threshold to consider a soft alignment a sure alignment. Default: %(default)s') + return params diff --git a/sockeye/attention.py b/sockeye/attention.py new file mode 100644 index 000000000..79523113c --- /dev/null +++ b/sockeye/attention.py @@ -0,0 +1,593 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Implementations of different attention mechanisms in sequence-to-sequence models. +""" +import logging +from typing import Callable, NamedTuple, Optional, Tuple + +import mxnet as mx + +import sockeye.coverage + +logger = logging.getLogger(__name__) + + +def get_attention(input_previous_word: bool, + attention_type: str, + attention_num_hidden: int, + rnn_num_hidden: int, + max_seq_len: int, + attention_coverage_type: str, + attention_coverage_num_hidden: int) -> 'Attention': + """ + Returns an Attention instance based on attention_type. + + :param input_previous_word: Feeds the previous target embedding into the attention mechanism. + :param attention_type: Attention name. + :param attention_num_hidden: Number of hidden units for attention networks. + :param rnn_num_hidden: Number of hidden units of encoder/decoder RNNs. + :param max_seq_len: Maximum length of source sequences. + :param attention_coverage_type: The type of update for the dynamic source encoding. + :param attention_coverage_num_hidden: Number of hidden units for coverage attention. + :return: Instance of Attention. 
+ """ + if attention_type == "bilinear": + if input_previous_word: + logger.warning("bilinear attention does not support input_previous_word") + return BilinearAttention(rnn_num_hidden) + elif attention_type == "dot": + return DotAttention(input_previous_word, rnn_num_hidden, attention_num_hidden) + elif attention_type == "fixed": + return EncoderLastStateAttention(input_previous_word) + elif attention_type == "location": + return LocationAttention(input_previous_word, max_seq_len) + elif attention_type == "mlp": + return MlpAttention(input_previous_word=input_previous_word, + attention_num_hidden=attention_num_hidden) + elif attention_type == "coverage": + if attention_coverage_type == 'count' and attention_coverage_num_hidden != 1: + logging.warning("Ignoring coverage_num_hidden=%d and setting to 1" % attention_coverage_num_hidden) + attention_coverage_num_hidden = 1 + return MlpAttention(input_previous_word=input_previous_word, + attention_num_hidden=attention_num_hidden, + attention_coverage_type=attention_coverage_type, + attention_coverage_num_hidden=attention_coverage_num_hidden) + else: + raise ValueError("Unknown attention type %s" % attention_type) + + +AttentionInput = NamedTuple('AttentionInput', [('seq_idx', int), ('query', mx.sym.Symbol)]) +""" +Input to attention callables. + +:param seq_idx: Decoder time step / sequence index. +:param query: Query input to attention mechanism, e.g. decoder hidden state (plus previous word). +""" + +AttentionState = NamedTuple('AttentionState', [ + ('context', mx.sym.Symbol), + ('probs', mx.sym.Symbol), + ('dynamic_source', mx.sym.Symbol), +]) +""" +Results returned from attention callables. + +:param context: Context vector (Bahdanau et al, 15). Shape: (batch_size, encoder_num_hidden) +:param probs: Attention distribution over source encoder states. Shape: (batch_size, source_seq_len). +:param dynamic_source: Dynamically updated source encoding. + Shape: (batch_size, source_seq_len, dynamic_source_num_hidden) +""" + + +class Attention(object): + """ + Generic attention interface that returns a callable for attending to source states. + + :param input_previous_word: Feed the previous target embedding into the attention mechanism. + :param dynamic_source_num_hidden: Number of hidden units of dynamic source encoding update mechanism. + """ + + def __init__(self, input_previous_word: bool, dynamic_source_num_hidden: int = 1) -> None: + self.dynamic_source_num_hidden = dynamic_source_num_hidden + self._input_previous_word = input_previous_word + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. + """ + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + """ + Returns updated attention state given attention input and current attention state. + + :param att_input: Attention input as returned by make_input(). + :param att_state: Current attention state + :return: Updated attention state. 
+ """ + raise NotImplementedError() + + return attend + + def get_initial_state(self, source_length: mx.sym.Symbol, source_seq_len: int) -> AttentionState: + """ + Returns initial attention state. Dynamic source encoding is initialized with zeros. + + :param source_length: Source length. Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + """ + dynamic_source = mx.sym.expand_dims(mx.sym.expand_dims(mx.sym.zeros_like(source_length), axis=1), axis=2) + # dynamic_source: (batch_size, source_seq_len, num_hidden_dynamic_source) + dynamic_source = mx.sym.broadcast_to(dynamic_source, shape=(0, source_seq_len, self.dynamic_source_num_hidden)) + return AttentionState(context=None, probs=None, dynamic_source=dynamic_source) + + def make_input(self, + seq_idx: int, + word_vec_prev: mx.sym.Symbol, + decoder_state: mx.sym.Symbol) -> AttentionInput: + """ + Returns AttentionInput to be fed into the attend callable returned by the on() method. + + :param seq_idx: Decoder time step. + :param word_vec_prev: Embedding of previously predicted ord + :param decoder_state: Current decoder state + :return: Attention input. + """ + query = decoder_state + if self._input_previous_word: + # (batch_size, num_target_embed + rnn_num_hidden) + query = mx.sym.concat(word_vec_prev, decoder_state, dim=1, name='att_concat_prev_word_%d' % seq_idx) + return AttentionInput(seq_idx=seq_idx, query=query) + + +class BilinearAttention(Attention): + """ + Bilinear attention based on Luong et al. 2015. + + :math:`score(h_t, h_s) = h_t^T \\mathbf{W} h_s` + + For implementation reasons we modify to: + + :math:`score(h_t, h_s) = h_s^T \\mathbf{W} h_t` + + :param num_hidden: Number of hidden units. + """ + + def __init__(self, + num_hidden: int, + prefix: str = '') -> None: + super().__init__(False) + self.prefix = prefix + self.num_hidden = num_hidden + self.s2t_weight = mx.sym.Variable("%satt_s2t_weight", self.prefix) + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. + """ + + # (batch_size * seq_len, self.num_hidden) + source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source, shape=(-3, -1), name="att_flat_source"), + weight=self.s2t_weight, num_hidden=self.num_hidden, + no_bias=True, name="att_source_hidden_fc") + # (batch_size, seq_len, self.num_hidden) + source_hidden = mx.sym.reshape(source_hidden, shape=(-1, source_seq_len, self.num_hidden), + name="att_source_hidden") + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + """ + Returns updated attention state given attention input and current attention state. + + :param att_input: Attention input as returned by make_input(). + :param att_state: Current attention state + :return: Updated attention state. + """ + # (batch_size, decoder_num_hidden, 1) + query = mx.sym.expand_dims(att_input.query, axis=2) + + # in: (batch_size, source_seq_len, self.num_hidden) X (batch_size, self.num_hidden, 1) + # out: (batch_size, source_seq_len, 1). 
+ attention_scores = mx.sym.batch_dot(lhs=source_hidden, rhs=query, name="att_batch_dot") + + context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores) + + return AttentionState(context=context, + probs=attention_probs, + dynamic_source=att_state.dynamic_source) + + return attend + + +class DotAttention(Attention): + """ + Attention mechanism with dot product between encoder and decoder hidden states [Luong et al. 2015]. + + :math:`score(h_t, h_s) = \\langle h_t, h_s \\rangle` + + :math:`a = softmax(score(*, h_s))` + + If rnn_num_hidden != num_hidden, states are projected with additional parameters to num_hidden. + + :math:`score(h_t, h_s) = \\langle \\mathbf{W}_t h_t, \\mathbf{W}_s h_s \\rangle` + + :param input_previous_word: Feed the previous target embedding into the attention mechanism. + :param rnn_num_hidden: Number of hidden units in encoder/decoder RNNs. + :param num_hidden: Number of hidden units. + """ + + def __init__(self, + input_previous_word: bool, + rnn_num_hidden: int, + num_hidden: int) -> None: + super().__init__(input_previous_word) + self.project = rnn_num_hidden != num_hidden + self.num_hidden = num_hidden + self.t2h_weight = mx.sym.Variable("att_t2h_weight") if self.project else None + self.s2h_weight = mx.sym.Variable("att_s2h_weight") if self.project else None + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. + """ + + if self.project: + # (batch_size * seq_len, self.num_hidden) + source_hidden = mx.sym.FullyConnected( + data=mx.sym.reshape(data=source, shape=(-3, -1), name="att_flat_source"), + weight=self.s2h_weight, num_hidden=self.num_hidden, + no_bias=True, name="att_source_hidden_fc") + # (batch_size, seq_len, self.num_hidden) + source_hidden = mx.sym.reshape(source_hidden, shape=(-1, source_seq_len, self.num_hidden), + name="att_source_hidden") + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + """ + Returns updated attention state given attention input and current attention state. + + :param att_input: Attention input as returned by make_input(). + :param att_state: Current attention state + :return: Updated attention state. + """ + query = att_input.query + local_source = source + if self.project: + local_source = source_hidden + # query: (batch_size, self.num_hidden) + query = mx.sym.FullyConnected(data=query, + weight=self.t2h_weight, + num_hidden=self.num_hidden, + no_bias=True, name="att_query_hidden_fc") + + # (batch_size, decoder_num_hidden, 1) + expanded_decoder_state = mx.sym.expand_dims(query, axis=2) + + # batch_dot: (batch, M, K) X (batch, K, N) –> (batch, M, N). 
+ # (batch_size, seq_len, 1) + attention_scores = mx.sym.batch_dot(lhs=local_source, rhs=expanded_decoder_state, name="att_batch_dot") + + context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores) + return AttentionState(context=context, + probs=attention_probs, + dynamic_source=att_state.dynamic_source) + + return attend + + +class EncoderLastStateAttention(Attention): + """ + Always returns the last encoder state independent of the query vector. + Equivalent to no attention. + """ + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. + """ + source = mx.sym.swapaxes(source, dim1=0, dim2=1) + encoder_last_state = mx.sym.SequenceLast(data=source, sequence_length=source_length, + use_sequence_length=True) + fixed_probs = mx.sym.one_hot(source_length - 1, depth=source_seq_len) + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + return AttentionState(context=encoder_last_state, + probs=fixed_probs, + dynamic_source=att_state.dynamic_source) + + return attend + + +class LocationAttention(Attention): + """ + Attends to locations in the source [Luong et al, 2015] + + :math:`a_t = softmax(\\mathbf{W}_a h_t)` for decoder hidden state at time t. + + :note: :math:`\\mathbf{W}_a` is of shape (max_source_seq_len, decoder_num_hidden). + + :param input_previous_word: Feed the previous target embedding into the attention mechanism. + :param max_source_seq_len: Maximum length of source sequences. + """ + + def __init__(self, input_previous_word: bool, max_source_seq_len: int) -> None: + super().__init__(input_previous_word) + self.max_source_seq_len = max_source_seq_len + self.location_weight = mx.sym.Variable("att_loc_weight") + self.location_bias = mx.sym.Variable("att_loc_bias") + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. + """ + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + """ + Returns updated attention state given attention input and current attention state. + + :param att_input: Attention input as returned by make_input(). + :param att_state: Current attention state + :return: Updated attention state. 
+ """ + # attention_scores: (batch_size, seq_len) + attention_scores = mx.sym.FullyConnected(data=att_input.query, + num_hidden=self.max_source_seq_len, + weight=self.location_weight, + bias=self.location_bias) + + # attention_scores: (batch_size, seq_len) + attention_scores = mx.sym.slice_axis(data=attention_scores, + axis=1, + begin=0, + end=source_seq_len) + + # attention_scores: (batch_size, seq_len, 1) + attention_scores = mx.sym.expand_dims(data=attention_scores, axis=2) + + context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores) + return AttentionState(context=context, + probs=attention_probs, + dynamic_source=att_state.dynamic_source) + + return attend + + +class MlpAttention(Attention): + """ + Attention computed through a one-layer MLP with num_hidden units [Luong et al, 2015]. + + :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s] + b)` + + :math:`a = softmax(score(*, h_s))` + + Optionally, if attention_coverage_type is not None, attention uses dynamic source encoding ('coverage' mechanism) + as in Tu et al. (2016): Modeling Coverage for Neural Machine Translation. + + :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s, c_s] + b)` + + :math:`c_s` is the decoder time-step dependent source encoding which is updated using the current + decoder state. + + :param input_previous_word: Feed the previous target embedding into the attention mechanism. + :param attention_num_hidden: Number of hidden units. + :param attention_coverage_type: The type of update for the dynamic source encoding. + If None, no dynamic source encoding is done. + :param attention_coverage_num_hidden: Number of hidden units for coverage attention. + """ + + def __init__(self, + input_previous_word: bool, + attention_num_hidden: int, + attention_coverage_type: Optional[str] = None, + attention_coverage_num_hidden: int = 1, + prefix='') -> None: + dynamic_source_num_hidden = 1 if attention_coverage_type is None else attention_coverage_num_hidden + super().__init__(input_previous_word=input_previous_word, + dynamic_source_num_hidden=dynamic_source_num_hidden) + self.prefix = prefix + self.attention_num_hidden = attention_num_hidden + # input (encoder) to hidden + self.att_e2h_weight = mx.sym.Variable("%satt_e2h_weight" % prefix) + # input (query) to hidden + self.att_q2h_weight = mx.sym.Variable("%satt_q2h_weight" % prefix) + # hidden to score + self.att_h2s_weight = mx.sym.Variable("%satt_h2s_weight" % prefix) + # dynamic source (coverage) weights and settings + # input (coverage) to hidden + self.att_c2h_weight = mx.sym.Variable("%satt_c2h_weight" % prefix) if attention_coverage_type else None + self.coverage = sockeye.coverage.get_coverage(attention_coverage_type, + dynamic_source_num_hidden) if attention_coverage_type else None + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for recurrent attention in a sequence decoder. + The callable is a recurrent function of the form: + AttentionState = attend(AttentionInput, AttentionState). + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Attention callable. 
+ """ + + coverage_func = self.coverage.on(source, source_length, source_seq_len) if self.coverage else None + + # (batch_size * seq_len, attention_num_hidden) + source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source, + shape=(-3, -1), + name="%satt_flat_source" % self.prefix), + weight=self.att_e2h_weight, + num_hidden=self.attention_num_hidden, + no_bias=True, + name="%satt_source_hidden_fc" % self.prefix) + + # (batch_size, seq_len, attention_num_hidden) + source_hidden = mx.sym.reshape(source_hidden, + shape=(-1, source_seq_len, self.attention_num_hidden), + name="%satt_source_hidden" % self.prefix) + + def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState: + """ + Returns updated attention state given attention input and current attention state. + + :param att_input: Attention input as returned by make_input(). + :param att_state: Current attention state + :return: Updated attention state. + """ + + # (batch_size, attention_num_hidden) + query_hidden = mx.sym.FullyConnected(data=att_input.query, + weight=self.att_q2h_weight, + num_hidden=self.attention_num_hidden, + no_bias=True, + name="%satt_query_hidden" % self.prefix) + + # (batch_size, 1, attention_num_hidden) + query_hidden = mx.sym.expand_dims(data=query_hidden, + axis=1, + name="%satt_query_hidden_expanded" % self.prefix) + + attention_hidden_lhs = source_hidden + if self.coverage: + # (batch_size * seq_len, attention_num_hidden) + dynamic_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=att_state.dynamic_source, + shape=(-3, -1), + name="%satt_flat_dynamic_source" % self.prefix), + weight=self.att_c2h_weight, + num_hidden=self.attention_num_hidden, + no_bias=True, + name="%satt_dynamic_source_hidden_fc" % self.prefix) + + # (batch_size, seq_len, attention_num_hidden) + dynamic_hidden = mx.sym.reshape(dynamic_hidden, + shape=(-1, source_seq_len, self.attention_num_hidden), + name="%satt_dynamic_source_hidden" % self.prefix) + + # (batch_size, seq_len, attention_num_hidden + attention_hidden_lhs = dynamic_hidden + source_hidden + + # (batch_size, seq_len, attention_num_hidden) + attention_hidden = mx.sym.broadcast_add(lhs=attention_hidden_lhs, rhs=query_hidden, + name="%satt_query_plus_input" % self.prefix) + + # (batch_size * seq_len, attention_num_hidden) + attention_hidden = mx.sym.reshape(data=attention_hidden, + shape=(-3, -1), + name="%satt_query_plus_input_before_fc" % self.prefix) + + # (batch_size * seq_len, attention_num_hidden) + attention_hidden = mx.sym.Activation(attention_hidden, act_type="tanh", + name="%satt_hidden" % self.prefix) + + # (batch_size * seq_len, 1) + attention_scores = mx.sym.FullyConnected(data=attention_hidden, + weight=self.att_h2s_weight, + num_hidden=1, + no_bias=True, + name="%sraw_att_score_fc" % self.prefix) + + # (batch_size, seq_len, 1) + attention_scores = mx.sym.reshape(attention_scores, + shape=(-1, source_seq_len, 1), + name="%sraw_att_score_fc" % self.prefix) + + context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores) + + dynamic_source = att_state.dynamic_source + if self.coverage: + # update dynamic source encoding + # Note: this is a slight change to the Tu et al, 2016 paper: input to the coverage update + # is the attention input query, not the previous decoder state. 
+ dynamic_source = coverage_func(prev_hidden=att_input.query, + attention_prob_scores=attention_probs, + prev_coverage=att_state.dynamic_source) + + return AttentionState(context=context, + probs=attention_probs, + dynamic_source=dynamic_source) + + return attend + + +def get_context_and_attention_probs(source: mx.sym.Symbol, + source_length: mx.sym.Symbol, + attention_scores: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]: + """ + Returns context vector and attention probs via a weighted sum over the masked, softmaxed attention scores. + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param attention_scores: Shape: (batch_size, seq_len, 1). + :return: context: (batch_size, encoder_num_hidden), attention_probs: (batch_size, seq_len). + """ + + # TODO: It would be nice if SequenceMask could take a 2d input... + # Note: we need to add an axis as SequenceMask expects 3D input + # TODO: we should probably replace this with a multiplication of a 0-1 mask, to avoid the multiplication + attention_scores = mx.sym.swapaxes(data=attention_scores, dim1=0, dim2=1) + attention_scores = mx.sym.SequenceMask(data=attention_scores, + use_sequence_length=True, + sequence_length=source_length, + value=-99999999.) + attention_scores = mx.sym.swapaxes(data=attention_scores, dim1=0, dim2=1) + # attention_scores is batch_major from here: (batch_size, seq_len, 1) + + # (batch_size, seq_len) + attention_scores = mx.sym.reshape(data=attention_scores, shape=(0, 0)) + + # (batch_size, seq_len) + attention_probs = mx.sym.softmax(attention_scores, name='attention_softmax') + + # (batch_size, seq_len, 1) + attention_probs_expanded = mx.sym.expand_dims(data=attention_probs, axis=2) + + # batch_dot: (batch, M, K) X (batch, K, N) –> (batch, M, N). + # (batch_size, seq_len, encoder_num_hidden) X (batch_size, seq_len, 1) -> (batch_size, encoder_num_hidden) + context = mx.sym.batch_dot(lhs=source, rhs=attention_probs_expanded, transpose_a=True) + context = mx.sym.reshape(data=context, shape=(0, 0)) + + return context, attention_probs diff --git a/sockeye/average.py b/sockeye/average.py new file mode 100644 index 000000000..c67ff32fc --- /dev/null +++ b/sockeye/average.py @@ -0,0 +1,196 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Average parameters from multiple model checkpoints. Checkpoints can be either +specificed manually or automatically chosen according to one of several +strategies. The default strategy of simply selecting the top-scoring N points +works well in practice. +""" + +import argparse +import itertools +import os +from typing import Dict, Iterable, Tuple + +import mxnet as mx + +import sockeye.constants as C +import sockeye.utils +from sockeye.log import setup_main_logger + +logger = setup_main_logger(__name__, console=True, file_logging=False) + + +def average(param_paths: Iterable[str]) -> Dict[str, mx.nd.NDArray]: + """ + Averages parameters from a list of .params file paths. 
+ + :param param_paths: List of paths to parameter files. + :return: Averaged parameter dictionary. + """ + all_arg_params = [] + all_aux_params = [] + for path in param_paths: + logger.info("Loading parameters from '%s'", path) + arg_params, aux_params = sockeye.utils.load_params(path) + all_arg_params.append(arg_params) + all_aux_params.append(aux_params) + + logger.info("%d models loaded", len(all_arg_params)) + assert all( + all_arg_params[0].keys() == p.keys() + for p in all_arg_params), "arg_param names do not match across models" + assert all( + all_aux_params[0].keys() == p.keys() + for p in all_aux_params), "aux_param names do not match across models" + + avg_params = {} + # average arg_params + for k in all_arg_params[0]: + arrays = [p[k] for p in all_arg_params] + avg_params["arg:" + k] = sockeye.utils.average_arrays(arrays) + # average aux_params + for k in all_aux_params[0]: + arrays = [p[k] for p in all_aux_params] + avg_params["aux:" + k] = sockeye.utils.average_arrays(arrays) + + return avg_params + + +def find_checkpoints(model_path: str, size=4, strategy="best", maximize=False) -> Iterable[str]: + """ + Finds N best points from .metrics file according to strategy + + :param model_path: Path to model. + :param size: Number of checkpoints to combine. + :param strategy: Combination strategy. + :param maximize: Whether the value of the metric should be maximized. + :return: List of paths corresponding to chosen checkpoints. + """ + metrics_path = os.path.join(model_path, C.METRICS_NAME) + points = read_metrics_points(metrics_path) + + if strategy == "best": + # N best scoring points + top_n = sorted(points, reverse=maximize)[:size] + + elif strategy == "last": + # N sequential points ending with overall best + best = max if maximize else min + after_top = points.index(best(points)) + 1 + top_n = points[after_top - size:after_top] + + elif strategy == "lifespan": + # Track lifespan of every "new best" point + # Points dominated by a previous better point have lifespan 0 + top_n = [] + cur_best = points[0] + cur_lifespan = 0 + for point in points[1:]: + better = point > cur_best if maximize else point < cur_best + if better: + top_n.append(list(itertools.chain([cur_lifespan], cur_best))) + cur_best = point + cur_lifespan = 0 + else: + top_n.append(list(itertools.chain([0], point))) + cur_lifespan += 1 + top_n.append(list(itertools.chain([cur_lifespan], cur_best))) + # Sort by lifespan, then by val + top_n = sorted( + top_n, + key=lambda point: [point[0], point[1] if maximize else -point[1]], + reverse=True)[:size] + + else: + raise RuntimeError("Unknown strategy, options: best last lifespan") + + # Assemble paths for params files corresponding to chosen checkpoints + # Last element in point is always the checkpoint id + params_paths = [ + os.path.join(model_path, C.PARAMS_NAME % point[-1]) for point in top_n + ] + + # Report + logger.info("Found: " + ", ".join(str(point) for point in top_n)) + + return params_paths + + +def read_metrics_points(path: str) -> Iterable[Tuple[float, int]]: + """ + Reads lines from .metrics file and return list of elements [val, checkpoint] + + :param path: File to read metric values from. + :return: List of pairs (metric value, checkpoint). 
+ """ + points = [] + # First field is checkpoint id + # Metric on validation (dev) set looks like this: METRIC-val=N + with open(path, "r") as metrics_in: + for line in metrics_in: + fields = line.split() + checkpoint = int(fields[0]) + for field in fields[1:]: + key_value = field.split("=") + if len(key_value) == 2: + metric_set = key_value[0].split("-") + if len(metric_set) == 2 and metric_set[0] != C.ACCURACY and metric_set[1] == "val": + metric_value = float(key_value[1]) + points.append([metric_value, checkpoint]) + return points + + +def main(): + """ + Commandline interface to average parameters. + """ + params = argparse.ArgumentParser( + description="Averages parameters from multiple models.") + params.add_argument( + "inputs", + metavar="INPUT", + type=str, + nargs="+", + help="either a single model directory (automatic checkpoint selection) " + "or multiple .params files (manual checkpoint selection)") + params.add_argument( + "--max", action="store_true", help="maximize metric (default: min)") + params.add_argument( + "-n", + type=int, + default=4, + help="number of checkpoints to find (default: 4)") + params.add_argument( + "--output", "-o", required=True, type=str, help="output param file") + params.add_argument( + "--strategy", + choices=["best", "last", "lifespan"], + default="best", + help="selection method (default: best)") + args = params.parse_args() + + if len(args.inputs) > 1: + avg_params = average(args.inputs) + else: + param_paths = find_checkpoints(args.inputs[0], args.n, args.strategy, + args.max) + avg_params = average(param_paths) + + mx.nd.save(args.output, avg_params) + logger.info("Averaged parameters written to '%s'", args.output) + + +if __name__ == "__main__": + main() diff --git a/sockeye/bleu.py b/sockeye/bleu.py new file mode 100644 index 000000000..8b1cc6564 --- /dev/null +++ b/sockeye/bleu.py @@ -0,0 +1,107 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Implementation of the BLEU-4 score [Papineni, 2002] +""" +import logging +from collections import Counter, namedtuple +from itertools import tee, islice +from math import log, exp +from typing import List, AnyStr + +logger = logging.getLogger(__name__) + +ORDER = 4 +Statistics = namedtuple('Statistics', ['common', 'total']) + + +def zipngram(words, n): + return zip(*(islice(it, pos, None) for pos, it in enumerate(tee(words, n + 1)))) + + +def bleu_from_counts(count_triple, offset=0.01): + counts, hyp_count, ref_count = count_triple + bleu = 0.0 + brevity = 0.0 + effective_order = 0 + for n in range(ORDER): + count = counts.common[n] + total = counts.total[n] + if total <= 0: + if n == 0: + return 0 + else: + break + effective_order += 1 + if count == 0: + count = offset + bleu += log(float(count) / total) + if hyp_count > 0: + brevity = min(0., 1. 
- float(ref_count) / hyp_count) + return exp(bleu / effective_order + brevity) + + +def bleu_counts(hyp, ref): + counts = Statistics([0, 0, 0, 0], [0, 0, 0, 0]) + + hyp_words = hyp.split() + ref_words = ref.split() + + hyp_wcount = len(hyp_words) + ref_wcount = len(ref_words) + for n in range(ORDER): + h_grams = Counter(zipngram(hyp_words, n)) + r_grams = Counter(zipngram(ref_words, n)) + + # do clipping + inter = (min(h_grams[g], r_grams[g]) for g in h_grams if g in r_grams) + counts.common[n] += sum(inter) + counts.total[n] += sum(h_grams.values()) + + return counts, hyp_wcount, ref_wcount + + +def add_counts_in_place(c1, c2): + for n in range(ORDER): + c1.common[n] += c2.common[n] + c1.total[n] += c2.total[n] + + +def corpus_bleu_counts(hyps, refs): + counts = Statistics([0, 0, 0, 0], [0, 0, 0, 0]) + hyp_total_wcount, ref_total_wcount = 0, 0 + + if len(hyps) != len(refs): + logger.error("Hyps and refs lengths are not the same") + + for hyp, ref in zip(hyps, refs): + sent_counts, hyp_wcount, ref_wcount = bleu_counts(hyp, ref) + + add_counts_in_place(counts, sent_counts) + hyp_total_wcount += hyp_wcount + ref_total_wcount += ref_wcount + + return counts, hyp_total_wcount, ref_total_wcount + + +def corpus_bleu(hyps: List[AnyStr], refs: List[AnyStr], offset: float = 0.01) -> float: + """ + Computes corpus BLEU. + + :param hyps: List of hypotheses. + :param refs: List of references. + :param offset: Smoothing value. + :return: BLEU score. + """ + return bleu_from_counts(corpus_bleu_counts(hyps, refs), offset) diff --git a/sockeye/callback.py b/sockeye/callback.py new file mode 100644 index 000000000..9357b203a --- /dev/null +++ b/sockeye/callback.py @@ -0,0 +1,272 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Provides functionality to track metrics on training and validation data during training and controls +early-stopping. +""" +import logging +import multiprocessing as mp +import os +import shutil +import time +from typing import Optional, Tuple, Dict + +import mxnet as mx +import numpy as np + +import sockeye.checkpoint_decoder +import sockeye.constants as C +import sockeye.inference + +logger = logging.getLogger(__name__) + + +class TrainingMonitor(object): + """ + TrainingMonitor logs metrics on training and validation data, submits decoding processes to compute BLEU scores, + and writes metrics to the model output folder. + It further controls early stopping as it decides based on the specified metric to optimize, whether the model + has improved w.r.t to the last checkpoint. + Technically, TrainingMonitor exposes a couple of callback function that are called in the fit() method of + TrainingModel. + + :param batch_size: Batch size during training. + :param output_folder: Folder where model files are written to. + :param optimized_metric: Name of the metric that controls early stopping. + :param use_tensorboard: Whether to use Tensorboard logging of metrics. + :param checkpoint_decoder: Optional CheckpointDecoder instance for BLEU monitoring. 
+ :param num_concurrent_decodes: Number of concurrent subprocesses to decode validation data. + """ + + def __init__(self, + batch_size: int, + output_folder: str, + optimized_metric: str = C.PERPLEXITY, + use_tensorboard: bool = False, + checkpoint_decoder: Optional[sockeye.checkpoint_decoder.CheckpointDecoder] = None, + num_concurrent_decodes: int = 1) -> None: + self.metrics = [] # stores dicts of metric names & values for each checkpoint + self.metrics_filename = os.path.join(output_folder, C.METRICS_NAME) + open(self.metrics_filename, 'w').close() # clear metrics file + self.best_checkpoint = 0 + self.start_tic = time.time() + self.summary_writer = None + if use_tensorboard: + import tensorboard + log_dir = os.path.join(output_folder, C.TENSORBOARD_NAME) + if os.path.exists(log_dir): + logger.info("Deleting existing tensorboard log dir %s", log_dir) + shutil.rmtree(log_dir) + logger.info("Logging training events for Tensorboard at '%s'", log_dir) + self.summary_writer = tensorboard.FileWriter(log_dir) + self.checkpoint_decoder = checkpoint_decoder + self.ctx = mp.get_context('spawn') + self.num_concurrent_decodes = num_concurrent_decodes + self.decoder_metric_queue = self.ctx.Queue() + self.decoder_processes = [] + # TODO(fhieber): MXNet Speedometer uses root logger. How to fix this? + self.speedometer = mx.callback.Speedometer(batch_size=batch_size, + frequent=C.MEASURE_SPEED_EVERY, + auto_reset=False) + self.optimized_metric = optimized_metric + if self.optimized_metric == C.PERPLEXITY: + self.minimize = True + self.validation_best = np.inf + elif self.optimized_metric == C.ACCURACY: + self.minimize = False + self.validation_best = -np.inf + elif self.optimized_metric == C.BLEU: + assert self.checkpoint_decoder is not None, "BLEU requires CheckpointDecoder" + self.minimize = False + self.validation_best = -np.inf + else: + raise ValueError("No other metrics supported") + logger.info("Early stopping by optimizing '%s' (minimize=%s)", + self.optimized_metric, self.minimize) + self.tic = 0 + + def get_best_checkpoint(self) -> int: + """ + Returns current best checkpoint. + """ + return self.best_checkpoint + + def get_best_validation_score(self) -> float: + """ + Returns current best validation result for optimized metric. + """ + return self.validation_best + + def _is_better(self, value): + return value < self.validation_best if self.minimize else value > self.validation_best + + def batch_end_callback(self, epoch: int, nbatch: int, metric: mx.metric.EvalMetric): + """ + Callback function when processing of a data bach is completed. + + :param epoch: Current epoch. + :param nbatch: Current batch. + :param metric: Evaluation metric for training data. + """ + self.speedometer( + mx.model.BatchEndParam( + epoch=epoch, nbatch=nbatch, eval_metric=metric, locals=None)) + + def checkpoint_callback(self, checkpoint: int, train_metric: mx.metric.EvalMetric): + """ + Callback function when a model checkpoint is performed. + If TrainingMonitor uses Tensorboard, training metrics are written to the Tensorboard event file. + + :param checkpoint: Current checkpoint. + :param train_metric: Evaluation metric for training data. 
+ """ + metrics = {} + for name, value in train_metric.get_name_value(): + metrics[name + "-train"] = value + self.metrics.append(metrics) + if self.summary_writer: + write_tensorboard(self.summary_writer, metrics, checkpoint) + + def eval_end_callback(self, checkpoint: int, val_metric: mx.metric.EvalMetric) -> Tuple[bool, int]: + """ + Callback function when processing of held-out validation data is complete. + Counts time elapsed since the start of training. + If TrainingMonitor uses Tensorboard, validation metrics are written to the Tensorboard event file. + If BLEU is monitored with subprocesses, this function collects result from finished decoder processes + and starts a new one for the current checkpoint. + + :param checkpoint: Current checkpoint. + :param val_metric: Evaluation metric for validation data. + :return: Tuple of boolean indicating if model improved on validation data according to the. + optimized metric, and the (updated) best checkpoint. + """ + metrics = {} + for name, value in val_metric.get_name_value(): + metrics[name + "-val"] = value + metrics['time-elapsed'] = time.time() - self.start_tic + + if self.summary_writer: + write_tensorboard(self.summary_writer, metrics, checkpoint) + + if self.checkpoint_decoder: + self._empty_decoder_metric_queue() + self._start_decode_process(checkpoint) + + self.metrics[-1].update(metrics) + self._write_scores() + + has_improved, best_checkpoint = self._find_best_checkpoint() + return has_improved, best_checkpoint + + def _find_best_checkpoint(self): + """ + Returns True if optimized_metric has improved since the last call of + this function, together with the best checkpoint + """ + has_improved = False + for checkpoint, metric_dict in enumerate(self.metrics, 1): + value = metric_dict.get(self.optimized_metric + "-val", + self.validation_best) + if self._is_better(value): + self.validation_best = value + self.best_checkpoint = checkpoint + has_improved = True + + if has_improved: + logger.info("Validation-%s improved to %f.", self.optimized_metric, + self.validation_best) + else: + logger.info("Validation-%s has not improved, best so far: %f", + self.optimized_metric, self.validation_best) + return has_improved, self.best_checkpoint + + def _write_scores(self): + """ + Overwrite metrics_filename with latest metrics results. 
+ """ + with open(self.metrics_filename, 'w') as metrics_out: + for checkpoint, metric_dict in enumerate(self.metrics, 1): + metrics_out.write("%d\t" % checkpoint) + metrics_out.write("\t".join(["%s=%.6f" % (name, value) + for name, value in sorted( + metric_dict.items())]) + "\n") + + def _start_decode_process(self, checkpoint): + self._wait_for_decode_slot() + process = self.ctx.Process( + target=_decode_and_evaluate, + args=(self.checkpoint_decoder, checkpoint, + self.decoder_metric_queue)) + process.name = 'Decoder-%d' % checkpoint + logger.info("Starting process: %s", process.name) + process.start() + self.decoder_processes.append(process) + + def _empty_decoder_metric_queue(self): + """ + Get metric results from decoder_process queue and optionally write to tensorboard logs + """ + while not self.decoder_metric_queue.empty(): + decoded_checkpoint, decoder_metrics = self.decoder_metric_queue.get() + logger.info("Checkpoint [%d]: Decoder finished (%s)", + decoded_checkpoint, decoder_metrics) + self.metrics[decoded_checkpoint - 1].update(decoder_metrics) + if self.summary_writer: + write_tensorboard(self.summary_writer, decoder_metrics, + decoded_checkpoint) + + def _wait_for_decode_slot(self, timeout: int = 5): + while len(self.decoder_processes) == self.num_concurrent_decodes: + self.decoder_processes = [p for p in self.decoder_processes + if p.is_alive()] + time.sleep(timeout) + + def stop_fit_callback(self): + """ + Callback function when fitting is stopped. Collects results from decoder processes and writes their results. + """ + for process in self.decoder_processes: + if process.is_alive(): + logger.info("Waiting for %s process to finish." % process.name) + process.join() + self._empty_decoder_metric_queue() + self._write_scores() + + +def _decode_and_evaluate(checkpoint_decoder: sockeye.checkpoint_decoder.CheckpointDecoder, + checkpoint: int, + queue: mp.Queue): + """ + Decodes and evaluates using given checkpoint_decoder and puts result in the queue, + indexed by the checkpoint. + """ + metrics = checkpoint_decoder.decode_and_evaluate(checkpoint) + queue.put((checkpoint, metrics)) + + +def write_tensorboard(summary_writer, + metrics: Dict[str, float], + checkpoint: int): + """ + Writes a Tensorboard scalar event to the given SummaryWriter. + + :param summary_writer: A Tensorboard SummaryWriter instance. + :param metrics: Mapping of metric names to their values. + :param checkpoint: Current checkpoint. + """ + from tensorboard.summary import scalar + for name, value in metrics.items(): + summary_writer.add_summary( + scalar( + name=name, scalar=value), global_step=checkpoint) diff --git a/sockeye/checkpoint_decoder.py b/sockeye/checkpoint_decoder.py new file mode 100644 index 000000000..80887fbca --- /dev/null +++ b/sockeye/checkpoint_decoder.py @@ -0,0 +1,104 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Implements a thin wrapper around Translator to compute BLEU scores on (a sample of) validation data during training. 
+""" +import logging +import os +import random +from typing import Dict + +import mxnet as mx + +import sockeye.bleu +import sockeye.inference +import sockeye.output_handler +from sockeye import constants as C +from sockeye.data_io import smart_open + +logger = logging.getLogger(__name__) + + +class CheckpointDecoder: + """ + Decodes a (random sample of a) dataset using parameters at given checkpoint and computes BLEU against references. + + :param context: MXNet context to bind the model to. + :param inputs: Path to file containing input sentences. + :param references: Path to file containing references. + :param model: Model to load. + :param max_input_len: Maximum input length. + :param beam_size: Size of the beam. + :param limit: Maximum number of sentences to sample and decode. If <=0, all sentences are used. + """ + + def __init__(self, + context: mx.context.Context, + inputs: str, + references: str, + model: str, + max_input_len: int, + beam_size=C.DEFAULT_BEAM_SIZE, + limit: int = -1): + self.context = context + self.max_input_len = max_input_len + self.beam_size = beam_size + self.model = model + with smart_open(inputs) as inputs_fin, smart_open(references) as references_fin: + input_sentences = inputs_fin.readlines() + target_sentences = references_fin.readlines() + assert len(input_sentences) == len(target_sentences), "Number of sentence pairs do not match" + if limit <= 0: + limit = len(input_sentences) + if limit < len(input_sentences): + self.input_sentences, self.target_sentences = zip( + *random.sample(list(zip(input_sentences, target_sentences)), + limit)) + else: + self.input_sentences, self.target_sentences = input_sentences, target_sentences + + logger.info("Created CheckpointDecoder(max_input_len=%d, beam_size=%d, model=%s, num_sentences=%d)", + max_input_len, beam_size, model, len(self.input_sentences)) + + with smart_open(os.path.join(self.model, C.DECODE_REF_NAME), 'w') as trg_out, \ + smart_open(os.path.join(self.model, C.DECODE_IN_NAME), 'w') as src_out: + [trg_out.write(s) for s in self.target_sentences] + [src_out.write(s) for s in self.input_sentences] + + def decode_and_evaluate(self, checkpoint: int) -> Dict[str, float]: + """ + Decodes data set and evaluates given a checkpoint. + + :param checkpoint: Checkpoint to load parameters from. + :return: Mapping of metric names to scores. + """ + translator = sockeye.inference.Translator(self.context, 'linear', + *sockeye.inference.load_models(self.context, + self.max_input_len, + self.beam_size, + [self.model], + [checkpoint])) + + output_name = os.path.join(self.model, C.DECODE_OUT_NAME % checkpoint) + with smart_open(output_name, 'w') as output: + handler = sockeye.output_handler.StringOutputHandler(output) + translations = [] + for sent_id, input_sentence in enumerate(self.input_sentences): + trans_input = translator.make_input(sent_id, input_sentence) + trans_output = translator.translate(trans_input) + handler.handle(trans_input, trans_output) + translations.append(trans_output.translation) + logger.info("Checkpoint [%d] %d translations saved to '%s'", checkpoint, len(translations), output_name) + # TODO(fhieber): eventually add more metrics (METEOR etc.) + return {"bleu-val": sockeye.bleu.corpus_bleu(translations, self.target_sentences)} diff --git a/sockeye/constants.py b/sockeye/constants.py new file mode 100644 index 000000000..12e23b5de --- /dev/null +++ b/sockeye/constants.py @@ -0,0 +1,94 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not
+# use this file except in compliance with the License. A copy of the License
+# is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+"""
+Defines various constants used throughout the project
+"""
+
+BOS_SYMBOL = "<s>"
+EOS_SYMBOL = "</s>"
+UNK_SYMBOL = "<unk>"
+PAD_SYMBOL = "<pad>"
+PAD_ID = 0
+TOKEN_SEPARATOR = " "
+VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL]
+
+# default encoder prefixes
+ENCODER_PREFIX = "encoder_"
+EMBEDDING_PREFIX = "embed_"
+BIDIRECTIONALRNN_PREFIX = ENCODER_PREFIX + "birnn_"
+STACKEDRNN_PREFIX = ENCODER_PREFIX + "rnn_"
+FORWARD_PREFIX = "forward_"
+REVERSE_PREFIX = "reverse_"
+
+# embedding prefixes
+SOURCE_EMBEDDING_PREFIX = "source_embed_"
+TARGET_EMBEDDING_PREFIX = "target_embed_"
+
+# rnn types
+LSTM_TYPE = 'lstm'
+GRU_TYPE = 'gru'
+
+# init types
+RNN_INIT_ORTHOGONAL = 'orthogonal'
+RNN_INIT_ORTHOGONAL_STACKED = 'orthogonal_stacked'
+
+# default decoder prefixes
+DECODER_PREFIX = "decoder_"
+
+# default I/O variable names
+SOURCE_NAME = "source"
+SOURCE_LENGTH_NAME = "source_length"
+TARGET_NAME = "target"
+TARGET_LABEL_NAME = "target_label"
+LEXICON_NAME = "lexicon"
+
+SOURCE_ENCODED_NAME = "encoded_source"
+TARGET_PREVIOUS_NAME = "prev_target_word_id"
+HIDDEN_PREVIOUS_NAME = "prev_hidden"
+SOURCE_DYNAMIC_PREVIOUS_NAME = "prev_dynamic_source"
+
+LOGITS_NAME = "logits"
+SOFTMAX_NAME = "softmax"
+SOFTMAX_OUTPUT_NAME = SOFTMAX_NAME + "_output"
+
+MEASURE_SPEED_EVERY = 50 # measure speed and metrics every X batches
+
+DEFAULT_BEAM_SIZE = 5
+
+CONFIG_NAME = "config"
+LOG_NAME = "log"
+JSON_SUFFIX = ".json"
+VOCAB_SRC_NAME = "vocab.src"
+VOCAB_TRG_NAME = "vocab.trg"
+PARAMS_NAME = "params.%04d"
+PARAMS_BEST_NAME = "params.best"
+DECODE_OUT_NAME = "decode.output.%04d"
+DECODE_IN_NAME = "decode.source"
+DECODE_REF_NAME = "decode.target"
+SYMBOL_NAME = "symbol" + JSON_SUFFIX
+METRICS_NAME = "metrics"
+TENSORBOARD_NAME = "tensorboard"
+
+# data layout strings
+BATCH_MAJOR = "NTC"
+TIME_MAJOR = "TNC"
+
+# metric names
+ACCURACY = 'accuracy'
+PERPLEXITY = 'perplexity'
+BLEU = 'bleu'
+
+# loss names
+CROSS_ENTROPY = 'cross-entropy'
+SMOOTHED_CROSS_ENTROPY = 'smoothed-cross-entropy'
diff --git a/sockeye/coverage.py b/sockeye/coverage.py
new file mode 100644
index 000000000..86e9f1c6d
--- /dev/null
+++ b/sockeye/coverage.py
@@ -0,0 +1,294 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not
+# use this file except in compliance with the License. A copy of the License
+# is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+"""
+Defines the dynamic source encodings ('coverage' mechanisms) for encoder/decoder networks as used in Tu et al. (2016).
+""" +import logging +from typing import Callable + +import mxnet as mx + +logger = logging.getLogger(__name__) + + +def get_coverage(coverage_type: str, + coverage_num_hidden: int) -> 'Coverage': + """ + Returns a Coverage instance. + + :param coverage_type: Name of coverage type. + :param coverage_num_hidden: Number of hidden units for coverage vectors. + :return: Instance of Coverage. + """ + + if coverage_type == "gru": + return GRUCoverage(coverage_num_hidden) + elif coverage_type in {"tanh", "sigmoid", "relu", "softrelu"}: + return ActivationCoverage(coverage_num_hidden, coverage_type) + elif coverage_type == "count": + return CountCoverage() + else: + raise ValueError("Unknown coverage type %s" % coverage_type) + + +class Coverage: + """ + Generic coverage class. Similar to Attention classes, a coverage instance returns a callable, update_coverage(), + function when self.on() is called. + """ + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for updating coverage vectors in a sequence decoder. + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Coverage callable. + """ + + def update_coverage(prev_hidden: mx.sym.Symbol, + attention_prob_scores: mx.sym.Symbol, + prev_coverage: mx.sym.Symbol): + """ + :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden). + :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1). + :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). + :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). + """ + raise NotImplementedError() + + return update_coverage + + +class CountCoverage(Coverage): + """ + Coverage class that accumulates the attention weights for each source word. + """ + + def __init__(self, prefix='') -> None: + self.prefix = prefix + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for updating coverage vectors in a sequence decoder. + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Coverage callable. + """ + + def update_coverage(prev_hidden: mx.sym.Symbol, + attention_prob_scores: mx.sym.Symbol, + prev_coverage: mx.sym.Symbol): + """ + :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden). + :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1). + :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). + :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). + """ + return prev_coverage + mx.sym.expand_dims(attention_prob_scores, axis=2) + + return update_coverage + + +class GRUCoverage(Coverage): + """ + Implements a GRU whose state is the coverage vector. + + TODO: This implementation is slightly inefficient since the source is fed in at every step. + It would be better to pre-compute the mapping of the source but this will likely mean opening up the GRU. + + :param coverage_num_hidden: Number of hidden units for coverage vectors. 
+ """ + + def __init__(self, coverage_num_hidden: int, prefix='') -> None: + self.prefix = prefix + self.num_hidden = coverage_num_hidden + self.gru = mx.rnn.GRUCell(self.num_hidden, prefix="%scoverage_gru" % self.prefix) + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for updating coverage vectors in a sequence decoder. + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Coverage callable. + """ + + def update_coverage(prev_hidden: mx.sym.Symbol, + attention_prob_scores: mx.sym.Symbol, + prev_coverage: mx.sym.Symbol): + """ + :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden). + :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1). + :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). + :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). + """ + + # (batch_size, source_seq_len, decoder_num_hidden) + expanded_decoder = mx.sym.broadcast_axis( + data=mx.sym.expand_dims(data=prev_hidden, axis=1, name="%scov_expand_decoder" % self.prefix), + axis=1, size=source_seq_len, name="%scov_broadcast_decoder" % self.prefix) + + expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores, + axis=2, + name="%scov_expand_attention_scores" % self.prefix) + + # (batch_size, source_seq_len, encoder_num_hidden + decoder_num_hidden + 1) + # +1 for the attention_prob_score for the source word + concat_input = mx.sym.concat(source, expanded_decoder, expanded_att_scores, dim=2, + name="%scov_concat_inputs" % self.prefix) + + # (batch_size * source_seq_len, encoder_num_hidden + decoder_num_hidden + 1) + flat_input = mx.sym.reshape(concat_input, shape=(-3, -1), name="%scov_flatten_inputs") + + # coverage: (batch_size * seq_len, coverage_num_hidden) + coverage = mx.sym.reshape(data=prev_coverage, shape=(-3, -1)) + updated_coverage, _ = self.gru(flat_input, states=[coverage]) + + # coverage: (batch_size, seq_len, coverage_num_hidden) + coverage = mx.sym.reshape(updated_coverage, shape=(-1, source_seq_len, self.num_hidden)) + + return mask_coverage(coverage, source_length) + + return update_coverage + + +class ActivationCoverage(Coverage): + """ + Implements a coverage mechanism whose updates are performed by a Perceptron with + configurable activation function. + + :param coverage_num_hidden: Number of hidden units for coverage vectors. + :param activation: Type of activation for Perceptron. + """ + + def __init__(self, coverage_num_hidden: int, activation: str, prefix='') -> None: + self.prefix = prefix + self.activation = activation + self.num_hidden = coverage_num_hidden + # input (encoder) to hidden + self.cov_e2h_weight = mx.sym.Variable("%scov_e2h_weight" % self.prefix) + # decoder to hidden + self.cov_dec2h_weight = mx.sym.Variable("%scov_i2h_weight" % self.prefix) + # previous coverage to hidden + self.cov_prev2h_weight = mx.sym.Variable("%scov_prev2h_weight" % self.prefix) + # attention scores to hidden + self.cov_a2h_weight = mx.sym.Variable("%scov_a2h_weight" % self.prefix) + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for updating coverage vectors in a sequence decoder. 
+ + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Coverage callable. + """ + + # (batch_size * seq_len, coverage_hidden_num) + source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source, + shape=(-3, -1), + name="%scov_flat_source" % self.prefix), + weight=self.cov_e2h_weight, + no_bias=True, + num_hidden=self.num_hidden, + name="%scov_source_hidden_fc" % self.prefix) + + # (batch_size, seq_len, coverage_hidden_num) + source_hidden = mx.sym.reshape(source_hidden, + shape=(-1, source_seq_len, self.num_hidden), + name="%scov_source_hidden" % self.prefix) + + def update_coverage(prev_hidden: mx.sym.Symbol, + attention_prob_scores: mx.sym.Symbol, + prev_coverage: mx.sym.Symbol): + """ + :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden). + :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1). + :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). + :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). + """ + + # (batch_size * seq_len, coverage_hidden_num) + coverage_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=prev_coverage, + shape=(-3, -1), + name="%scov_flat_previous" % self.prefix), + weight=self.cov_prev2h_weight, + no_bias=True, + num_hidden=self.num_hidden, + name="%scov_previous_hidden_fc" % self.prefix) + + # (batch_size, source_seq_len, coverage_hidden_num) + coverage_hidden = mx.sym.reshape(coverage_hidden, + shape=(-1, source_seq_len, self.num_hidden), + name="%scov_previous_hidden" % self.prefix) + + # (batch_size, source_seq_len, 1) + attention_prob_score = mx.sym.expand_dims(attention_prob_scores, axis=2) + + # (batch_size * source_seq_len, coverage_num_hidden) + attention_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(attention_prob_score, + shape=(-3, 0), + name="%scov_reshape_att_probs" % self.prefix), + weight=self.cov_a2h_weight, + no_bias=True, + num_hidden=self.num_hidden, + name="%scov_attention_fc" % self.prefix) + + # (batch_size, source_seq_len, coverage_num_hidden) + attention_hidden = mx.sym.reshape(attention_hidden, + shape=(-1, source_seq_len, self.num_hidden), + name="%scov_reshape_att" % self.prefix) + + # (batch_size, coverage_num_hidden) + prev_hidden = mx.sym.FullyConnected(data=prev_hidden, weight=self.cov_dec2h_weight, no_bias=True, + num_hidden=self.num_hidden, name="%scov_decoder_hidden") + + # (batch_size, 1, coverage_num_hidden) + prev_hidden = mx.sym.expand_dims(data=prev_hidden, axis=1, + name="%scov_input_decoder_hidden_expanded" % self.prefix) + + # (batch_size, source_seq_len, coverage_num_hidden) + intermediate = mx.sym.broadcast_add(lhs=source_hidden, rhs=prev_hidden, + name="%scov_source_plus_hidden" % self.prefix) + + # (batch_size, source_seq_len, coverage_num_hidden) + updated_coverage = intermediate + attention_hidden + coverage_hidden + + # (batch_size, seq_len, coverage_num_hidden) + coverage = mx.sym.Activation(data=updated_coverage, + act_type=self.activation, + name="%scov_activation" % self.prefix) + + return mask_coverage(coverage, source_length) + + return update_coverage + + +def mask_coverage(coverage: mx.sym.Symbol, source_length: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Masks all coverage scores that are outside the actual sequence. + + :param coverage: Input coverage vector. 
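# Illustrative sketch (not part of this commit): the reshape(shape=(-3, -1))
# calls above merge the batch and source_seq_len axes so one FullyConnected can
# project every (sentence, position) pair at once, after which the result is
# reshaped back to (batch, seq_len, num_hidden). Numpy equivalent with toy sizes:
import numpy as np

batch_size, seq_len, num_hidden, cov_hidden = 2, 5, 3, 7
x = np.random.rand(batch_size, seq_len, num_hidden)

flat = x.reshape(batch_size * seq_len, num_hidden)          # what shape=(-3, -1) does
projected = flat @ np.random.rand(num_hidden, cov_hidden)   # stand-in for the FC layer
restored = projected.reshape(batch_size, seq_len, cov_hidden)

print(flat.shape, projected.shape, restored.shape)  # (10, 3) (10, 7) (2, 5, 7)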
Shape: (batch_size, seq_len, coverage_num_hidden). + :param source_length: Source length. Shape: (batch_size,). + :return: Masked coverage vector. Shape: (batch_size, seq_len, coverage_num_hidden). + """ + coverage = mx.sym.SwapAxis(data=coverage, dim1=0, dim2=1) + coverage = mx.sym.SequenceMask(data=coverage, use_sequence_length=True, sequence_length=source_length) + coverage = mx.sym.SwapAxis(data=coverage, dim1=0, dim2=1) + return coverage diff --git a/sockeye/data_io.py b/sockeye/data_io.py new file mode 100644 index 000000000..436265afa --- /dev/null +++ b/sockeye/data_io.py @@ -0,0 +1,467 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Implements data iterators and I/O related functions for sequence-to-sequence models. +""" +import bisect +import gzip +import logging +import random +from typing import Dict, Iterator, Iterable, List, NamedTuple, Optional, Tuple + +import mxnet as mx +import numpy as np + +import sockeye.constants as C + +logger = logging.getLogger(__name__) + + +def define_buckets(max_seq_len: int, step=10) -> List[int]: + """ + Returns a list of integers defining bucket boundaries. + + :param max_seq_len: Maximum bucket size. + :param step: Distance between buckets. + :return: List of bucket sizes. + """ + step = min(step, max_seq_len) + return [bucket_len for bucket_len in range(step, max_seq_len + step, step)] + + +def define_parallel_buckets(max_seq_len: int, bucket_width=10, length_ratio=1.0) -> List[Tuple[int, int]]: + """ + Returns (src,trg) buckets in steps of 10. + + :param max_seq_len: Maximum bucket size. + :param bucket_width: Width of buckets. + :param length_ratio: Length ratio between source and target data. + """ + step = min(bucket_width, max_seq_len) + return list(zip(define_buckets(max_seq_len, step), define_buckets(max_seq_len, int(step * length_ratio)))) + + +def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]: + """ + Given sequence length and a list of buckets, return corresponding bucket. + + :param seq_len: Sequence length. + :param buckets: List of buckets. + :return: Chosen bucket. + """ + bucket_idx = bisect.bisect_left(buckets, seq_len) + if bucket_idx == len(buckets): + return None + return buckets[bucket_idx] + + +def get_data_iter(data_source: str, data_target: str, + vocab_source: Dict[str, int], vocab_target: Dict[str, int], + batch_size: int, + fill_up: str, + max_seq_len: int, + bucketing: bool, + bucket_width: int) -> 'ParallelBucketSentenceIter': + """ + Returns a ParallelBucketSentenceIter for bucketed data I/O. + + :param data_source: Path to source data. + :param data_target: Path to target data. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. + :param batch_size: Batch size. + :param fill_up: Fill-up strategy for buckets. + :param max_seq_len: Maximum sequence length. + :param bucketing: Whether to use bucketing. + :param bucket_width: Size of buckets. + :return: Data iterator for parallel data. 
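# Illustrative sketch (not part of this commit): a stand-alone rerun of the
# bucketing helpers above, showing what they return for small inputs. The
# function bodies mirror define_buckets/get_bucket/define_parallel_buckets.
import bisect

def define_buckets(max_seq_len, step=10):
    step = min(step, max_seq_len)
    return list(range(step, max_seq_len + step, step))

def get_bucket(seq_len, buckets):
    idx = bisect.bisect_left(buckets, seq_len)
    return None if idx == len(buckets) else buckets[idx]

def define_parallel_buckets(max_seq_len, bucket_width=10, length_ratio=1.0):
    step = min(bucket_width, max_seq_len)
    return list(zip(define_buckets(max_seq_len, step),
                    define_buckets(max_seq_len, int(step * length_ratio))))

buckets = define_buckets(30, step=10)                 # [10, 20, 30]
print(get_bucket(7, buckets))                         # 10 -> smallest bucket that fits
print(get_bucket(25, buckets))                        # 30
print(get_bucket(31, buckets))                        # None -> sentence would be discarded
print(define_parallel_buckets(30, 10, length_ratio=1.5))
# [(10, 15), (20, 30)] -- target buckets grow 1.5x faster; zip stops at the shorter list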
+ """ + source_sentences = read_sentences(data_source, vocab_source, add_bos=False) + target_sentences = read_sentences(data_target, vocab_target, add_bos=True) + assert len(source_sentences) == len(target_sentences) + eos_id = vocab_target[C.EOS_SYMBOL] + + length_ratio = sum(len(s) / float(len(t)) for s, t in zip(source_sentences, target_sentences)) / len( + source_sentences) + logger.info("Average length ratio between src & trg: %.2f", length_ratio) + + buckets = define_parallel_buckets(max_seq_len, bucket_width, length_ratio) if bucketing else [ + (max_seq_len, max_seq_len)] + return ParallelBucketSentenceIter(source_sentences, target_sentences, buckets, batch_size, eos_id, C.PAD_ID, + vocab_target[C.UNK_SYMBOL], fill_up=fill_up) + + +def get_training_data_iters(source: str, target: str, + validation_source: str, validation_target: str, + vocab_source: Dict[str, int], vocab_target: Dict[str, int], + batch_size: int, + fill_up: str, + max_seq_len: int, + bucketing: bool, + bucket_width: int) -> Tuple['ParallelBucketSentenceIter', 'ParallelBucketSentenceIter']: + """ + Returns data iterators for training and validation data. + + :param source: Path to source training data. + :param target: Path to target training data. + :param validation_source: Path to source validation data. + :param validation_target: Path to target validation data. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. + :param batch_size: Batch size. + :param fill_up: Fill-up strategy for buckets. + :param max_seq_len: Maximum sequence length. + :param bucketing: Whether to use bucketing. + :param bucket_width: Size of buckets. + :return: Data iterators for parallel data. + """ + logger.info("Creating train data iterator") + train_iter = get_data_iter(source, target, vocab_source, vocab_target, batch_size, fill_up, + max_seq_len, bucketing, bucket_width=bucket_width) + logger.info("Creating validation data iterator") + eval_iter = get_data_iter(validation_source, validation_target, vocab_source, vocab_target, batch_size, fill_up, + max_seq_len, bucketing, bucket_width=bucket_width) + return train_iter, eval_iter + + +DataInfo = NamedTuple('DataInfo', [ + ('source', str), + ('target', str), + ('validation_source', str), + ('validation_target', str), + ('vocab_source', str), + ('vocab_target', str), +]) +""" +Tuple to collect data information for training. + +:param source: Path to training source. +:param target: Path to training target. +:param validation_source: Path to validation source. +:param validation_target: Path to validation target. +:param vocab_source: Path to source vocabulary. +:param vocab_target: Path to target vocabulary. +""" + + +def smart_open(filename: str, mode="rt", ftype="auto", errors='replace'): + """ + Returns a file descriptor for filename with UTF-8 encoding. + If mode is "rt", file is opened read-only. + If ftype is "auto", uses gzip iff filename endswith .gz. + If ftype is {"gzip","gz"}, uses gzip. + + Note: encoding error handling defaults to "replace" + + :param filename: The filename to open. + :param mode: Reader mode. + :param ftype: File type. If 'auto' checks filename suffix for gz to try gzip.open + :param errors: Encoding error handling during reading. 
Defaults to 'replace' + :return: File descriptor + """ + if ftype == 'gzip' or ftype == 'gz' or (ftype == 'auto' and filename.endswith(".gz")): + return gzip.open(filename, mode=mode, encoding='utf-8', errors=errors) + else: + return open(filename, mode=mode, encoding='utf-8', errors=errors) + + +def read_content(path: str, limit=None) -> Iterator[List[str]]: + """ + Returns a list of tokens for each line in path up to a limit. + + :param path: Path to files containing sentences. + :param limit: How many lines to read from path. + :return: Iterator over lists of words. + """ + with smart_open(path) as indata: + for i, line in enumerate(indata): + if limit is not None and i == limit: + break + yield list(get_tokens(line)) + + +def get_tokens(line: str) -> Iterator[str]: + """ + Yields tokens from input string. + + :param line: Input string. + :return: Iterator over tokens. + """ + for token in line.rstrip().split(): + if len(token) > 0: + yield token + + +def tokens2ids(tokens: Iterable[str], vocab: Dict[str, int]) -> List[int]: + """ + Returns sequence of ids given a sequence of tokens and vocab. + + :param tokens: List of tokens. + :param vocab: Vocabulary (containing UNK symbol). + :return: List of word ids. + """ + return [vocab.get(w, vocab[C.UNK_SYMBOL]) for w in tokens] + + +def read_sentences(path: str, vocab: Dict[str, int], add_bos=False, limit=None) -> List[List[int]]: + """ + Reads sentences from path and creates word id sentences. + + :param path: Path to read data from. + :param vocab: Vocabulary mapping. + :param add_bos: Whether to add Beginning-Of-Sentence (BOS) symbol. + :param limit: Read limit. + :return: List of integer sequences. + """ + assert C.UNK_SYMBOL in vocab + assert C.UNK_SYMBOL in vocab + assert vocab[C.PAD_SYMBOL] == C.PAD_ID + assert C.BOS_SYMBOL in vocab + assert C.EOS_SYMBOL in vocab + sentences = [] + for sentence_tokens in read_content(path, limit): + sentence = tokens2ids(sentence_tokens, vocab) + assert len(sentence) > 0, "Empty sentence in file %s" % path + if add_bos: + sentence.insert(0, vocab[C.BOS_SYMBOL]) + sentences.append(sentence) + logger.info("%d sentences loaded from '%s'", len(sentences), path) + return sentences + + +# TODO: consider more memory-efficient data reading (load from disk on demand) +# TODO: consider using HDF5 format for language data +class ParallelBucketSentenceIter(mx.io.DataIter): + """ + A Bucket sentence iterator for parallel data. Randomly shuffles the data after every call to reset(). + Data is stored in NDArrays for each epoch for fast indexing during iteration. + + :param source_sentences: List of source sentences (integer-coded). + :param target_sentences: List of target sentences (integer-coded). + :param buckets: List of buckets. + :param batch_size: Batch_size of generated data batches. + Incomplete batches are discarded if fill_up == None, or filled up according to the fill_up strategy. + :param fill_up: If not None, fill up bucket data to a multiple of batch_size to avoid discarding incomplete batches. + for each bucket. If set to 'replicate', sample examples from the bucket and use them to fill up. + :param eos_id: Word id for end-of-sentence. + :param pad_id: Word id for padding symbols. + :param unk_id: Word id for unknown symbols. + :param dtype: Data type of generated NDArrays. 
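# Illustrative sketch (not part of this commit): tokens2ids falls back to the
# <unk> id for out-of-vocabulary words via dict.get, and add_bos=True merely
# prepends the <s> id. The toy vocabulary is hypothetical (ids 0-3 being the
# special symbols, as in C.VOCAB_SYMBOLS).
toy_vocab = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "hello": 4, "world": 5}

def toy_tokens2ids(tokens, vocab):
    return [vocab.get(token, vocab["<unk>"]) for token in tokens]

print(toy_tokens2ids("hello unseen world".split(), toy_vocab))                 # [4, 1, 5]
print([toy_vocab["<s>"]] + toy_tokens2ids("hello world".split(), toy_vocab))   # [2, 4, 5]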
+ """ + + def __init__(self, + source_sentences: List[List[int]], + target_sentences: List[List[int]], + buckets: List[Tuple[int, int]], + batch_size: int, + eos_id: int, + pad_id: int, + unk_id: int, + fill_up: Optional[str] = None, + source_data_name=C.SOURCE_NAME, + source_data_length_name=C.SOURCE_LENGTH_NAME, + target_data_name=C.TARGET_NAME, + label_name=C.TARGET_LABEL_NAME, + dtype='float32'): + super(ParallelBucketSentenceIter, self).__init__() + + self.buckets = list(buckets) + self.buckets.sort() + self.default_bucket_key = max(self.buckets) + self.batch_size = batch_size + self.eos_id = eos_id + self.pad_id = pad_id + self.unk_id = unk_id + self.dtype = dtype + self.source_data_name = source_data_name + self.source_data_length_name = source_data_length_name + self.target_data_name = target_data_name + self.label_name = label_name + self.fill_up = fill_up + + # TODO: consider avoiding explicitly creating length and label arrays to save host memory + self.data_source = [[] for _ in self.buckets] + self.data_length = [[] for _ in self.buckets] + self.data_target = [[] for _ in self.buckets] + self.data_label = [[] for _ in self.buckets] + + # assign sentence pairs to buckets + self._assign_to_buckets(source_sentences, target_sentences) + + # convert to single numpy array for each bucket + self._convert_to_array() + + self.provide_data = [ + mx.io.DataDesc(name=source_data_name, shape=(batch_size, self.default_bucket_key[0]), layout=C.BATCH_MAJOR), + mx.io.DataDesc(name=source_data_length_name, shape=(batch_size,), layout=C.BATCH_MAJOR), + mx.io.DataDesc(name=target_data_name, shape=(batch_size, self.default_bucket_key[1]), layout=C.BATCH_MAJOR)] + self.provide_label = [ + mx.io.DataDesc(name=label_name, shape=(self.batch_size, self.default_bucket_key[1]), layout=C.BATCH_MAJOR)] + + self.data_names = [self.source_data_name, self.source_data_length_name, self.target_data_name] + self.label_names = [self.label_name] + + # create index tuples (i,j) into buckets: i := bucket index ; j := row index of bucket array + self.idx = [] + for i, buck in enumerate(self.data_source): + rest = len(buck) % batch_size + if rest > 0: + logger.info("Discarding %d samples from bucket %s due to incomplete batch", rest, self.buckets[i]) + idxs = [(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)] + self.idx.extend(idxs) + self.curr_idx = 0 + + # holds NDArrays + self.nd_source = [] + self.nd_length = [] + self.nd_target = [] + self.nd_label = [] + + self.reset() + + @staticmethod + def _get_bucket(buckets, length_source, length_target): + """ + Determines bucket given source and target length. 
+ """ + bucket = None, None + for j, (source_bkt, target_bkt) in enumerate(buckets): + if source_bkt >= length_source and target_bkt >= length_target: + bucket = j, (source_bkt, target_bkt) + break + return bucket + + def _assign_to_buckets(self, source_sentences, target_sentences): + ndiscard = 0 + tokens_source = 0 + tokens_target = 0 + num_of_unks_source = 0 + num_of_unks_target = 0 + for source, target in zip(source_sentences, target_sentences): + tokens_source += len(source) + tokens_target += len(target) + num_of_unks_source += source.count(self.unk_id) + num_of_unks_source += target.count(self.unk_id) + + buck_idx, buck = self._get_bucket(self.buckets, len(source), len(target)) + if buck is None: + ndiscard += 1 + continue + + buff_source = np.full((buck[0],), self.pad_id, dtype=self.dtype) + buff_target = np.full((buck[1],), self.pad_id, dtype=self.dtype) + buff_label = np.full((buck[1],), self.pad_id, dtype=self.dtype) + buff_source[:len(source)] = source + buff_target[:len(target)] = target + buff_label[:len(target)] = target[1:] + [self.eos_id] + self.data_source[buck_idx].append(buff_source) + self.data_length[buck_idx].append(len(source)) + self.data_target[buck_idx].append(buff_target) + self.data_label[buck_idx].append(buff_label) + + logger.info("Source words: %d", tokens_source) + logger.info("Target words: %d", tokens_target) + logger.info("Vocab coverage source: %.0f%%", (1 - num_of_unks_source / tokens_source) * 100) + logger.info("Vocab coverage target: %.0f%%", (1 - num_of_unks_target / tokens_target) * 100) + logger.info('Total: {0} samples in {1} buckets'.format(len(self.data_source), len(self.buckets))) + nsamples = 0 + for bkt, buck in zip(self.buckets, self.data_length): + logger.info("bucket of {0} : {1} samples".format(bkt, len(buck))) + nsamples += len(buck) + assert nsamples > 0, "0 data points available in the data iterator. " \ + "%d data points have been discarded because they didn't fit into any bucket. Consider " \ + "increasing the --max-seq-len to fit your data." % ndiscard + logger.info("%d sentence pairs out of buckets", ndiscard) + logger.info("fill up mode: %s", self.fill_up) + logger.info("") + + def _convert_to_array(self): + for i in range(len(self.data_source)): + self.data_source[i] = np.asarray(self.data_source[i], dtype=self.dtype) + self.data_length[i] = np.asarray(self.data_length[i], dtype=self.dtype) + self.data_target[i] = np.asarray(self.data_target[i], dtype=self.dtype) + self.data_label[i] = np.asarray(self.data_label[i], dtype=self.dtype) + + n = len(self.data_source[i]) + if n % self.batch_size != 0: + buck_shape = self.buckets[i] + rest = self.batch_size - n % self.batch_size + if self.fill_up == 'pad': + raise NotImplementedError + elif self.fill_up == 'replicate': + logger.info( + "Replicating %d random examples from bucket %s to size it to multiple of batch size %d", rest, + buck_shape, self.batch_size) + random_indices = np.random.randint(self.data_source[i].shape[0], size=rest) + + self.data_source[i] = np.concatenate((self.data_source[i], self.data_source[i][random_indices, :]), + axis=0) + self.data_length[i] = np.concatenate((self.data_length[i], self.data_length[i][random_indices]), + axis=0) + self.data_target[i] = np.concatenate((self.data_target[i], self.data_target[i][random_indices, :]), + axis=0) + self.data_label[i] = np.concatenate((self.data_label[i], self.data_label[i][random_indices, :]), + axis=0) + + def reset(self): + """ + Resets and reshuffles the data. 
+ """ + self.curr_idx = 0 + # shuffle indices + random.shuffle(self.idx) + + self.nd_source = [] + self.nd_length = [] + self.nd_target = [] + self.nd_label = [] + for i in range(len(self.data_source)): + # shuffle indices within each bucket + indices = np.random.permutation(len(self.data_source[i])) + self.nd_source.append(mx.nd.array(self.data_source[i].take(indices, axis=0), dtype=self.dtype)) + self.nd_length.append(mx.nd.array(self.data_length[i].take(indices, axis=0), dtype=self.dtype)) + self.nd_target.append(mx.nd.array(self.data_target[i].take(indices, axis=0), dtype=self.dtype)) + self.nd_label.append(mx.nd.array(self.data_label[i].take(indices, axis=0), dtype=self.dtype)) + + def iter_next(self) -> bool: + """ + True if iterator can return another batch + """ + return self.curr_idx != len(self.idx) + + def next(self) -> mx.io.DataBatch: + """ + Returns the next batch from the data iterator. + """ + if not self.iter_next(): + raise StopIteration + + i, j = self.idx[self.curr_idx] + self.curr_idx += 1 + + source = self.nd_source[i][j:j + self.batch_size] + length = self.nd_length[i][j:j + self.batch_size] + target = self.nd_target[i][j:j + self.batch_size] + data = [source, length, target] + label = [self.nd_label[i][j:j + self.batch_size]] + + provide_data = [mx.io.DataDesc(name=n, shape=x.shape, layout=C.BATCH_MAJOR) for n, x in + zip(self.data_names, data)] + provide_label = [mx.io.DataDesc(name=n, shape=x.shape, layout=C.BATCH_MAJOR) for n, x in + zip(self.label_names, label)] + + # TODO: num pad examples is not set here if fillup strategy would be padding + return mx.io.DataBatch(data, label, + pad=0, index=None, bucket_key=self.buckets[i], + provide_data=provide_data, provide_label=provide_label) diff --git a/sockeye/decoder.py b/sockeye/decoder.py new file mode 100644 index 000000000..d11ac71cc --- /dev/null +++ b/sockeye/decoder.py @@ -0,0 +1,452 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Sequence-to-Sequence Decoders +""" +from typing import Callable, List, NamedTuple, Tuple +from typing import Optional + +import mxnet as mx + +import sockeye.attention +import sockeye.constants as C +import sockeye.coverage +import sockeye.encoder +import sockeye.lexicon +import sockeye.rnn +import sockeye.utils + + +def get_decoder(num_embed: int, + vocab_size: int, + num_layers: int, + rnn_num_hidden: int, + attention: sockeye.attention.Attention, + cell_type: str, residual: bool, + forget_bias: float, + dropout=0., + weight_tying: bool = False, + lexicon: Optional[sockeye.lexicon.Lexicon] = None, + context_gating: bool = False) -> 'Decoder': + """ + Returns a StackedRNNDecoder with the following properties. + + :param num_embed: Target word embedding size. + :param vocab_size: Target vocabulary size. + :param num_layers: Number of RNN layers in the decoder. + :param rnn_num_hidden: Number of hidden units per decoder RNN cell. + :param attention: Attention model. + :param cell_type: RNN cell type. 
+ :param residual: Whether to add residual connections to multi-layer RNNs. + :param forget_bias: Initial value of the RNN forget bias. + :param dropout: Dropout probability for decoder RNN. + :param weight_tying: Whether to share embedding and prediction parameter matrices. + :param lexicon: Optional Lexicon. + :param context_gating: Whether to use context gating. + :return: Decoder instance. + """ + return StackedRNNDecoder(rnn_num_hidden, + attention, + vocab_size, + num_embed, + num_layers, + weight_tying=weight_tying, + dropout=dropout, + cell_type=cell_type, + residual=residual, + forget_bias=forget_bias, + lexicon=lexicon, + context_gating=context_gating) + + +class Decoder: + """ + Generic decoder interface. + """ + + def get_num_hidden(self) -> int: + """ + Returns the representation size of this decoder. + + :raises: NotImplementedError + """ + raise NotImplementedError() + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this decoder. + + :raises: NotImplementedError + """ + raise NotImplementedError() + + +DecoderState = NamedTuple('DecoderState', [ + ('hidden', mx.sym.Symbol), + ('layer_states', List[mx.sym.Symbol]), +]) +""" +Decoder state. + +:param hidden: Hidden state after attention mechanism. Shape: (batch_size, num_hidden). +:param layer_states: Hidden states for RNN layers of StackedRNNDecoder. Shape: List[(batch_size, rnn_num_hidden)] + +""" + + +class StackedRNNDecoder(Decoder): + """ + Class to generate the decoder part of the computation graph in sequence-to-sequence models. + The architecture is based on Luong et al, 2015: Effective Approaches to Attention-based Neural Machine Translation + + :param num_hidden: Number of hidden units in decoder RNN. + :param attention: Attention model. + :param target_vocab_size: Size of target vocabulary. + :param num_target_embed: Size of target word embedding. + :param num_layers: Number of decoder RNN layers. + :param prefix: Decoder symbol prefix. + :param weight_tying: Whether to share embedding and prediction parameter matrices. + :param dropout: Dropout probability for decoder RNN. + :param cell_type: RNN cell type. + :param residual: Whether to add residual connections to multi-layer RNNs. + :param forget_bias: Initial value of the RNN forget bias. + :param lexicon: Optional Lexicon. + :param context_gating: Whether to use context gating. 
+ """ + + def __init__(self, + num_hidden: int, + attention: sockeye.attention.Attention, + target_vocab_size: int, + num_target_embed: int, + num_layers=1, + prefix=C.DECODER_PREFIX, + weight_tying=False, + dropout=0.0, + cell_type: str = C.LSTM_TYPE, + residual: bool = False, + forget_bias: float = 0.0, + lexicon: Optional[sockeye.lexicon.Lexicon] = None, + context_gating: bool = False): + # TODO: implement variant without input feeding + self.num_layers = num_layers + self.prefix = prefix + self.dropout = dropout + self.num_hidden = num_hidden + self.attention = attention + self.target_vocab_size = target_vocab_size + self.num_target_embed = num_target_embed + self.context_gating = context_gating + if self.context_gating: + self.gate_w = mx.sym.Variable("%sgate_weight" % prefix) + self.gate_b = mx.sym.Variable("%sgate_bias" % prefix) + self.mapped_rnn_output_w = mx.sym.Variable("%smapped_rnn_output_weight" % prefix) + self.mapped_rnn_output_b = mx.sym.Variable("%smapped_rnn_output_bias" % prefix) + self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" % prefix) + self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" % prefix) + + # Decoder stacked RNN + self.rnn = sockeye.rnn.get_stacked_rnn(cell_type, num_hidden, num_layers, dropout, prefix, residual, + forget_bias) + + # Decoder parameters + # RNN init state parameters + self._create_layer_parameters() + # Hidden state parameters + self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix) + self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix) + # Embedding & output parameters + self.embedding = sockeye.encoder.Embedding(self.num_target_embed, self.target_vocab_size, + prefix=C.TARGET_EMBEDDING_PREFIX, dropout=0.) # TODO dropout? + if weight_tying: + assert self.num_hidden == self.num_target_embed, \ + "Weight tying requires target embedding size and rnn_num_hidden to be equal" + self.cls_w = self.embedding.embed_weight + else: + self.cls_w = mx.sym.Variable("%scls_weight" % prefix) + self.cls_b = mx.sym.Variable("%scls_bias" % prefix) + + self.lexicon = lexicon + + def get_num_hidden(self) -> int: + """ + Returns the representation size of this decoder. + + :return: Number of hidden units. + """ + return self.num_hidden + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this decoder. + """ + return [self.rnn] + + def _create_layer_parameters(self): + """ + Creates parameters for encoder last state transformation into decoder layer initial states. + """ + self.init_ws, self.init_bs = [], [] + for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape): + self.init_ws.append(mx.sym.Variable("%senc2decinit_%d_weight" % (self.prefix, state_idx))) + self.init_bs.append(mx.sym.Variable("%senc2decinit_%d_bias" % (self.prefix, state_idx))) + + def create_layer_input_variables(self, batch_size: int) \ + -> Tuple[List[mx.sym.Symbol], List[mx.io.DataDesc], List[str]]: + """ + Creates RNN layer state variables. Used for inference. + Returns nested list of layer_states variables, flat list of layer shapes (for module binding), + and a flat list of layer names (for BucketingModule's data names) + + :param batch_size: Batch size. 
+ """ + layer_states, layer_shapes, layer_names = [], [], [] + for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape): + name = "%senc2decinit_%d" % (self.prefix, state_idx) + layer_states.append(mx.sym.Variable(name)) + layer_shapes.append(mx.io.DataDesc(name=name, shape=(batch_size, init_num_hidden), layout=C.BATCH_MAJOR)) + layer_names.append(name) + return layer_states, layer_shapes, layer_names + + def compute_init_states(self, + source_encoded: mx.sym.Symbol, + source_length: mx.sym.Symbol) -> DecoderState: + """ + Computes initial states of the decoder, hidden state, and one for each RNN layer. + Init states for RNN layers are computed using 1 non-linear FC with the last state of the encoder as input. + + :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden). + :param source_length: Lengths of source sequences. Shape: (batch_size,). + :return: Decoder state. + """ + # initial decoder hidden state + hidden = mx.sym.tile(data=mx.sym.expand_dims(data=source_length * 0, axis=1), reps=(1, self.num_hidden)) + # initial states for each layer + layer_states = [] + for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape): + init = mx.sym.FullyConnected(data=mx.sym.SequenceLast(data=source_encoded, + sequence_length=source_length, + use_sequence_length=True), + num_hidden=init_num_hidden, + weight=self.init_ws[state_idx], + bias=self.init_bs[state_idx], + name="%senc2decinit_%d" % (self.prefix, state_idx)) + init = mx.sym.Activation(data=init, act_type="tanh", + name="%senc2dec_inittanh_%d" % (self.prefix, state_idx)) + layer_states.append(init) + return DecoderState(hidden, layer_states) + + def _step(self, + word_vec_prev: mx.sym.Symbol, + state: DecoderState, + attention_func: Callable, + attention_state: sockeye.attention.AttentionState, + seq_idx: int = 0) -> Tuple[DecoderState, sockeye.attention.AttentionState]: + + """ + Performs single-time step in the RNN, given previous word vector, previous hidden state, attention function, + and RNN layer states. + + :param word_vec_prev: Embedding of previous target word. Shape: (batch_size, num_target_embed). + :param state: Decoder state consisting of hidden and layer states. + :param attention_func: Attention function to produce context vector. + :param attention_state: Previous attention state. + :param seq_idx: Decoder time step. + :return: (new decoder state, updated attention state). 
+ """ + # (1) RNN step + # concat previous word embedding and previous hidden state + rnn_input = mx.sym.concat(word_vec_prev, state.hidden, dim=1, + name="%sconcat_target_context_t%d" % (self.prefix, seq_idx)) + # rnn_output: (batch_size, rnn_num_hidden) + # next_layer_states: num_layers * [batch_size, rnn_num_hidden] + rnn_output, layer_states = self.rnn(rnn_input, state.layer_states) + + # (2) Attention step + attention_input = self.attention.make_input(seq_idx, word_vec_prev, rnn_output) + attention_state = attention_func(attention_input, attention_state) + + # (3) Combine context with hidden state + if self.context_gating: + # context: (batch_size, encoder_num_hidden) + # gate: (batch_size, rnn_num_hidden) + gate = mx.sym.FullyConnected(data=mx.sym.concat(word_vec_prev, rnn_output, attention_state.context, dim=1), + num_hidden=self.num_hidden, weight=self.gate_w, bias=self.gate_b) + gate = mx.sym.Activation(data=gate, act_type="sigmoid", + name="%sgate_activation_t%d" % (self.prefix, seq_idx)) + + # mapped_rnn_output: (batch_size, rnn_num_hidden) + mapped_rnn_output = mx.sym.FullyConnected(data=rnn_output, + num_hidden=self.num_hidden, + weight=self.mapped_rnn_output_w, + bias=self.mapped_rnn_output_b, + name="%smapped_rnn_output_fc_t%d" % (self.prefix, seq_idx)) + # mapped_context: (batch_size, rnn_num_hidden) + mapped_context = mx.sym.FullyConnected(data=attention_state.context, + num_hidden=self.num_hidden, + weight=self.mapped_context_w, + bias=self.mapped_context_b, + name="%smapped_context_fc_t%d" % (self.prefix, seq_idx)) + + # hidden: (batch_size, rnn_num_hidden) + hidden = mx.sym.Activation(data=gate * mapped_rnn_output + (1 - gate) * mapped_context, + act_type="tanh", + name="%snext_hidden_t%d" % (self.prefix, seq_idx)) + + else: + # hidden: (batch_size, rnn_num_hidden) + hidden = mx.sym.FullyConnected(data=mx.sym.concat(rnn_output, attention_state.context, dim=1), + # use same number of hidden states as RNN + num_hidden=self.num_hidden, + weight=self.hidden_w, + bias=self.hidden_b) + # hidden: (batch_size, rnn_num_hidden) + hidden = mx.sym.Activation(data=hidden, act_type="tanh", + name="%snext_hidden_t%d" % (self.prefix, seq_idx)) + + return DecoderState(hidden, layer_states), attention_state + + def decode(self, + source_encoded: mx.sym.Symbol, + source_seq_len: int, + source_length: mx.sym.Symbol, + target: mx.sym.Symbol, + target_seq_len: int, + source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + """ + Returns decoder logits with batch size and target sequence length collapsed into a single dimension. + + :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden). + :param source_seq_len: Maximum source sequence length. + :param source_length: Lengths of source sequences. Shape: (batch_size,). + :param target: Target sequence. Shape: (batch_size, target_seq_len). + :param target_seq_len: Maximum target sequence length. + :param source_lexicon: Lexical biases for current sentence. + Shape: (batch_size, target_vocab_size, source_seq_len) + :return: Logits of next-word predictions for target sequence. 
+ Shape: (batch_size * target_seq_len, target_vocab_size) + """ + # process encoder states + source_encoded_batch_major = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1, name='source_encoded_batch_major') + + # embed and slice target words + # target_embed: (batch_size, target_seq_len, num_target_embed) + target_embed = self.embedding.encode(target, None, target_seq_len) + # target_embed: target_seq_len * (batch_size, num_target_embed) + target_embed = mx.sym.split(data=target_embed, num_outputs=target_seq_len, axis=1, squeeze_axis=True) + + # get recurrent attention function conditioned on source + attention_func = self.attention.on(source_encoded_batch_major, source_length, source_seq_len) + attention_state = self.attention.get_initial_state(source_length, source_seq_len) + + # initialize decoder states + # hidden: (batch_size, rnn_num_hidden) + # layer_states: List[(batch_size, state_num_hidden] + state = self.compute_init_states(source_encoded, source_length) + + # hidden_all: target_seq_len * (batch_size, 1, rnn_num_hidden) + hidden_all = [] + + # TODO: possible alternative: feed back the context vector instead of the hidden (see lamtram) + + lexical_biases = [] + + self.rnn.reset() + + for seq_idx in range(target_seq_len): + # hidden: (batch_size, rnn_num_hidden) + state, attention_state = self._step(target_embed[seq_idx], + state, + attention_func, + attention_state, + seq_idx) + + # hidden_expanded: (batch_size, 1, rnn_num_hidden) + hidden_all.append(mx.sym.expand_dims(data=state.hidden, axis=1)) + + if source_lexicon is not None: + assert self.lexicon is not None, "source_lexicon should not be None if no lexicon available" + lexical_biases.append(self.lexicon.calculate_lex_bias(source_lexicon, attention_state.probs)) + + # concatenate along time axis + # hidden_concat: (batch_size, target_seq_len, rnn_num_hidden) + hidden_concat = mx.sym.concat(*hidden_all, dim=1, name="%shidden_concat" % self.prefix) + # hidden_concat: (batch_size * target_seq_len, rnn_num_hidden) + hidden_concat = mx.sym.reshape(data=hidden_concat, shape=(-1, self.num_hidden)) + + # logits: (batch_size * target_seq_len, target_vocab_size) + logits = mx.sym.FullyConnected(data=hidden_concat, num_hidden=self.target_vocab_size, + weight=self.cls_w, bias=self.cls_b, name=C.LOGITS_NAME) + + if source_lexicon is not None: + # lexical_biases_concat: (batch_size, target_seq_len, target_vocab_size) + lexical_biases_concat = mx.sym.concat(*lexical_biases, dim=1, name='lex_bias_concat') + # lexical_biases_concat: (batch_size * target_seq_len, target_vocab_size) + lexical_biases_concat = mx.sym.reshape(data=lexical_biases_concat, shape=(-1, self.target_vocab_size)) + logits = mx.sym.broadcast_add(lhs=logits, rhs=lexical_biases_concat, + name='%s_plus_lex_bias' % C.LOGITS_NAME) + + return logits + + def predict(self, + word_id_prev: mx.sym.Symbol, + state_prev: DecoderState, + attention_func: Callable, + attention_state_prev: sockeye.attention.AttentionState, + source_lexicon: Optional[mx.sym.Symbol] = None, + softmax_temperature: Optional[float] = None) -> Tuple[mx.sym.Symbol, + DecoderState, + sockeye.attention.AttentionState]: + """ + Given previous word id, attention function, previous hidden state and RNN layer states, + returns Softmax predictions (not a loss symbol), next hidden state, and next layer + states. Used for inference. + + :param word_id_prev: Previous target word id. Shape: (1,). + :param state_prev: Previous decoder state consisting of hidden and layer states. 
+ :param attention_func: Attention function to produce context vector. + :param attention_state_prev: Previous attention state. + :param source_lexicon: Lexical biases for current sentence. + Shape: (batch_size, target_vocab_size, source_seq_len). + :param softmax_temperature: Optional parameter to control steepness of softmax distribution. + :return: (predicted next-word distribution, decoder state, attention state). + """ + # target side embedding + word_vec_prev = self.embedding.encode(word_id_prev, None, 1) + + # state.hidden: (batch_size, rnn_num_hidden) + # attention_state.dynamic_source: (batch_size, source_seq_len, coverage_num_hidden) + # attention_state.probs: (batch_size, source_seq_len) + state, attention_state = self._step(word_vec_prev, + state_prev, + attention_func, + attention_state_prev) + + # logits: (batch_size, target_vocab_size) + logits = mx.sym.FullyConnected(data=state.hidden, num_hidden=self.target_vocab_size, + weight=self.cls_w, bias=self.cls_b, name=C.LOGITS_NAME) + + if source_lexicon is not None: + assert self.lexicon is not None + # lex_bias: (batch_size, 1, target_vocab_size) + lex_bias = self.lexicon.calculate_lex_bias(source_lexicon, attention_state.probs) + # lex_bias: (batch_size, target_vocab_size) + lex_bias = mx.sym.reshape(data=lex_bias, shape=(-1, self.target_vocab_size)) + logits = mx.sym.broadcast_add(lhs=logits, rhs=lex_bias, name='%s_plus_lex_bias' % C.LOGITS_NAME) + + if softmax_temperature is not None: + logits /= softmax_temperature + + softmax_out = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME) + return softmax_out, state, attention_state diff --git a/sockeye/embeddings.py b/sockeye/embeddings.py new file mode 100644 index 000000000..1b54765c8 --- /dev/null +++ b/sockeye/embeddings.py @@ -0,0 +1,116 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Command-line tool to inspect model embeddings. +""" +import argparse +import sys +from typing import List, Tuple + +import mxnet as mx +import numpy as np + +import sockeye.constants as C +import sockeye.translate +import sockeye.utils +import sockeye.vocab +from sockeye.log import setup_main_logger + +logger = setup_main_logger(__name__, file_logging=False) + + +def compute_sims(inputs: mx.nd.NDArray, normalize: bool) -> mx.nd.NDArray: + """ + Returns a matrix with pair-wise similarity scores between inputs. + Similarity score is (normalized) Euclidean distance. 'Similarity with self' is masked + to large negative value. + + :param inputs: NDArray of inputs. + :param normalize: Whether to normalize to unit-length. + :return: NDArray with pairwise similarities of same shape as inputs. + """ + if normalize: + logger.info("Normalizing embeddings to unit length") + inputs = mx.nd.L2Normalization(inputs, mode='instance') + sims = mx.nd.dot(inputs, inputs, transpose_b=True) + sims_np = sims.asnumpy() + np.fill_diagonal(sims_np, -9999999.) 
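# Illustrative sketch (not part of this commit): dividing the logits by a
# softmax temperature > 1 (as predict() does above) flattens the output
# distribution, while a temperature < 1 sharpens it. Toy logits are hypothetical.
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
for temperature in (0.5, 1.0, 2.0):
    print(temperature, np.round(softmax(logits / temperature), 3))
# higher temperature -> probabilities move towards uniform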
+ sims = mx.nd.array(sims_np) + return sims + + +def nearest_k(similarity_matrix: mx.nd.NDArray, + query_word_id: int, + k: int, + gamma: float = 1.0) -> List[Tuple[int, float]]: + """ + Returns values and indices of k items with largest similarity. + + :param similarity_matrix: Similarity matrix. + :param query_word_id: Query word id. + :param k: Number of closest items to retrieve. + :param gamma: Parameter to control distribution steepness. + :return: List of indices and values of k nearest elements. + """ + values, indices = mx.nd.topk(mx.nd.softmax(similarity_matrix[query_word_id] / gamma), k=k, ret_typ='both') + return zip(indices.asnumpy(), values.asnumpy()) + + +def main(): + """ + Command-line tool to inspect model embeddings. + """ + params = argparse.ArgumentParser(description='Shows nearest neighbours of input tokens in the embedding space.') + params.add_argument('--params', '-p', required=True, help='params file to read parameters from') + params.add_argument('--vocab', '-v', required=True, help='vocab file') + params.add_argument('--side', '-s', required=True, choices=['source', 'target'], help='what embeddings to look at') + params.add_argument('--norm', '-n', action='store_true', help='normalize embeddings to unit length') + params.add_argument('-k', type=int, default=5, help='Number of neighbours to print') + params.add_argument('--gamma', '-g', type=float, default=1.0, help='Softmax distribution steepness.') + args = params.parse_args() + + logger.info("Arguments: %s", args) + + vocab = sockeye.vocab.vocab_from_pickle(args.vocab) + vocab_inv = sockeye.vocab.reverse_vocab(vocab) + + params, _ = sockeye.utils.load_params(args.params) + weights = params[C.SOURCE_EMBEDDING_PREFIX + "weight"] + if args.side == 'target': + weights = params[C.TARGET_EMBEDDING_PREFIX + "weight"] + logger.info("Embedding size: %d", weights.shape[1]) + + sims = compute_sims(weights, args.norm) + + # weights (vocab, num_target_embed) + assert weights.shape[0] == len(vocab), "vocab and embeddings matrix do not match: %d vs. %d" % ( + weights.shape[0], len(vocab)) + + for line in sys.stdin: + line = line.rstrip() + for token in line.split(): + if token not in vocab: + sys.stdout.write("\n") + logger.error("'%s' not in vocab", token) + continue + sys.stdout.write("Token: %s [%d]: " % (token, vocab[token])) + neighbours = nearest_k(sims, vocab[token], args.k, args.gamma) + for i, (wid, score) in enumerate(neighbours, 1): + sys.stdout.write("%d. %s[%d] %.4f\t" % (i, vocab_inv[wid], wid, score)) + sys.stdout.write("\n") + sys.stdout.flush() + + +if __name__ == '__main__': + main() diff --git a/sockeye/encoder.py b/sockeye/encoder.py new file mode 100644 index 000000000..90514707a --- /dev/null +++ b/sockeye/encoder.py @@ -0,0 +1,413 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Defines Encoder interface and various implementations. 
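# Illustrative sketch (not part of this commit): the embedding inspector boils
# down to "normalize rows, take a dot product, mask the diagonal, pick top-k".
# A numpy version with a hypothetical 4-word, 2-dimensional embedding table:
import numpy as np

embeddings = np.array([[1.0, 0.0],
                       [0.9, 0.1],
                       [0.0, 1.0],
                       [0.1, 0.9]])
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
sims = normed @ normed.T
np.fill_diagonal(sims, -np.inf)              # mask 'similarity with self'

query, k = 0, 2                              # word id to inspect, neighbours to keep
neighbours = np.argsort(sims[query])[::-1][:k]
print(neighbours, np.round(sims[query][neighbours], 3))  # nearest word ids first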
+""" +import logging +from typing import List + +import mxnet as mx + +import sockeye.constants as C +import sockeye.rnn +import sockeye.utils + +logger = logging.getLogger(__name__) + + +def get_encoder(num_embed: int, + vocab_size: int, + num_layers: int, + rnn_num_hidden: int, + cell_type: str, + residual: bool, + dropout: float, + forget_bias: float, + fused: bool = False) -> 'Encoder': + """ + Returns an encoder with embedding, batch2time-major conversion, and bidirectional RNN encoder. + If num_layers > 1, adds uni-directional RNNs. + + :param num_embed: Size of embedding layer. + :param vocab_size: Source vocabulary size. + :param num_layers: Number of encoder layers. + :param rnn_num_hidden: Number of hidden units for RNN cells. + :param cell_type: RNN cell type. + :param residual: Whether to add residual connections to multi-layered RNNs. + :param dropout: Dropout probability for encoders (RNN and embedding). + :param forget_bias: Initial value of RNN forget biases. + :param fused: Whether to use FusedRNNCell (CuDNN). Only works with GPU context. + :return: Encoder instance. + """ + # TODO give more control on encoder architecture + encoders = list() + encoders.append(Embedding(num_embed=num_embed, + vocab_size=vocab_size, + prefix=C.SOURCE_EMBEDDING_PREFIX, + dropout=dropout)) + encoders.append(BatchMajor2TimeMajor()) + + EncoderClass = FusedRecurrentEncoder if fused else RecurrentEncoder + encoders.append(BiDirectionalRNNEncoder(num_hidden=rnn_num_hidden, + num_layers=1, + dropout=dropout, + layout=C.TIME_MAJOR, + cell_type=cell_type, + EncoderClass=EncoderClass, + forget_bias=forget_bias)) + + if num_layers > 1: + encoders.append(EncoderClass(num_hidden=rnn_num_hidden, + num_layers=num_layers - 1, + dropout=0., + layout=C.TIME_MAJOR, + cell_type=cell_type, + residual=residual, + forget_bias=forget_bias)) + + return EncoderSequence(encoders) + + +class Encoder: + """ + Generic encoder interface. + """ + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + raise NotImplementedError() + + def get_num_hidden(self) -> int: + """ + Return the representation size of this encoder. + """ + raise NotImplementedError() + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this encoder. + """ + raise NotImplementedError() + + +class BatchMajor2TimeMajor(Encoder): + """ + Converts batch major data to time major + """ + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples (data_length) and maximum sequence length (seq_len). + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + with mx.AttrScope(__layout__=C.TIME_MAJOR): + return mx.sym.swapaxes(data=data, dim1=0, dim2=1) + + def get_num_hidden(self) -> int: + """ + Return the representation size of this encoder. + """ + return 0 + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this encoder. + """ + return [] + + +class Embedding(Encoder): + """ + Thin wrapper around MXNet's Embedding symbol. 
Works with both time- and batch-major data layouts. + + :param num_embed: Embedding size. + :param vocab_size: Source vocabulary size. + :param prefix: Name prefix for symbols of this encoder. + :param dropout: Dropout probability. + """ + + def __init__(self, num_embed: int, vocab_size: int, prefix: str, dropout: float): + self.num_embed = num_embed + self.vocab_size = vocab_size + self.prefix = prefix + self.dropout = dropout + self.embed_weight = mx.sym.Variable(prefix + "weight") + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + embedding = mx.sym.Embedding(data=data, + input_dim=self.vocab_size, + weight=self.embed_weight, + output_dim=self.num_embed, + name=self.prefix + 'embed') + if self.dropout > 0: + embedding = mx.sym.Dropout(data=embedding, p=self.dropout, name="source_embed_dropout") + return embedding + + def get_num_hidden(self) -> int: + """ + Return the representation size of this encoder. + """ + return self.num_embed + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this encoder. + """ + return [] + + +class EncoderSequence(Encoder): + """ + A sequence of encoders is itself an encoder. + + :param encoders: List of encoders. + """ + + def __init__(self, encoders: List[Encoder]): + self.encoders = encoders + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + for encoder in self.encoders: + data = encoder.encode(data, data_length, seq_len) + return data + + def get_num_hidden(self) -> int: + """ + Return the representation size of this encoder. + """ + return self.encoders[-1].get_num_hidden() + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this encoder. + """ + cells = [] + for encoder in self.encoders: + for cell in encoder.get_rnn_cells(): + cells.append(cell) + return cells + + +class RecurrentEncoder(Encoder): + """ + Uni-directional (multi-layered) recurrent encoder + """ + + def __init__(self, + num_hidden: int, + num_layers: int, + prefix: str = C.STACKEDRNN_PREFIX, + dropout: float = 0., + layout: str = C.TIME_MAJOR, + cell_type: str = C.LSTM_TYPE, + residual: bool = False, + forget_bias=0.0): + self.layout = layout + self.num_hidden = num_hidden + self.rnn = sockeye.rnn.get_stacked_rnn(cell_type, num_hidden, + num_layers, dropout, prefix, + residual, forget_bias) + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + outputs, _ = self.rnn.unroll(seq_len, inputs=data, merge_outputs=True, layout=self.layout) + + return outputs + + def get_rnn_cells(self): + """ + Returns RNNCells used in this encoder. 
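# A simplified, MXNet-free sketch of the EncoderSequence pattern above: each stage transforms the
# data in turn, and the sequence of stages is itself an encoder. The stages here are toy placeholders.
import numpy as np
from typing import Callable, List

Stage = Callable[[np.ndarray], np.ndarray]

def encode_sequence(stages: List[Stage], data: np.ndarray) -> np.ndarray:
    # apply each encoder stage to the output of the previous one
    for stage in stages:
        data = stage(data)
    return data

stages = [lambda x: x * 2.0,                 # stand-in for an embedding lookup
          lambda x: np.swapaxes(x, 0, 1)]    # stand-in for batch-major -> time-major
print(encode_sequence(stages, np.ones((2, 3))).shape)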
+ """ + return [self.rnn] + + def get_num_hidden(self): + """ + Return the representation size of this encoder. + """ + return self.num_hidden + + +class FusedRecurrentEncoder(Encoder): + """ + Uni-directional (multi-layered) recurrent encoder + """ + + def __init__(self, + num_hidden: int, + num_layers: int, + prefix: str = C.STACKEDRNN_PREFIX, + dropout: float = 0., + layout: str = C.TIME_MAJOR, + cell_type: str = C.LSTM_TYPE, + residual: bool = False, + forget_bias=0.0): + self.layout = layout + self.num_hidden = num_hidden + logger.warning("%s: FusedRNNCell uses standard MXNet Orthogonal initializer w/ rand_type=uniform", prefix) + self.rnn = [mx.rnn.FusedRNNCell(num_hidden, + num_layers=num_layers, + mode=cell_type, + bidirectional=False, + forget_bias=forget_bias, + prefix=prefix)] + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :return: Encoded input data. + """ + outputs = data + for cell in self.rnn: + outputs, _ = cell.unroll(seq_len, inputs=outputs, merge_outputs=True, layout=self.layout) + + return outputs + + def get_rnn_cells(self): + """ + Returns RNNCells used in this encoder. + """ + return self.rnn + + def get_num_hidden(self): + """ + Return the representation size of this encoder. + """ + return self.num_hidden + + +class BiDirectionalRNNEncoder(Encoder): + """ + An encoder that runs a forward and a reverse RNN over input data. + States from both RNNs are concatenated together. + + :param num_hidden: Number of hidden units for final, concatenated encoder states. Must be a multiple of 2. + :param num_layers: Number of RNN layers. + :param prefix: Name prefix for symbols of this encoder. + :param dropout: Dropout probability. + :param layout: Input data layout. Default: time-major. + :param cell_type: RNN cell type. + :param fused: Whether to use FusedRNNCell (CuDNN). Only works with GPU context. + :param forget_bias: Initial value of RNN forget biases. + """ + + def __init__(self, + num_hidden: int, + num_layers: int, + prefix=C.BIDIRECTIONALRNN_PREFIX, + dropout: float = 0., + layout=C.TIME_MAJOR, + cell_type=C.LSTM_TYPE, + EncoderClass: Encoder = RecurrentEncoder, + forget_bias: float = 0.0): + assert num_hidden % 2 == 0, "num_hidden must be a multiple of 2 for BiDirectionalRNNEncoders." + self.num_hidden = num_hidden + if layout[0] == 'N': + logger.warning("Batch-major layout for encoder input. Consider using time-major layout for faster speed") + + # time-major layout as _encode needs to swap layout for SequenceReverse + self.forward_rnn = EncoderClass(num_hidden=num_hidden // 2, num_layers=num_layers, + prefix=prefix + C.FORWARD_PREFIX, dropout=dropout, + layout=C.TIME_MAJOR, cell_type=cell_type, + forget_bias=forget_bias) + self.reverse_rnn = EncoderClass(num_hidden=num_hidden // 2, num_layers=num_layers, + prefix=prefix + C.REVERSE_PREFIX, dropout=dropout, + layout=C.TIME_MAJOR, cell_type=cell_type, + forget_bias=forget_bias) + self.layout = layout + self.prefix = prefix + + def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. + + :param data: Input data. + :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. 
+ :return: Encoded input data. + """ + if self.layout[0] == 'N': + data = mx.sym.swapaxes(data=data, dim1=0, dim2=1) + data = self._encode(data, data_length, seq_len) + if self.layout[0] == 'N': + data = mx.sym.swapaxes(data=data, dim1=0, dim2=1) + return data + + def _encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, seq_len: int) -> mx.sym.Symbol: + """ + Bidirectionally encodes time-major data. + """ + # (seq_len, batch_size, num_embed) + data_reverse = mx.sym.SequenceReverse(data=data, sequence_length=data_length, + use_sequence_length=True) + # (seq_length, batch, cell_num_hidden) + hidden_forward = self.forward_rnn.encode(data, data_length, seq_len) + # (seq_length, batch, cell_num_hidden) + hidden_reverse = self.reverse_rnn.encode(data_reverse, data_length, seq_len) + # (seq_length, batch, cell_num_hidden) + hidden_reverse = mx.sym.SequenceReverse(data=hidden_reverse, sequence_length=data_length, + use_sequence_length=True) + # (seq_length, batch, 2 * cell_num_hidden) + hidden_concat = mx.sym.concat(hidden_forward, hidden_reverse, dim=2, name="%s_rnn" % self.prefix) + + return hidden_concat + + def get_num_hidden(self) -> int: + """ + Return the representation size of this encoder. + """ + return self.num_hidden + + def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: + """ + Returns a list of RNNCells used by this encoder. + """ + return self.forward_rnn.get_rnn_cells() + self.reverse_rnn.get_rnn_cells() diff --git a/sockeye/inference.py b/sockeye/inference.py new file mode 100644 index 000000000..1c63f0c22 --- /dev/null +++ b/sockeye/inference.py @@ -0,0 +1,646 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Code for inference/translation +""" +import logging +import os +from typing import Dict, List, NamedTuple, Optional, Tuple + +import mxnet as mx +import numpy as np + +import sockeye.bleu +import sockeye.constants as C +import sockeye.data_io +import sockeye.model +import sockeye.utils +import sockeye.vocab +from sockeye.attention import AttentionState +from sockeye.decoder import DecoderState + +logger = logging.getLogger(__name__) + + +class InferenceModel(sockeye.model.SockeyeModel): + """ + InferenceModel is a SockeyeModel that supports three operations used for inference/decoding: + + (1) Encoder forward call: encode source sentence and return initial decoder states, given a bucket_key. + (2) Decoder forward call: single decoder step: predict next word. + (3) Return decoder data shapes, given a bucket key. + + :param model_folder: Folder to load model from. + :param context: MXNet context to bind modules to. + :param fused: Whether to use FusedRNNCell (CuDNN). Only works with GPU context. + :param max_input_len: Maximum input length. + :param beam_size: Beam size. + :param checkpoint: Checkpoint to load. If None, finds best parameters in model_folder. + :param softmax_temperature: Optional parameter to control steepness of softmax distribution. 
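# Illustrative NumPy sketch of the bidirectional encoding above: reverse each sequence up to its
# true length, run a "reverse RNN" (a dummy transform here), undo the reversal, and concatenate
# with the forward states along the hidden dimension. All arrays are made-up toy data.
import numpy as np

def sequence_reverse(data, lengths):
    # data: (seq_len, batch, hidden); reverse only the first lengths[b] steps of each sequence
    out = data.copy()
    for b, length in enumerate(lengths):
        out[:length, b] = data[:length, b][::-1]
    return out

seq_len, batch, hidden = 4, 2, 3
data = np.arange(seq_len * batch * hidden, dtype=float).reshape(seq_len, batch, hidden)
lengths = [4, 2]

forward = data                                 # stand-in for the forward RNN states
reverse = sequence_reverse(data, lengths)      # stand-in for the reverse RNN input
reverse = sequence_reverse(reverse, lengths)   # re-align reverse states with forward time steps
encoded = np.concatenate([forward, reverse], axis=2)   # (seq_len, batch, 2 * hidden)
print(encoded.shape)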
+ """ + + def __init__(self, + model_folder: str, + context: mx.context.Context, + fused: bool, + max_input_len: Optional[int], + beam_size: int, + checkpoint: Optional[int] = None, + softmax_temperature: Optional[float] = None): + # load config & determine parameter file + super().__init__(sockeye.model.SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))) + fname_params = os.path.join(model_folder, C.PARAMS_NAME % checkpoint if checkpoint else C.PARAMS_BEST_NAME) + + if max_input_len is None: + max_input_len = self.config.max_seq_len + else: + if max_input_len != self.config.max_seq_len: + logger.warning("Model was trained with max_seq_len=%d, but using max_input_len=%d.", + self.config.max_seq_len, max_input_len) + self.max_input_len = max_input_len + + assert beam_size < self.config.vocab_target_size, 'beam size must be smaller than target vocab size' + + self.beam_size = beam_size + self.softmax_temperature = softmax_temperature + self.encoder_batch_size = 1 + self.context = context + + self._build_model_components(self.max_input_len, fused) + self.encoder_module, self.decoder_module = self._build_modules() + + self.decoder_data_shapes_cache = dict() # bucket_key -> shape cache + max_encoder_data_shapes = self._get_encoder_data_shapes(self.max_input_len) + max_decoder_data_shapes = self._get_decoder_data_shapes(self.max_input_len) + self.encoder_module.bind(data_shapes=max_encoder_data_shapes, for_training=False, grad_req="null") + self.decoder_module.bind(data_shapes=max_decoder_data_shapes, for_training=False, grad_req="null") + + self.load_params_from_file(fname_params) + self.encoder_module.init_params(arg_params=self.params, allow_missing=False) + self.decoder_module.init_params(arg_params=self.params, allow_missing=False) + + def _build_modules(self): + + # Encoder symbol & module + source = mx.sym.Variable(C.SOURCE_NAME) + source_length = mx.sym.Variable(C.SOURCE_LENGTH_NAME) + + def encoder_sym_gen(source_seq_len: int): + source_encoded = self.encoder.encode(source, source_length, seq_len=source_seq_len) + source_encoded_batch_major = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1) + + # initial decoder states + decoder_hidden_init, decoder_init_states = self.decoder.compute_init_states(source_encoded, + source_length) + # initial attention state + attention_state = self.attention.get_initial_state(source_length, source_seq_len) + + data_names = [C.SOURCE_NAME, C.SOURCE_LENGTH_NAME] + label_names = [] + + symbol_group = [source_encoded_batch_major, + attention_state.dynamic_source, + decoder_hidden_init] + decoder_init_states + return mx.sym.Group(symbol_group), data_names, label_names + + encoder_module = mx.mod.BucketingModule(sym_gen=encoder_sym_gen, + default_bucket_key=self.max_input_len, + context=self.context) + + # Decoder symbol & module + source_encoded = mx.sym.Variable(C.SOURCE_ENCODED_NAME) + dynamic_source_prev = mx.sym.Variable(C.SOURCE_DYNAMIC_PREVIOUS_NAME) + word_id_prev = mx.sym.Variable(C.TARGET_PREVIOUS_NAME) + hidden_prev = mx.sym.Variable(C.HIDDEN_PREVIOUS_NAME) + layer_states, self.layer_shapes, layer_names = self.decoder.create_layer_input_variables(self.beam_size) + state = DecoderState(hidden_prev, layer_states) + attention_state = AttentionState(context=None, probs=None, dynamic_source=dynamic_source_prev) + + def decoder_sym_gen(source_seq_len: int): + data_names = [C.SOURCE_ENCODED_NAME, + C.SOURCE_DYNAMIC_PREVIOUS_NAME, + C.SOURCE_LENGTH_NAME, + C.TARGET_PREVIOUS_NAME, + C.HIDDEN_PREVIOUS_NAME] + layer_names + label_names = [] + + 
attention_func = self.attention.on(source_encoded, source_length, source_seq_len) + + softmax_out, next_state, next_attention_state = \ + self.decoder.predict(word_id_prev, + state, + attention_func, + attention_state, + softmax_temperature=self.softmax_temperature) + + symbol_group = [softmax_out, + next_attention_state.probs, + next_attention_state.dynamic_source, + next_state.hidden] + next_state.layer_states + return mx.sym.Group(symbol_group), data_names, label_names + + decoder_module = mx.mod.BucketingModule(sym_gen=decoder_sym_gen, + default_bucket_key=self.max_input_len, + context=self.context) + + return encoder_module, decoder_module + + @staticmethod + def _get_encoder_data_shapes(max_input_length: int) -> List[mx.io.DataDesc]: + """ + Returns data shapes of the encoder module. + Encoder batch size is always 1. + + Shapes: + source: (1, max_input_len) + length: (1,) + + :param max_input_length: Maximum input length. + :return: List of data descriptions. + """ + return [mx.io.DataDesc(name=C.SOURCE_NAME, shape=(1, max_input_length), layout=C.BATCH_MAJOR), + mx.io.DataDesc(name=C.SOURCE_LENGTH_NAME, shape=(1,), layout=C.BATCH_MAJOR)] + + def _get_decoder_data_shapes(self, input_length) -> List[mx.io.DataDesc]: + """ + Returns data shapes of the decoder module, given a bucket_key (source input length) + Caches results for bucket_keys if called iteratively. + + Shapes: + source_encoded: (beam_size, input_length, encoder_num_hidden) + source_length: (beam_size,) + prev_target_id: (beam_size,) + prev_hidden: (beam_size, decoder_num_hidden) + + :param input_length: Input length. + :return: List of data descriptions. + """ + if input_length in self.decoder_data_shapes_cache: + return self.decoder_data_shapes_cache[input_length] + + shapes = self._get_decoder_variable_shapes(input_length) + self.layer_shapes + self.decoder_data_shapes_cache[input_length] = shapes + return shapes + + def _get_decoder_variable_shapes(self, input_length): + """ + Returns only the data shapes of input variables. Auxiliary method to adjust the computation graph to the + presence or absence of coverage vectors. + + :param input_length: The maximal source sentence length + :return: A list of input shapes + """ + shapes = [mx.io.DataDesc(C.SOURCE_ENCODED_NAME, + (self.beam_size, input_length, self.encoder.get_num_hidden()), + layout=C.BATCH_MAJOR), + mx.io.DataDesc(C.SOURCE_DYNAMIC_PREVIOUS_NAME, + (self.beam_size, input_length, self.attention.dynamic_source_num_hidden), + layout=C.BATCH_MAJOR), + mx.io.DataDesc(C.SOURCE_LENGTH_NAME, + (self.beam_size,), + layout="N"), + mx.io.DataDesc(C.TARGET_PREVIOUS_NAME, + (self.beam_size,), + layout="N"), + mx.io.DataDesc(C.HIDDEN_PREVIOUS_NAME, + (self.beam_size, self.decoder.get_num_hidden()), + layout="NC")] + return shapes + + def run_encoder(self, + source: mx.nd.NDArray, + source_length: mx.nd.NDArray, + bucket_key: int) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, + mx.nd.NDArray, mx.nd.NDArray, + List[mx.nd.NDArray]]: + """ + Runs forward pass of the encoder. + Encodes source given source length and bucket key. + Returns encoder representation of the source, source_length, initial hidden state of decoder RNN, + and initial decoder states tiled to beam size. + + :param source: Integer-coded input tokens. + :param source_length: Length of input sentence. + :param bucket_key: Bucket key. + :return: Encoded source, source length, initial decoder hidden state, initial decoder hidden states. 
+ """ + batch = mx.io.DataBatch(data=[source, source_length], label=None, + bucket_key=bucket_key, + provide_data=[ + mx.io.DataDesc(name=C.SOURCE_NAME, shape=(self.encoder_batch_size, bucket_key), + layout=C.BATCH_MAJOR), + mx.io.DataDesc(name=C.SOURCE_LENGTH_NAME, shape=(self.encoder_batch_size,), + layout=C.BATCH_MAJOR)]) + + self.encoder_module.forward(data_batch=batch, is_train=False) + encoded_source, source_dynamic_init, decoder_hidden_init, *decoder_states = self.encoder_module.get_outputs() + # replicate encoder/init module results beam size times + encoded_source = mx.nd.tile(encoded_source, reps=(self.beam_size, 1, 1)) + source_dynamic_init = mx.nd.tile(source_dynamic_init, reps=(self.beam_size, 1, 1)) + decoder_hidden_init = mx.nd.tile(decoder_hidden_init, reps=(self.beam_size, 1)) + decoder_states = [mx.nd.tile(state, reps=(self.beam_size, 1)) for state in decoder_states] + source_length = mx.nd.tile(source_length, reps=(self.beam_size,)) + + return encoded_source, source_dynamic_init, source_length, decoder_hidden_init, decoder_states + + def run_decoder(self, + encoded_source: mx.nd.NDArray, + dynamic_source: mx.nd.NDArray, + source_length: mx.nd.NDArray, + previous_word_id: mx.nd.NDArray, + previous_hidden: mx.nd.NDArray, + decoder_states: List[mx.nd.NDArray], + bucket_key: int) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, + mx.nd.NDArray, mx.nd.NDArray, + List[mx.nd.NDArray]]: + """ + Runs forward pass of the single-step decoder. + + :param encoded_source: Encoded source sentence. + :param dynamic_source: Dynamic encoding of source sentence. + :param source_length: Source length. + :param previous_word_id: Previous predicted word id. + :param previous_hidden: Previous hidden decoder state. + :param decoder_states: Decoder states. + :param bucket_key: Bucket key. + :return: Probability distribution over next word, attention scores, dynamic source encoding, + next hidden state, next decoder states. + """ + + data = [encoded_source, + dynamic_source, + source_length, + previous_word_id.as_in_context(self.context), + previous_hidden] + decoder_states + + decoder_batch = mx.io.DataBatch( + data=data, + label=None, bucket_key=bucket_key, provide_data=self._get_decoder_data_shapes(bucket_key)) + # run forward pass + self.decoder_module.forward(data_batch=decoder_batch, is_train=False) + # collect outputs + softmax_out, attention_probs, dynamic_source, next_hidden, *next_layer_states = \ + self.decoder_module.get_outputs() + + return softmax_out, attention_probs, dynamic_source, next_hidden, next_layer_states + + +def load_models(context: mx.context.Context, + max_input_len: int, + beam_size: int, + model_folders: List[str], + checkpoints: Optional[List[int]] = None, + softmax_temperature: Optional[float] = None) \ + -> Tuple[List[InferenceModel], Dict[str, int], Dict[str, int]]: + """ + Loads a list of models for inference. + + :param context: MXNet context to bind modules to. + :param max_input_len: Maximum input length. + :param beam_size: Beam size. + :param model_folders: List of model folders to load models from. + :param checkpoints: List of checkpoints to use for each model in model_folders. Use None to load best checkpoint. + :param softmax_temperature: Optional parameter to control steepness of softmax distribution. + :return: List of models, source vocabulary, target vocabulary. 
+ """ + models, source_vocabs, target_vocabs = [], [], [] + if checkpoints is None: + checkpoints = [None] * len(model_folders) + for model_folder, checkpoint in zip(model_folders, checkpoints): + + source_vocabs.append(sockeye.vocab.vocab_from_json_or_pickle(os.path.join(model_folder, C.VOCAB_SRC_NAME))) + target_vocabs.append(sockeye.vocab.vocab_from_json_or_pickle(os.path.join(model_folder, C.VOCAB_TRG_NAME))) + model = InferenceModel(model_folder=model_folder, + context=context, + fused=False, + max_input_len=max_input_len, + beam_size=beam_size, + softmax_temperature=softmax_temperature, + checkpoint=checkpoint) + models.append(model) + + # check vocabulary consistency + assert all(set(vocab.items()) == set(source_vocabs[0].items()) for vocab in + source_vocabs), "Source vocabulary ids do not match" + assert all(set(vocab.items()) == set(target_vocabs[0].items()) for vocab in + target_vocabs), "Target vocabulary ids do not match" + + return models, source_vocabs[0], target_vocabs[0] + + +TranslatorInput = NamedTuple('TranslatorInput', [ + ('id', int), + ('sentence', str), + ('tokens', List[str]), +]) +""" +Required input for Translator. + +:param id: Sentence id. +:param sentence: Input sentence. +:param tokens: List of input tokens. +""" + +TranslatorOutput = NamedTuple('TranslatorOutput', [ + ('id', int), + ('translation', str), + ('tokens', List[str]), + ('attention_matrix', np.ndarray), + ('score', float), +]) +""" +Output structure from Translator. + +:param id: Id of input sentence. +:param translation: Translation string without sentence boundary tokens. +:param tokens: List of translated tokens. +:param attention_matrix: Attention matrix. Shape: (target_length, source_length). +:param score: Negative log probability of generated translation. +""" + + +class Translator: + """ + Translator uses one or several models to translate input. + It holds references to vocabularies to takes care of encoding input strings as word ids and conversion + of target ids into a translation string. + + :param context: MXNet context to bind modules to. + :param ensemble_mode: Ensemble mode: linear or log_linear combination. + :param models: List of models. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. 
+ """ + + def __init__(self, + context: mx.context.Context, + ensemble_mode: str, + models: List[InferenceModel], + vocab_source: Dict[str, int], + vocab_target: Dict[str, int]): + self.context = context + self.vocab_source = vocab_source + self.vocab_target = vocab_target + self.vocab_target_inv = sockeye.vocab.reverse_vocab(self.vocab_target) + self.start_id = self.vocab_target[C.BOS_SYMBOL] + self.stop_ids = {self.vocab_target[C.EOS_SYMBOL], C.PAD_ID} + self.models = models + self.interpolation_func = self._get_interpolation_func(ensemble_mode) + self.beam_size = self.models[0].beam_size + self.buckets = sockeye.data_io.define_buckets(self.models[0].max_input_len) + logger.info("Translator (%d model(s) beam_size=%d ensemble_mode=%s)", + len(self.models), self.beam_size, "None" if len(self.models) == 1 else ensemble_mode) + + @staticmethod + def _get_interpolation_func(ensemble_mode): + if ensemble_mode == 'linear': + return Translator._linear_interpolation + elif ensemble_mode == 'log_linear': + return Translator._log_linear_interpolation + else: + raise ValueError("unknown interpolation type") + + @staticmethod + def _linear_interpolation(predictions): + return -mx.nd.log(sockeye.utils.average_arrays(predictions)) + + @staticmethod + def _log_linear_interpolation(predictions): + """ + Returns averaged and re-normalized log probabilities + """ + log_probs = sockeye.utils.average_arrays([mx.nd.log(p) for p in predictions]) + return -mx.nd.log(mx.nd.softmax(log_probs)) + + @staticmethod + def make_input(sentence_id: int, sentence: str) -> TranslatorInput: + """ + Returns TranslatorInput from input_string + + :param sentence_id: Input sentence id. + :param sentence: Input sentence. + :return: Input for translate method. + """ + tokens = list(sockeye.data_io.get_tokens(sentence)) + return TranslatorInput(id=sentence_id, sentence=sentence.rstrip(), tokens=tokens) + + def translate(self, trans_input: TranslatorInput) -> TranslatorOutput: + """ + Translates a TranslatorInput and returns a TranslatorOutput + + :param trans_input: TranslatorInput as returned by make_input(). + :return: translation result. + """ + if not trans_input.tokens: + TranslatorOutput(id=trans_input.id, + translation="", + tokens=[""], + attention_matrix=np.asarray([[0]]), + score=-np.inf) + + return self._make_result(trans_input, *self.translate_nd(*self._get_inference_input(trans_input.tokens))) + + def _get_inference_input(self, tokens: List[str]) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, Optional[int]]: + """ + Returns NDArray of source ids, NDArray of sentence length, and corresponding bucket_key + + :param tokens: List of input tokens. + """ + bucket_key = sockeye.data_io.get_bucket(len(tokens), self.buckets) + if bucket_key is None: + logger.warning("Input (%d) exceeds max bucket size (%d). Stripping", len(tokens), self.buckets[-1]) + bucket_key = self.buckets[-1] + tokens = tokens[:bucket_key] + + source = mx.nd.zeros((1, bucket_key)) + ids = sockeye.data_io.tokens2ids(tokens, self.vocab_source) + for i, wid in enumerate(ids): + source[0, i] = wid + length = mx.nd.array([len(ids)]) + return source, length, bucket_key + + def _make_result(self, + trans_input: TranslatorInput, + target_ids: List[int], + attention_matrix: np.ndarray, + neg_logprob: float) -> TranslatorOutput: + """ + Returns a translator result from generated target-side word ids, attention matrix, and score. + Strips stop ids from translation string. + + :param trans_input: Translator input. + :param target_ids: List of translated ids. 
+ :param attention_matrix: Attention matrix. + :return: TranslatorOutput. + """ + target_tokens = [self.vocab_target_inv[target_id] for target_id in target_ids] + target_string = C.TOKEN_SEPARATOR.join( + target_token for target_id, target_token in zip(target_ids, target_tokens) if + target_id not in self.stop_ids) + attention_matrix = attention_matrix[:, :len(trans_input.tokens)] + return TranslatorOutput(id=trans_input.id, + translation=target_string, + tokens=target_tokens, + attention_matrix=attention_matrix, + score=neg_logprob) + + def translate_nd(self, source: mx.nd.NDArray, source_length: mx.nd.NDArray, bucket_key: int) \ + -> Tuple[List[int], np.ndarray, float]: + """ + Translates source of source_length, given a bucket_key. + + :param source: Source. + :param source_length: Source length. + :param bucket_key: Bucket key. + + :return: Sequence of translated ids, attention matrix, length-normalized negative log probability. + """ + # allow output sentence to be at most 2 times the current bucket_key + # TODO: max_output_length adaptive to source_length + max_output_length = bucket_key * 2 + + return self._beam_search(source, source_length, bucket_key, max_output_length) + + def _combine_predictions(self, + predictions: List[mx.nd.NDArray], + attention_prob_scores: List[mx.nd.NDArray]) -> Tuple[mx.nd.NDArray, np.ndarray]: + """ + Returns combined predictions of models as negative log probabilities, as well as averaged attention prob scores. + """ + # average attention prob scores. TODO: is there a smarter way to do this? + attention_prob_score = sockeye.utils.average_arrays(attention_prob_scores).asnumpy() + + # combine model predictions and convert to neg log probs + if len(self.models) == 1: + neg_logprobs = -mx.nd.log(predictions[0]) + else: + neg_logprobs = self.interpolation_func(predictions) + return neg_logprobs, attention_prob_score + + def _beam_search(self, + source: mx.nd.NDArray, + source_length: mx.nd.NDArray, + bucket_key: int, + max_output_length: int) -> Tuple[List[int], np.ndarray, float]: + """ + Translates a single sentence using beam search. + """ + + # encode source and initialize decoder states for each model + model_encoded_source, model_dynamic_source, model_source_length, model_decoder_states = [], [], [], [] + model_prev_hidden = [] + + for model in self.models: + # encode input sentence and initialize decoder states + # encoded_source: (self.beam_size, bucket_key, rnn_num_hidden) + # decoder_states: [(self.beam_size, rnn_num_hidden),...] 
+ encoded_source, source_dynamic_init, source_length, prev_hidden, decoder_states = \ + model.run_encoder(source, + source_length, + bucket_key) + model_encoded_source.append(encoded_source) + model_dynamic_source.append(source_dynamic_init) + model_source_length.append(source_length) + model_decoder_states.append(decoder_states) + model_prev_hidden.append(prev_hidden) + + # prev_target_word_id(s): (beam_size,) + prev_target_word_id = mx.nd.zeros((self.beam_size,), ctx=self.context) + prev_target_word_id[:] = self.start_id + + accumulated_scores = mx.nd.zeros((self.beam_size,), ctx=self.context) + lengths = mx.nd.zeros((self.beam_size,), ctx=self.context) + finished = [False] * self.beam_size + sequences = [[] for _ in range(self.beam_size)] + # one list of source word attention vectors per hypothesis + attention_lists = [[] for _ in range(self.beam_size)] + prev_hyp_indices = None + + for t in range(0, max_output_length): + + # decode one step for each model + model_probs, model_attention_prob_score, model_next_hidden = [], [], [] + model_next_dynamic_source, model_next_decoder_states = [], [] + for model_index, model in enumerate(self.models): + probs, attention_prob_score, next_dynamic_source, next_hidden, next_decoder_states = model.run_decoder( + model_encoded_source[model_index], + model_dynamic_source[model_index], + model_source_length[model_index], + prev_target_word_id, + model_prev_hidden[model_index], + model_decoder_states[model_index], + bucket_key) + model_probs.append(probs) + model_attention_prob_score.append(attention_prob_score) + model_next_hidden.append(next_hidden) + model_next_dynamic_source.append(next_dynamic_source) + model_next_decoder_states.append(next_decoder_states) + + # combine predictions + hyp_scores, attention_prob_score = self._combine_predictions(model_probs, model_attention_prob_score) + + for hyp_idx in range(self.beam_size): + if not finished[hyp_idx]: + # re-normalize hypothesis score by length + hyp_scores[hyp_idx] = (hyp_scores[hyp_idx] + accumulated_scores[hyp_idx] * lengths[hyp_idx]) / ( + lengths[hyp_idx] + 1) + else: + hyp_scores[hyp_idx][:] = np.inf + hyp_scores[hyp_idx][C.PAD_ID] = accumulated_scores[hyp_idx] + + # get self.beam_size smallest hypotheses + # prev_hyp_indices: row indices in hyp_scores. 
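# NumPy sketch of the per-hypothesis score update above: scores are length-normalized negative log
# probabilities, so extending a live hypothesis re-averages over (length + 1), while a finished
# hypothesis keeps its score and only allows the PAD column. The values below are made up.
import numpy as np

PAD_ID = 0
vocab_size, accumulated, length, finished = 5, 1.2, 3, False

step_neg_logprobs = np.array([2.3, 0.4, 1.1, 3.0, 0.9])   # -log p(word | hypothesis)
if not finished:
    scores = (step_neg_logprobs + accumulated * length) / (length + 1)
else:
    scores = np.full(vocab_size, np.inf)
    scores[PAD_ID] = accumulated
print(scores)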
+ # next_word_ids: column indices in hyp_scores + # accumulated_scores: chosen smallest scores in hyp_scores + (prev_hyp_indices, next_word_ids), accumulated_scores = sockeye.utils.smallest_k(hyp_scores.asnumpy(), + self.beam_size, t == 0) + + # select attention according to hypothesis + attention_prob_score = attention_prob_score[prev_hyp_indices, :] + + # list of new hypothesis that are now finished + new_hyp_finished = [word_id in self.stop_ids for word_id in next_word_ids] + new_sequences = [None for _ in range(self.beam_size)] + new_attention_lists = [None for _ in range(self.beam_size)] + for new_hyp_idx, prev_hyp_idx in enumerate(prev_hyp_indices): + if not finished[prev_hyp_idx]: + new_sequences[new_hyp_idx] = sequences[prev_hyp_idx] + [next_word_ids[new_hyp_idx]] + new_attention_lists[new_hyp_idx] = attention_lists[prev_hyp_idx] + [ + attention_prob_score[new_hyp_idx, :]] + else: + new_sequences[new_hyp_idx] = sequences[prev_hyp_idx] + new_attention_lists[new_hyp_idx] = attention_lists[prev_hyp_idx] + lengths[new_hyp_idx] = len(new_sequences[new_hyp_idx]) + + finished = new_hyp_finished + sequences = new_sequences + attention_lists = new_attention_lists + + if all(new_hyp_finished): + break + + # prepare new batch + prev_hyp_indices_nd = mx.nd.array(prev_hyp_indices, ctx=self.context) + prev_target_word_id = mx.nd.array(next_word_ids, ctx=self.context) + model_prev_hidden = [mx.nd.take(next_hidden, prev_hyp_indices_nd) for next_hidden in model_next_hidden] + model_dynamic_source = [mx.nd.take(next_dynamic_source, prev_hyp_indices_nd) for + next_dynamic_source in model_next_dynamic_source] + model_decoder_states = [[mx.nd.take(state, prev_hyp_indices_nd) for state in decoder_states] for + decoder_states in model_next_decoder_states] + + # sequences & accumulated scores are in latest 'k-best order', thus 0th element is best + best = 0 + # attention_matrix: (target_seq_len, source_seq_len) + attention_matrix = np.stack(attention_lists[best], axis=0) + return sequences[best], attention_matrix, accumulated_scores[best] diff --git a/sockeye/initializer.py b/sockeye/initializer.py new file mode 100644 index 000000000..4ca357689 --- /dev/null +++ b/sockeye/initializer.py @@ -0,0 +1,99 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import logging +from typing import Optional + +import mxnet as mx +import numpy as np + +import sockeye.constants as C +from sockeye.lexicon import LexiconInitializer + +logger = logging.getLogger(__name__) + + +def get_initializer(rnn_init_type, lexicon: Optional[mx.nd.NDArray] = None) -> mx.initializer.Initializer: + """ + Returns a mixed MXNet initializer given rnn_init_type and optional lexicon. + + :param rnn_init_type: Initialization type. + :param lexicon: Optional lexicon. + :return: Mixed initializer. 
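# Hypothetical NumPy equivalent of the smallest_k selection used above (the real helper lives in
# sockeye/utils.py, which is not shown in this hunk): pick the k lowest scores in a (beam, vocab)
# matrix and return their row (previous hypothesis) and column (next word) indices.
import numpy as np

def smallest_k(scores, k, only_first_row=False):
    if only_first_row:              # at t == 0 all rows are identical, so only row 0 matters
        scores = scores[:1]
    flat = scores.ravel()
    best = np.argpartition(flat, k - 1)[:k]
    best = best[np.argsort(flat[best])]                   # order the chosen k by score
    rows, cols = np.unravel_index(best, scores.shape)
    return (rows, cols), flat[best]

scores = np.array([[1.5, 0.2, 2.0],
                   [0.9, 3.0, 0.1]])
print(smallest_k(scores, k=2))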
+ """ + + if rnn_init_type == C.RNN_INIT_ORTHOGONAL: + logger.info("Orthogonal RNN initializer") + h2h_init = mx.initializer.Orthogonal() + elif rnn_init_type == C.RNN_INIT_ORTHOGONAL_STACKED: + logger.info("Stacked orthogonal RNN initializer") + h2h_init = StackedOrthogonalInit(scale=1.0, rand_type="eye") + else: + raise ValueError('unknown rnn initialization type: %s' % rnn_init_type) + + lexicon_init = LexiconInitializer(lexicon) if lexicon is not None else None + + params_init_pairs = [ + (".*h2h.*", h2h_init), + (C.LEXICON_NAME, lexicon_init), + (".*", mx.init.Xavier(factor_type="in", magnitude=2.34)) + ] + return mx.initializer.Mixed(*zip(*params_init_pairs)) + + +class StackedOrthogonalInit(mx.initializer.Initializer): + """ + Initializes weight as Orthogonal matrix. Here we assume that the weight consists of stacked square matrices of + the same size. + For example one could have 3 (2,2) matrices resulting in a (6,2) matrix. This situation arises in RNNs when one + wants to perform multiple h2h transformations in a single matrix multiplication. + + Reference: + Exact solutions to the nonlinear dynamics of learning in deep linear neural networks + arXiv preprint arXiv:1312.6120 (2013). + + :param scale: Scaling factor of weight. + :param rand_type: use "uniform" or "normal" random number to initialize weight. + "eye" simply sets the matrix to an identity matrix. + + """ + + def __init__(self, scale=1.414, rand_type="uniform"): + super().__init__() + self.scale = scale + self.rand_type = rand_type + + def _init_weight(self, sym_name, arr): + assert len(arr.shape) == 2, "Only 2d weight matrices supported." + base_dim = arr.shape[1] + stacked_dim = arr.shape[0] # base_dim * num_sub_matrices + assert stacked_dim % base_dim == 0, \ + "Dim1 must be a multiple of dim2 (as weight = stacked square matrices)." + + num_sub_matrices = stacked_dim // base_dim + logger.info("Initializing weight %s (shape=%s, num_sub_matrices=%d) with an orthogonal weight matrix.", + sym_name, arr.shape, num_sub_matrices) + + for mat_idx in range(0, num_sub_matrices): + if self.rand_type == "uniform": + tmp = np.random.uniform(-1.0, 1.0, (base_dim, base_dim)) + _, __, q = np.linalg.svd(tmp) + elif self.rand_type == "normal": + tmp = np.random.normal(0.0, 1.0, (base_dim, base_dim)) + _, __, q = np.linalg.svd(tmp) + elif self.rand_type == "eye": + q = np.eye(base_dim) + else: + raise ValueError("unknown rand_type %s" % self.rand_type) + q = self.scale * q + arr[mat_idx * base_dim:mat_idx * base_dim + base_dim] = q diff --git a/sockeye/lexicon.py b/sockeye/lexicon.py new file mode 100644 index 000000000..cdf8fb107 --- /dev/null +++ b/sockeye/lexicon.py @@ -0,0 +1,159 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import logging +from typing import Dict + +import mxnet as mx +import numpy as np + +import sockeye.constants as C +from sockeye.data_io import smart_open + +logger = logging.getLogger(__name__) + + +class Lexicon: + """ + Lexicon model component. 
Stores lexicon and supports two operations: + (1) Given source batch, lookup translation distributions in the lexicon + (2) Given attention score vector and lexicon lookups, compute the lexical bias for the decoder + + :param source_vocab_size: Source vocabulary size. + :param target_vocab_size: Target vocabulary size. + :param learn: Whether to adapt lexical biases during training. + """ + + def __init__(self, source_vocab_size: int, target_vocab_size: int, learn: bool = False): + self.source_vocab_size = source_vocab_size + self.target_vocab_size = target_vocab_size + # TODO: once half-precision works, use float16 for this variable to save memory + self.lexicon = mx.sym.Variable(name=C.LEXICON_NAME, + shape=(self.source_vocab_size, + self.target_vocab_size)) + if not learn: + logger.info("Fixed lexicon bias terms") + self.lexicon = mx.sym.BlockGrad(self.lexicon) + else: + logger.info("Learning lexicon bias terms") + + def lookup(self, source: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Lookup lexicon distributions for source. + + :param source: Input. Shape: (batch_size, source_seq_len). + :return: Lexicon distributions for input. Shape: (batch_size, target_vocab_size, source_seq_len). + """ + return mx.sym.swapaxes(data=mx.sym.Embedding(data=source, + input_dim=self.source_vocab_size, + weight=self.lexicon, + output_dim=self.target_vocab_size, + name=C.LEXICON_NAME + "_lookup"), dim1=1, dim2=2) + + @staticmethod + def calculate_lex_bias(source_lexicon: mx.sym.Symbol, attention_prob_score: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Given attention/alignment scores, calculates a weighted sum over lexical distributions + that serve as a bias for the decoder softmax. + * https://arxiv.org/pdf/1606.02006.pdf + * http://www.aclweb.org/anthology/W/W16/W16-4610.pdf + + :param source_lexicon: Lexical biases for sentence Shape: (batch_size, target_vocab_size, source_seq_len). + :param attention_prob_score: Attention score. Shape: (batch_size, source_seq_len). + :return: Lexical bias. Shape: (batch_size, 1, target_vocab_size). + """ + # attention_prob_score: (batch_size, source_seq_len) -> (batch_size, source_seq_len, 1) + attention_prob_score = mx.sym.expand_dims(attention_prob_score, axis=2) + # lex_bias: (batch_size, target_vocab_size, 1) + lex_bias = mx.sym.batch_dot(source_lexicon, attention_prob_score) + # lex_bias: (batch_size, 1, target_vocab_size) + lex_bias = mx.sym.swapaxes(data=lex_bias, dim1=1, dim2=2) + return lex_bias + + +def initialize_lexicon(cmdline_arg: str, vocab_source: Dict[str, int], vocab_target: Dict[str, int]) -> mx.nd.NDArray: + """ + Reads a probabilistic word lexicon as given by the commandline argument and converts + to log probabilities. + If specified, smooths with custom value, uses 0.001 otherwise. + + :param cmdline_arg: Commandline argument. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. + :return: Lexicon array. Shape: (vocab_source_size, vocab_target_size). 
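# NumPy sketch of calculate_lex_bias above: attention scores weight the per-source-word lexical
# distributions, yielding one bias vector over the target vocabulary per sentence. Toy shapes only.
import numpy as np

batch, trg_vocab, src_len = 2, 6, 4
source_lexicon = np.random.rand(batch, trg_vocab, src_len)    # p(trg | src word), per position
attention = np.random.rand(batch, src_len)
attention /= attention.sum(axis=1, keepdims=True)             # normalized attention scores

# (batch, trg_vocab, src_len) @ (batch, src_len, 1) -> (batch, trg_vocab, 1)
lex_bias = source_lexicon @ attention[:, :, None]
lex_bias = np.swapaxes(lex_bias, 1, 2)                        # (batch, 1, trg_vocab)
print(lex_bias.shape)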
+ """ + fields = cmdline_arg.split(":", 1) + path = fields[0] + lexicon = read_lexicon(path, vocab_source, vocab_target) + assert lexicon.shape == (len(vocab_source), len(vocab_target)), "Invalid lexicon shape" + eps = 0.001 + if len(fields) == 2: + eps = float(fields[1]) + assert eps > 0, "epsilon must be >0" + logger.info("Smoothing lexicon with eps=%.4f", eps) + lexicon = mx.nd.array(np.log(lexicon + eps)) + return lexicon + + +def read_lexicon(path: str, vocab_source: Dict[str, int], vocab_target: Dict[str, int]) -> np.ndarray: + """ + Loads lexical translation probabilities from a translation table of format: src, trg, logprob. + Source words unknown to vocab_source are discarded. + Target words unknown to vocab_target contribute to p(unk|source_word). + See Incorporating Discrete Translation Lexicons into Neural Machine Translation, Section 3.1 & Equation 5 + (https://arxiv.org/pdf/1606.02006.pdf)) + + :param path: Path to lexicon file. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. + :return: Lexicon array. Shape: (vocab_source_size, vocab_target_size). + """ + assert C.UNK_SYMBOL in vocab_source + assert C.UNK_SYMBOL in vocab_target + src_unk_id = vocab_source[C.UNK_SYMBOL] + trg_unk_id = vocab_target[C.UNK_SYMBOL] + lexicon = np.zeros((len(vocab_source), len(vocab_target))) + n = 0 + with smart_open(path) as fin: + for line in fin: + src, trg, logprob = line.rstrip('\n').split("\t") + prob = np.exp(float(logprob)) + src_id = vocab_source.get(src, src_unk_id) + trg_id = vocab_target.get(trg, trg_unk_id) + if src_id == src_unk_id: + continue + if trg_id == trg_unk_id: + lexicon[src_id, trg_unk_id] += prob + else: + lexicon[src_id, trg_id] = prob + n += 1 + logger.info("Loaded lexicon from '%s' with %d entries", path, n) + return lexicon + + +class LexiconInitializer(mx.initializer.Initializer): + """ + Given a lexicon NDArray, initialize the variable named C.LEXICON_NAME with it. + + :param lexicon: Lexicon array. + """ + + def __init__(self, lexicon: mx.nd.NDArray): + super().__init__() + self.lexicon = lexicon + + def _init_default(self, sym_name, arr): + assert sym_name == C.LEXICON_NAME, "This initializer should only be used for a lexicon parameter variable" + logger.info("Initializing '%s' with lexicon.", sym_name) + assert len(arr.shape) == 2, "Only 2d weight matrices supported." + self.lexicon.copyto(arr) diff --git a/sockeye/log.py b/sockeye/log.py new file mode 100644 index 000000000..965fd2621 --- /dev/null +++ b/sockeye/log.py @@ -0,0 +1,119 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import logging +import logging.config +from typing import Optional + +FORMATTERS = { + 'verbose': { + 'format': '[%(asctime)s:%(levelname)s:%(name)s:%(funcName)s] %(message)s', + 'datefmt': "%Y-%m-%d:%H:%M:%S", + }, + 'simple': { + 'format': '[%(levelname)s:%(name)s] %(message)s' + }, +} + +FILE_LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': FORMATTERS, + 'handlers': { + 'rotating': { + 'level': 'DEBUG', + 'formatter': 'verbose', + 'class': 'logging.handlers.RotatingFileHandler', + 'maxBytes': 10000000, + 'backupCount': 5, + 'filename': 'sockeye.log', + } + }, + 'root': { + 'handlers': ['rotating'], + 'level': 'DEBUG', + } +} + +CONSOLE_LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': FORMATTERS, + 'handlers': { + 'console': { + 'level': 'INFO', + 'formatter': 'simple', + 'class': 'logging.StreamHandler', + 'stream': None + }, + }, + 'root': { + 'handlers': ['console'], + 'level': 'DEBUG', + } +} + +FILE_CONSOLE_LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': FORMATTERS, + 'handlers': { + 'console': { + 'level': 'INFO', + 'formatter': 'simple', + 'class': 'logging.StreamHandler', + 'stream': None + }, + 'rotating': { + 'level': 'DEBUG', + 'formatter': 'verbose', + 'class': 'logging.handlers.RotatingFileHandler', + 'maxBytes': 10000000, + 'backupCount': 5, + 'filename': 'sockeye.log', + } + }, + 'root': { + 'handlers': ['console', 'rotating'], + 'level': 'DEBUG', + } +} + +LOGGING_CONFIGS = { + "file_only": FILE_LOGGING, + "console_only": CONSOLE_LOGGING, + "file_console": FILE_CONSOLE_LOGGING, +} + + +def setup_main_logger(name: str, file_logging=True, console=True, path: Optional[str] = None) -> logging.Logger: + """ + Return a logger that configures logging for the main application. + + :param name: Name of the returned logger. + :param file_logging: Whether to log to a file. + :param console: Whether to log to the console. + :param path: Optional path to write logfile to. + """ + if file_logging and console: + log_config = LOGGING_CONFIGS["file_console"] + elif file_logging: + log_config = LOGGING_CONFIGS["file_only"] + else: + log_config = LOGGING_CONFIGS["console_only"] + + if path: + log_config["handlers"]["rotating"]["filename"] = path + + logging.config.dictConfig(log_config) + return logging.getLogger(name) diff --git a/sockeye/loss.py b/sockeye/loss.py new file mode 100644 index 000000000..6b94f63c6 --- /dev/null +++ b/sockeye/loss.py @@ -0,0 +1,148 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Functions to generate loss symbols for sequence-to-sequence models. +""" +from typing import Tuple + +import mxnet as mx + +import sockeye.constants as C +import sockeye.model + + +def get_loss(config: sockeye.model.ModelConfig) -> 'Loss': + """ + Returns Loss instance given loss_name. + + :param config: Model configuration. 
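# Minimal usage sketch of the logging setup defined above: console-only logging through
# setup_main_logger, as the embeddings tool does. Assumes the sockeye package is importable.
from sockeye.log import setup_main_logger

logger = setup_main_logger(__name__, file_logging=False, console=True)
logger.info("console handler at INFO, root logger at DEBUG")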
+ """ + if config.loss == C.CROSS_ENTROPY: + return CrossEntropyLoss(config.normalize_loss) + elif config.loss == C.SMOOTHED_CROSS_ENTROPY: + return SmoothedCrossEntropyLoss(config.smoothed_cross_entropy_alpha, config.vocab_target_size, + config.normalize_loss) + else: + raise ValueError("unknown loss name") + + +class Loss: + """ + Generic Loss interface. + get_loss() method should return a loss symbol and the softmax outputs. + The softmax outputs (named C.SOFTMAX_NAME) are used by EvalMetrics to compute various metrics, + e.g. perplexity, accuracy. In the special case of cross_entropy, the SoftmaxOutput symbol + provides softmax outputs for forward() AND cross_entropy gradients for backward(). + """ + + def get_loss(self, logits: mx.sym.Symbol, labels: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Returns loss and softmax output symbols given logits and integer-coded labels. + + :param logits: Shape: (batch_size * target_seq_len, target_vocab_size). + :param labels: Shape: (batch_size * target_seq_len,). + :return: Loss and softmax output symbols. + """ + raise NotImplementedError() + + +class CrossEntropyLoss(Loss): + """ + Computes the cross-entropy loss. + + :param normalize: If True normalize the gradient by dividing by the number of non-PAD tokens. + """ + + def __init__(self, normalize: bool = False): + self._normalize = normalize + + def get_loss(self, logits: mx.sym.Symbol, labels: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Returns loss and softmax output symbols given logits and integer-coded labels. + + :param logits: Shape: (batch_size * target_seq_len, target_vocab_size). + :param labels: Shape: (batch_size * target_seq_len,). + :return: Loss and softmax output symbols. + """ + if self._normalize: + normalization = "valid" + else: + normalization = "null" + return mx.sym.SoftmaxOutput(data=logits, + label=labels, + ignore_label=C.PAD_ID, + use_ignore=True, + normalization=normalization, + name=C.SOFTMAX_NAME) + + +def _normalize(loss: mx.sym.Symbol, labels: mx.sym.Symbol): + """ + Normalize loss by the number of non-PAD tokens. + + :param loss: A loss value for each label. + :param labels: A label for each loss entry (potentially containing PAD tokens). + :return: The normalized loss. + """ + return mx.sym.broadcast_div(loss, mx.sym.sum(labels != C.PAD_ID)) + + +class SmoothedCrossEntropyLoss(Loss): + """ + Computes a smoothed cross-entropy loss. Smoothing is defined by alpha which indicates the + amount of probability mass subtracted from the true label probability (1-alpha). + Alpha is then uniformly distributed across other labels. + + :param alpha: Smoothing value. + :param vocab_size: Size of the target vocabulary. + :param normalize: If True normalize the gradient by dividing by the number of non-PAD tokens. + """ + + def __init__(self, alpha: float, vocab_size: int, normalize: bool = False): + assert alpha >= 0, "alpha must be >= 0" + self._alpha = alpha + self._vocab_size = vocab_size + self._normalize = normalize + + def get_loss(self, logits: mx.sym.Symbol, labels: mx.sym.Symbol) -> Tuple[mx.sym.Symbol]: + """ + Returns loss and softmax output symbols given logits and integer-coded labels. + + :param logits: Shape: (batch_size * target_seq_len, target_vocab_size). + :param labels: Shape: (batch_size * target_seq_len,). + :return: Loss and softmax output symbols. 
+ """ + probs = mx.sym.softmax(data=logits) + + on_value = 1.0 - self._alpha + off_value = self._alpha / (self._vocab_size - 1.0) + cross_entropy = mx.sym.one_hot(indices=mx.sym.cast(data=labels, dtype='int32'), + depth=self._vocab_size, + on_value=on_value, + off_value=off_value) + + # zero out pad symbols (0) + cross_entropy = mx.sym.where(labels, cross_entropy, mx.sym.zeros((0, self._vocab_size))) + + # compute cross_entropy + cross_entropy *= - mx.sym.log(data=probs + 1e-10) + cross_entropy = mx.sym.sum(data=cross_entropy, axis=1) + + if self._normalize: + cross_entropy = _normalize(cross_entropy, labels) + + cross_entropy = mx.sym.MakeLoss(cross_entropy, name=C.SMOOTHED_CROSS_ENTROPY) + probs = mx.sym.BlockGrad(probs, name=C.SOFTMAX_NAME) + return cross_entropy, probs + diff --git a/sockeye/lr_scheduler.py b/sockeye/lr_scheduler.py new file mode 100644 index 000000000..29da6edf8 --- /dev/null +++ b/sockeye/lr_scheduler.py @@ -0,0 +1,166 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import logging +from math import sqrt +from typing import Optional + +logger = logging.getLogger(__name__) + + +class LearningRateScheduler: + def new_evaluation_result(self, has_improved: bool): + pass + + def __call__(self, num_updates): + pass + + +class LearningRateSchedulerInvSqrtT(LearningRateScheduler): + """ + Learning rate schedule: lr / sqrt(1 + factor * t). + Note: The factor is calculated from the half life of the learning rate. + + :param updates_per_checkpoint: Number of batches between checkpoints. + :param half_life: Half life of the learning rate in number of checkpoints. + """ + + def __init__(self, updates_per_checkpoint: int, half_life: int) -> None: + assert updates_per_checkpoint > 0, "updates_per_checkpoint needs to be > 0." + assert half_life > 0, "half_life needs to be > 0." + # Note: will be overwritten by optimizer in mxnet + self.base_lr = None + # 0.5 base_lr = base_lr * sqrt(1 + T * factor) + # then factor = 3 ./ T, with T = half_life * updates_per_checkpoint + self.factor = 3. / (half_life * updates_per_checkpoint) + self.t_last_log = -1 + self.log_every_t = int(half_life * updates_per_checkpoint) + + def __call__(self, num_updates: int): + lr = self.base_lr / sqrt(1 + num_updates * self.factor) + + # Note: this method is called once per parameter for the same t. Making sure to just log once. + if num_updates > self.t_last_log and num_updates % self.log_every_t == 0: + logger.info("Learning rate currently at %1.2e", lr) + self.t_last_log = num_updates + + return lr + + +class LearningRateSchedulerInvT(LearningRateScheduler): + """ + Learning rate schedule: lr / (1 + factor * t). + Note: The factor is calculated from the half life of the learning rate. + + :param updates_per_checkpoint: Number of batches between checkpoints. + :param half_life: Half life of the learning rate in number of checkpoints. 
+ """ + + def __init__(self, updates_per_checkpoint: int, half_life: int) -> None: + assert updates_per_checkpoint > 0, "updates_per_checkpoint needs to be > 0." + assert half_life > 0, "half_life needs to be > 0." + # Note: will be overwritten by optimizer + self.base_lr = None + # 0.5 base_lr = base_lr * (1 + T * factor) + # then factor = 1 ./ T, with T = half_life * updates_per_checkpoint + self.factor = 1. / (half_life * updates_per_checkpoint) + self.t_last_log = -1 + self.log_every_t = int(half_life * updates_per_checkpoint) + + def __call__(self, num_updates: int): + lr = self.base_lr / (1 + num_updates * self.factor) + + # Note: this method is called once per parameter for the same t. Making sure to just log once. + if num_updates > self.t_last_log and num_updates % self.log_every_t == 0: + logger.info("Learning rate currently at %1.2e", lr) + self.t_last_log = num_updates + + return lr + + +class LearningRateSchedulerPlateauReduce(LearningRateScheduler): + """ + Lower the learning rate as soon as the validation score plateaus. + + :param reduce_factor: Factor to reduce learning rate with. + :param reduce_num_not_improved: Number of checkpoints with no improvement after which learning rate is reduced. + """ + + def __init__(self, reduce_factor: float, reduce_num_not_improved: int) -> None: + self.reduce_factor = reduce_factor + self.reduce_num_not_improved = reduce_num_not_improved + self.num_not_improved = 0 + self.logger = logging.getLogger("LearningRateSchedulerPlateauReduce") + # Note: will be overwritten by optimizer in mxnet + self.base_lr = None # type: float + self.lr = None # type: float + logger.info("Will reduce the learning rate by a factor of %.2f whenever" + " the validation score doesn't improve %d times.", + reduce_factor, reduce_num_not_improved) + + def new_evaluation_result(self, has_improved: bool): + if self.lr is None: + assert self.base_lr is not None + self.lr = self.base_lr + if has_improved: + self.num_not_improved = 0 + else: + self.num_not_improved += 1 + if self.num_not_improved >= self.reduce_num_not_improved: + self.lr *= self.reduce_factor + self.logger.info("Validation score hasn't improved for %d checkpoints, " + "lowering learning rate to %1.2e", self.num_not_improved, self.lr) + self.num_not_improved = 0 + + def __call__(self, t): + if self.lr is None: + assert self.base_lr is not None + self.lr = self.base_lr + return self.lr + + def __repr__(self): + return "LearningRateSchedulerPlateauReduce(reduce_factor=%.2f, " \ + "reduce_num_not_improved=%d)" % (self.reduce_factor, self.num_not_improved) + + +def get_lr_scheduler(scheduler_type: str, + updates_per_checkpoint: int, + learning_rate_half_life: int, + learning_rate_reduce_factor: float, + learning_rate_reduce_num_not_improved: int) -> Optional[LearningRateScheduler]: + """ + Returns a learning rate scheduler. + + :param scheduler_type: Scheduler type. + :param updates_per_checkpoint: Number of batches between checkpoints. + :param learning_rate_half_life: Half life of the learning rate in number of checkpoints. + :param learning_rate_reduce_factor: Factor to reduce learning rate with. + :param learning_rate_reduce_num_not_improved: Number of checkpoints with no improvement after which learning rate is + reduced. + :raises: ValueError if unknown scheduler_type + :return: Learning rate scheduler. 
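+
+    Example (illustrative; the argument values are arbitrary)::
+
+        scheduler = get_lr_scheduler("plateau-reduce",
+                                     updates_per_checkpoint=1000,
+                                     learning_rate_half_life=10,
+                                     learning_rate_reduce_factor=0.5,
+                                     learning_rate_reduce_num_not_improved=3)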
+ """ + if scheduler_type is None: + return None + if scheduler_type == "fixed-rate-inv-sqrt-t": + return LearningRateSchedulerInvSqrtT(updates_per_checkpoint, learning_rate_half_life) + elif scheduler_type == "fixed-rate-inv-t": + return LearningRateSchedulerInvT(updates_per_checkpoint, learning_rate_half_life) + elif scheduler_type == "plateau-reduce": + assert learning_rate_reduce_factor is not None, "learning_rate_reduce_factor needed for plateau-reduce " \ + "scheduler" + assert learning_rate_reduce_num_not_improved is not None, "learning_rate_reduce_num_not_improved needed for " \ + "plateau-reduce scheduler" + return LearningRateSchedulerPlateauReduce(learning_rate_reduce_factor, learning_rate_reduce_num_not_improved) + else: + raise ValueError("Unknown learning rate scheduler type %s." % scheduler_type) diff --git a/sockeye/model.py b/sockeye/model.py new file mode 100644 index 000000000..9a4ff925d --- /dev/null +++ b/sockeye/model.py @@ -0,0 +1,183 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import json +import logging +import os + +import sockeye.attention +import sockeye.coverage +import sockeye.data_io +import sockeye.decoder +import sockeye.encoder +import sockeye.lexicon +import sockeye.utils +from sockeye import constants as C + +logger = logging.getLogger(__name__) + +ModelConfig = sockeye.utils.namedtuple_with_defaults('ModelConfig', + [ + "max_seq_len", + "vocab_source_size", + "vocab_target_size", + "num_embed_source", + "num_embed_target", + "attention_type", + "attention_num_hidden", + "attention_coverage_type", + "attention_coverage_num_hidden", + "attention_use_prev_word", + "dropout", + "rnn_cell_type", + "rnn_num_layers", + "rnn_num_hidden", + "rnn_residual_connections", + "weight_tying", + "context_gating", + "lexical_bias", + "learn_lexical_bias", + "data_info", + "loss", + "normalize_loss", + "smoothed_cross_entropy_alpha", + ], + default_values={ + "attention_use_prev_word": False, + "context_gating": False, + "loss": C.CROSS_ENTROPY, + "normalize_loss": False + }) +""" +ModelConfig defines model parameters defined at training time which are relevant to model inference. +Add new model parameters here. If you want backwards compatibility for models trained with code that did not +contain these parameters, provide a reasonable default under default_values. +""" + + +class SockeyeModel: + """ + SockeyeModel shares components needed for both training and inference. + ModelConfig contains parameters and their values that are fixed at training time and must be re-used at inference + time. + + :param config: Model configuration. + """ + + def __init__(self, config: ModelConfig): + self.config = config + logger.info("%s", self.config) + self.encoder = None + self.attention = None + self.decoder = None + self.rnn_cells = [] + self.built = False + self.params = None + + def save_config(self, folder: str): + """ + Saves model configuration to /config + + :param folder: Destination folder. 
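+
+        For example, save_config("model_dir") (an illustrative path) writes the configuration
+        as indented, key-sorted JSON to os.path.join("model_dir", C.CONFIG_NAME).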
+ """ + fname = os.path.join(folder, C.CONFIG_NAME) + with open(fname, "w") as out: + json.dump(self.config._asdict(), out, indent=2, sort_keys=True) + logger.info('Saved config to "%s"', fname) + + @staticmethod + def load_config(fname: str) -> ModelConfig: + """ + Loads model configuration. + + :param fname: Path to load model configuration from. + :return: Model configuration. + """ + with open(fname, "r") as inp: + config = ModelConfig(**json.load(inp)) + logger.info('ModelConfig loaded from "%s"', fname) + return config + + def save_params_to_file(self, fname: str): + """ + Saves model parameters to file. + + :param fname: Path to save parameters to. + """ + assert self.built + params = self.params.copy() + # unpack rnn cell weights + for cell in self.rnn_cells: + params = cell.unpack_weights(params) + sockeye.utils.save_params(params, fname) + logging.info('Saved params to "%s"', fname) + + def load_params_from_file(self, fname: str): + """ + Loads and sets model parameters from file. + + :param fname: Path to load parameters from. + """ + assert self.built + self.params, _ = sockeye.utils.load_params(fname) + # pack rnn cell weights + for cell in self.rnn_cells: + self.params = cell.pack_weights(self.params) + logger.info('Loaded params from "%s"', fname) + + def _build_model_components(self, max_seq_len: int, fused_encoder: bool, rnn_forget_bias: float = 0.0): + """ + Builds and sets model components given maximum sequence length. + + :param max_seq_len: Maximum sequence length supported by the model. + :param fused_encoder: Use FusedRNNCells in encoder. + :param rnn_forget_bias: forget bias initialization for RNNs. + """ + self.encoder = sockeye.encoder.get_encoder(self.config.num_embed_source, + self.config.vocab_source_size, + self.config.rnn_num_layers, + self.config.rnn_num_hidden, + self.config.rnn_cell_type, + self.config.rnn_residual_connections, + self.config.dropout, + rnn_forget_bias, + fused_encoder) + + self.attention = sockeye.attention.get_attention(self.config.attention_use_prev_word, + self.config.attention_type, + self.config.attention_num_hidden, + self.config.rnn_num_hidden, + max_seq_len, + self.config.attention_coverage_type, + self.config.attention_coverage_num_hidden) + + self.lexicon = sockeye.lexicon.Lexicon(self.config.vocab_source_size, + self.config.vocab_target_size, + self.config.learn_lexical_bias) if self.config.lexical_bias else None + + self.decoder = sockeye.decoder.get_decoder(self.config.num_embed_target, + self.config.vocab_target_size, + self.config.rnn_num_layers, + self.config.rnn_num_hidden, + self.attention, + self.config.rnn_cell_type, + self.config.rnn_residual_connections, + rnn_forget_bias, + self.config.dropout, + self.config.weight_tying, + self.lexicon, + self.config.context_gating) + + self.rnn_cells = self.encoder.get_rnn_cells() + self.decoder.get_rnn_cells() + + self.built = True diff --git a/sockeye/output_handler.py b/sockeye/output_handler.py new file mode 100644 index 000000000..303783b72 --- /dev/null +++ b/sockeye/output_handler.py @@ -0,0 +1,142 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. 
See the License for the specific language governing +# permissions and limitations under the License. + +import sockeye.inference +from sockeye.utils import plot_attention, print_attention_text, get_alignments + + +def get_output_handler(output_type: str, + output_stream, + align_plot_prefix: str, + sure_align_threshold: float) -> 'OutputHandler': + """ + + :param output_type: Type of output handler. + :param output_stream: Output stream to write to. + :param align_plot_prefix: Prefix for alignment plot files. + :param sure_align_threshold: Threshold to consider an alignment link as 'sure'. + :raises: ValueError for unknown output_type + :return: Output handler. + """ + if output_type == "translation": + return StringOutputHandler(output_stream) + elif output_type == "translation_with_alignments": + return StringWithAlignmentsOutputHandler(output_stream, sure_align_threshold) + elif output_type == "align_plot": + return AlignPlotHandler(plot_prefix=align_plot_prefix) + elif output_type == "align_text": + return AlignTextHandler(sure_align_threshold) + else: + raise ValueError("unknown output type") + + +class OutputHandler: + """ + Abstract output handler interface + """ + + def handle(self, t_input: sockeye.inference.TranslatorInput, t_output: sockeye.inference.TranslatorOutput): + """ + :raises: NotImplementedError + :param t_input: Translator input. + :param t_output: Translator output. + """ + raise NotImplementedError() + + +class StringOutputHandler(OutputHandler): + """ + Output handler to write translation to a stream + + :param stream: Stream to write translations to (e.g. sys.stdout). + """ + + def __init__(self, stream): + self.stream = stream + + def handle(self, t_input: sockeye.inference.TranslatorInput, t_output: sockeye.inference.TranslatorOutput): + """ + :param t_input: Translator input. + :param t_output: Translator output. + """ + self.stream.write("%s\n" % t_output.translation) + self.stream.flush() + + +class StringWithAlignmentsOutputHandler(StringOutputHandler): + """ + Output handler to write translations and alignments to a stream. Translation and alignment string + are separated by a tab. + Alignments are written in the format: + - ... + An alignment link is included if its probability is above the threshold. + + :param stream: Stream to write translations and alignments to. + :param threshold: Threshold for including alignment links. + """ + + def __init__(self, stream, threshold: float): + super().__init__(stream) + self.threshold = threshold + + def handle(self, t_input: sockeye.inference.TranslatorInput, t_output: sockeye.inference.TranslatorOutput): + """ + :param t_input: Translator input. + :param t_output: Translator output. + """ + alignments = " ".join( + ["%d-%d" % (s, t) for s, t in get_alignments(t_output.attention_matrix, threshold=self.threshold)]) + self.stream.write("%s\t%s\n" % (t_output.translation, alignments)) + self.stream.flush() + + +class AlignPlotHandler(OutputHandler): + """ + Output handler to plot alignment matrices to PNG files. + + :param plot_prefix: Prefix for generated PNG files. + """ + + def __init__(self, plot_prefix: str): + self.plot_prefix = plot_prefix + + def handle(self, t_input: sockeye.inference.TranslatorInput, t_output: sockeye.inference.TranslatorOutput): + """ + :param t_input: Translator input. + :param t_output: Translator output. 
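+
+        Writes one PNG per input, named "<plot_prefix>_<input id>.png" (see the call below).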
+ """ + plot_attention(t_output.attention_matrix, + t_input.tokens, + t_output.tokens, + "%s_%d.png" % (self.plot_prefix, t_input.id)) + + +class AlignTextHandler(OutputHandler): + """ + Output handler to write alignment matrices as ASCII art. + + :param threshold: Threshold for considering alignment links as sure. + """ + + def __init__(self, threshold: float): + self.threshold = threshold + + def handle(self, t_input: sockeye.inference.TranslatorInput, t_output: sockeye.inference.TranslatorOutput): + """ + :param t_input: Translator input. + :param t_output: Translator output. + """ + print_attention_text(t_output.attention_matrix, + t_input.tokens, + t_output.tokens, + self.threshold) diff --git a/sockeye/rnn.py b/sockeye/rnn.py new file mode 100644 index 000000000..8cc38cd38 --- /dev/null +++ b/sockeye/rnn.py @@ -0,0 +1,57 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import mxnet as mx + +from sockeye import constants as C + + +def get_stacked_rnn(cell_type: str, + num_hidden: int, + num_layers: int, + dropout: float, + prefix: str, + residual: bool = False, + forget_bias: float = 0.0) -> mx.rnn.SequentialRNNCell: + """ + Returns (stacked) RNN cell given parameters. + + :param cell_type: RNN cell type. + :param num_hidden: Number of RNN hidden units. + :param num_layers: Number of RNN layers. + :param dropout: Dropout probability on RNN outputs. + :param prefix: Symbol prefix for RNN. + :param residual: Whether to add residual connections between multi-layered RNNs. + :param forget_bias: Initial value of forget biases. + :return: RNN cell. + """ + + rnn = mx.rnn.SequentialRNNCell() + for layer in range(num_layers): + # fhieber: the 'l' in the prefix does NOT stand for 'layer' but for the direction 'l' as in mx.rnn.rnn_cell::517 + # this ensures parameter name compatibility of training w/ FusedRNN and decoding with 'unfused' RNN. + cell_prefix = "%sl%d_" % (prefix, layer) + if cell_type == C.LSTM_TYPE: + cell = mx.rnn.LSTMCell(num_hidden=num_hidden, prefix=cell_prefix, forget_bias=forget_bias) + elif cell_type == C.GRU_TYPE: + cell = mx.rnn.GRUCell(num_hidden=num_hidden, prefix=cell_prefix) + else: + raise NotImplementedError() + if residual and layer > 0: + cell = mx.rnn.ResidualCell(cell) + rnn.add(cell) + + if dropout > 0.: + # TODO(fhieber): add pervasive dropout? + rnn.add(mx.rnn.DropoutCell(dropout, prefix=cell_prefix + "_dropout")) + return rnn diff --git a/sockeye/train.py b/sockeye/train.py new file mode 100644 index 000000000..a42d4d0a9 --- /dev/null +++ b/sockeye/train.py @@ -0,0 +1,218 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. 
See the License for the specific language governing +# permissions and limitations under the License. + +""" +Simple Training CLI. +""" +import argparse +import os +import random +import sys +from contextlib import ExitStack +from typing import Optional, Dict + +import mxnet as mx +import numpy as np + +import sockeye.arguments as arguments +import sockeye.attention +import sockeye.constants as C +import sockeye.data_io +import sockeye.decoder +import sockeye.encoder +import sockeye.initializer +import sockeye.lexicon +import sockeye.lr_scheduler +import sockeye.model +import sockeye.training +import sockeye.utils +import sockeye.vocab +from sockeye.log import setup_main_logger +from sockeye.utils import acquire_gpu, get_num_gpus + + +def none_if_negative(val): + return None if val < 0 else val + + +def _build_or_load_vocab(existing_vocab_path: Optional[str], data_path: str, num_words: int, + word_min_count: int) -> Dict: + if existing_vocab_path is None: + vocabulary = sockeye.vocab.build_from_path(data_path, + num_words=num_words, + min_count=word_min_count) + else: + vocabulary = sockeye.vocab.vocab_from_json(existing_vocab_path) + return vocabulary + + +def main(): + params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.') + params = arguments.add_io_args(params) + params = arguments.add_model_parameters(params) + params = arguments.add_training_args(params) + params = arguments.add_device_args(params) + args = params.parse_args() + + # seed the RNGs + np.random.seed(args.seed) + random.seed(args.seed) + mx.random.seed(args.seed) + + if args.use_fused_rnn: + assert not args.use_cpu, "GPU required for FusedRNN cells" + + if args.rnn_residual_connections: + assert args.rnn_num_layers > 2, "Residual connections require at least 3 RNN layers" + + assert args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics, \ + "Must optimize either BLEU or one of tracked metrics (--metrics)" + + output_folder = os.path.abspath(args.output) + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + logger = setup_main_logger(__name__, console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME)) + + logger.info("Command: %s", " ".join(sys.argv)) + logger.info("Arguments: %s", args) + + with ExitStack() as exit_stack: + # context + if args.use_cpu: + context = [mx.cpu()] + else: + num_gpus = get_num_gpus() + assert num_gpus > 0, "No GPUs found, consider running on the CPU with --use-cpu " \ + "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi " \ + "binary isn't on the path)." 
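+            # One context is created per entry in args.device_ids; a negative id requests an
+            # automatically acquired GPU via the lock-file mechanism in sockeye.utils.acquire_gpu.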
+ context = [] + for gpu_id in args.device_ids: + if gpu_id < 0: + # get an automatic gpu id: + gpu_id = exit_stack.enter_context(acquire_gpu()) + context.append(mx.gpu(gpu_id)) + + # create vocabs + vocab_source = _build_or_load_vocab(args.source_vocab, args.source, args.num_words, args.word_min_count) + sockeye.vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX) + + vocab_target = _build_or_load_vocab(args.target_vocab, args.target, args.num_words, args.word_min_count) + sockeye.vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX) + + vocab_source_size = len(vocab_source) + vocab_target_size = len(vocab_target) + logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size) + + data_info = sockeye.data_io.DataInfo(os.path.abspath(args.source), + os.path.abspath(args.target), + os.path.abspath(args.validation_source), + os.path.abspath(args.validation_target), + args.source_vocab, + args.target_vocab) + + # create data iterators + train_iter, eval_iter = sockeye.data_io.get_training_data_iters(source=data_info.source, + target=data_info.target, + validation_source=data_info.validation_source, + validation_target=data_info.validation_target, + vocab_source=vocab_source, + vocab_target=vocab_target, + batch_size=args.batch_size, + fill_up=args.fill_up, + max_seq_len=args.max_seq_len, + bucketing=not args.no_bucketing, + bucket_width=args.bucket_width) + + # learning rate scheduling + learning_rate_half_life = none_if_negative(args.learning_rate_half_life) + lr_scheduler = sockeye.lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type, + args.checkpoint_frequency, + learning_rate_half_life, + args.learning_rate_reduce_factor, + args.learning_rate_reduce_num_not_improved) + + # model configuration + num_embed_source = args.num_embed if args.num_embed_source is None else args.num_embed_source + num_embed_target = args.num_embed if args.num_embed_target is None else args.num_embed_target + attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden + model_config = sockeye.model.ModelConfig(max_seq_len=args.max_seq_len, + vocab_source_size=vocab_source_size, + vocab_target_size=vocab_target_size, + num_embed_source=num_embed_source, + num_embed_target=num_embed_target, + attention_type=args.attention_type, + attention_num_hidden=attention_num_hidden, + attention_coverage_type=args.attention_coverage_type, + attention_coverage_num_hidden=args.attention_coverage_num_hidden, + attention_use_prev_word=args.attention_use_prev_word, + dropout=args.dropout, + rnn_cell_type=args.rnn_cell_type, + rnn_num_layers=args.rnn_num_layers, + rnn_num_hidden=args.rnn_num_hidden, + rnn_residual_connections=args.rnn_residual_connections, + weight_tying=args.weight_tying, + context_gating=args.context_gating, + lexical_bias=args.lexical_bias, + learn_lexical_bias=args.learn_lexical_bias, + data_info=data_info, + loss=args.loss, + normalize_loss=args.normalize_loss, + smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha) + + # create training model + model = sockeye.training.TrainingModel(model_config=model_config, + context=context, + train_iter=train_iter, + fused=args.use_fused_rnn, + bucketing=not args.no_bucketing, + lr_scheduler=lr_scheduler, + rnn_forget_bias=args.rnn_forget_bias) + + if args.params: + model.load_params_from_file(args.params) + logger.info("Training will continue from parameters loaded from '%s'", args.params) + 
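+        # Optional lexical bias: a lexicon is only initialized (from args.lexical_bias together
+        # with the source and target vocabularies) when lexical bias is enabled.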
+ lexicon = sockeye.lexicon.initialize_lexicon(args.lexical_bias, + vocab_source, vocab_target) if args.lexical_bias else None + + initializer = sockeye.initializer.get_initializer(args.rnn_h2h_init, lexicon=lexicon) + + optimizer = args.optimizer + optimizer_params = {'wd': args.weight_decay, + "learning_rate": args.initial_learning_rate} + if lr_scheduler is not None: + optimizer_params["lr_scheduler"] = lr_scheduler + clip_gradient = none_if_negative(args.clip_gradient) + if clip_gradient is not None: + optimizer_params["clip_gradient"] = clip_gradient + if args.momentum is not None: + optimizer_params["momentum"] = args.momentum + logger.info("Optimizer: %s", optimizer) + logger.info("Optimizer Parameters: %s", optimizer_params) + + model.fit(train_iter, eval_iter, + output_folder=output_folder, + metrics=args.metrics, + initializer=initializer, + max_updates=args.max_updates, + checkpoint_frequency=args.checkpoint_frequency, + optimizer=optimizer, optimizer_params=optimizer_params, + optimized_metric=args.optimized_metric, + max_num_not_improved=args.max_num_checkpoint_not_improved, + monitor_bleu=args.monitor_bleu, + use_tensorboard=args.use_tensorboard) + + +if __name__ == "__main__": + main() diff --git a/sockeye/training.py b/sockeye/training.py new file mode 100644 index 000000000..3c3cb2478 --- /dev/null +++ b/sockeye/training.py @@ -0,0 +1,311 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Code for training +""" +import logging +import os +import time +from typing import List, AnyStr + +import mxnet as mx + +import sockeye.callback +import sockeye.checkpoint_decoder +import sockeye.constants as C +import sockeye.data_io +import sockeye.inference +import sockeye.loss +import sockeye.lr_scheduler +import sockeye.model +import sockeye.utils + +logger = logging.getLogger(__name__) + + +class TrainingModel(sockeye.model.SockeyeModel): + """ + Defines an Encoder/Decoder model (with attention). + RNN configuration (number of hidden units, number of layers, cell type) + is shared between encoder & decoder. + + :param model_config: Configuration object holding details about the model. + :param context: The context(s) that MXNet will be run in (GPU(s)/CPU) + :param train_iter: The iterator over the training data. + :param fused: If True fused RNN cells will be used (should be slightly more efficient, but is only available + on GPUs). + :param bucketing: If True bucketing will be used, if False the computation graph will always be + unrolled to the full length. + :param lr_scheduler: The scheduler that lowers the learning rate during training. + :param rnn_forget_bias: Initial value of the RNN forget biases. 
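+
+    Example (illustrative sketch; assumes an existing ModelConfig 'config' and a training
+    iterator 'train_iter')::
+
+        model = TrainingModel(model_config=config, context=[mx.cpu()], train_iter=train_iter,
+                              fused=False, bucketing=True, lr_scheduler=None,
+                              rnn_forget_bias=0.0)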
+ """ + + def __init__(self, + model_config: sockeye.model.ModelConfig, + context: List[mx.context.Context], + train_iter: sockeye.data_io.ParallelBucketSentenceIter, + fused: bool, + bucketing: bool, + lr_scheduler, + rnn_forget_bias: float) -> None: + super().__init__(model_config) + self.context = context + self.lr_scheduler = lr_scheduler + self._build_model_components(self.config.max_seq_len, fused, rnn_forget_bias) + self.module = self._build_module(train_iter, self.config.max_seq_len, bucketing) + + def _build_module(self, + train_iter: sockeye.data_io.ParallelBucketSentenceIter, + max_seq_len: int, + bucketing: bool): + """ + Initializes model components, creates training symbol and module, and binds it. + """ + source = mx.sym.Variable(C.SOURCE_NAME) + source_length = mx.sym.Variable(C.SOURCE_LENGTH_NAME) + target = mx.sym.Variable(C.TARGET_NAME) + labels = mx.sym.reshape(data=mx.sym.Variable(C.TARGET_LABEL_NAME), shape=(-1,)) + + loss = sockeye.loss.get_loss(self.config) + + data_names = [x[0] for x in train_iter.provide_data] + label_names = [x[0] for x in train_iter.provide_label] + + def sym_gen(seq_lens): + """ + Returns a (grouped) loss symbol given source & target input lengths. + Also returns data and label names for the BucketingModule. + """ + source_seq_len, target_seq_len = seq_lens + + source_encoded = self.encoder.encode(source, source_length, seq_len=source_seq_len) + source_lexicon = self.lexicon.lookup(source) if self.lexicon else None + + logits = self.decoder.decode(source_encoded, source_seq_len, source_length, + target, target_seq_len, source_lexicon) + + outputs = loss.get_loss(logits, labels) + + return mx.sym.Group(outputs), data_names, label_names + + if bucketing: + logger.info("Using bucketing. Default max_seq_len=%s", train_iter.default_bucket_key) + return mx.mod.BucketingModule(sym_gen=sym_gen, + logger=logger, + default_bucket_key=train_iter.default_bucket_key, + context=self.context) + else: + logger.info("No bucketing. Unrolled to max_seq_len=%s", max_seq_len) + symbol, _, __ = sym_gen(train_iter.buckets[0]) + return mx.mod.Module(symbol=symbol, + data_names=data_names, + label_names=label_names, + logger=logger, + context=self.context) + + @staticmethod + def _create_eval_metric(metric_names: List[AnyStr]) -> mx.metric.CompositeEvalMetric: + """ + Creates a composite EvalMetric given a list of metric names. + """ + metrics = [] + # output_names refers to the list of outputs this metric should use to update itself, e.g. the softmax output + for metric_name in metric_names: + if metric_name == C.ACCURACY: + metrics.append(sockeye.utils.Accuracy(ignore_label=C.PAD_ID, output_names=[C.SOFTMAX_OUTPUT_NAME])) + elif metric_name == C.PERPLEXITY: + metrics.append(mx.metric.Perplexity(ignore_label=C.PAD_ID, output_names=[C.SOFTMAX_OUTPUT_NAME])) + else: + raise ValueError("unknown metric name") + return mx.metric.create(metrics) + + def fit(self, + train_iter: sockeye.data_io.ParallelBucketSentenceIter, + val_iter: sockeye.data_io.ParallelBucketSentenceIter, + output_folder: str, + metrics: List[AnyStr], + initializer: mx.initializer.Initializer, + max_updates: int, + checkpoint_frequency: int, + optimizer: str, + optimizer_params: dict, + optimized_metric: str = "perplexity", + max_num_not_improved: int = 3, + monitor_bleu: int = 0, + use_tensorboard: bool = False): + """ + Fits model to data given by train_iter using early-stopping w.r.t data given by val_iter. 
+ Saves all intermediate and final output to output_folder + + :param train_iter: The training data iterator. + :param val_iter: The validation data iterator. + :param output_folder: The folder in which all model artifacts will be stored in (parameters, checkpoints, etc.). + :param metrics: The metrics that will be evaluated during training. + :param initializer: The parameter initializer. + :param max_updates: Maximum number of batches to process. + :param checkpoint_frequency: Frequency of checkpointing in number of updates. + :param optimizer: The MXNet optimizer that will update the parameters. + :param optimizer_params: The parameters for the optimizer. + :param optimized_metric: The metric that is tracked for early stopping. + :param max_num_not_improved: Stop training if the optimized_metric does not improve for this many checkpoints. + :param monitor_bleu: Monitor BLEU during training (0: off, >=0: the number of sentences to decode for BLEU + evaluation, -1: decode the full validation set.). + :param use_tensorboard: If True write tensorboard compatible logs for monitoring training and + validation metrics. + :return: Best score on validation data observed during training. + """ + self.save_config(output_folder) + + self.module.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label, + for_training=True, force_rebind=True, grad_req='write') + self.module.symbol.save(os.path.join(output_folder, C.SYMBOL_NAME)) + + self.module.init_params(initializer=initializer, arg_params=self.params, aux_params=None, + allow_missing=False, force_init=False) + + self.module.init_optimizer(kvstore='device', optimizer=optimizer, optimizer_params=optimizer_params) + + checkpoint_decoder = sockeye.checkpoint_decoder.CheckpointDecoder(self.context[-1], + self.config.data_info.validation_source, + self.config.data_info.validation_target, + output_folder, self.config.max_seq_len, + limit=monitor_bleu) \ + if monitor_bleu else None + + logger.info("Training started.") + training_monitor = sockeye.callback.TrainingMonitor(train_iter.batch_size, output_folder, + optimized_metric=optimized_metric, + use_tensorboard=use_tensorboard, + checkpoint_decoder=checkpoint_decoder) + self._fit(train_iter, val_iter, output_folder, + training_monitor, + metrics=metrics, + max_updates=max_updates, + checkpoint_frequency=checkpoint_frequency, + max_num_not_improved=max_num_not_improved) + + logger.info("Training finished. Best checkpoint: %d. Best validation %s: %.6f", + training_monitor.get_best_checkpoint(), + training_monitor.optimized_metric, + training_monitor.get_best_validation_score()) + return training_monitor.get_best_validation_score() + + def _fit(self, + train_iter: sockeye.data_io.ParallelBucketSentenceIter, + val_iter: sockeye.data_io.ParallelBucketSentenceIter, + output_folder: str, + training_monitor: sockeye.callback.TrainingMonitor, + metrics: List[AnyStr], + max_updates: int, + checkpoint_frequency: int, + max_num_not_improved: int): + """ + Internal fit method. Runtime determined by early stopping. + + :param train_iter: Training data iterator. + :param val_iter: Validation data iterator. + :param output_folder: Model output folder. + :param metrics: List of metric names to track on training and validation data. + :param max_updates: Maximum number of batches to process. + :param checkpoint_frequency: Frequency of checkpointing. + :param max_num_not_improved: Maximum number of checkpoints until fitting is stopped if model does not improve. 
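+
+        A checkpoint is written every checkpoint_frequency updates. Fitting stops when the
+        optimized metric has not improved for max_num_not_improved consecutive checkpoints,
+        or once max_updates batches have been processed (max_updates == -1 disables this limit).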
+ """ + metric_train = self._create_eval_metric(metrics) + metric_val = self._create_eval_metric(metrics) + num_not_improved = 0 + tic = time.time() + epoch = 0 + checkpoint = 0 + updates = 0 + samples = 0 + next_data_batch = train_iter.next() + while max_updates == -1 or updates < max_updates: + if not train_iter.iter_next(): + epoch += 1 + train_iter.reset() + + # process batch + batch = next_data_batch + self.module.forward_backward(batch) + self.module.update() + + if train_iter.iter_next(): + # pre-fetch next batch + next_data_batch = train_iter.next() + self.module.prepare(next_data_batch) + + self.module.update_metric(metric_train, batch.label) + training_monitor.batch_end_callback(epoch, updates, metric_train) + updates += 1 + samples += train_iter.batch_size + + if updates > 0 and updates % checkpoint_frequency == 0: + checkpoint += 1 + self._checkpoint(checkpoint, output_folder) + training_monitor.checkpoint_callback(checkpoint, metric_train) + + toc = time.time() + logger.info("Checkpoint [%d]\tUpdates=%d Epoch=%d Samples=%d Time-cost=%.3f", + checkpoint, updates, epoch, samples, (toc - tic)) + tic = time.time() + + for name, val in metric_train.get_name_value(): + logger.info('Checkpoint [%d]\tTrain-%s=%f', checkpoint, name, val) + metric_train.reset() + + # evaluation on validation set + has_improved, best_checkpoint = self._evaluate(checkpoint, val_iter, metric_val, training_monitor) + if self.lr_scheduler is not None: + self.lr_scheduler.new_evaluation_result(has_improved) + + if has_improved: + best_path = os.path.join(output_folder, C.PARAMS_BEST_NAME) + if os.path.lexists(best_path): + os.remove(best_path) + actual_best_fname = C.PARAMS_NAME % best_checkpoint + os.symlink(actual_best_fname, best_path) + num_not_improved = 0 + else: + num_not_improved += 1 + + if num_not_improved == max_num_not_improved: + logger.info("Model has not improved for %d checkpoints. Stopping fit.", num_not_improved) + training_monitor.stop_fit_callback() + break + + def _evaluate(self, checkpoint, val_iter, val_metric, training_monitor): + """ + Computes val_metric on val_iter. Returns whether model improved or not. + """ + val_iter.reset() + val_metric.reset() + + for nbatch, eval_batch in enumerate(val_iter): + self.module.forward(eval_batch, is_train=False) + self.module.update_metric(val_metric, eval_batch.label) + + for name, val in val_metric.get_name_value(): + logger.info('Checkpoint [%d]\tValidation-%s=%f', checkpoint, name, val) + + return training_monitor.eval_end_callback(checkpoint, val_metric) + + def _checkpoint(self, checkpoint, output_folder): + """ + Saves checkpoint. + """ + # sync aux params across devices + arg_params, aux_params = self.module.get_params() + self.module.set_params(arg_params, aux_params) + self.params = arg_params + self.save_params_to_file(os.path.join(output_folder, C.PARAMS_NAME % checkpoint)) diff --git a/sockeye/translate.py b/sockeye/translate.py new file mode 100644 index 000000000..55bdbbfb9 --- /dev/null +++ b/sockeye/translate.py @@ -0,0 +1,97 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. 
See the License for the specific language governing +# permissions and limitations under the License. + +""" +Translation CLI. +""" +import argparse +import sys +import time +from contextlib import ExitStack + +import mxnet as mx + +import sockeye.arguments as arguments +import sockeye.data_io +import sockeye.inference +import sockeye.output_handler +from sockeye.log import setup_main_logger +from sockeye.utils import acquire_gpu, get_num_gpus + + +def main(): + params = argparse.ArgumentParser(description='Translate from STDIN to STDOUT') + params = arguments.add_inference_args(params) + params = arguments.add_device_args(params) + args = params.parse_args() + + logger = setup_main_logger(__name__, file_logging=False) + + assert args.beam_size > 0, "Beam size must be 1 or greater." + if args.checkpoints is not None: + assert len(args.checkpoints) == len(args.models), "must provide checkpoints for each model" + + logger.info("Command: %s", " ".join(sys.argv)) + logger.info("Arguments: %s", args) + + output_stream = sys.stdout + output_handler = sockeye.output_handler.get_output_handler(args.output_type, + output_stream, + args.align_plot_prefix, + args.sure_align_threshold) + + with ExitStack() as exit_stack: + if args.use_cpu: + context = mx.cpu() + else: + num_gpus = get_num_gpus() + assert num_gpus > 0, "No GPUs found, consider running on the CPU with --use-cpu " \ + "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi " \ + "binary isn't on the path)." + assert len(args.device_ids) == 1, "cannot run on multiple devices for now" + gpu_id = args.device_ids[0] + if gpu_id < 0: + # get a gpu id automatically: + gpu_id = exit_stack.enter_context(acquire_gpu()) + context = mx.gpu(gpu_id) + + translator = sockeye.inference.Translator(context, + args.ensemble_mode, + *sockeye.inference.load_models(context, + args.max_input_len, + args.beam_size, + args.models, + args.checkpoints, + args.softmax_temperature)) + total_time = 0 + i = 0 + for i, line in enumerate(sys.stdin, 1): + trans_input = translator.make_input(i, line) + logger.debug(" IN: %s", trans_input) + + tic = time.time() + trans_output = translator.translate(trans_input) + trans_wall_time = time.time() - tic + total_time += trans_wall_time + + logger.debug("OUT: %s", trans_output) + logger.debug("OUT: time=%.2f", trans_wall_time) + + output_handler.handle(trans_input, trans_output) + + logger.info("Processed %d lines. Total time: %.4f sec/sent: %.4f sent/sec: %.4f", i, total_time, total_time / i, + i / total_time) + + +if __name__ == '__main__': + main() diff --git a/sockeye/utils.py b/sockeye/utils.py new file mode 100644 index 000000000..f1b49f1c1 --- /dev/null +++ b/sockeye/utils.py @@ -0,0 +1,335 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +A set of utility methods. 
+""" +import collections +import errno +import fcntl +import logging +import os +import shutil +import subprocess +import sys +import time +from contextlib import contextmanager +from typing import Mapping, NamedTuple, Any, List, Iterator, Tuple, Dict, Optional + +import mxnet as mx +import numpy as np + +logger = logging.getLogger(__name__) + + +def save_graph(symbol: mx.sym.Symbol, filename: str, hide_weights: bool = True): + """ + Dumps computation graph visualization to .pdf and .dot file. + + :param symbol: The symbol representing the computation graph. + :param filename: The filename to save the graphic to. + :param hide_weights: If true the weights will not be shown. + """ + dot = mx.viz.plot_network(symbol, hide_weights=hide_weights) + dot.render(filename=filename) + + +def save_params(arg_params: Mapping[str, mx.nd.NDArray], fname: str, + aux_params: Optional[Mapping[str, mx.nd.NDArray]] = None): + """ + Saves the parameters to a file. + + :param arg_params: Mapping from parameter names to the actual parameters. + :param fname: The file name to store the parameters in. + :param aux_params: Optional mapping from parameter names to the auxiliary parameters. + """ + save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()} + if aux_params is not None: + save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()}) + mx.nd.save(fname, save_dict) + + +def load_params(fname: str) -> Tuple[Dict[str, mx.nd.NDArray], Dict[str, mx.nd.NDArray]]: + """ + Loads parameters from a file. + + :param fname: The file containing the parameters. + :return: Mapping from parameter names to the actual parameters for both the arg parameters and the aux parameters. + """ + save_dict = mx.nd.load(fname) + arg_params = {} + aux_params = {} + for k, v in save_dict.items(): + tp, name = k.split(':', 1) + if tp == 'arg': + arg_params[name] = v + if tp == 'aux': + aux_params[name] = v + return arg_params, aux_params + + +class Accuracy(mx.metric.EvalMetric): + """ + Calculates accuracy. Taken from MXNet and adapted to work with batch-major labels + (reshapes (batch_size, time) -> (batch_size * time). + Also allows defining an ignore_label/pad symbol + """ + + def __init__(self, + name='accuracy', + output_names=None, + label_names=None, + ignore_label=None): + super(Accuracy, self).__init__(name=name, + output_names=output_names, + label_names=label_names, + ignore_label=ignore_label) + self.ignore_label = ignore_label + + def update(self, labels, preds): + mx.metric.check_label_shapes(labels, preds) + + for label, pred_label in zip(labels, preds): + if pred_label.shape != label.shape: + pred_label = mx.nd.argmax_channel(pred_label) + pred_label = pred_label.asnumpy().astype('int32') + label = mx.nd.reshape(label, shape=(pred_label.size,)).asnumpy().astype('int32') + + mx.metric.check_label_shapes(label, pred_label) + if self.ignore_label is not None: + correct = ((pred_label.flat == label.flat) * (label.flat != self.ignore_label)).sum() + ignore = (label.flat == self.ignore_label).sum() + n = pred_label.size - ignore + else: + correct = (pred_label.flat == label.flat).sum() + n = pred_label.size + + self.sum_metric += correct + self.num_inst += n + + +def smallest_k(matrix: np.ndarray, k: int, + only_first_row: bool = False) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]: + """ + Find the smallest elements in a numpy matrix. + + :param matrix: Any matrix. + :param k: The number of smallest elements to return. 
+ :param only_first_row: If true the search is constrained to the first row of the matrix. + :return: The row indices, column indices and values of the k smallest items in matrix. + """ + if only_first_row: + flatten = matrix[:1, :].flatten() + else: + flatten = matrix.flatten() + + # args are the indices in flatten of the k smallest elements + args = np.argpartition(flatten, k)[:k] + # args are the indices in flatten of the sorted k smallest elements + args = args[np.argsort(flatten[args])] + # flatten[args] are the values for args + return np.unravel_index(args, matrix.shape), flatten[args] + + +def smallest_k_mx(matrix: mx.nd.NDArray, k: int, + only_first_row: bool = False) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]: + """ + Find the smallest elements in a NDarray. + + :param matrix: Any matrix. + :param k: The number of smallest elements to return. + :param only_first_row: If True the search is constrained to the first row of the matrix. + :return: The row indices, column indices and values of the k smallest items in matrix. + """ + if only_first_row: + matrix = mx.nd.reshape(matrix[0], shape=(1, -1)) + + values, indices = mx.nd.topk(matrix, axis=None, k=k, ret_typ='both', is_ascend=True) + + return np.unravel_index(indices.astype(np.int32).asnumpy(), matrix.shape), values + + +def plot_attention(attention_matrix: np.ndarray, source_tokens: List[str], target_tokens: List[str], filename: str): + """ + Uses matplotlib for creating a visualization of the attention matrix. + + :param attention_matrix: The attention matrix. + :param source_tokens: A list of source tokens. + :param target_tokens: A list of target tokens. + :param filename: The file to which the attention visualization will be written to. + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + assert attention_matrix.shape[0] == len(target_tokens) + + plt.imshow(attention_matrix.transpose(), interpolation="nearest", cmap="Greys") + plt.xlabel("target") + plt.ylabel("source") + plt.gca().set_xticks([i for i in range(0, len(target_tokens))]) + plt.gca().set_yticks([i for i in range(0, len(source_tokens))]) + plt.gca().set_xticklabels(target_tokens, rotation='vertical') + plt.gca().set_yticklabels(source_tokens) + plt.tight_layout() + plt.savefig(filename) + logger.info("Saved alignment visualization to " + filename) + + +def print_attention_text(attention_matrix: np.ndarray, source_tokens: List[str], target_tokens: List[str], + threshold: float): + """ + Prints the attention matrix to standard out. + + :param attention_matrix: The attention matrix. + :param source_tokens: A list of source tokens. + :param target_tokens: A list of target tokens. + :param threshold: The threshold for including an alignment link in the result. 
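+
+        Each link is printed as (*) if its attention probability exceeds the threshold, as (?)
+        if it only exceeds 0.4, and as blank space otherwise; the source token is shown at the
+        end of its row.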
+ """ + sys.stdout.write(" ") + for j in target_tokens: + sys.stdout.write("---") + sys.stdout.write("\n") + for (i, f_i) in enumerate(source_tokens): + sys.stdout.write(" |") + for (j, _) in enumerate(target_tokens): + align_prob = attention_matrix[j, i] + if align_prob > threshold: + sys.stdout.write("(*)") + elif align_prob > 0.4: + sys.stdout.write("(?)") + else: + sys.stdout.write(" ") + sys.stdout.write(" | %s\n" % f_i) + sys.stdout.write(" ") + for j in target_tokens: + sys.stdout.write("---") + sys.stdout.write("\n") + for k in range(max(map(len, target_tokens))): + sys.stdout.write(" ") + for word in target_tokens: + letter = word[k] if len(word) > k else " " + sys.stdout.write(" %s " % letter) + sys.stdout.write("\n") + sys.stdout.write("\n") + + +def get_alignments(attention_matrix: np.ndarray, threshold: float = .9) -> Iterator[Tuple[int, int]]: + """ + Yields hard alignments from an attention_matrix (target_length, source_length) + given a threshold. + + :param attention_matrix: The attention matrix. + :param threshold: The threshold for including an alignment link in the result. + :return: Generator yielding strings of the form 0-0, 0-1, 2-1, 2-2, 3-4... + """ + for src_idx in range(attention_matrix.shape[1]): + for trg_idx in range(attention_matrix.shape[0]): + if attention_matrix[trg_idx, src_idx] > threshold: + yield (src_idx, trg_idx) + + +def average_arrays(arrays: List[mx.sym.NDArray]) -> mx.sym.NDArray: + """ + Take a list of arrays of the same shape and take the element wise average. + + :param arrays: A list of NDArrays with the same shape that will be averaged. + :return: The average of the NDArrays in the same context as arrays[0]. + """ + if len(arrays) == 1: + return arrays[0] + assert all(arrays[0].shape == a.shape for a in arrays), "nd array shapes do not match" + new_array = mx.nd.zeros(arrays[0].shape, dtype=arrays[0].dtype, ctx=arrays[0].context) + for a in arrays: + new_array += a.as_in_context(new_array.context) + new_array /= len(arrays) + return new_array + + +def get_num_gpus() -> int: + """ + Gets the number of GPUs available on the host (depends on nvidia-smi). + + :return: The number of GPUs on the system. + """ + if shutil.which("nvidia-smi") is None: + logger.warning("Couldn't find nvidia-smi, therefore we assume no GPUs are available.") + return 0 + sp = subprocess.Popen(['nvidia-smi', '-L'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out_str = sp.communicate()[0].decode("utf-8") + num_gpus = len(out_str.rstrip("\n").split("\n")) + return num_gpus + + +@contextmanager +def acquire_gpu(lock_dir: str = "/var/lock", retry_wait: int = 10): + """ + Acquires a gpu by locking a file (therefore this assumes that everyone using gpus calls this method and shares the + lock directory). + + :param lock_dir: The directory for storing the lock file. + :param retry_wait: The number of seconds to wait between retries. + """ + num_gpus = get_num_gpus() + + logger.info("Trying to acquire one of the %d gpus", num_gpus) + + # try to acquire a GPU lock + while True: + for gpu_id in range(num_gpus): + # try to acquire a lock + lockfile_path = os.path.join(lock_dir, "sockeye.gpu%d.lock" % gpu_id) + with open(lockfile_path, 'w') as lock_file: + try: + # exclusive non-blocking lock + fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB) + # got the lock, let's write our PID into it: + lock_file.write("%d\n" % os.getpid()) + lock_file.flush() + logger.info("Acquired GPU %d." % gpu_id) + + yield gpu_id + + logger.info("Releasing GPU %d." 
% gpu_id) + # release lock: + fcntl.flock(lock_file, fcntl.LOCK_UN) + os.remove(lockfile_path) + return + except IOError as e: + # raise on unrelated IOErrors + if e.errno != errno.EAGAIN: + logger.error("Failed acquiring gpu lock.", exc_info=True) + raise + else: + logger.info("GPU %d is currently locked." % gpu_id, + exc_info=True) + logger.info("No GPU available will try again in %ss." % retry_wait) + time.sleep(retry_wait) + + +def namedtuple_with_defaults(typename, field_names, default_values: Mapping[str, Any] = ()) -> NamedTuple: + """ + Create a named tuple with default values. + + :param typename: The name of the new type. + :param field_names: The fields the type will have. + :param default_values: A mapping from field names to default values. + :return: The new named tuple with default values. + """ + T = collections.namedtuple(typename, field_names) + T.__new__.__defaults__ = (None,) * len(T._fields) + if isinstance(default_values, collections.Mapping): + prototype = T(**default_values) + else: + prototype = T(*default_values) + T.__new__.__defaults__ = tuple(prototype) + return T diff --git a/sockeye/vocab.py b/sockeye/vocab.py new file mode 100644 index 000000000..aa7934490 --- /dev/null +++ b/sockeye/vocab.py @@ -0,0 +1,147 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import json +import logging +import os +import pickle +from collections import Counter +from itertools import chain, islice +from typing import Dict, Iterable, Mapping + +import sockeye.constants as C +from sockeye.data_io import get_tokens, smart_open + +logger = logging.getLogger(__name__) + + +def build_from_path(path: str, num_words: int = 50000, min_count: int = 1) -> Dict[str, int]: + """ + Creates vocabulary from path to a file in sentence-per-line format. A sentence is just a whitespace delimited + list of tokens. Note that special symbols like the beginning of sentence (BOS) symbol will be added to the + vocabulary. + + :param path: Path to file with one sentence per line. + :param num_words: Maximum number of words in the vocabulary. + :param min_count: Minimum occurrences of words to be included in the vocabulary. + :return: Word-to-id mapping. + """ + with smart_open(path) as data: + logger.info("Building vocabulary from dataset: %s", path) + return build_vocab(data, num_words, min_count) + + +def build_vocab(data: Iterable[str], num_words: int = 50000, min_count: int = 1) -> Dict[str, int]: + """ + Creates a vocabulary mapping from words to ids. Increasing integer ids are assigned by word frequency, + using lexical sorting as a tie breaker. The only exception to this are special symbols such as the padding symbol + (PAD). + + :param data: Sequence of sentences containing whitespace delimited tokens. + :param num_words: Maximum number of words in the vocabulary. + :param min_count: Minimum occurrences of words to be included in the vocabulary. + :return: Word-to-id mapping. 
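+
+    For example (illustrative data), build_vocab(["a b b c", "b c c"], num_words=2, min_count=1)
+    keeps the two most frequent tokens, c and b (ties are broken in reverse lexical order, so c
+    precedes b), and assigns their ids directly after the special symbols, with the padding
+    symbol always at id 0.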
+ """ + vocab_symbols_set = set(C.VOCAB_SYMBOLS) + raw_vocab = Counter(token for line in data for token in get_tokens(line) + if token not in vocab_symbols_set) + logger.info("Initial vocabulary: %d types" % len(raw_vocab)) + + # For words with the same count, they will be ordered reverse alphabetically. + # Not an issue since we only care for consistency + pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True) + logger.info("Pruned vocabulary: %d types (min frequency %d)", len(pruned_vocab), min_count) + + vocab = islice((w for c, w in pruned_vocab), num_words) + + word_to_id = {word: idx for idx, word in enumerate(chain(C.VOCAB_SYMBOLS, vocab))} + logger.info("Final vocabulary: %d types (min frequency %d, top %d types)", + len(word_to_id), min_count, num_words) + + # Important: pad symbol becomes index 0 + assert word_to_id[C.PAD_SYMBOL] == C.PAD_ID + return word_to_id + + +def vocab_to_pickle(vocab: Mapping, path: str): + """ + Saves vocabulary in pickle format. + + :param vocab: Vocabulary mapping. + :param path: Output file path. + """ + with open(path, 'wb') as out: + pickle.dump(vocab, out) + logger.info('Vocabulary saved to "%s"', path) + + +def vocab_to_json(vocab: Mapping, path: str): + """ + Saves vocabulary in human-readable json. + + :param vocab: Vocabulary mapping. + :param path: Output file path. + """ + with open(path, "w") as out: + json.dump(vocab, out, indent=4) + logger.info('Vocabulary saved to "%s"', path) + + +def vocab_from_json_or_pickle(path) -> Dict: + """ + Try loading the json version of the vocab and fall back to pickle for backwards compatibility. + + :param path: Path to vocab without the json suffix. If it exists the `path` + '.json' will be loaded as a JSON + object and otherwise `path` is loaded as a pickle object. + :return: The loaded vocabulary. + """ + if os.path.exists(path + C.JSON_SUFFIX): + return vocab_from_json(path + C.JSON_SUFFIX) + else: + return vocab_from_pickle(path) + + +def vocab_from_pickle(path: str) -> Dict: + """ + Saves vocabulary in pickle format. + + :param path: Path to pickle file containing the vocabulary. + :return: The loaded vocabulary. + """ + with open(path, 'rb') as inp: + vocab = pickle.load(inp) + logger.info('Vocabulary (%d words) loaded from "%s"', len(vocab), path) + return vocab + + +def vocab_from_json(path: str) -> Dict: + """ + Saves vocabulary in json format. + + :param path: Path to json file containing the vocabulary. + :return: The loaded vocabulary. + """ + with open(path) as inp: + vocab = json.load(inp) + logger.info('Vocabulary (%d words) loaded from "%s"', len(vocab), path) + return vocab + + +def reverse_vocab(vocab: Mapping) -> Dict: + """ + Returns value-to-key mapping from key-to-value-mapping. + + :param vocab: Key to value mapping. + :return: A mapping from values to keys. + """ + return {v: k for k, v in vocab.items()} diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 000000000..214e3177f --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. 
See the License for the specific language governing +# permissions and limitations under the License. + diff --git a/test/test_attention.py b/test/test_attention.py new file mode 100644 index 000000000..ee6f8a19c --- /dev/null +++ b/test/test_attention.py @@ -0,0 +1,191 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import mxnet as mx +import numpy as np +import pytest +import sockeye.attention +import sockeye.coverage +from test.test_utils import gaussian_vector, integer_vector + +attention_types = ['bilinear', 'dot', 'location', 'mlp'] + + +@pytest.mark.parametrize("attention_type", attention_types) +def test_attention(attention_type, + batch_size=1, + encoder_num_hidden=2, + decoder_num_hidden=2): + # source: (batch_size, seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + # source_length: (batch_size,) + source_length = mx.sym.Variable("source_length") + source_seq_len = 3 + + attention = sockeye.attention.get_attention(input_previous_word=False, + attention_type=attention_type, + attention_num_hidden=2, + rnn_num_hidden=2, + max_seq_len=source_seq_len, + attention_coverage_type="", + attention_coverage_num_hidden=2) + attention_state = attention.get_initial_state(source_length, source_seq_len) + attention_func = attention.on(source, source_length, source_seq_len) + attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state")) + attention_state = attention_func(attention_input, attention_state) + sym = mx.sym.Group([attention_state.context, attention_state.probs]) + + executor = sym.simple_bind(ctx=mx.cpu(), + source=(batch_size, source_seq_len, encoder_num_hidden), + source_length=(batch_size,), + decoder_state=(batch_size, decoder_num_hidden)) + + # TODO: test for other inputs (that are not equal at each source position) + executor.arg_dict["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]]) + executor.arg_dict["source_length"][:] = np.asarray([2.0]) + executor.arg_dict["decoder_state"][:] = np.asarray([[5, 6]]) + exec_output = executor.forward() + context_result = exec_output[0].asnumpy() + attention_prob_result = exec_output[1].asnumpy() + + # expecting uniform attention_weights of 0.5: 0.5 * seq1 + 0.5 * seq2 + assert np.isclose(context_result, np.asarray([[1., 2.]])).all() + # equal attention to first two and no attention to third + assert np.isclose(attention_prob_result, np.asarray([[0.5, 0.5, 0.]])).all() + + +coverage_cases = [("gru", 10), ("tanh", 4), ("count", 1), ("sigmoid", 1), ("relu", 30)] + + +@pytest.mark.parametrize("attention_coverage_type,attention_coverage_num_hidden", coverage_cases) +def test_coverage_attention(attention_coverage_type, + attention_coverage_num_hidden, + batch_size=3, + encoder_num_hidden=2, + decoder_num_hidden=2): + # source: (batch_size, seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + # source_length: (batch_size, ) + source_length = mx.sym.Variable("source_length") + source_seq_len = 10 + + attention = 
sockeye.attention.get_attention(input_previous_word=False, + attention_type="coverage", + attention_num_hidden=5, + rnn_num_hidden=0, + max_seq_len=source_seq_len, + attention_coverage_type=attention_coverage_type, + attention_coverage_num_hidden=attention_coverage_num_hidden) + attention_state = attention.get_initial_state(source_length, source_seq_len) + attention_func = attention.on(source, source_length, source_seq_len) + attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state")) + attention_state = attention_func(attention_input, attention_state) + sym = mx.sym.Group([attention_state.context, attention_state.probs, attention_state.dynamic_source]) + + source_shape = (batch_size, source_seq_len, encoder_num_hidden) + source_length_shape = (batch_size,) + decoder_state_shape = (batch_size, decoder_num_hidden) + + executor = sym.simple_bind(ctx=mx.cpu(), + source=source_shape, + source_length=source_length_shape, + decoder_state=decoder_state_shape) + + source_length_vector = integer_vector(shape=source_length_shape, max_value=source_seq_len) + executor.arg_dict["source"][:] = gaussian_vector(shape=source_shape) + executor.arg_dict["source_length"][:] = source_length_vector + executor.arg_dict["decoder_state"][:] = gaussian_vector(shape=decoder_state_shape) + exec_output = executor.forward() + context_result = exec_output[0].asnumpy() + attention_prob_result = exec_output[1].asnumpy() + dynamic_source_result = exec_output[2].asnumpy() + + expected_probs = (1 / source_length_vector).reshape((batch_size, 1)) + expected_dynamic_source = (1 / source_length_vector).reshape((batch_size, 1)) + + assert context_result.shape == (batch_size, encoder_num_hidden) + assert attention_prob_result.shape == (batch_size, source_seq_len) + assert dynamic_source_result.shape == (batch_size, source_seq_len, attention_coverage_num_hidden) + assert (np.sum(np.isclose(attention_prob_result, expected_probs), axis=1) == source_length_vector).all() + + +def test_last_state_attention(batch_size=1, + encoder_num_hidden=2): + """ + EncoderLastStateAttention is a bit different from other attention mechanisms as it doesn't take a query argument + and doesn't return a probability distribution over the inputs (aka alignment). 
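+    For example, with a source of length 3 and source_length == 2, the returned context equals the second
+    encoder state and the returned 'probs' put all mass on that position, i.e. [0., 1., 0.], which is exactly
+    what the assertions below check.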
+ """ + # source: (batch_size, seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + # source_length: (batch_size,) + source_length = mx.sym.Variable("source_length") + source_seq_len = 3 + + attention = sockeye.attention.get_attention(input_previous_word=False, + attention_type="fixed", + attention_num_hidden=0, + rnn_num_hidden=0, + max_seq_len=source_seq_len, + attention_coverage_type="", + attention_coverage_num_hidden=0) + attention_state = attention.get_initial_state(source_length, source_seq_len) + attention_func = attention.on(source, source_length, source_seq_len) + attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state")) + attention_state = attention_func(attention_input, attention_state) + sym = mx.sym.Group([attention_state.context, attention_state.probs]) + + executor = sym.simple_bind(ctx=mx.cpu(), + source=(batch_size, source_seq_len, encoder_num_hidden), + source_length=(batch_size,)) + + # TODO: test for other inputs (that are not equal at each source position) + executor.arg_dict["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]]) + executor.arg_dict["source_length"][:] = np.asarray([2.0]) + exec_output = executor.forward() + context_result = exec_output[0].asnumpy() + attention_prob_result = exec_output[1].asnumpy() + + # expecting attention on last state based on source_length + assert np.isclose(context_result, np.asarray([[1., 2.]])).all() + assert np.isclose(attention_prob_result, np.asarray([[0., 1.0, 0.]])).all() + + +def test_get_context_and_attention_probs(): + source = mx.sym.Variable('source') + source_length = mx.sym.Variable('source_length') + attention_scores = mx.sym.Variable('scores') + context, att_probs = sockeye.attention.get_context_and_attention_probs(source, source_length, attention_scores) + sym = mx.sym.Group([context, att_probs]) + assert len(sym.list_arguments()) == 3 + + batch_size, seq_len, num_hidden = 32, 50, 100 + + # data + source_nd = mx.nd.random_normal(shape=(batch_size, seq_len, num_hidden)) + source_length_np = np.random.random_integers(1, seq_len, (batch_size,)) + source_length_nd = mx.nd.array(source_length_np) + scores_nd = mx.nd.zeros((batch_size, seq_len, 1)) + + in_shapes, out_shapes, _ = sym.infer_shape(source=source_nd.shape, + source_length=source_length_nd.shape, + scores=scores_nd.shape) + + assert in_shapes == [(batch_size, seq_len, num_hidden), (batch_size, seq_len, 1), (batch_size,)] + assert out_shapes == [(batch_size, num_hidden), (batch_size, seq_len)] + + context, probs = sym.eval(source=source_nd, + source_length=source_length_nd, + scores=scores_nd) + + expected_probs = (1. / source_length_nd).reshape((batch_size, 1)).asnumpy() + assert (np.sum(np.isclose(probs.asnumpy(), expected_probs), axis=1) == source_length_np).all() diff --git a/test/test_bleu.py b/test/test_bleu.py new file mode 100644 index 000000000..9c4281645 --- /dev/null +++ b/test/test_bleu.py @@ -0,0 +1,26 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import pytest +import sockeye.bleu + + +test_cases = [(["this is a test", "another test"], ["ref1", "ref2"], 0.003799178428257963), + (["this is a test"], ["this is a test"], 1.0), + (["this is a fest"], ["this is a test"], 0.223606797749979)] + + +@pytest.mark.parametrize("hypotheses, references, expected_bleu", test_cases) +def test_bleu(hypotheses, references, expected_bleu): + bleu = sockeye.bleu.corpus_bleu(hypotheses, references) + assert abs(bleu - expected_bleu) < 1e-8 \ No newline at end of file diff --git a/test/test_callback.py b/test/test_callback.py new file mode 100644 index 000000000..cfa6bfe41 --- /dev/null +++ b/test/test_callback.py @@ -0,0 +1,69 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Tests sockeye.callback.TrainingMonitor optimization logic +""" +import pytest +import numpy as np +import sockeye.callback +import tempfile +import os + +test_constants = [('perplexity', np.inf, True, + [{'perplexity': 100.0, '_': 42}, {'perplexity': 50.0}, {'perplexity': 60.0}, {'perplexity': 80.0}], + [{'perplexity': 200.0}, {'perplexity': 100.0}, {'perplexity': 100.001}, {'perplexity': 99.99}], + [True, True, False, True]), + ('accuracy', -np.inf, False, + [{'accuracy': 100.0}, {'accuracy': 50.0}, {'accuracy': 60.0}, {'accuracy': 80.0}], + [{'accuracy': 200.0}, {'accuracy': 100.0}, {'accuracy': 100.001}, {'accuracy': 99.99}], + [True, False, False, False])] + + +class DummyMetric(object): + def __init__(self, metric_dict): + self.metric_dict = metric_dict + + def get_name_value(self): + for metric_name, value in self.metric_dict.items(): + yield metric_name, value + + +@pytest.mark.parametrize("optimized_metric, initial_best, minimize, train_metrics, eval_metrics, improved_seq", + test_constants) +def test_callback(optimized_metric, initial_best, minimize, train_metrics, eval_metrics, improved_seq): + with tempfile.TemporaryDirectory() as tmpdir: + batch_size = 32 + monitor = sockeye.callback.TrainingMonitor(batch_size=batch_size, + output_folder=tmpdir, + optimized_metric=optimized_metric) + assert monitor.optimized_metric == optimized_metric + assert monitor.get_best_validation_score() == initial_best + assert monitor.minimize == minimize + + for checkpoint, (train_metric, eval_metric, expected_improved) in enumerate( + zip(train_metrics, eval_metrics, improved_seq), 1): + monitor.checkpoint_callback(checkpoint, DummyMetric(train_metric)) + assert len(monitor.metrics) == checkpoint + assert monitor.metrics[-1] == {k + "-train": v for k, v in train_metric.items()} + improved, best_checkpoint = monitor.eval_end_callback(checkpoint, DummyMetric(eval_metric)) + assert {k + "-val" for k in eval_metric.keys()} <= monitor.metrics[-1].keys() + assert improved == expected_improved + + +def test_bleu_requires_checkpoint_decoder(): + with pytest.raises(AssertionError), tempfile.TemporaryDirectory() as tmpdir: + sockeye.callback.TrainingMonitor(batch_size=1, + output_folder=tmpdir, + optimized_metric='bleu', + checkpoint_decoder=None) diff --git 
a/test/test_coverage.py b/test/test_coverage.py new file mode 100644 index 000000000..990546592 --- /dev/null +++ b/test/test_coverage.py @@ -0,0 +1,138 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import mxnet as mx +import numpy as np +import pytest +import sockeye.coverage +from test.test_utils import gaussian_vector, integer_vector, uniform_vector + +activation_types = ["tanh", "sigmoid", "relu", "softrelu"] + + +@pytest.mark.parametrize("act_type", activation_types) +def test_activation_coverage(act_type): + encoder_num_hidden, decoder_num_hidden, coverage_num_hidden, source_seq_len, batch_size = 5, 5, 2, 10, 4 + + # source: (batch_size, source_seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + # source_length: (batch_size,) + source_length = mx.sym.Variable("source_length") + # prev_hidden: (batch_size, decoder_num_hidden) + prev_hidden = mx.sym.Variable("prev_hidden") + # prev_coverage: (batch_size, source_seq_len, coverage_num_hidden) + prev_coverage = mx.sym.Variable("prev_coverage") + # attention_scores: (batch_size, source_seq_len) + attention_scores = mx.sym.Variable("attention_scores") + + source_shape = (batch_size, source_seq_len, encoder_num_hidden) + source_length_shape = (batch_size,) + prev_hidden_shape = (batch_size, decoder_num_hidden) + attention_scores_shape = (batch_size, source_seq_len, 1) + prev_coverage_shape = (batch_size, source_seq_len, coverage_num_hidden) + + source_data = gaussian_vector(shape=source_shape) + source_length_data = integer_vector(shape=source_length_shape, max_value=source_seq_len) + prev_hidden_data = gaussian_vector(shape=prev_hidden_shape) + prev_coverage_data = gaussian_vector(shape=prev_coverage_shape) + attention_scores_data = uniform_vector(shape=attention_scores_shape) + attention_scores_data = attention_scores_data / np.sum(attention_scores_data) + + coverage = sockeye.coverage.get_coverage(coverage_type=act_type, coverage_num_hidden=coverage_num_hidden) + coverage_func = coverage.on(source, source_length, source_seq_len) + updated_coverage = coverage_func(prev_hidden, attention_scores, prev_coverage) + + executor = updated_coverage.simple_bind(ctx=mx.cpu(), + source=source_shape, + source_length=source_length_shape, + prev_hidden=prev_hidden_shape, + prev_coverage=prev_coverage_shape, + attention_scores=attention_scores_shape) + + executor.arg_dict["source"][:] = source_data + executor.arg_dict["source_length"][:] = source_length_data + executor.arg_dict["prev_hidden"][:] = prev_hidden_data + executor.arg_dict["prev_coverage"][:] = prev_coverage_data + executor.arg_dict["attention_scores"][:] = attention_scores_data + + result = executor.forward() + + # this is needed to modulate the 0 input. The output changes according to the activation type used. 
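+    # For a zero input this reference value is simply act_type applied to 0, e.g. tanh(0) = 0, relu(0) = 0,
+    # sigmoid(0) = 0.5 and softrelu(0) = log(2) ~ 0.69; the single-element Activation below computes it, and the
+    # (currently disabled) masking assertion further down compares the coverage output against it.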
+ activation = mx.sym.Activation(name="activation", act_type=act_type) + modulated = activation.eval(ctx=mx.cpu(), activation_data=mx.nd.zeros((1,)))[0].asnumpy() + + new_coverage = result[0].asnumpy() + + assert new_coverage.shape == prev_coverage_shape + # For this to work the mask value in sockeye.coverage.mask_coverage needs to be set to something != 0 as + # it will otherwise be identical to the output of the coverage_func for some activations (e.g. tanh). + # What this test does is that finds all words for which the coverage is 0 (i.e. all words that have + # not been masked). It then checks whether the number of these words equals the sentence length + # TODO: at the moment I have set the value of the mask to 1 -> this is not ideal + # assert (np.sum(np.sum(new_coverage == modulated, axis=2) != 0, axis=1) == source_length_data).all() + + +def test_gru_coverage(): + encoder_num_hidden, decoder_num_hidden, coverage_num_hidden, source_seq_len, batch_size = 5, 5, 2, 10, 4 + + # source: (batch_size, source_seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + # source_length: (batch_size,) + source_length = mx.sym.Variable("source_length") + # prev_hidden: (batch_size, decoder_num_hidden) + prev_hidden = mx.sym.Variable("prev_hidden") + # prev_coverage: (batch_size, source_seq_len, coverage_num_hidden) + prev_coverage = mx.sym.Variable("prev_coverage") + # attention_scores: (batch_size, source_seq_len) + attention_scores = mx.sym.Variable("attention_scores") + + source_shape = (batch_size, source_seq_len, encoder_num_hidden) + source_length_shape = (batch_size,) + prev_hidden_shape = (batch_size, decoder_num_hidden) + attention_scores_shape = (batch_size, source_seq_len) + prev_coverage_shape = (batch_size, source_seq_len, coverage_num_hidden) + + source_data = gaussian_vector(shape=source_shape) + source_length_data = integer_vector(shape=source_length_shape, max_value=source_seq_len) + prev_hidden_data = gaussian_vector(shape=prev_hidden_shape) + prev_coverage_data = gaussian_vector(shape=prev_coverage_shape) + attention_scores_data = uniform_vector(shape=attention_scores_shape) + attention_scores_data = attention_scores_data / np.sum(attention_scores_data) + + coverage = sockeye.coverage.get_coverage(coverage_type="gru", coverage_num_hidden=coverage_num_hidden) + coverage_func = coverage.on(source, source_length, source_seq_len) + updated_coverage = coverage_func(prev_hidden, attention_scores, prev_coverage) + + executor = updated_coverage.simple_bind(ctx=mx.cpu(), + source=source_shape, + source_length=source_length_shape, + prev_hidden=prev_hidden_shape, + prev_coverage=prev_coverage_shape, + attention_scores=attention_scores_shape) + + executor.arg_dict["source"][:] = source_data + executor.arg_dict["source_length"][:] = source_length_data + executor.arg_dict["prev_hidden"][:] = prev_hidden_data + executor.arg_dict["prev_coverage"][:] = prev_coverage_data + executor.arg_dict["attention_scores"][:] = attention_scores_data + + result = executor.forward() + new_coverage = result[0].asnumpy() + + assert new_coverage.shape == prev_coverage_shape + # For this to work the mask value in sockeye.coverage.mask_coverage needs to be set to something != 0 as + # it will otherwise be identical to the output of the coverage_func for some activations (e.g. tanh). + # What this test does is that finds all words for which the coverage is 0 (i.e. all words that have + # not been masked). 
It then checks whether the number of these words equals the sentence length + # TODO: at the moment I have set the value of the mask to 1 -> this is not ideal + # assert (np.sum(np.sum(new_coverage != 1, axis=2) != 0, axis=1) == source_length_data).all() diff --git a/test/test_data_io.py b/test/test_data_io.py new file mode 100644 index 000000000..373e74ceb --- /dev/null +++ b/test/test_data_io.py @@ -0,0 +1,72 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import pytest +import sockeye.data_io +import sockeye.constants as C + +define_bucket_tests = [(50, 10, [10, 20, 30, 40, 50]), + (50, 50, [50]), + (5, 10, [5]), + (11, 10, [10, 20])] + + +@pytest.mark.parametrize("max_seq_len, step, expected_buckets", define_bucket_tests) +def test_define_buckets(max_seq_len, step, expected_buckets): + buckets = sockeye.data_io.define_buckets(max_seq_len, step=step) + assert buckets == expected_buckets + + +define_parallel_bucket_tests = [(50, 10, 1.0, [(10, 10), (20, 20), (30, 30), (40, 40), (50, 50)]), + (50, 10, 0.5, [(10, 5), (20, 10), (30, 15), (40, 20), (50, 25)]), + (50, 50, 0.5, [(50, 25)]), + (50, 50, 1.5, [(50, 50)]), + (75, 50, 1.5, [(50, 75)]), + (10, 2, 0.567, [(2, 1), (4, 2), (6, 3), (8, 4), (10, 5)])] + + +@pytest.mark.parametrize("max_seq_len, bucket_width, length_ratio, expected_buckets", define_parallel_bucket_tests) +def test_define_parallel_buckets(max_seq_len, bucket_width, length_ratio, expected_buckets): + buckets = sockeye.data_io.define_parallel_buckets(max_seq_len, bucket_width=bucket_width, length_ratio=length_ratio) + assert buckets == expected_buckets + + +get_bucket_tests = [(10, [10, 20, 30, 40, 50], 10), + (11, [10], None), + (2, [1, 4, 8], 4)] + + +@pytest.mark.parametrize("seq_len, buckets, expected_bucket", get_bucket_tests) +def test_get_bucket(seq_len, buckets, expected_bucket): + bucket = sockeye.data_io.get_bucket(seq_len, buckets) + assert bucket == expected_bucket + + +get_tokens_tests = [("this is a line \n", ["this", "is", "a", "line"]), + (" a \tb \r \n", ["a", "b"])] + + +@pytest.mark.parametrize("line, expected_tokens", get_tokens_tests) +def test_get_tokens(line, expected_tokens): + tokens = list(sockeye.data_io.get_tokens(line)) + assert tokens == expected_tokens + + +tokens2ids_tests = [(["a", "b", "c"], {"a": 1, "b": 0, "c": 300, C.UNK_SYMBOL: 12}, [1, 0, 300]), + (["a", "x", "c"], {"a": 1, "b": 0, "c": 300, C.UNK_SYMBOL: 12}, [1, 12, 300])] + + +@pytest.mark.parametrize("tokens, vocab, expected_ids", tokens2ids_tests) +def test_tokens2ids(tokens, vocab, expected_ids): + ids = sockeye.data_io.tokens2ids(tokens, vocab) + assert ids == expected_ids diff --git a/test/test_decoder.py b/test/test_decoder.py new file mode 100644 index 000000000..12ab33d18 --- /dev/null +++ b/test/test_decoder.py @@ -0,0 +1,97 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. 
A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import mxnet as mx +import pytest +import sockeye.decoder +import sockeye.attention +import sockeye.constants as C +from test.test_utils import gaussian_vector, integer_vector + + +step_tests = [(C.GRU_TYPE, True), (C.LSTM_TYPE, False)] + + +@pytest.mark.parametrize("cell_type, context_gating", step_tests) +def test_step(cell_type, context_gating, + num_embed=2, + encoder_num_hidden=5, + decoder_num_hidden=5): + + attention_num_hidden, vocab_size, num_layers, \ + batch_size, source_seq_len, coverage_num_hidden = 2, 10, 1, 10, 7, 2 + + # (batch_size, source_seq_len, encoder_num_hidden) + source = mx.sym.Variable("source") + source_shape = (batch_size, source_seq_len, encoder_num_hidden) + # (batch_size,) + source_length = mx.sym.Variable("source_length") + source_length_shape = (batch_size,) + # (batch_size, num_embed) + word_vec_prev = mx.sym.Variable("word_vec_prev") + word_vec_prev_shape = (batch_size, num_embed) + # (batch_size, decoder_num_hidden) + hidden_prev = mx.sym.Variable("hidden_prev") + hidden_prev_shape = (batch_size, decoder_num_hidden) + # List(mx.sym.Symbol(batch_size, decoder_num_hidden) + states_shape = (batch_size, decoder_num_hidden) + + attention = sockeye.attention.get_attention(input_previous_word=False, + attention_type="coverage", + attention_num_hidden=attention_num_hidden, + rnn_num_hidden=decoder_num_hidden, + max_seq_len=source_seq_len, + attention_coverage_type="tanh", + attention_coverage_num_hidden=coverage_num_hidden) + attention_state = attention.get_initial_state(source_length, source_seq_len) + attention_func = attention.on(source, source_length, source_seq_len) + + decoder = sockeye.decoder.get_decoder(num_embed=num_embed, + vocab_size=vocab_size, + num_layers=num_layers, + rnn_num_hidden=decoder_num_hidden, + attention=attention, + cell_type=cell_type, + residual=False, + forget_bias=0., + dropout=0., + weight_tying=False, + lexicon=None, + context_gating=context_gating) + + if cell_type == C.GRU_TYPE: + layer_states = [gaussian_vector(shape=states_shape, return_symbol=True) for _ in range(num_layers)] + elif cell_type == C.LSTM_TYPE: + layer_states = [gaussian_vector(shape=states_shape, return_symbol=True) for _ in range(num_layers*2)] + + state, attention_state = decoder._step(word_vec_prev=word_vec_prev, + state=sockeye.decoder.DecoderState(hidden_prev, layer_states), + attention_func=attention_func, + attention_state=attention_state) + sym = mx.sym.Group([state.hidden, attention_state.probs, attention_state.dynamic_source]) + + executor = sym.simple_bind(ctx=mx.cpu(), + source=source_shape, + source_length=source_length_shape, + word_vec_prev=word_vec_prev_shape, + hidden_prev=hidden_prev_shape) + executor.arg_dict["source"][:] = gaussian_vector(source_shape) + executor.arg_dict["source_length"][:] = integer_vector(source_length_shape, source_seq_len) + executor.arg_dict["word_vec_prev"][:] = gaussian_vector(word_vec_prev_shape) + executor.arg_dict["hidden_prev"][:] = gaussian_vector(hidden_prev_shape) + executor.arg_dict["states"] = layer_states + hidden_result, attention_probs_result, attention_dynamic_source_result = executor.forward() + + assert hidden_result.shape == hidden_prev_shape + 
assert attention_probs_result.shape == (batch_size, source_seq_len) + assert attention_dynamic_source_result.shape == (batch_size, source_seq_len, coverage_num_hidden) diff --git a/test/test_loss.py b/test/test_loss.py new file mode 100644 index 000000000..15fcea1cd --- /dev/null +++ b/test/test_loss.py @@ -0,0 +1,135 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import mxnet as mx +import numpy as np + +import sockeye.constants as C +import sockeye.loss +import sockeye.model + + +def test_cross_entropy_loss(): + loss = sockeye.loss.get_loss(sockeye.model.ModelConfig(loss=C.CROSS_ENTROPY)) + assert isinstance(loss, sockeye.loss.CrossEntropyLoss) + + logits = mx.sym.Variable("logits") + labels = mx.sym.Variable("labels") + sym = mx.sym.Group(loss.get_loss(logits, labels)) + + assert sym.list_arguments() == ['logits', 'labels'] + assert sym.list_outputs() == [C.SOFTMAX_NAME + "_output"] + + logits_np = mx.nd.array([[1, 2, 3, 4], + [4, 2, 2, 2], + [3, 3, 3, 3], + [4, 4, 4, 4]]) + labels_np = mx.nd.array([1, 0, 2, 3]) # C.PAD_ID == 0 + + expected_softmax = np.asarray([[0.0320586, 0.08714432, 0.23688284, 0.64391428], + [0.71123451, 0.09625512, 0.09625512, 0.09625512], + [0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25]]) + expected_grads = np.asarray([[0.0320586, -0.91285568, 0.23688284, 0.64391428], + [0., 0., 0., 0.], + [0.25, 0.25, -0.75, 0.25], + [0.25, 0.25, 0.25, -0.75]]) + + _, out_shapes, _ = (sym.infer_shape(logits=logits_np.shape, labels=labels_np.shape)) + assert out_shapes[0] == logits_np.shape + + executor = sym.simple_bind(ctx=mx.cpu(), + logits=logits_np.shape, + labels=labels_np.shape) + executor.arg_dict["logits"][:] = logits_np + executor.arg_dict["labels"][:] = labels_np + softmax = executor.forward(is_train=True)[0].asnumpy() + assert np.isclose(softmax, expected_softmax).all() + + executor.backward() + grads = executor.grad_dict["logits"].asnumpy() + assert np.isclose(grads, expected_grads).all() + label_grad_sum = executor.grad_dict["labels"].asnumpy().sum() + assert label_grad_sum == 0 + + +def test_smoothed_cross_entropy_loss(): + alpha = 0.5 + vocab_target_size = 4 + loss = sockeye.loss.get_loss(sockeye.model.ModelConfig(loss=C.SMOOTHED_CROSS_ENTROPY, + vocab_target_size=vocab_target_size, + smoothed_cross_entropy_alpha=alpha)) + assert isinstance(loss, sockeye.loss.SmoothedCrossEntropyLoss) + + logits = mx.sym.Variable("logits") + labels = mx.sym.Variable("labels") + sym = mx.sym.Group(loss.get_loss(logits, labels)) + + assert sym.list_arguments() == ['labels', 'logits'] + assert sym.list_outputs() == [C.SMOOTHED_CROSS_ENTROPY + "_output", C.SOFTMAX_NAME + "_output"] + + logits_np = mx.nd.array([[1, 2, 3, 4], + [4, 2, 2, 2], + [3, 3, 3, 3], + [4, 4, 4, 4]]) + labels_np = mx.nd.array([1, 0, 2, 3]) # C.PAD_ID == 0 + + expected_softmax = np.asarray([[0.0320586, 0.08714432, 0.23688284, 0.64391428], + [0.71123451, 0.09625512, 0.09625512, 0.09625512], + [0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25]]) + expected_cross_entropy = 
np.asarray([2.10685635, 0., 1.38629436, 1.38629436]) + expected_grads = np.asarray([[-0.13460806, -0.41285568, 0.07021617, 0.4772476], + [0., 0., 0., 0.], + [0.08333333, 0.08333333, -0.25, 0.08333333], + [0.08333333, 0.08333333, 0.08333333, -0.25]]) + + _, out_shapes, _ = (sym.infer_shape(logits=logits_np.shape, labels=labels_np.shape)) + assert len(out_shapes) == 2 + assert out_shapes[0] == (4,) + assert out_shapes[1] == logits_np.shape + + executor = sym.simple_bind(ctx=mx.cpu(), + logits=logits_np.shape, + labels=labels_np.shape) + executor.arg_dict["logits"][:] = logits_np + executor.arg_dict["labels"][:] = labels_np + outputs = executor.forward(is_train=True) + smoothed_cross_entropy = outputs[0].asnumpy() + softmax = outputs[1].asnumpy() + assert np.isclose(softmax, expected_softmax).all() + assert np.isclose(smoothed_cross_entropy, expected_cross_entropy).all() + + executor.backward() + grads = executor.grad_dict["logits"].asnumpy() + assert np.isclose(grads, expected_grads).all() + label_grad_sum = executor.grad_dict["labels"].asnumpy().sum() + assert label_grad_sum == 0 + + +def test_normalize(): + loss = mx.sym.Variable("loss") + labels = mx.sym.Variable("labels") + + normalized_loss = sockeye.loss._normalize(loss, labels) + executor = normalized_loss.simple_bind(loss=(2, 2), labels=(2, 2), ctx=mx.cpu()) + executor.arg_dict["loss"][:] = np.asarray([[0., 2.], [0., 4.]]) + executor.arg_dict["labels"][:] = np.asarray([[0, 4], [0, 5]]) + executor.forward() + normalized_loss_np = executor.outputs[0].asnumpy() + + expected_normalized_loss = np.asarray([[0.0, 1.0], [0.0, 2.0]]) + + assert np.isclose(normalized_loss_np, expected_normalized_loss).all() + + diff --git a/test/test_lr_scheduler.py b/test/test_lr_scheduler.py new file mode 100644 index 000000000..d94a55e85 --- /dev/null +++ b/test/test_lr_scheduler.py @@ -0,0 +1,28 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from sockeye.lr_scheduler import LearningRateSchedulerInvSqrtT, LearningRateSchedulerInvT +import pytest + + +def test_lr_scheduler(): + updates_per_epoch = 13 + half_life_num_epochs = 3 + + schedulers = [LearningRateSchedulerInvT(updates_per_epoch, half_life_num_epochs), + LearningRateSchedulerInvSqrtT(updates_per_epoch, half_life_num_epochs)] + for scheduler in schedulers: + scheduler.base_lr = 1.0 + # test correct half-life: + + assert scheduler(updates_per_epoch * half_life_num_epochs) == pytest.approx(0.5) diff --git a/test/test_output_handler.py b/test/test_output_handler.py new file mode 100644 index 000000000..4837c2df6 --- /dev/null +++ b/test/test_output_handler.py @@ -0,0 +1,53 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import pytest +import io +import numpy as np +from sockeye.inference import TranslatorInput, TranslatorOutput +import sockeye.output_handler + +stream_handler_tests = [(sockeye.output_handler.StringOutputHandler(io.StringIO()), + TranslatorInput(id=0, sentence="a test", tokens=None), + TranslatorOutput(id=0, translation="ein Test", tokens=None, + attention_matrix=None, + score=0.), + "ein Test\n"), + (sockeye.output_handler.StringOutputHandler(io.StringIO()), + TranslatorInput(id=0, sentence="", tokens=None), + TranslatorOutput(id=0, translation="", tokens=None, + attention_matrix=None, + score=0.), + "\n"), + (sockeye.output_handler.StringWithAlignmentsOutputHandler(io.StringIO(), threshold=0.5), + TranslatorInput(id=0, sentence="a test", tokens=None), + TranslatorOutput(id=0, translation="ein Test", tokens=None, + attention_matrix=np.asarray([[1, 0], + [0, 1]]), + score=0.), + "ein Test\t0-0 1-1\n"), + (sockeye.output_handler.StringWithAlignmentsOutputHandler(io.StringIO(), threshold=0.5), + TranslatorInput(id=0, sentence="a test", tokens=None), + TranslatorOutput(id=0, translation="ein Test !", tokens=None, + attention_matrix=np.asarray([[0.4, 0.6], + [0.8, 0.2], + [0.5, 0.5]]), + score=0.), + "ein Test !\t0-1 1-0\n"), + ] + + +@pytest.mark.parametrize("handler, translation_input, translation_output, expected_string", stream_handler_tests) +def test_stream_output_handler(handler, translation_input, translation_output, expected_string): + handler.handle(translation_input, translation_output) + assert handler.stream.getvalue() == expected_string diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 000000000..22f471eef --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,68 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import sockeye.utils +import numpy as np +import mxnet as mx +import numpy as np + + +def test_get_alignments(): + attention_matrix = np.asarray([[0.1, 0.4, 0.5], + [0.2, 0.8, 0.0], + [0.4, 0.4, 0.2]]) + test_cases = [(0.5, [(1, 1)]), + (0.8, []), + (0.1, [(0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 2)])] + + for threshold, expected_alignment in test_cases: + alignment = list(sockeye.utils.get_alignments(attention_matrix, threshold=threshold)) + assert alignment == expected_alignment + + +def gaussian_vector(shape, return_symbol=False): + """ + Generates random normal tensors (diagonal covariance) + + :param shape: shape of the tensor. + :param return_symbol: True if the result should be a Symbol, False if it should be an Numpy array. + :return: A gaussian tensor. 
+ """ + return mx.sym.random_normal(shape=shape) if return_symbol else np.random.normal(size=shape) + + +def integer_vector(shape, max_value, return_symbol=False): + """ + Generates a random positive integer tensor + + :param shape: shape of the tensor. + :param max_value: maximum integer value. + :param return_symbol: True if the result should be a Symbol, False if it should be an Numpy array. + :return: A random integer tensor. + """ + return mx.sym.round(mx.sym.random_uniform(shape=shape) * max_value) if return_symbol \ + else np.round(np.random.uniform(size=shape) * max_value) + + +def uniform_vector(shape, min_value=0, max_value=1, return_symbol=False): + """ + Generates a uniformly random tensor + + :param shape: shape of the tensor + :param min_value: minimum possible value + :param max_value: maximum possible value (exclusive) + :param return_symbol: True if the result should be a mx.sym.Symbol, False if it should be a Numpy array + :return: + """ + return mx.sym.random_uniform(low=min_value, high=max_value, shape=shape) if return_symbol \ + else np.random.uniform(low=min_value, high=max_value, size=shape) diff --git a/test/test_vocab.py b/test/test_vocab.py new file mode 100644 index 000000000..df36d53d5 --- /dev/null +++ b/test/test_vocab.py @@ -0,0 +1,54 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import pytest + +import sockeye.constants as C +from sockeye.vocab import build_vocab + +test_vocab = [ + # Example 1 + (["one two three", "one two three"], 3, 1, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "two": 4, "three": 5, "one": 6}), + (["one two three", "one two three"], 3, 2, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "two": 4, "three": 5, "one": 6}), + (["one two three", "one two three"], 2, 2, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "two": 4, "three": 5}), + # Example 2 + (["one one two three ", "one two three"], 3, 1, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "one": 4, "two": 5, "three": 6}), + (["one one two three ", "one two three"], 3, 2, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "one": 4, "two": 5, "three": 6}), + (["one one two three ", "one two three"], 3, 3, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "one": 4}), + (["one one two three ", "one two three"], 2, 1, {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "one": 4, "two": 5}), + ] + + +@pytest.mark.parametrize("data,size,min_count,expected", test_vocab) +def test_build_vocab(data, size, min_count, expected): + vocab = build_vocab(data, size, min_count) + assert vocab == expected + +test_constants = [ + # Example 1 + (["one two three", "one two three"], 3, 1, C.VOCAB_SYMBOLS), + (["one two three", "one two three"], 3, 2, C.VOCAB_SYMBOLS), + (["one two three", "one two three"], 2, 2, C.VOCAB_SYMBOLS), + # Example 2 + (["one one two three ", "one two three"], 3, 1, C.VOCAB_SYMBOLS), + (["one one two three ", "one two three"], 3, 2, C.VOCAB_SYMBOLS), + (["one one two three ", "one two three"], 3, 3, C.VOCAB_SYMBOLS), + (["one one two three ", "one two three"], 2, 1, C.VOCAB_SYMBOLS), + ] + + +@pytest.mark.parametrize("data,size,min_count,constants", test_constants) +def test_constants_in_vocab(data, size, min_count, constants): + vocab = build_vocab(data, size, min_count) + for const in constants: + assert const in vocab diff --git a/typechecked-files b/typechecked-files new file mode 100644 index 000000000..5c6a87678 --- /dev/null +++ b/typechecked-files @@ -0,0 +1,5 @@ +sockeye/bleu.py +sockeye/constants.py +sockeye/rnn.py +sockeye/attention.py +sockeye/lr_scheduler.py
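
A minimal usage sketch of sockeye.vocab (illustrative only; the file names "train.src" and "vocab.src.json" are hypothetical):

    import sockeye.vocab

    # Build a vocabulary from a tokenized, sentence-per-line training corpus and save it as JSON.
    word_to_id = sockeye.vocab.build_from_path("train.src", num_words=50000, min_count=1)
    sockeye.vocab.vocab_to_json(word_to_id, "vocab.src.json")

    # Load it back later (falling back to pickle for older vocabularies) and invert it to map ids to tokens.
    word_to_id = sockeye.vocab.vocab_from_json_or_pickle("vocab.src")
    id_to_word = sockeye.vocab.reverse_vocab(word_to_id)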