diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a11f4a5..8c50050 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # This file is managed by 'repo_helper'. Don't edit it directly. --- -exclude: ^$ +exclude: ^snapshottest/ ci: autoupdate_schedule: quarterly diff --git a/repo_helper.yml b/repo_helper.yml index 98f9018..9b44b36 100644 --- a/repo_helper.yml +++ b/repo_helper.yml @@ -13,6 +13,7 @@ use_whey: true min_coverage: 95 tox_testenv_extras: all standalone_contrib_guide: true +pre_commit_exclude: "^snapshottest/" conda_channels: - conda-forge diff --git a/snapshottest/LICENSE b/snapshottest/LICENSE new file mode 100644 index 0000000..fb89122 --- /dev/null +++ b/snapshottest/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017-Present Syrus Akbary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/snapshottest/README.md b/snapshottest/README.md new file mode 100644 index 0000000..b190395 --- /dev/null +++ b/snapshottest/README.md @@ -0,0 +1,141 @@ +# SnapshotTest [![travis][travis-image]][travis-url] [![pypi][pypi-image]][pypi-url] + +[travis-image]: https://img.shields.io/travis/syrusakbary/snapshottest.svg?style=flat +[travis-url]: https://travis-ci.org/syrusakbary/snapshottest +[pypi-image]: https://img.shields.io/pypi/v/snapshottest.svg?style=flat +[pypi-url]: https://pypi.python.org/pypi/snapshottest + + +Snapshot testing is a way to test your APIs without writing actual test cases. + +1. A snapshot is a single state of your API, saved in a file. +2. You have a set of snapshots for your API endpoints. +3. Once you add a new feature, you can generate *automatically* new snapshots for the updated API. + +## Installation + + $ pip install snapshottest + + +## Usage with unittest/nose + +```python +from snapshottest import TestCase + +class APITestCase(TestCase): + def test_api_me(self): + """Testing the API for /me""" + my_api_response = api.client.get('/me') + self.assertMatchSnapshot(my_api_response) + + # Set custom snapshot name: `gpg_response` + my_gpg_response = api.client.get('/me?gpg_key') + self.assertMatchSnapshot(my_gpg_response, 'gpg_response') +``` + +If you want to update the snapshots automatically you can use the `nosetests --snapshot-update`. + +Check the [Unittest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/unittest). + +## Usage with pytest + +```python +def test_mything(snapshot): + """Testing the API for /me""" + my_api_response = api.client.get('/me') + snapshot.assert_match(my_api_response) + + # Set custom snapshot name: `gpg_response` + my_gpg_response = api.client.get('/me?gpg_key') + snapshot.assert_match(my_gpg_response, 'gpg_response') +``` + +If you want to update the snapshots automatically you can use the `--snapshot-update` config. + +Check the [Pytest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/pytest). + +## Usage with django +Add to your settings: +```python +TEST_RUNNER = 'snapshottest.django.TestRunner' +``` +To create your snapshottest: +```python +from snapshottest.django import TestCase + +class APITestCase(TestCase): + def test_api_me(self): + """Testing the API for /me""" + my_api_response = api.client.get('/me') + self.assertMatchSnapshot(my_api_response) +``` +If you want to update the snapshots automatically you can use the `python manage.py test --snapshot-update`. +Check the [Django example](https://github.com/syrusakbary/snapshottest/tree/master/examples/django_project). + +## Disabling terminal colors + +Set the environment variable `ANSI_COLORS_DISABLED` (to any value), e.g. + + ANSI_COLORS_DISABLED=1 pytest + + +# Contributing + +After cloning this repo and configuring a virtualenv for snapshottest (optional, but highly recommended), ensure dependencies are installed by running: + +```sh +make develop +``` + +After developing, ensure your code is formatted properly by running: + +```sh +make format-fix +``` + +and then run the full test suite with: + +```sh +make lint +# and +make test +``` + +To test locally on all supported Python versions, you can use +[tox](https://tox.readthedocs.io/): + +```sh +pip install tox # (if you haven't before) +tox +``` + +# Notes + +This package is heavily inspired in [jest snapshot testing](https://facebook.github.io/jest/docs/snapshot-testing.html). + +# Reasons to use this package + +> Most of this content is taken from the [Jest snapshot blogpost](https://facebook.github.io/jest/blog/2016/07/27/jest-14.html). + +We want to make it as frictionless as possible to write good tests that are useful. +We observed that when engineers are provided with ready-to-use tools, they end up writing more tests, which in turn results in stable and healthy code bases. + +However engineers frequently spend more time writing a test than the component itself. As a result many people stopped writing tests altogether which eventually led to instabilities. + +A typical snapshot test case for a mobile app renders a UI component, takes a screenshot, then compares it to a reference image stored alongside the test. The test will fail if the two images do not match: either the change is unexpected, or the screenshot needs to be updated to the new version of the UI component. + + +## Snapshot Testing with SnapshotTest + +A similar approach can be taken when it comes to testing your APIs. +Instead of rendering the graphical UI, which would require building the entire app, you can use a test renderer to quickly generate a serializable value for your API response. + + +## License + +[MIT License](https://github.com/syrusakbary/snapshottest/blob/master/LICENSE) + +[![coveralls][coveralls-image]][coveralls-url] + +[coveralls-image]: https://coveralls.io/repos/syrusakbary/snapshottest/badge.svg?branch=master&service=github +[coveralls-url]: https://coveralls.io/github/syrusakbary/snapshottest?branch=master diff --git a/snapshottest/__init__.py b/snapshottest/__init__.py new file mode 100644 index 0000000..8db737d --- /dev/null +++ b/snapshottest/__init__.py @@ -0,0 +1,7 @@ +from .snapshot import Snapshot +from .generic_repr import GenericRepr +from .module import assert_match_snapshot +from .unittest import TestCase + + +__all__ = ["Snapshot", "GenericRepr", "assert_match_snapshot", "TestCase"] diff --git a/snapshottest/diff.py b/snapshottest/diff.py new file mode 100644 index 0000000..3fde50a --- /dev/null +++ b/snapshottest/diff.py @@ -0,0 +1,39 @@ +from termcolor import colored +from fastdiff import compare + +from .sorted_dict import SortedDict +from .formatter import Formatter + + +def format_line(line): + line = line.rstrip("\n") + if line.startswith("-"): + return colored(line, "green", attrs=["bold"]) + elif line.startswith("+"): + return colored(line, "red", attrs=["bold"]) + elif line.startswith("?"): + return colored("") + colored(line, "yellow", attrs=["bold"]) + + return colored("") + colored(line, "white", attrs=["dark"]) + + +class PrettyDiff(object): + def __init__(self, obj, snapshottest): + self.pretty = Formatter() + self.snapshottest = snapshottest + if isinstance(obj, dict): + obj = SortedDict(obj) + self.obj = self.pretty(obj) + + def __eq__(self, other): + return isinstance(other, PrettyDiff) and self.obj == other.obj + + def __repr__(self): + return repr(self.obj) + + def get_diff(self, other): + text1 = "Received \n\n" + self.pretty(self.obj) + text2 = "Snapshot \n\n" + self.pretty(other) + + lines = list(compare(text2, text1)) + return [format_line(line) for line in lines] diff --git a/snapshottest/django.py b/snapshottest/django.py new file mode 100644 index 0000000..9d20b9c --- /dev/null +++ b/snapshottest/django.py @@ -0,0 +1,61 @@ +from django.test import TestCase as dTestCase +from django.test import SimpleTestCase as dSimpleTestCase +from django.test.runner import DiscoverRunner + +from snapshottest.reporting import reporting_lines +from .unittest import TestCase as uTestCase +from .module import SnapshotModule + + +class TestRunnerMixin(object): + separator1 = "=" * 70 + separator2 = "-" * 70 + + def __init__(self, snapshot_update=False, **kwargs): + super(TestRunnerMixin, self).__init__(**kwargs) + uTestCase.snapshot_should_update = snapshot_update + + @classmethod + def add_arguments(cls, parser): + super(TestRunnerMixin, cls).add_arguments(parser) + parser.add_argument( + "--snapshot-update", + default=False, + action="store_true", + dest="snapshot_update", + help="Update the snapshots automatically.", + ) + + def run_tests(self, test_labels, extra_tests=None, **kwargs): + result = super(TestRunnerMixin, self).run_tests( + test_labels=test_labels, extra_tests=extra_tests, **kwargs + ) + self.print_report() + if TestCase.snapshot_should_update: + for module in SnapshotModule.get_modules(): + module.delete_unvisited() + module.save() + + return result + + def print_report(self): + lines = list(reporting_lines("python manage.py test")) + if lines: + print("\n" + self.separator1) + print("SnapshotTest summary") + print(self.separator2) + for line in lines: + print(line) + print(self.separator1) + + +class TestRunner(TestRunnerMixin, DiscoverRunner): + pass + + +class TestCase(uTestCase, dTestCase): + pass + + +class SimpleTestCase(uTestCase, dSimpleTestCase): + pass diff --git a/snapshottest/error.py b/snapshottest/error.py new file mode 100644 index 0000000..da0ff8a --- /dev/null +++ b/snapshottest/error.py @@ -0,0 +1,11 @@ +class SnapshotError(Exception): + pass + + +class SnapshotNotFound(SnapshotError): + def __init__(self, module, test_name): + super(SnapshotNotFound, self).__init__( + "Snapshot '{snapshot_id!s}' not found in {snapshot_file!s}".format( + snapshot_id=test_name, snapshot_file=module.filepath + ) + ) diff --git a/snapshottest/file.py b/snapshottest/file.py new file mode 100644 index 0000000..3a5c494 --- /dev/null +++ b/snapshottest/file.py @@ -0,0 +1,73 @@ +import os +import shutil +import filecmp + +from .formatter import Formatter +from .formatters import BaseFormatter + + +class FileSnapshot(object): + def __init__(self, path): + """ + Create a file snapshot pointing to the specified `path`. In a snapshot, `path` + is considered to be relative to the test module's "snapshots" folder. (This is + done to prevent ugly path manipulations inside the snapshot file.) + """ + self.path = path + + def __repr__(self): + return "FileSnapshot({})".format(repr(self.path)) + + def __eq__(self, other): + return self.path == other.path + + +class FileSnapshotFormatter(BaseFormatter): + def can_format(self, value): + return isinstance(value, FileSnapshot) + + def store(self, test, value): + """ + Copy the file from the test location to the snapshot location. + + If the original test file has an extension, the snapshot file will + use the same extension. + """ + + file_snapshot_dir = self.get_file_snapshot_dir(test) + if not os.path.exists(file_snapshot_dir): + os.makedirs(file_snapshot_dir, 0o0700) + extension = os.path.splitext(value.path)[1] + snapshot_file = os.path.join(file_snapshot_dir, test.test_name) + extension + shutil.copy(value.path, snapshot_file) + relative_snapshot_filename = os.path.relpath( + snapshot_file, test.module.snapshot_dir + ) + return FileSnapshot(relative_snapshot_filename) + + def get_imports(self): + return (("snapshottest.file", "FileSnapshot"),) + + def format(self, value, indent, formatter): + return repr(value) + + def assert_value_matches_snapshot( + self, test, test_value, snapshot_value, formatter + ): + snapshot_path = os.path.join(test.module.snapshot_dir, snapshot_value.path) + files_identical = filecmp.cmp(test_value.path, snapshot_path, shallow=False) + assert files_identical, "Stored file differs from test file" + + @staticmethod + def get_file_snapshot_dir(test): + """ + Get the directory for storing file snapshots for `test`. + Snapshot files are stored under: + snapshots/snap_/ + Right next to where the snapshot module is stored: + snapshots/snap_.py + """ + return os.path.join(test.module.snapshot_dir, test.module.module) + + +Formatter.register_formatter(FileSnapshotFormatter()) diff --git a/snapshottest/formatter.py b/snapshottest/formatter.py new file mode 100644 index 0000000..3b7281d --- /dev/null +++ b/snapshottest/formatter.py @@ -0,0 +1,37 @@ +from .formatters import default_formatters + + +class Formatter(object): + formatters = default_formatters() + + def __init__(self, imports=None): + self.htchar = " " * 4 + self.lfchar = "\n" + self.indent = 0 + self.imports = imports + + def __call__(self, value, **args): + return self.format(value, self.indent) + + def format(self, value, indent): + formatter = self.get_formatter(value) + for module, import_name in formatter.get_imports(): + self.imports[module].add(import_name) + return formatter.format(value, indent, self) + + def normalize(self, value): + formatter = self.get_formatter(value) + return formatter.normalize(value, self) + + @staticmethod + def get_formatter(value): + for formatter in Formatter.formatters: + if formatter.can_format(value): + return formatter + + # This should never happen as GenericFormatter is registered by default. + raise RuntimeError("No formatter found for value") + + @staticmethod + def register_formatter(formatter): + Formatter.formatters.insert(0, formatter) diff --git a/snapshottest/formatters.py b/snapshottest/formatters.py new file mode 100644 index 0000000..39a0644 --- /dev/null +++ b/snapshottest/formatters.py @@ -0,0 +1,174 @@ +import math +from collections import defaultdict + +from .sorted_dict import SortedDict +from .generic_repr import GenericRepr + + +class BaseFormatter(object): + def can_format(self, value): + raise NotImplementedError() + + def format(self, value, indent, formatter): + raise NotImplementedError() + + def get_imports(self): + return () + + def assert_value_matches_snapshot( + self, test, test_value, snapshot_value, formatter + ): + test.assert_equals(formatter.normalize(test_value), snapshot_value) + + def store(self, test, value): + return value + + def normalize(self, value, formatter): + return value + + +class TypeFormatter(BaseFormatter): + def __init__(self, types, format_func): + self.types = types + self.format_func = format_func + + def can_format(self, value): + return isinstance(value, self.types) + + def format(self, value, indent, formatter): + return self.format_func(value, indent, formatter) + + +class CollectionFormatter(TypeFormatter): + def normalize(self, value, formatter): + iterator = iter(value.items()) if isinstance(value, dict) else iter(value) + # https://github.com/syrusakbary/snapshottest/issues/115 + # Normally we shouldn't need to turn this into a list, but some iterable + # constructors need a list not an iterator (e.g. unittest.mock.call). + return value.__class__([formatter.normalize(item) for item in iterator]) + + +class DefaultDictFormatter(TypeFormatter): + def normalize(self, value, formatter): + return defaultdict( + value.default_factory, (formatter.normalize(item) for item in value.items()) + ) + + +def trepr(s): + text = "\n".join([repr(line).lstrip("u")[1:-1] for line in s.split("\n")]) + quotes, dquotes = "'''", '"""' + if quotes in text: + if dquotes in text: + text = text.replace(quotes, "\\'\\'\\'") + else: + quotes = dquotes + return "%s%s%s" % (quotes, text, quotes) + + +def format_none(value, indent, formatter): + return "None" + + +def format_str(value, indent, formatter): + if "\n" in value: + # Is a multiline string, so we use '''{}''' for the repr + return trepr(value) + + # Snapshots are saved with `from __future__ import unicode_literals`, + # so the `u'...'` repr is unnecessary, even on Python 2 + return repr(value).lstrip("u") + + +def format_float(value, indent, formatter): + if math.isinf(value) or math.isnan(value): + return 'float("%s")' % repr(value) + return repr(value) + + +def format_std_type(value, indent, formatter): + return repr(value) + + +def format_dict(value, indent, formatter): + value = SortedDict(value) + items = [ + formatter.lfchar + + formatter.htchar * (indent + 1) + + formatter.format(key, indent) + + ": " + + formatter.format(value[key], indent + 1) + for key in value + ] + return "{%s}" % (",".join(items) + formatter.lfchar + formatter.htchar * indent) + + +def format_list(value, indent, formatter): + return "[%s]" % format_sequence(value, indent, formatter) + + +def format_sequence(value, indent, formatter): + items = [ + formatter.lfchar + + formatter.htchar * (indent + 1) + + formatter.format(item, indent + 1) + for item in value + ] + return ",".join(items) + formatter.lfchar + formatter.htchar * indent + + +def format_tuple(value, indent, formatter): + return "(%s%s" % ( + format_sequence(value, indent, formatter), + ",)" if len(value) == 1 else ")", + ) + + +def format_set(value, indent, formatter): + return "set([%s])" % format_sequence(value, indent, formatter) + + +def format_frozenset(value, indent, formatter): + return "frozenset([%s])" % format_sequence(value, indent, formatter) + + +class GenericFormatter(BaseFormatter): + def can_format(self, value): + return True + + def store(self, test, value): + return GenericRepr.from_value(value) + + def normalize(self, value, formatter): + return GenericRepr.from_value(value) + + def format(self, value, indent, formatter): + if not isinstance(value, GenericRepr): + value = GenericRepr.from_value(value) + return repr(value) + + def get_imports(self): + return [("snapshottest", "GenericRepr")] + + def assert_value_matches_snapshot( + self, test, test_value, snapshot_value, formatter + ): + test_value = GenericRepr.from_value(test_value) + # Assert equality between the representations to provide a nice textual diff. + test.assert_equals(test_value.representation, snapshot_value.representation) + + +def default_formatters(): + return [ + TypeFormatter(type(None), format_none), + DefaultDictFormatter(defaultdict, format_dict), + CollectionFormatter(dict, format_dict), + CollectionFormatter(tuple, format_tuple), + CollectionFormatter(list, format_list), + CollectionFormatter(set, format_set), + CollectionFormatter(frozenset, format_frozenset), + TypeFormatter((str,), format_str), + TypeFormatter((float,), format_float), + TypeFormatter((int, complex, bool, bytes), format_std_type), + GenericFormatter(), + ] diff --git a/snapshottest/generic_repr.py b/snapshottest/generic_repr.py new file mode 100644 index 0000000..0bf2287 --- /dev/null +++ b/snapshottest/generic_repr.py @@ -0,0 +1,22 @@ +class GenericRepr(object): + def __init__(self, representation): + self.representation = representation + + def __repr__(self): + return "GenericRepr({})".format(repr(self.representation)) + + def __eq__(self, other): + return ( + isinstance(other, GenericRepr) + and self.representation == other.representation + ) + + def __hash__(self): + return hash(self.representation) + + @staticmethod + def from_value(value): + representation = repr(value) + # Remove the hex id, if found. + representation = representation.replace(hex(id(value)), "0x100000000") + return GenericRepr(representation) diff --git a/snapshottest/module.py b/snapshottest/module.py new file mode 100644 index 0000000..3dc43a2 --- /dev/null +++ b/snapshottest/module.py @@ -0,0 +1,308 @@ +import codecs +import errno +import os +import sys +import importlib.util +from collections import defaultdict +import logging + +from .snapshot import Snapshot +from .formatter import Formatter +from .error import SnapshotNotFound + + +logger = logging.getLogger(__name__) + + +def _escape_quotes(text): + return text.replace("'", "\\'") + + +def _load_source(module_name, filepath): + """ + Replaces old imp.load_source() call. + + The imp module was dropped in Python 3.12 in favor of the importlib. + See: https://docs.python.org/3.11/library/imp.html#imp.load_module + + Following code was inspired by the importlib documentation example: + https://docs.python.org/3.12/library/importlib.html#importing-a-source-file-directly + + This approach has been also encouraged in the official mailing lists: + https://discuss.python.org/t/how-do-i-migrate-from-imp/27885 + """ + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + # As a performance optimization, store loaded module for further use. + # https://docs.python.org/3.11/library/sys.html#sys.modules + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +class SnapshotModule(object): + _snapshot_modules = {} + + def __init__(self, module, filepath): + self._original_snapshot = None + self._snapshots = None + self.module = module + self.filepath = filepath + self.imports = defaultdict(set) + self.visited_snapshots = set() + self.new_snapshots = set() + self.failed_snapshots = set() + self.imports["snapshottest"].add("Snapshot") + + def load_snapshots(self): + try: + source = _load_source(self.module, self.filepath) + # except FileNotFoundError: # Python 3 + except (IOError, OSError) as err: + if err.errno == errno.ENOENT: + return Snapshot() + else: + raise + else: + assert isinstance(source.snapshots, Snapshot) + return source.snapshots + + def visit(self, snapshot_name): + self.visited_snapshots.add(snapshot_name) + + def delete_unvisited(self): + for unvisited in self.unvisited_snapshots: + del self.snapshots[unvisited] + + @property + def unvisited_snapshots(self): + return set(self.snapshots.keys()) - self.visited_snapshots + + @classmethod + def total_unvisited_snapshots(cls): + unvisited_snapshots = 0 + unvisited_modules = 0 + for module in cls.get_modules(): + unvisited_snapshot_len = len(module.unvisited_snapshots) + unvisited_snapshots += unvisited_snapshot_len + unvisited_modules += min(unvisited_snapshot_len, 1) + + return unvisited_snapshots, unvisited_modules + + @classmethod + def get_modules(cls): + return SnapshotModule._snapshot_modules.values() + + @classmethod + def stats_for_module(cls, getter): + count_snapshots = 0 + count_modules = 0 + for module in SnapshotModule._snapshot_modules.values(): + length = getter(module) + count_snapshots += length + count_modules += min(length, 1) + + return count_snapshots, count_modules + + @classmethod + def stats_unvisited_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.unvisited_snapshots)) + + @classmethod + def stats_visited_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.visited_snapshots)) + + @classmethod + def stats_new_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.new_snapshots)) + + @classmethod + def stats_failed_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.failed_snapshots)) + + @classmethod + def stats_successful_snapshots(cls): + stats_visited = cls.stats_visited_snapshots() + stats_failed = cls.stats_failed_snapshots() + return stats_visited[0] - stats_failed[0] + + @classmethod + def has_snapshots(cls): + return cls.stats_visited_snapshots()[0] > 0 + + @property + def original_snapshot(self): + if not self._original_snapshot: + self._original_snapshot = self.load_snapshots() + return self._original_snapshot + + @property + def snapshots(self): + if not self._snapshots: + self._snapshots = Snapshot(self.original_snapshot) + return self._snapshots + + def __getitem__(self, test_name): + try: + return self.snapshots[test_name] + except KeyError: + raise SnapshotNotFound(self, test_name) + + def __setitem__(self, key, value): + if key not in self.snapshots: + # It's a new test + self.new_snapshots.add(key) + self.snapshots[key] = value + + def mark_failed(self, key): + return self.failed_snapshots.add(key) + + @property + def snapshot_dir(self): + return os.path.dirname(self.filepath) + + def save(self): + if self.original_snapshot == self.snapshots: + # If there are no changes, we do nothing + return + + # Create the snapshot dir in case doesn't exist + try: + os.makedirs(self.snapshot_dir, 0o0700) + except (IOError, OSError): + pass + + # Create __init__.py in case doesn't exist + open(os.path.join(self.snapshot_dir, "__init__.py"), "a").close() + + pretty = Formatter(self.imports) + + with codecs.open(self.filepath, "w", encoding="utf-8") as snapshot_file: + snapshots_declarations = [ + """snapshots['{}'] = {}""".format( + _escape_quotes(key), pretty(self.snapshots[key]) + ) + for key in sorted(self.snapshots.keys()) + ] + + imports = "\n".join( + [ + "from {} import {}".format( + module, ", ".join(sorted(module_imports)) + ) + for module, module_imports in sorted(self.imports.items()) + ] + ) + snapshot_file.write( + """# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +{} + + +snapshots = Snapshot() + +{} +""".format( + imports, "\n\n".join(snapshots_declarations) + ) + ) + + @classmethod + def get_module_for_testpath(cls, test_filepath): + if test_filepath not in cls._snapshot_modules: + dirname = os.path.dirname(test_filepath) + snapshot_dir = os.path.join(dirname, "snapshots") + + snapshot_basename = "snap_{}.py".format( + os.path.splitext(os.path.basename(test_filepath))[0] + ) + snapshot_filename = os.path.join(snapshot_dir, snapshot_basename) + snapshot_module = "{}".format(os.path.splitext(snapshot_basename)[0]) + + cls._snapshot_modules[test_filepath] = SnapshotModule( + snapshot_module, snapshot_filename + ) + + return cls._snapshot_modules[test_filepath] + + +class SnapshotTest(object): + _current_tester = None + + def __init__(self): + self.curr_snapshot = "" + self.snapshot_counter = 1 + + @property + def module(self): + raise NotImplementedError("module property needs to be implemented") + + @property + def update(self): + return False + + @property + def test_name(self): + raise NotImplementedError("test_name property needs to be implemented") + + def __enter__(self): + SnapshotTest._current_tester = self + return self + + def __exit__(self, type, value, tb): + self.save_changes() + SnapshotTest._current_tester = None + + def visit(self): + self.module.visit(self.test_name) + + def fail(self): + self.module.mark_failed(self.test_name) + + def store(self, data): + formatter = Formatter.get_formatter(data) + data = formatter.store(self, data) + self.module[self.test_name] = data + + def assert_value_matches_snapshot(self, test_value, snapshot_value): + formatter = Formatter.get_formatter(test_value) + formatter.assert_value_matches_snapshot( + self, test_value, snapshot_value, Formatter() + ) + + def assert_equals(self, value, snapshot): + assert value == snapshot + + def assert_match(self, value, name=""): + self.curr_snapshot = name or self.snapshot_counter + self.visit() + if self.update: + self.store(value) + else: + try: + prev_snapshot = self.module[self.test_name] + except SnapshotNotFound: + self.store(value) # first time this test has been seen + else: + try: + self.assert_value_matches_snapshot(value, prev_snapshot) + except AssertionError: + self.fail() + raise + + if not name: + self.snapshot_counter += 1 + + def save_changes(self): + self.module.save() + + +def assert_match_snapshot(value, name=""): + if not SnapshotTest._current_tester: + raise Exception( + "You need to use assert_match_snapshot in the SnapshotTest context." + ) + + SnapshotTest._current_tester.assert_match(value, name) diff --git a/snapshottest/nose.py b/snapshottest/nose.py new file mode 100644 index 0000000..9d0e6b4 --- /dev/null +++ b/snapshottest/nose.py @@ -0,0 +1,61 @@ +import logging +import os + +from nose.plugins import Plugin + +from .module import SnapshotModule +from .reporting import reporting_lines +from .unittest import TestCase + +log = logging.getLogger("nose.plugins.snapshottest") + + +class SnapshotTestPlugin(Plugin): + name = "snapshottest" + enabled = True + + separator1 = "=" * 70 + separator2 = "-" * 70 + + def options(self, parser, env=os.environ): + super(SnapshotTestPlugin, self).options(parser, env=env) + parser.add_option( + "--snapshot-update", + action="store_true", + default=False, + dest="snapshot_update", + help="Update the snapshots.", + ) + parser.add_option( + "--snapshot-disable", + action="store_true", + dest="snapshot_disable", + default=False, + help="Disable special SnapshotTest", + ) + + def configure(self, options, conf): + super(SnapshotTestPlugin, self).configure(options, conf) + self.snapshot_update = options.snapshot_update + self.enabled = not options.snapshot_disable + + def wantClass(self, cls): + if issubclass(cls, TestCase): + cls.snapshot_should_update = self.snapshot_update + + def afterContext(self): + if self.snapshot_update: + for module in SnapshotModule.get_modules(): + module.delete_unvisited() + module.save() + + def report(self, stream): + if not SnapshotModule.has_snapshots(): + return + + stream.writeln(self.separator1) + stream.writeln("SnapshotTest summary") + stream.writeln(self.separator2) + for line in reporting_lines("nosetests"): + stream.writeln(line) + stream.writeln(self.separator1) diff --git a/snapshottest/pytest.py b/snapshottest/pytest.py new file mode 100644 index 0000000..b820e1a --- /dev/null +++ b/snapshottest/pytest.py @@ -0,0 +1,92 @@ +import pytest +import re + +from .module import SnapshotModule, SnapshotTest +from .diff import PrettyDiff +from .reporting import reporting_lines, diff_report + + +def pytest_addoption(parser): + group = parser.getgroup("snapshottest") + group.addoption( + "--snapshot-update", + action="store_true", + default=False, + dest="snapshot_update", + help="Update the snapshots.", + ) + group.addoption( + "--snapshot-verbose", + action="store_true", + default=False, + help="Dump diagnostic and progress information.", + ) + + +class PyTestSnapshotTest(SnapshotTest): + def __init__(self, request=None): + self.request = request + super(PyTestSnapshotTest, self).__init__() + + @property + def module(self): + return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath) + + @property + def update(self): + return self.request.config.option.snapshot_update + + @property + def test_name(self): + cls_name = getattr(self.request.node.cls, "__name__", "") + flattened_node_name = re.sub( + r"\s+", " ", self.request.node.name.replace(r"\n", " ") + ) + return "{}{} {}".format( + "{}.".format(cls_name) if cls_name else "", + flattened_node_name, + self.curr_snapshot, + ) + + +class SnapshotSession(object): + def __init__(self, config): + self.verbose = config.getoption("snapshot_verbose") + self.config = config + + def display(self, tr): + if not SnapshotModule.has_snapshots(): + return + + tr.write_sep("=", "SnapshotTest summary") + + for line in reporting_lines("pytest"): + tr.write_line(line) + + +def pytest_assertrepr_compare(op, left, right): + if isinstance(left, PrettyDiff) and op == "==": + return diff_report(left, right) + + +@pytest.fixture +def snapshot(request): + with PyTestSnapshotTest(request) as snapshot_test: + yield snapshot_test + + +def pytest_terminal_summary(terminalreporter): + if terminalreporter.config.option.snapshot_update: + for module in SnapshotModule.get_modules(): + module.delete_unvisited() + module.save() + + terminalreporter.config._snapshotsession.display(terminalreporter) + + +# force the other plugins to initialise first +# (fixes issue with capture not being properly initialised) +@pytest.mark.trylast +def pytest_configure(config): + config._snapshotsession = SnapshotSession(config) + # config.pluginmanager.register(bs, "snapshottest") diff --git a/snapshottest/reporting.py b/snapshottest/reporting.py new file mode 100644 index 0000000..26ca51f --- /dev/null +++ b/snapshottest/reporting.py @@ -0,0 +1,60 @@ +import os +from termcolor import colored + +from .module import SnapshotModule + + +def reporting_lines(testing_cli): + successful_snapshots = SnapshotModule.stats_successful_snapshots() + bold = ["bold"] + if successful_snapshots: + yield (colored("{} snapshots passed", attrs=bold) + ".").format( + successful_snapshots + ) + new_snapshots = SnapshotModule.stats_new_snapshots() + if new_snapshots[0]: + yield ( + colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." + ).format(*new_snapshots) + inspect_str = colored( + "Inspect your code or run with `{} --snapshot-update` to update them.".format( + testing_cli + ), + attrs=["dark"], + ) + failed_snapshots = SnapshotModule.stats_failed_snapshots() + if failed_snapshots[0]: + yield ( + colored("{} snapshots failed", "red", attrs=bold) + + " in {} test suites. " + + inspect_str + ).format(*failed_snapshots) + unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots() + if unvisited_snapshots[0]: + yield ( + colored("{} snapshots deprecated", "yellow", attrs=bold) + + " in {} test suites. " + + inspect_str + ).format(*unvisited_snapshots) + + +def diff_report(left, right): + return [ + "stored snapshot should match the received value", + "", + colored("> ") + + colored("Received value", "red", attrs=["bold"]) + + colored(" does not match ", attrs=["bold"]) + + colored( + "stored snapshot `{}`".format( + left.snapshottest.test_name, + ), + "green", + attrs=["bold"], + ) + + colored(".", attrs=["bold"]), + colored("") + + "> " + + os.path.relpath(left.snapshottest.module.filepath, os.getcwd()), + "", + ] + left.get_diff(right) diff --git a/snapshottest/snapshot.py b/snapshottest/snapshot.py new file mode 100644 index 0000000..200227f --- /dev/null +++ b/snapshottest/snapshot.py @@ -0,0 +1,5 @@ +from collections import OrderedDict + + +class Snapshot(OrderedDict): + pass diff --git a/snapshottest/sorted_dict.py b/snapshottest/sorted_dict.py new file mode 100644 index 0000000..9e27c62 --- /dev/null +++ b/snapshottest/sorted_dict.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + + +class SortedDict(OrderedDict): + def __init__(self, values): + super(SortedDict, self).__init__() + + try: + sorted_items = sorted(values.items()) + except TypeError: + # Enums are not sortable + sorted_items = values.items() + for key, value in sorted_items: + if isinstance(value, dict): + self[key] = SortedDict(value) + elif isinstance(value, list): + self[key] = self._sort_list(value) + else: + self[key] = value + + def _sort_list(self, value): + def sort(val): + if isinstance(val, dict): + return SortedDict(val) + elif isinstance(val, list): + return self._sort_list(val) + else: + return val + + return [sort(item) for item in value] diff --git a/snapshottest/unittest.py b/snapshottest/unittest.py new file mode 100644 index 0000000..535b24a --- /dev/null +++ b/snapshottest/unittest.py @@ -0,0 +1,101 @@ +import unittest +import inspect + +from .module import SnapshotModule, SnapshotTest +from .diff import PrettyDiff +from .reporting import diff_report + + +class UnitTestSnapshotTest(SnapshotTest): + def __init__(self, test_class, test_id, test_filepath, should_update, assertEqual): + self.test_class = test_class + self.test_id = test_id + self.test_filepath = test_filepath + self.assertEqual = assertEqual + self.should_update = should_update + super(UnitTestSnapshotTest, self).__init__() + + @property + def module(self): + return SnapshotModule.get_module_for_testpath(self.test_filepath) + + @property + def update(self): + return self.should_update + + def assert_equals(self, value, snapshot): + self.assertEqual(value, snapshot) + + @property + def test_name(self): + class_name = self.test_class.__name__ + test_name = self.test_id.split(".")[-1] + return "{}::{} {}".format(class_name, test_name, self.curr_snapshot) + + +# Inspired by https://gist.github.com/twolfson/13f5f5784f67fd49b245 +class TestCase(unittest.TestCase): + + snapshot_should_update = False + + @classmethod + def setUpClass(cls): + """On inherited classes, run our `setUp` method""" + cls._snapshot_tests = [] + cls._snapshot_file = inspect.getfile(cls) + + if cls is not TestCase and cls.setUp is not TestCase.setUp: + orig_setUp = cls.setUp + orig_tearDown = cls.tearDown + + def setUpOverride(self, *args, **kwargs): + TestCase.setUp(self) + return orig_setUp(self, *args, **kwargs) + + def tearDownOverride(self, *args, **kwargs): + TestCase.tearDown(self) + return orig_tearDown(self, *args, **kwargs) + + cls.setUp = setUpOverride + cls.tearDown = tearDownOverride + + super(TestCase, cls).setUpClass() + + def comparePrettyDifs(self, obj1, obj2, msg): + # self + # assert obj1 == obj2 + if not (obj1 == obj2): + raise self.failureException("\n".join(diff_report(obj1, obj2))) + # raise self.failureException("DIFF") + + @classmethod + def tearDownClass(cls): + if cls._snapshot_tests: + module = SnapshotModule.get_module_for_testpath(cls._snapshot_file) + module.save() + super(TestCase, cls).tearDownClass() + + def setUp(self): + """Do some custom setup""" + # print dir(self.__module__) + self.addTypeEqualityFunc(PrettyDiff, self.comparePrettyDifs) + self._snapshot = UnitTestSnapshotTest( + test_class=self.__class__, + test_id=self.id(), + test_filepath=self._snapshot_file, + should_update=self.snapshot_should_update, + assertEqual=self.assertEqual, + ) + self._snapshot_tests.append(self._snapshot) + SnapshotTest._current_tester = self._snapshot + + def tearDown(self): + """Do some custom setup""" + # print dir(self.__module__) + SnapshotTest._current_tester = None + self._snapshot = None + + def assert_match_snapshot(self, value, name=""): + self._snapshot.assert_match(value, name=name) + + assertMatchSnapshot = assert_match_snapshot diff --git a/tests/conftest.py b/tests/conftest.py index f3658c5..2b6c221 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import asyncio # noqa: F401 # Fixes intermittent error where the import causes a KeyError import pathlib -pytest_plugins = ("coincidence", "sphinx_toolbox.testing") +pytest_plugins = ("coincidence", "sphinx_toolbox.testing", "snapshottest.pytest") repo_root = pathlib.Path(__file__).parent.parent diff --git a/tests/requirements.txt b/tests/requirements.txt index 74e7447..c71c915 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,8 +1,10 @@ # snapshottest>=0.6.0 +# snapshottest@ git+https://github.com/MarcellPerger1/snapshottest.git@dont-use-imp alabaster<=0.7.13 coincidence>=0.2.0 coverage>=5.1 coverage-pyver-pragma>=0.2.1 +fastdiff<1,>=0.1.4 importlib-metadata>=3.6.0 mypy<1.8.0; platform_python_implementation == "CPython" pytest<7.2.0,>=6.0.0 @@ -14,10 +16,10 @@ pytest-randomly>=3.7.0 pytest-timeout>=1.4.2 sdjson>=0.3.0 setuptools>=59.6.0 -snapshottest@ git+https://github.com/MarcellPerger1/snapshottest.git@dont-use-imp sphinxcontrib-applehelp<=1.0.4 sphinxcontrib-devhelp<=1.0.2 sphinxcontrib-htmlhelp<=2.0.1 sphinxcontrib-jsmath<=1.0.1 sphinxcontrib-qthelp<=1.0.3 sphinxcontrib-serializinghtml<=1.1.5 +termcolor>=2.4.0