Skip to content

Commit

Permalink
Replacing upper case types with lower case types for Python 3.9+
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Dec 1, 2024
1 parent f3b0805 commit da7b4de
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 56 deletions.
37 changes: 20 additions & 17 deletions src/tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pprint import pformat
from shlex import quote, split
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
from typing_inspect import is_literal_type

from tap.utils import (
Expand Down Expand Up @@ -53,9 +53,9 @@ def __init__(
*args,
underscores_to_dashes: bool = False,
explicit_bool: bool = False,
config_files: Optional[List[PathLike]] = None,
config_files: Optional[list[PathLike]] = None,
**kwargs,
):
) -> None:
"""Initializes the Tap instance.
:param args: Arguments passed to the super class ArgumentParser.
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(
self.argument_buffer = {}

# Create a place to put the subparsers
self._subparser_buffer: List[Tuple[str, type, Dict[str, Any]]] = []
self._subparser_buffer: list[tuple[str, type, dict[str, Any]]] = []

# Get class variables help strings from the comments
self.class_variables = self._get_class_variables()
Expand Down Expand Up @@ -369,7 +369,7 @@ def configure(self) -> None:
pass

@staticmethod
def get_reproducibility_info(repo_path: Optional[PathLike] = None) -> Dict[str, str]:
def get_reproducibility_info(repo_path: Optional[PathLike] = None) -> dict[str, str]:
"""Gets a dictionary of reproducibility information.
Reproducibility information always includes:
Expand Down Expand Up @@ -405,7 +405,7 @@ def get_reproducibility_info(repo_path: Optional[PathLike] = None) -> Dict[str,

return reproducibility

def _log_all(self, repo_path: Optional[PathLike] = None) -> Dict[str, Any]:
def _log_all(self, repo_path: Optional[PathLike] = None) -> dict[str, Any]:
"""Gets all arguments along with reproducibility information.
:param repo_path: Path to the git repo to examine for reproducibility info.
Expand All @@ -418,7 +418,10 @@ def _log_all(self, repo_path: Optional[PathLike] = None) -> Dict[str, Any]:
return arg_log

def parse_args(
self: TapType, args: Optional[Sequence[str]] = None, known_only: bool = False, legacy_config_parsing=False
self: TapType,
args: Optional[Sequence[str]] = None,
known_only: bool = False,
legacy_config_parsing: bool = False,
) -> TapType:
"""Parses arguments, sets attributes of self equal to the parsed arguments, and processes arguments.
Expand Down Expand Up @@ -483,7 +486,7 @@ def parse_args(
return self

@classmethod
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union[Dict[str, Any], Dict]:
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union[dict[str, Any], dict]:
"""Returns a dictionary mapping variable names to values.
Variables and values are extracted from classes using key starting
Expand Down Expand Up @@ -518,7 +521,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union

return dictionary

def _get_class_dict(self) -> Dict[str, Any]:
def _get_class_dict(self) -> dict[str, Any]:
"""Returns a dictionary mapping class variable names to values from the class dict."""
class_dict = self._get_from_self_and_super(
extract_func=lambda super_class: dict(getattr(super_class, "__dict__", dict()))
Expand All @@ -531,7 +534,7 @@ def _get_class_dict(self) -> Dict[str, Any]:

return class_dict

def _get_annotations(self) -> Dict[str, Any]:
def _get_annotations(self) -> dict[str, Any]:
"""Returns a dictionary mapping variable names to their type annotations."""
return self._get_from_self_and_super(extract_func=lambda super_class: dict(get_type_hints(super_class)))

Expand Down Expand Up @@ -559,15 +562,15 @@ def _get_class_variables(self) -> dict:

return class_variables

def _get_argument_names(self) -> Set[str]:
def _get_argument_names(self) -> set[str]:
"""Returns a list of variable names corresponding to the arguments."""
return (
{get_dest(*name_or_flags, **kwargs) for name_or_flags, kwargs in self.argument_buffer.values()}
| set(self._get_class_dict().keys())
| set(self._annotations.keys())
) - {"help"}

def as_dict(self) -> Dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
"""Returns the member variables corresponding to the parsed arguments.
Note: This does not include attributes set directly on an instance
Expand Down Expand Up @@ -596,7 +599,7 @@ def as_dict(self) -> Dict[str, Any]:

return stored_dict

def from_dict(self, args_dict: Dict[str, Any], skip_unsettable: bool = False) -> TapType:
def from_dict(self, args_dict: dict[str, Any], skip_unsettable: bool = False) -> TapType:
"""Loads arguments from a dictionary, ensuring all required arguments are set.
:param args_dict: A dictionary from argument names to the values of the arguments.
Expand Down Expand Up @@ -682,7 +685,7 @@ def load(

return self

def _load_from_config_files(self, config_files: Optional[List[str]]) -> List[str]:
def _load_from_config_files(self, config_files: Optional[list[str]]) -> list[str]:
"""Loads arguments from a list of configuration files containing command line arguments.
:param config_files: A list of paths to configuration files containing the command line arguments
Expand All @@ -708,7 +711,7 @@ def __str__(self) -> str:
"""
return pformat(self.as_dict())

def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:
def __deepcopy__(self, memo: dict[int, Any] = None) -> TapType:
"""Deepcopy the Tap object."""
copied = type(self).__new__(type(self))

Expand All @@ -722,11 +725,11 @@ def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:

return copied

def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
"""Gets the state of the object for pickling."""
return self.as_dict()

def __setstate__(self, d: Dict[str, Any]) -> None:
def __setstate__(self, d: dict[str, Any]) -> None:
"""
Initializes the object with the provided dictionary of arguments for unpickling.
Expand Down
20 changes: 10 additions & 10 deletions src/tap/tapify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import dataclasses
import inspect
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union

from docstring_parser import Docstring, parse
from packaging.version import Version
Expand Down Expand Up @@ -68,7 +68,7 @@ class _TapData:
Data about a class' or function's arguments which are sufficient to inform a Tap class.
"""

args_data: List[_ArgData]
args_data: list[_ArgData]
"List of data about each argument in the class or function"

has_kwargs: bool
Expand All @@ -79,7 +79,7 @@ class _TapData:


def _is_pydantic_base_model(obj: Union[Type[Any], Any]) -> bool:
if inspect.isclass(obj): # issublcass requires that obj is a class
if inspect.isclass(obj): # issubclass requires that obj is a class
return issubclass(obj, BaseModel)
else:
return isinstance(obj, BaseModel)
Expand All @@ -94,7 +94,7 @@ def _is_pydantic_dataclass(obj: Union[Type[Any], Any]) -> bool:


def _tap_data_from_data_model(
data_model: Any, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str] = None
data_model: Any, func_kwargs: dict[str, Any], param_to_description: dict[str, str] = None
) -> _TapData:
"""
Currently only works when `data_model` is a:
Expand Down Expand Up @@ -153,7 +153,7 @@ def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optiona
# dataclass fields in a pydantic BaseModel. It's also possible to use (builtin) dataclass fields and pydantic Fields
# in the same data model. Therefore, the type of the data model doesn't determine the type of each field. The
# solution is to iterate through the fields and check each type.
args_data: List[_ArgData] = []
args_data: list[_ArgData] = []
for name, field in name_to_field.items():
if isinstance(field, dataclasses.Field):
# Idiosyncrasy: if a pydantic Field is used in a pydantic dataclass, then field.default is a FieldInfo
Expand All @@ -177,7 +177,7 @@ def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optiona


def _tap_data_from_class_or_function(
class_or_function: _ClassOrFunction, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str]
class_or_function: _ClassOrFunction, func_kwargs: dict[str, Any], param_to_description: dict[str, str]
) -> _TapData:
"""
Extract data by inspecting the signature of `class_or_function`.
Expand All @@ -186,7 +186,7 @@ def _tap_data_from_class_or_function(
----
Deletes redundant keys from `func_kwargs`
"""
args_data: List[_ArgData] = []
args_data: list[_ArgData] = []
has_kwargs = False
known_only = False

Expand Down Expand Up @@ -240,7 +240,7 @@ def _docstring(class_or_function) -> Docstring:
return parse(doc)


def _tap_data(class_or_function: _ClassOrFunction, param_to_description: Dict[str, str], func_kwargs) -> _TapData:
def _tap_data(class_or_function: _ClassOrFunction, param_to_description: dict[str, str], func_kwargs) -> _TapData:
"""
Controls how :class:`_TapData` is extracted from `class_or_function`.
"""
Expand Down Expand Up @@ -298,7 +298,7 @@ def to_tap_class(class_or_function: _ClassOrFunction) -> Type[Tap]:
def tapify(
class_or_function: Union[Callable[[InputType], OutputType], Type[OutputType]],
known_only: bool = False,
command_line_args: Optional[List[str]] = None,
command_line_args: Optional[list[str]] = None,
explicit_bool: bool = False,
description: Optional[str] = None,
**func_kwargs,
Expand Down Expand Up @@ -339,7 +339,7 @@ def tapify(

# Prepare command line arguments for class_or_function, respecting positional-only args
class_or_function_args: list[Any] = []
class_or_function_kwargs: Dict[str, Any] = {}
class_or_function_kwargs: dict[str, Any] = {}
command_line_args_dict = command_line_args.as_dict()
for arg_data in tap_data.args_data:
arg_value = command_line_args_dict[arg_data.name]
Expand Down
22 changes: 8 additions & 14 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from argparse import ArgumentParser, ArgumentTypeError
import ast
from base64 import b64encode, b64decode
import copy
from functools import wraps
import inspect
from io import StringIO
from json import JSONEncoder
Expand All @@ -16,15 +14,11 @@
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
Expand All @@ -38,7 +32,7 @@
PathLike = Union[str, os.PathLike]


def check_output(command: List[str], suppress_stderr: bool = True, **kwargs) -> str:
def check_output(command: list[str], suppress_stderr: bool = True, **kwargs) -> str:
"""Runs subprocess.check_output and returns the result as a string.
:param command: A list of strings representing the command to run on the command line.
Expand Down Expand Up @@ -225,7 +219,7 @@ def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
raise ValueError("Could not find any class variables in the class.")


def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> dict[int, list[dict[str, Union[str, int]]]]:
"""Extract a map from each line number to list of mappings providing information about each token."""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
Expand All @@ -244,7 +238,7 @@ def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, Lis
return line_to_tokens


def get_subsequent_assign_lines(source_cls: str) -> Tuple[Set[int], Set[int]]:
def get_subsequent_assign_lines(source_cls: str) -> tuple[set[int], set[int]]:
"""For all multiline assign statements, get the line numbers after the first line in the assignment.
:param source_cls: The source code of the class.
Expand Down Expand Up @@ -301,7 +295,7 @@ def get_subsequent_assign_lines(source_cls: str) -> Tuple[Set[int], Set[int]]:
return intermediate_assign_lines, final_assign_lines


def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
def get_class_variables(cls: type) -> dict[str, dict[str, str]]:
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
# Get the source code and tokens of the class
source_cls = inspect.getsource(cls)
Expand Down Expand Up @@ -387,7 +381,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
return variable_to_comment


def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[type]]:
def get_literals(literal: Literal, variable: str) -> tuple[Callable[[str], Any], list[type]]:
"""Extracts the values from a Literal type and ensures that the values are all primitive types."""
literals = list(get_args(literal))

Expand Down Expand Up @@ -424,7 +418,7 @@ def boolean_type(flag_value: str) -> bool:
class TupleTypeEnforcer:
"""The type argument to argparse for checking and applying types to Tuples."""

def __init__(self, types: List[type], loop: bool = False):
def __init__(self, types: list[type], loop: bool = False):
self.types = [boolean_type if t == bool else t for t in types]
self.loop = loop
self.index = 0
Expand Down Expand Up @@ -545,7 +539,7 @@ def as_python_object(dct: Any) -> Any:


def enforce_reproducibility(
saved_reproducibility_data: Optional[Dict[str, str]], current_reproducibility_data: Dict[str, str], path: PathLike
saved_reproducibility_data: Optional[dict[str, str]], current_reproducibility_data: dict[str, str], path: PathLike
) -> None:
"""Checks if reproducibility has failed and raises the appropriate error.
Expand Down Expand Up @@ -597,7 +591,7 @@ def get_origin(tp: Any) -> Any:


# TODO: remove this once typing_inspect.get_args is fixed for Python 3.10 union types
def get_args(tp: Any) -> Tuple[type, ...]:
def get_args(tp: Any) -> tuple[type, ...]:
"""Same as typing_inspect.get_args but fixes Python 3.10 union types."""
if sys.version_info >= (3, 10) and isinstance(tp, UnionType):
return tp.__args__
Expand Down
1 change: 0 additions & 1 deletion tests/test_actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from typing import List, Literal
import unittest
from unittest import TestCase
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def write(self, msg):
sys.stderr = DevNull()


def stringify(arg_list: Iterable[Any]) -> List[str]:
def stringify(arg_list: Iterable[Any]) -> list[str]:
"""Converts an iterable of arguments of any type to a list of strings.
:param arg_list: An iterable of arguments of any type.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_tapify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
import io
import sys
from typing import Dict, List, Optional, Tuple, Any
from typing import List, Optional, Tuple, Any
import unittest
from unittest import TestCase

Expand Down Expand Up @@ -237,21 +237,21 @@ def __eq__(self, other: str) -> bool:
self.assertEqual(output, "1 simple 3.14 -0.3 True wee")

def test_tapify_complex_types(self):
def concat(complexity: List[str], requires: Tuple[int, int], intelligence: Person) -> str:
def concat(complexity: list[str], requires: tuple[int, int], intelligence: Person) -> str:
return f'{" ".join(complexity)} {requires[0]} {requires[1]} {intelligence}'

def concat_with_positionals(complexity: List[str], /, requires: Tuple[int, int], intelligence: Person) -> str:
def concat_with_positionals(complexity: list[str], /, requires: tuple[int, int], intelligence: Person) -> str:
return f'{" ".join(complexity)} {requires[0]} {requires[1]} {intelligence}'

class Concat:
def __init__(self, complexity: List[str], requires: Tuple[int, int], intelligence: Person):
def __init__(self, complexity: list[str], requires: tuple[int, int], intelligence: Person):
self.kwargs = {"complexity": complexity, "requires": requires, "intelligence": intelligence}

def __eq__(self, other: str) -> bool:
return other == concat(**self.kwargs)

class ConcatWithPositionals:
def __init__(self, complexity: List[str], /, requires: Tuple[int, int], intelligence: Person):
def __init__(self, complexity: list[str], /, requires: tuple[int, int], intelligence: Person):
self.kwargs = {"complexity": complexity, "requires": requires, "intelligence": intelligence}

def __eq__(self, other: str) -> bool:
Expand Down Expand Up @@ -1468,7 +1468,7 @@ def __eq__(self, other: str) -> bool:
pydantic_data_models = []

class Concat:
def __init__(self, a: int, b: int = 2, **kwargs: Dict[str, str]):
def __init__(self, a: int, b: int = 2, **kwargs: dict[str, str]):
"""Concatenate three numbers.
:param a: The first number.
Expand All @@ -1482,7 +1482,7 @@ def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.b, **self.kwargs)

class ConcatWithPositionals:
def __init__(self, a: int, /, b: int = 2, **kwargs: Dict[str, str]):
def __init__(self, a: int, /, b: int = 2, **kwargs: dict[str, str]):
"""Concatenate three numbers.
:param a: The first number.
Expand Down
Loading

0 comments on commit da7b4de

Please sign in to comment.