From 531417f4f9e07be8c687747729029186bd9c4f86 Mon Sep 17 00:00:00 2001 From: David Vo Date: Sat, 16 Sep 2023 00:10:32 +1000 Subject: [PATCH] Tighten Talon marks type hints (#1848) This removes a bunch of `Any`s from the Talon code related to marks. Ref: #725 ## Checklist - [ ] ~I have added [tests](https://www.cursorless.org/docs/contributing/test-case-recorder/)~ - [ ] ~I have updated the [docs](https://github.com/cursorless-dev/cursorless/tree/main/docs) and [cheatsheet](https://github.com/cursorless-dev/cursorless/tree/main/cursorless-talon/src/cheatsheet)~ - [x] I have not broken the cheatsheet --- src/marks/decorated_mark.py | 11 ++++++----- src/marks/lines_number.py | 22 ++++++++++------------ src/marks/mark.py | 6 +++--- src/marks/mark_types.py | 31 +++++++++++++++++++++++++++++++ src/marks/simple_mark.py | 4 +++- src/targets/target_types.py | 8 +++++--- 6 files changed, 58 insertions(+), 24 deletions(-) create mode 100644 src/marks/mark_types.py diff --git a/src/marks/decorated_mark.py b/src/marks/decorated_mark.py index 2a92f77c..75675ee8 100644 --- a/src/marks/decorated_mark.py +++ b/src/marks/decorated_mark.py @@ -4,6 +4,7 @@ from talon import Module, actions, cron, fs from ..csv_overrides import init_csv_and_watch_changes +from .mark_types import DecoratedSymbol mod = Module() @@ -28,9 +29,9 @@ def cursorless_grapheme(m) -> str: @mod.capture( rule="[{user.cursorless_hat_color}] [{user.cursorless_hat_shape}] " ) -def cursorless_decorated_symbol(m) -> dict[str, Any]: +def cursorless_decorated_symbol(m) -> DecoratedSymbol: """A decorated symbol""" - hat_color = getattr(m, "cursorless_hat_color", "default") + hat_color: str = getattr(m, "cursorless_hat_color", "default") try: hat_style_name = f"{hat_color}-{m.cursorless_hat_shape}" except AttributeError: @@ -82,10 +83,10 @@ def cursorless_decorated_symbol(m) -> dict[str, Any]: } FALLBACK_COLOR_ENABLEMENT = DEFAULT_COLOR_ENABLEMENT -unsubscribe_hat_styles = None +unsubscribe_hat_styles: Any = None -def setup_hat_styles_csv(hat_colors: dict, hat_shapes: dict): +def setup_hat_styles_csv(hat_colors: dict[str, str], hat_shapes: dict[str, str]): global unsubscribe_hat_styles ( @@ -149,7 +150,7 @@ def setup_hat_styles_csv(hat_colors: dict, hat_shapes: dict): slow_reload_job = None -def init_hats(hat_colors: dict, hat_shapes: dict): +def init_hats(hat_colors: dict[str, str], hat_shapes: dict[str, str]): setup_hat_styles_csv(hat_colors, hat_shapes) vscode_settings_path: Path = actions.user.vscode_settings_path().resolve() diff --git a/src/marks/lines_number.py b/src/marks/lines_number.py index 4f892cd1..a1ff66fe 100644 --- a/src/marks/lines_number.py +++ b/src/marks/lines_number.py @@ -1,10 +1,10 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any from talon import Module from ..targets.range_target import RangeConnective +from .mark_types import LineNumber, LineNumberMark, LineNumberType mod = Module() @@ -14,8 +14,8 @@ @dataclass class CustomizableTerm: cursorlessIdentifier: str - type: str - formatter: Callable + type: LineNumberType + formatter: Callable[[int], int] # NOTE: Please do not change these dicts. Use the CSVs for customization. @@ -35,15 +35,13 @@ class CustomizableTerm: "[ ]" ) ) -def cursorless_line_number(m) -> dict[str, Any]: +def cursorless_line_number(m) -> LineNumber: direction = directions_map[m.cursorless_line_direction] - anchor = create_line_number_mark( - direction.type, direction.formatter(m.private_cursorless_number_small_list[0]) - ) - if len(m.private_cursorless_number_small_list) > 1: + numbers: list[int] = m.private_cursorless_number_small_list + anchor = create_line_number_mark(direction.type, direction.formatter(numbers[0])) + if len(numbers) > 1: active = create_line_number_mark( - direction.type, - direction.formatter(m.private_cursorless_number_small_list[1]), + direction.type, direction.formatter(numbers[1]) ) range_connective: RangeConnective = m.cursorless_range_connective return { @@ -56,9 +54,9 @@ def cursorless_line_number(m) -> dict[str, Any]: return anchor -def create_line_number_mark(line_number_type: str, line_number: int) -> dict[str, Any]: +def create_line_number_mark(type: LineNumberType, line_number: int) -> LineNumberMark: return { "type": "lineNumber", - "lineNumberType": line_number_type, + "lineNumberType": type, "lineNumber": line_number, } diff --git a/src/marks/mark.py b/src/marks/mark.py index 0231ebb9..18e21abb 100644 --- a/src/marks/mark.py +++ b/src/marks/mark.py @@ -1,7 +1,7 @@ -from typing import Any - from talon import Module +from .mark_types import Mark + mod = Module() @@ -12,5 +12,5 @@ "" # row (ie absolute mod 100), up, down ) ) -def cursorless_mark(m) -> dict[str, Any]: +def cursorless_mark(m) -> Mark: return m[0] diff --git a/src/marks/mark_types.py b/src/marks/mark_types.py new file mode 100644 index 00000000..3985b7d4 --- /dev/null +++ b/src/marks/mark_types.py @@ -0,0 +1,31 @@ +from typing import Literal, TypedDict, Union + + +class DecoratedSymbol(TypedDict): + type: Literal["decoratedSymbol"] + symbolColor: str + character: str + + +SimpleMark = dict[Literal["type"], str] + +LineNumberType = Literal["modulo100", "relative"] + + +class LineNumberMark(TypedDict): + type: Literal["lineNumber"] + lineNumberType: LineNumberType + lineNumber: int + + +class LineNumberRange(TypedDict): + type: Literal["range"] + anchor: LineNumberMark + active: LineNumberMark + excludeAnchor: bool + excludeActive: bool + + +LineNumber = Union[LineNumberMark, LineNumberRange] + +Mark = Union[DecoratedSymbol, SimpleMark, LineNumber] diff --git a/src/marks/simple_mark.py b/src/marks/simple_mark.py index 4359d05d..a98d3d84 100644 --- a/src/marks/simple_mark.py +++ b/src/marks/simple_mark.py @@ -1,5 +1,7 @@ from talon import Module +from .mark_types import SimpleMark + mod = Module() mod.list("cursorless_simple_mark", desc="Cursorless simple marks") @@ -15,7 +17,7 @@ @mod.capture(rule="{user.cursorless_simple_mark}") -def cursorless_simple_mark(m) -> dict[str, str]: +def cursorless_simple_mark(m) -> SimpleMark: return { "type": simple_marks[m.cursorless_simple_mark], } diff --git a/src/targets/target_types.py b/src/targets/target_types.py index 6e130ce9..49e02832 100644 --- a/src/targets/target_types.py +++ b/src/targets/target_types.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union + +from ..marks.mark_types import Mark RangeTargetType = Literal["vertical"] @@ -7,8 +9,8 @@ @dataclass class PrimitiveTarget: type = "primitive" - mark: Optional[dict] - modifiers: Optional[list[dict]] + mark: Optional[Mark] + modifiers: Optional[list[dict[str, Any]]] @dataclass