Skip to content

Commit

Permalink
Appease mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
domdfcoding committed Jul 8, 2024
1 parent 63a0fe3 commit 722840c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion snapshottest/diff.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from termcolor import colored
from fastdiff import compare
from fastdiff import compare # type: ignore[import]

from .sorted_dict import SortedDict
from .formatter import Formatter
Expand Down
9 changes: 5 additions & 4 deletions snapshottest/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib.util
from collections import defaultdict
import logging
from typing import Dict

from .snapshot import Snapshot
from .formatter import Formatter
Expand Down Expand Up @@ -32,16 +33,16 @@ def _load_source(module_name, filepath):
https://discuss.python.org/t/how-do-i-migrate-from-imp/27885
"""
spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
# As a performance optimization, store loaded module for further use.
# https://docs.python.org/3.11/library/sys.html#sys.modules
sys.modules[module_name] = module
spec.loader.exec_module(module)
spec.loader.exec_module(module) # type: ignore[union-attr]
return module


class SnapshotModule(object):
_snapshot_modules = {}
_snapshot_modules: Dict = {}

def __init__(self, module, filepath):
self._original_snapshot = None
Expand Down Expand Up @@ -276,7 +277,7 @@ def assert_equals(self, value, snapshot):
assert value == snapshot

def assert_match(self, value, name=""):
self.curr_snapshot = name or self.snapshot_counter
self.curr_snapshot = name or self.snapshot_counter # type: ignore[assignment]
self.visit()
if self.update:
self.store(value)
Expand Down
8 changes: 4 additions & 4 deletions snapshottest/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ def reporting_lines(testing_cli):
successful_snapshots = SnapshotModule.stats_successful_snapshots()
bold = ["bold"]
if successful_snapshots:
yield (colored("{} snapshots passed", attrs=bold) + ".").format(
yield (colored("{} snapshots passed", attrs=bold) + ".").format( # type: ignore[arg-type]
successful_snapshots
)
new_snapshots = SnapshotModule.stats_new_snapshots()
if new_snapshots[0]:
yield (
colored("{} snapshots written", "green", attrs=bold) + " in {} test suites."
colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." # type: ignore[arg-type]
).format(*new_snapshots)
inspect_str = colored(
"Inspect your code or run with `{} --snapshot-update` to update them.".format(
Expand All @@ -25,14 +25,14 @@ def reporting_lines(testing_cli):
failed_snapshots = SnapshotModule.stats_failed_snapshots()
if failed_snapshots[0]:
yield (
colored("{} snapshots failed", "red", attrs=bold)
colored("{} snapshots failed", "red", attrs=bold) # type: ignore[arg-type]
+ " in {} test suites. "
+ inspect_str
).format(*failed_snapshots)
unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots()
if unvisited_snapshots[0]:
yield (
colored("{} snapshots deprecated", "yellow", attrs=bold)
colored("{} snapshots deprecated", "yellow", attrs=bold) # type: ignore[arg-type]
+ " in {} test suites. "
+ inspect_str
).format(*unvisited_snapshots)
Expand Down
18 changes: 9 additions & 9 deletions snapshottest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class TestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""On inherited classes, run our `setUp` method"""
cls._snapshot_tests = []
cls._snapshot_file = inspect.getfile(cls)
cls._snapshot_tests = [] # type: ignore[attr-defined]
cls._snapshot_file = inspect.getfile(cls) # type: ignore[attr-defined]

if cls is not TestCase and cls.setUp is not TestCase.setUp:
orig_setUp = cls.setUp
Expand All @@ -56,8 +56,8 @@ def tearDownOverride(self, *args, **kwargs):
TestCase.tearDown(self)
return orig_tearDown(self, *args, **kwargs)

cls.setUp = setUpOverride
cls.tearDown = tearDownOverride
cls.setUp = setUpOverride # type: ignore[assignment]
cls.tearDown = tearDownOverride # type: ignore[assignment]

super(TestCase, cls).setUpClass()

Expand All @@ -70,8 +70,8 @@ def comparePrettyDifs(self, obj1, obj2, msg):

@classmethod
def tearDownClass(cls):
if cls._snapshot_tests:
module = SnapshotModule.get_module_for_testpath(cls._snapshot_file)
if cls._snapshot_tests: # type: ignore[attr-defined]
module = SnapshotModule.get_module_for_testpath(cls._snapshot_file) # type: ignore[attr-defined]
module.save()
super(TestCase, cls).tearDownClass()

Expand All @@ -82,18 +82,18 @@ def setUp(self):
self._snapshot = UnitTestSnapshotTest(
test_class=self.__class__,
test_id=self.id(),
test_filepath=self._snapshot_file,
test_filepath=self._snapshot_file, # type: ignore[attr-defined]
should_update=self.snapshot_should_update,
assertEqual=self.assertEqual,
)
self._snapshot_tests.append(self._snapshot)
self._snapshot_tests.append(self._snapshot) # type: ignore[attr-defined]
SnapshotTest._current_tester = self._snapshot

def tearDown(self):
"""Do some custom setup"""
# print dir(self.__module__)
SnapshotTest._current_tester = None
self._snapshot = None
self._snapshot = None # type: ignore[assignment]

def assert_match_snapshot(self, value, name=""):
self._snapshot.assert_match(value, name=name)
Expand Down
2 changes: 1 addition & 1 deletion tests/attrs_serde_tests/snapshots/snap_test_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# snapshottest: v1 - https://goo.gl/zC4yUc

# 3rd party
from snapshottest import Snapshot # type: ignore
from snapshottest import Snapshot

snapshots = Snapshot()

Expand Down

0 comments on commit 722840c

Please sign in to comment.