diff --git a/.github/weaviate-compose.yml b/.github/weaviate-compose.yml index 8bbedb7b23..8d715c758f 100644 --- a/.github/weaviate-compose.yml +++ b/.github/weaviate-compose.yml @@ -11,8 +11,6 @@ services: image: semitechnologies/weaviate:1.21.1 ports: - 8080:8080 - volumes: - - weaviate_data restart: on-failure:0 environment: QUERY_DEFAULTS_LIMIT: 25 diff --git a/.github/workflows/test_destination_clickhouse.yml b/.github/workflows/test_destination_clickhouse.yml index d834df6b28..5b6848f2fe 100644 --- a/.github/workflows/test_destination_clickhouse.yml +++ b/.github/workflows/test_destination_clickhouse.yml @@ -1,4 +1,3 @@ - name: test | clickhouse on: @@ -8,7 +7,7 @@ on: - devel workflow_dispatch: schedule: - - cron: '0 2 * * *' + - cron: '0 2 * * *' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -20,7 +19,7 @@ env: DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} ACTIVE_DESTINATIONS: "[\"clickhouse\"]" - ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" jobs: get_docs_changes: @@ -67,12 +66,51 @@ jobs: - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + # OSS ClickHouse + - run: | + docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" up -d + echo "Waiting for ClickHouse to be healthy..." + timeout 30s bash -c 'until docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" ps | grep -q "healthy"; do sleep 1; done' + echo "ClickHouse is up and running" + name: Start ClickHouse OSS + + + - run: poetry run pytest tests/load -m "essential" + name: Run essential tests Linux (ClickHouse OSS) + if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} + env: + DESTINATION__CLICKHOUSE__CREDENTIALS__HOST: localhost + DESTINATION__CLICKHOUSE__CREDENTIALS__DATABASE: dlt_data + DESTINATION__CLICKHOUSE__CREDENTIALS__USERNAME: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PASSWORD: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PORT: 9000 + DESTINATION__CLICKHOUSE__CREDENTIALS__HTTP_PORT: 8123 + DESTINATION__CLICKHOUSE__CREDENTIALS__SECURE: 0 + + - run: poetry run pytest tests/load + name: Run all tests Linux (ClickHouse OSS) + if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} + env: + DESTINATION__CLICKHOUSE__CREDENTIALS__HOST: localhost + DESTINATION__CLICKHOUSE__CREDENTIALS__DATABASE: dlt_data + DESTINATION__CLICKHOUSE__CREDENTIALS__USERNAME: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PASSWORD: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PORT: 9000 + DESTINATION__CLICKHOUSE__CREDENTIALS__HTTP_PORT: 8123 + DESTINATION__CLICKHOUSE__CREDENTIALS__SECURE: 0 + + - name: Stop ClickHouse OSS + if: always() + run: docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" down -v + + # ClickHouse Cloud - run: | poetry run pytest tests/load -m "essential" - name: Run essential tests Linux + name: Run essential tests Linux (ClickHouse Cloud) if: ${{ ! 
(contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | poetry run pytest tests/load - name: Run all tests Linux + name: Run all tests Linux (ClickHouse Cloud) if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} + diff --git a/.github/workflows/test_destination_dremio.yml b/.github/workflows/test_destination_dremio.yml index 1b47268b59..7ec6c4f697 100644 --- a/.github/workflows/test_destination_dremio.yml +++ b/.github/workflows/test_destination_dremio.yml @@ -43,7 +43,7 @@ jobs: uses: actions/checkout@master - name: Start dremio - run: docker-compose -f "tests/load/dremio/docker-compose.yml" up -d + run: docker compose -f "tests/load/dremio/docker-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 @@ -87,4 +87,4 @@ jobs: - name: Stop dremio if: always() - run: docker-compose -f "tests/load/dremio/docker-compose.yml" down -v + run: docker compose -f "tests/load/dremio/docker-compose.yml" down -v diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index b140935d4c..6094f2c0ac 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -60,7 +60,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker-compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f ".github/weaviate-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index f1bf6016bc..78ea23ec1c 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -73,7 +73,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker-compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f ".github/weaviate-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 @@ -109,4 +109,4 @@ jobs: - name: Stop weaviate if: always() - run: docker-compose -f ".github/weaviate-compose.yml" down -v + run: docker compose -f ".github/weaviate-compose.yml" down -v diff --git a/.github/workflows/test_pyarrow17.yml b/.github/workflows/test_pyarrow17.yml new file mode 100644 index 0000000000..dd48c2af9d --- /dev/null +++ b/.github/workflows/test_pyarrow17.yml @@ -0,0 +1,77 @@ + +name: tests marked as needspyarrow17 + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + schedule: + - cron: '0 2 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + + ACTIVE_DESTINATIONS: "[\"filesystem\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + + run_pyarrow17: + name: needspyarrow17 tests + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + steps: + + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: 
actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-pyarrow17 + + - name: Install dependencies + run: poetry install --no-interaction --with sentry-sdk --with pipeline -E deltalake -E gs -E s3 -E az + + - name: Upgrade pyarrow + run: poetry run pip install pyarrow==17.0.0 + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - run: | + poetry run pytest tests/libs tests/load -m needspyarrow17 + name: Run needspyarrow17 tests Linux diff --git a/deploy/dlt/Dockerfile b/deploy/dlt/Dockerfile index f3d4f9d707..3f9f6a2341 100644 --- a/deploy/dlt/Dockerfile +++ b/deploy/dlt/Dockerfile @@ -31,7 +31,7 @@ RUN apk update &&\ # add build labels and envs ARG COMMIT_SHA="" ARG IMAGE_VERSION="" -LABEL commit_sha = ${COMMIT_SHA} +LABEL commit_sha=${COMMIT_SHA} LABEL version=${IMAGE_VERSION} ENV COMMIT_SHA=${COMMIT_SHA} ENV IMAGE_VERSION=${IMAGE_VERSION} diff --git a/deploy/dlt/Dockerfile.airflow b/deploy/dlt/Dockerfile.airflow index 43adf5ea95..620b72da0e 100644 --- a/deploy/dlt/Dockerfile.airflow +++ b/deploy/dlt/Dockerfile.airflow @@ -14,7 +14,7 @@ WORKDIR /tmp/pydlt # add build labels and envs ARG COMMIT_SHA="" ARG IMAGE_VERSION="" -LABEL commit_sha = ${COMMIT_SHA} +LABEL commit_sha=${COMMIT_SHA} LABEL version=${IMAGE_VERSION} ENV COMMIT_SHA=${COMMIT_SHA} ENV IMAGE_VERSION=${IMAGE_VERSION} diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index 522b3a6712..a1434133f0 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -1,6 +1,7 @@ import os import ast import shutil +import tomlkit from types import ModuleType from typing import Dict, List, Sequence, Tuple from importlib.metadata import version as pkg_version @@ -488,16 +489,19 @@ def init_command( # generate tomls with comments secrets_prov = SecretsTomlProvider() - # print(secrets_prov._toml) - write_values(secrets_prov._toml, required_secrets.values(), overwrite_existing=False) + secrets_toml = tomlkit.document() + write_values(secrets_toml, required_secrets.values(), overwrite_existing=False) + secrets_prov._config_doc = secrets_toml + config_prov = ConfigTomlProvider() - write_values(config_prov._toml, required_config.values(), overwrite_existing=False) + config_toml = tomlkit.document() + write_values(config_toml, required_config.values(), overwrite_existing=False) + config_prov._config_doc = config_toml + # write toml files secrets_prov.write_toml() config_prov.write_toml() - # telemetry_status_command() - # if there's no dependency system write the requirements file if dependency_system is None: requirements_txt = "\n".join(source_files.requirements.compiled()) diff --git a/dlt/cli/telemetry_command.py b/dlt/cli/telemetry_command.py index 3285c7cdf2..45e9c270f9 100644 --- a/dlt/cli/telemetry_command.py +++ b/dlt/cli/telemetry_command.py @@ -1,6 +1,6 @@ import os +import tomlkit -from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.providers.toml import ConfigTomlProvider from dlt.common.configuration.specs import RunConfiguration @@ -29,15 +29,21 @@ def change_telemetry_status_command(enabled: bool) 
-> None: ] # write local config config = ConfigTomlProvider(add_global_config=False) + config_toml = tomlkit.document() if not config.is_empty: - write_values(config._toml, telemetry_value, overwrite_existing=True) + write_values(config_toml, telemetry_value, overwrite_existing=True) + config._config_doc = config_toml config.write_toml() + # write global config global_path = ConfigTomlProvider.global_config_path() os.makedirs(global_path, exist_ok=True) config = ConfigTomlProvider(project_dir=global_path, add_global_config=False) - write_values(config._toml, telemetry_value, overwrite_existing=True) + config_toml = tomlkit.document() + write_values(config_toml, telemetry_value, overwrite_existing=True) + config._config_doc = config_toml config.write_toml() + if enabled: fmt.echo("Telemetry switched %s" % fmt.bold("ON")) else: diff --git a/dlt/common/configuration/accessors.py b/dlt/common/configuration/accessors.py index 1b32ae96f4..733a4b3016 100644 --- a/dlt/common/configuration/accessors.py +++ b/dlt/common/configuration/accessors.py @@ -80,6 +80,13 @@ def _get_value(self, field: str, type_hint: Type[Any] = None) -> Tuple[Any, List break return value, traces + @staticmethod + def register_provider(provider: ConfigProvider) -> None: + """Registers `provider` to participate in the configuration resolution. `provider` + is added after all existing providers and will be used if all others do not resolve. + """ + Container()[ConfigProvidersContext].add_provider(provider) + class _ConfigAccessor(_Accessor): """Provides direct access to configured values that are not secrets.""" diff --git a/dlt/common/configuration/providers/__init__.py b/dlt/common/configuration/providers/__init__.py index 3f5bc20cdc..7338b82b7c 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -4,12 +4,13 @@ from .toml import ( SecretsTomlProvider, ConfigTomlProvider, - TomlFileProvider, + ProjectDocProvider, CONFIG_TOML, SECRETS_TOML, StringTomlProvider, - SECRETS_TOML_KEY, + CustomLoaderDocProvider, ) +from .vault import SECRETS_TOML_KEY from .google_secrets import GoogleSecretsProvider from .context import ContextProvider @@ -19,11 +20,12 @@ "DictionaryProvider", "SecretsTomlProvider", "ConfigTomlProvider", - "TomlFileProvider", + "ProjectDocProvider", "CONFIG_TOML", "SECRETS_TOML", "StringTomlProvider", "SECRETS_TOML_KEY", "GoogleSecretsProvider", "ContextProvider", + "CustomLoaderDocProvider", ] diff --git a/dlt/common/configuration/providers/airflow.py b/dlt/common/configuration/providers/airflow.py index edd02c3487..113593b4da 100644 --- a/dlt/common/configuration/providers/airflow.py +++ b/dlt/common/configuration/providers/airflow.py @@ -1,10 +1,10 @@ import io import contextlib -from .toml import VaultTomlProvider +from .vault import VaultDocProvider -class AirflowSecretsTomlProvider(VaultTomlProvider): +class AirflowSecretsTomlProvider(VaultDocProvider): def __init__(self, only_secrets: bool = False, only_toml_fragments: bool = False) -> None: super().__init__(only_secrets, only_toml_fragments) diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py index dffe5f0c71..5358d80be3 100644 --- a/dlt/common/configuration/providers/dictionary.py +++ b/dlt/common/configuration/providers/dictionary.py @@ -1,38 +1,26 @@ from contextlib import contextmanager -from typing import Any, ClassVar, Iterator, Optional, Type, Tuple +from typing import ClassVar, Iterator -from dlt.common.typing import StrAny 
+from dlt.common.typing import DictStrAny -from .provider import ConfigProvider, get_key_name +from .provider import get_key_name +from .toml import BaseDocProvider -class DictionaryProvider(ConfigProvider): +class DictionaryProvider(BaseDocProvider): NAME: ClassVar[str] = "Dictionary Provider" def __init__(self) -> None: - self._values: StrAny = {} + super().__init__({}) + + @staticmethod + def get_key_name(key: str, *sections: str) -> str: + return get_key_name(key, "__", *sections) @property def name(self) -> str: return self.NAME - def get_value( - self, key: str, hint: Type[Any], pipeline_name: str, *sections: str - ) -> Tuple[Optional[Any], str]: - full_path = sections + (key,) - if pipeline_name: - full_path = (pipeline_name,) + full_path - full_key = get_key_name(key, "__", pipeline_name, *sections) - node = self._values - try: - for k in full_path: - if not isinstance(node, dict): - raise KeyError(k) - node = node[k] - return node, full_key - except KeyError: - return None, full_key - @property def supports_secrets(self) -> bool: return True @@ -42,8 +30,8 @@ def supports_sections(self) -> bool: return True @contextmanager - def values(self, v: StrAny) -> Iterator[None]: - p_values = self._values - self._values = v + def values(self, v: DictStrAny) -> Iterator[None]: + p_values = self._config_doc + self._config_doc = v yield - self._values = p_values + self._config_doc = p_values diff --git a/dlt/common/configuration/providers/google_secrets.py b/dlt/common/configuration/providers/google_secrets.py index 43a284c67c..55cc35e02c 100644 --- a/dlt/common/configuration/providers/google_secrets.py +++ b/dlt/common/configuration/providers/google_secrets.py @@ -5,7 +5,7 @@ from dlt.common.json import json from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.exceptions import MissingDependencyException -from .toml import VaultTomlProvider +from .vault import VaultDocProvider from .provider import get_key_name # Create a translation table to replace punctuation with "" @@ -33,7 +33,7 @@ def normalize_key(in_string: str) -> str: return stripped_whitespace -class GoogleSecretsProvider(VaultTomlProvider): +class GoogleSecretsProvider(VaultDocProvider): def __init__( self, credentials: GcpServiceAccountCredentials, diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index 10e0b470de..c13d1f8454 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -1,30 +1,24 @@ import os -import abc import tomlkit -import contextlib +import yaml +import functools from tomlkit.items import Item as TOMLItem from tomlkit.container import Container as TOMLContainer -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type -from dlt.common.pendulum import pendulum from dlt.common.configuration.paths import get_dlt_settings_dir, get_dlt_data_dir -from dlt.common.configuration.utils import auto_cast -from dlt.common.configuration.specs import known_sections -from dlt.common.configuration.specs.base_configuration import is_secret_hint +from dlt.common.configuration.utils import auto_cast, auto_config_fragment from dlt.common.utils import update_dict_nested -from dlt.common.typing import AnyType - from .provider import ConfigProvider, ConfigProviderException, get_key_name CONFIG_TOML = "config.toml" SECRETS_TOML = "secrets.toml" -SECRETS_TOML_KEY = "dlt_secrets_toml" -class BaseTomlProvider(ConfigProvider): - def 
__init__(self, toml_document: TOMLContainer) -> None: - self._toml = toml_document +class BaseDocProvider(ConfigProvider): + def __init__(self, config_doc: Dict[str, Any]) -> None: + self._config_doc = config_doc @staticmethod def get_key_name(key: str, *sections: str) -> str: @@ -37,49 +31,74 @@ def get_value( if pipeline_name: full_path = (pipeline_name,) + full_path full_key = self.get_key_name(key, pipeline_name, *sections) - node: Union[TOMLContainer, TOMLItem] = self._toml + node = self._config_doc try: for k in full_path: if not isinstance(node, dict): raise KeyError(k) node = node[k] - rv = node.unwrap() if isinstance(node, (TOMLContainer, TOMLItem)) else node - return rv, full_key + return node, full_key except KeyError: return None, full_key - def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> None: + def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None: + """Sets `value` under `key` in `sections` and optionally for `pipeline_name` + + If key already has value of type dict and value to set is also of type dict, the new value + is merged with old value. + """ if pipeline_name: sections = (pipeline_name,) + sections + if key is None: + raise ValueError("dlt_secrets_toml must contain toml document") + + master: Dict[str, Any] + # descend from root, create tables if necessary + master = self._config_doc + for k in sections: + if not isinstance(master, dict): + raise KeyError(k) + if k not in master: + master[k] = {} + master = master[k] + if isinstance(value, dict): + # remove none values, TODO: we need recursive None removal + value = {k: v for k, v in value.items() if v is not None} + # if target is also dict then merge recursively + if isinstance(master.get(key), dict): + update_dict_nested(master[key], value) + return + master[key] = value + + def set_fragment( + self, key: Optional[str], value_or_fragment: str, pipeline_name: str, *sections: str + ) -> None: + """Tries to interpret `value_or_fragment` as a fragment of toml, yaml or json string and replace/merge into config doc. - if isinstance(value, TOMLContainer): + If `key` is not provided, fragment is considered a full document and will replace internal config doc. Otherwise + fragment is merged with config doc from the root element and not from the element under `key`! + + For simple values it falls back to `set_value` method. 
+ """ + fragment = auto_config_fragment(value_or_fragment) + if fragment is not None: + # always update the top document if key is None: - self._toml = value + self._config_doc = fragment else: - # always update the top document # TODO: verify that value contains only the elements under key - update_dict_nested(self._toml, value) + update_dict_nested(self._config_doc, fragment) else: - if key is None: - raise ValueError("dlt_secrets_toml must contain toml document") + # set value using auto_cast + self.set_value(key, auto_cast(value_or_fragment), pipeline_name, *sections) - master: TOMLContainer - # descend from root, create tables if necessary - master = self._toml - for k in sections: - if not isinstance(master, dict): - raise KeyError(k) - if k not in master: - master[k] = tomlkit.table() - master = master[k] # type: ignore - if isinstance(value, dict): - # remove none values, TODO: we need recursive None removal - value = {k: v for k, v in value.items() if v is not None} - # if target is also dict then merge recursively - if isinstance(master.get(key), dict): - update_dict_nested(master[key], value) # type: ignore - return - master[key] = value + def to_toml(self) -> str: + return tomlkit.dumps(self._config_doc) + + def to_yaml(self) -> str: + return yaml.dump( + self._config_doc, allow_unicode=True, default_flow_style=False, sort_keys=False + ) @property def supports_sections(self) -> bool: @@ -87,18 +106,18 @@ def supports_sections(self) -> bool: @property def is_empty(self) -> bool: - return len(self._toml.body) == 0 + return len(self._config_doc) == 0 -class StringTomlProvider(BaseTomlProvider): +class StringTomlProvider(BaseDocProvider): def __init__(self, toml_string: str) -> None: - super().__init__(StringTomlProvider.loads(toml_string)) + super().__init__(StringTomlProvider.loads(toml_string).unwrap()) - def update(self, toml_string: str) -> None: - self._toml = self.loads(toml_string) + # def update(self, toml_string: str) -> None: + # self._config_doc = StringTomlProvider.loads(toml_string).unwrap() def dumps(self) -> str: - return tomlkit.dumps(self._toml) + return tomlkit.dumps(self._config_doc) @staticmethod def loads(toml_string: str) -> tomlkit.TOMLDocument: @@ -113,124 +132,49 @@ def name(self) -> str: return "memory" -class VaultTomlProvider(BaseTomlProvider): - """A toml-backed Vault abstract config provider. - - This provider allows implementation of providers that store secrets in external vaults: like Hashicorp, Google Secrets or Airflow Metadata. - The basic working principle is obtain config and secrets values from Vault keys and reconstitute a `secrets.toml` like document that is then used - as a cache. - - The implemented must provide `_look_vault` method that returns a value from external vault from external key. - - To reduce number of calls to external vaults the provider is searching for a known configuration fragments which should be toml documents and merging - them with the - - only keys with secret type hint (CredentialsConfiguration, TSecretValue) will be looked up by default. - - provider gathers `toml` document fragments that contain source and destination credentials in path specified below - - single values will not be retrieved, only toml fragments by default - - """ - - def __init__(self, only_secrets: bool, only_toml_fragments: bool) -> None: - """Initializes the toml backed Vault provider by loading a toml fragment from `dlt_secrets_toml` key and using it as initial configuration. 
+class CustomLoaderDocProvider(BaseDocProvider): + def __init__( + self, name: str, loader: Callable[[], Dict[str, Any]], supports_secrets: bool = True + ) -> None: + """Provider that calls `loader` function to get a Python dict with config/secret values to be queried. + The `loader` function typically loads a string (ie. from file), parses it (ie. as toml or yaml), does additional + processing and returns a Python dict to be queried. - _extended_summary_ + Instance of CustomLoaderDocProvider must be registered for the returned dict to be used to resolve config values. + >>> import dlt + >>> dlt.config.register_provider(provider) Args: - only_secrets (bool): Only looks for secret values (CredentialsConfiguration, TSecretValue) by returning None (not found) - only_toml_fragments (bool): Only load the known toml fragments and ignore any other lookups by returning None (not found) - """ - self.only_secrets = only_secrets - self.only_toml_fragments = only_toml_fragments - self._vault_lookups: Dict[str, pendulum.DateTime] = {} + name(str): name of the provider that will be visible ie. in exceptions + loader(Callable[[], Dict[str, Any]]): user-supplied function that will load the document with config/secret values + supports_secrets(bool): allows to store secret values in this provider - super().__init__(tomlkit.document()) - self._update_from_vault(SECRETS_TOML_KEY, None, AnyType, None, ()) - - def get_value( - self, key: str, hint: type, pipeline_name: str, *sections: str - ) -> Tuple[Optional[Any], str]: - full_key = self.get_key_name(key, pipeline_name, *sections) + """ + self._name = name + self._supports_secrets = supports_secrets + super().__init__(loader()) - value, _ = super().get_value(key, hint, pipeline_name, *sections) - if value is None: - # only secrets hints are handled - if self.only_secrets and not is_secret_hint(hint) and hint is not AnyType: - return None, full_key - - if pipeline_name: - # loads dlt_secrets_toml for particular pipeline - lookup_fk = self.get_key_name(SECRETS_TOML_KEY, pipeline_name) - self._update_from_vault(lookup_fk, "", AnyType, pipeline_name, ()) - - # generate auxiliary paths to get from vault - for known_section in [known_sections.SOURCES, known_sections.DESTINATION]: - - def _look_at_idx(idx: int, full_path: Tuple[str, ...], pipeline_name: str) -> None: - lookup_key = full_path[idx] - lookup_sections = full_path[:idx] - lookup_fk = self.get_key_name(lookup_key, *lookup_sections) - self._update_from_vault( - lookup_fk, lookup_key, AnyType, pipeline_name, lookup_sections - ) - - def _lookup_paths(pipeline_name_: str, known_section_: str) -> None: - with contextlib.suppress(ValueError): - full_path = sections + (key,) - if pipeline_name_: - full_path = (pipeline_name_,) + full_path - idx = full_path.index(known_section_) - _look_at_idx(idx, full_path, pipeline_name_) - # if there's element after index then also try it (destination name / source name) - if len(full_path) - 1 > idx: - _look_at_idx(idx + 1, full_path, pipeline_name_) - - # first query the shortest paths so the longer paths can override it - _lookup_paths(None, known_section) # check sources and sources. - if pipeline_name: - _lookup_paths( - pipeline_name, known_section - ) # check .sources and .sources. 
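# NOTE (editorial aside, not part of the patch): a sketch of wiring the
# `CustomLoaderDocProvider` added above into config resolution via the new
# `dlt.config.register_provider` accessor. The loader, file name and keys are invented.
import json

import dlt
from dlt.common.configuration.providers import CustomLoaderDocProvider


def load_settings() -> dict:
    # any callable returning a plain (possibly nested) dict of config values works
    with open("settings.json", encoding="utf-8") as f:
        return json.load(f)


provider = CustomLoaderDocProvider("my_json_settings", load_settings, supports_secrets=False)
dlt.config.register_provider(provider)  # consulted after all built-in providers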
- - value, _ = super().get_value(key, hint, pipeline_name, *sections) - # skip checking the exact path if we check only toml fragments - if value is None and not self.only_toml_fragments: - # look for key in the vault and update the toml document - self._update_from_vault(full_key, key, hint, pipeline_name, sections) - value, _ = super().get_value(key, hint, pipeline_name, *sections) - - # if value: - # print(f"GSM got value for {key} {pipeline_name}-{sections}") - # else: - # print(f"GSM FAILED value for {key} {pipeline_name}-{sections}") - return value, full_key + @property + def name(self) -> str: + return self._name @property def supports_secrets(self) -> bool: - return True - - @abc.abstractmethod - def _look_vault(self, full_key: str, hint: type) -> str: - pass - - def _update_from_vault( - self, full_key: str, key: str, hint: type, pipeline_name: str, sections: Tuple[str, ...] - ) -> None: - if full_key in self._vault_lookups: - return - # print(f"tries '{key}' {pipeline_name} | {sections} at '{full_key}'") - secret = self._look_vault(full_key, hint) - self._vault_lookups[full_key] = pendulum.now() - if secret is not None: - self.set_value(key, auto_cast(secret), pipeline_name, *sections) + return self._supports_secrets @property - def is_empty(self) -> bool: - return False + def is_writable(self) -> bool: + return True -class TomlFileProvider(BaseTomlProvider): +class ProjectDocProvider(CustomLoaderDocProvider): def __init__( - self, file_name: str, project_dir: str = None, add_global_config: bool = False + self, + name: str, + supports_secrets: bool, + file_name: str, + project_dir: str = None, + add_global_config: bool = False, ) -> None: """Creates config provider from a `toml` file @@ -240,6 +184,8 @@ def __init__( If none of the files exist, an empty provider is created. Args: + name(str): name of the provider when registering in context + supports_secrets(bool): allows to store secret values in this provider file_name (str): The name of `toml` file to load project_dir (str, optional): The location of `file_name`. 
If not specified, defaults to $cwd/.dlt add_global_config (bool, optional): Looks for `file_name` in `dlt` home directory which in most cases is $HOME/.dlt @@ -247,23 +193,16 @@ def __init__( Raises: TomlProviderReadException: File could not be read, most probably `toml` parsing error """ - toml_document = self._read_toml_file(file_name, project_dir, add_global_config) - super().__init__(toml_document) - - def _read_toml_file( - self, file_name: str, project_dir: str = None, add_global_config: bool = False - ) -> tomlkit.TOMLDocument: - self._file_name = file_name self._toml_path = os.path.join(project_dir or get_dlt_settings_dir(), file_name) self._add_global_config = add_global_config - try: - project_toml = self._read_toml(self._toml_path) - if add_global_config: - global_toml = self._read_toml(os.path.join(self.global_config_path(), file_name)) - project_toml = update_dict_nested(global_toml, project_toml) - return project_toml - except Exception as ex: - raise TomlProviderReadException(self.name, file_name, self._toml_path, str(ex)) + + super().__init__( + name, + functools.partial( + self._read_toml_files, name, file_name, self._toml_path, add_global_config + ), + supports_secrets, + ) @staticmethod def global_config_path() -> str: @@ -274,7 +213,22 @@ def write_toml(self) -> None: not self._add_global_config ), "Will not write configs when `add_global_config` flag was set" with open(self._toml_path, "w", encoding="utf-8") as f: - tomlkit.dump(self._toml, f) + tomlkit.dump(self._config_doc, f) + + @staticmethod + def _read_toml_files( + name: str, file_name: str, toml_path: str, add_global_config: bool + ) -> Dict[str, Any]: + try: + project_toml = ProjectDocProvider._read_toml(toml_path).unwrap() + if add_global_config: + global_toml = ProjectDocProvider._read_toml( + os.path.join(ProjectDocProvider.global_config_path(), file_name) + ).unwrap() + project_toml = update_dict_nested(global_toml, project_toml) + return project_toml + except Exception as ex: + raise TomlProviderReadException(name, file_name, toml_path, str(ex)) @staticmethod def _read_toml(toml_path: str) -> tomlkit.TOMLDocument: @@ -286,34 +240,30 @@ def _read_toml(toml_path: str) -> tomlkit.TOMLDocument: return tomlkit.document() -class ConfigTomlProvider(TomlFileProvider): +class ConfigTomlProvider(ProjectDocProvider): def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: - super().__init__(CONFIG_TOML, project_dir=project_dir, add_global_config=add_global_config) - - @property - def name(self) -> str: - return CONFIG_TOML - - @property - def supports_secrets(self) -> bool: - return False + super().__init__( + CONFIG_TOML, + False, + CONFIG_TOML, + project_dir=project_dir, + add_global_config=add_global_config, + ) @property def is_writable(self) -> bool: return True -class SecretsTomlProvider(TomlFileProvider): +class SecretsTomlProvider(ProjectDocProvider): def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: - super().__init__(SECRETS_TOML, project_dir=project_dir, add_global_config=add_global_config) - - @property - def name(self) -> str: - return SECRETS_TOML - - @property - def supports_secrets(self) -> bool: - return True + super().__init__( + SECRETS_TOML, + True, + SECRETS_TOML, + project_dir=project_dir, + add_global_config=add_global_config, + ) @property def is_writable(self) -> bool: diff --git a/dlt/common/configuration/providers/vault.py b/dlt/common/configuration/providers/vault.py new file mode 100644 index 0000000000..0dcaa1b5c4 --- 
/dev/null +++ b/dlt/common/configuration/providers/vault.py @@ -0,0 +1,126 @@ +import abc +import contextlib +from typing import Any, Dict, Optional, Tuple + +from dlt.common.typing import AnyType +from dlt.common.pendulum import pendulum +from dlt.common.configuration.specs import known_sections +from dlt.common.configuration.specs.base_configuration import is_secret_hint + +from .toml import BaseDocProvider + +SECRETS_TOML_KEY = "dlt_secrets_toml" + + +class VaultDocProvider(BaseDocProvider): + """A toml-backed Vault abstract config provider. + + This provider allows implementing providers that store secrets in external vaults like Hashicorp, Google Secrets or Airflow Metadata. + The basic working principle is to obtain config and secret values from vault keys and reconstitute a `secrets.toml`-like document that is then used + as a cache. + + The implementer must provide a `_look_vault` method that returns a value from the external vault for an external key. + + To reduce the number of calls to external vaults, the provider searches for known configuration fragments, which should be toml documents, and merges + them into the cached document: + - only keys with a secret type hint (CredentialsConfiguration, TSecretValue) will be looked up by default. + - the provider gathers `toml` document fragments that contain source and destination credentials in the paths specified below + - single values will not be retrieved by default, only toml fragments + + """ + + def __init__(self, only_secrets: bool, only_toml_fragments: bool) -> None: + """Initializes the toml-backed vault provider by loading a toml fragment from the `dlt_secrets_toml` key and using it as the initial configuration. + + Vault lookups are cached so each key is requested from the external vault at most once. + + Args: + only_secrets (bool): Only look up secret values (CredentialsConfiguration, TSecretValue) and return None (not found) for other hints + only_toml_fragments (bool): Only load the known toml fragments and return None (not found) for any other lookups + """ + self.only_secrets = only_secrets + self.only_toml_fragments = only_toml_fragments + self._vault_lookups: Dict[str, pendulum.DateTime] = {} + + super().__init__({}) + self._update_from_vault(SECRETS_TOML_KEY, None, AnyType, None, ()) + + def get_value( + self, key: str, hint: type, pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: + full_key = self.get_key_name(key, pipeline_name, *sections) + + value, _ = super().get_value(key, hint, pipeline_name, *sections) + if value is None: + # only secrets hints are handled + if self.only_secrets and not is_secret_hint(hint) and hint is not AnyType: + return None, full_key + + if pipeline_name: + # loads dlt_secrets_toml for particular pipeline + lookup_fk = self.get_key_name(SECRETS_TOML_KEY, pipeline_name) + self._update_from_vault(lookup_fk, "", AnyType, pipeline_name, ()) + + # generate auxiliary paths to get from vault + for known_section in [known_sections.SOURCES, known_sections.DESTINATION]: + + def _look_at_idx(idx: int, full_path: Tuple[str, ...], pipeline_name: str) -> None: + lookup_key = full_path[idx] + lookup_sections = full_path[:idx] + lookup_fk = self.get_key_name(lookup_key, *lookup_sections) + self._update_from_vault( + lookup_fk, lookup_key, AnyType, pipeline_name, lookup_sections + ) + + def _lookup_paths(pipeline_name_: str, known_section_: str) -> None: + with contextlib.suppress(ValueError): + full_path = sections + (key,) + if pipeline_name_: + full_path = (pipeline_name_,) + full_path + idx = full_path.index(known_section_) + _look_at_idx(idx, full_path, pipeline_name_) + # if there's
element after index then also try it (destination name / source name) + if len(full_path) - 1 > idx: + _look_at_idx(idx + 1, full_path, pipeline_name_) + + # first query the shortest paths so the longer paths can override it + _lookup_paths(None, known_section) # check sources and sources. + if pipeline_name: + _lookup_paths( + pipeline_name, known_section + ) # check .sources and .sources. + + value, _ = super().get_value(key, hint, pipeline_name, *sections) + # skip checking the exact path if we check only toml fragments + if value is None and not self.only_toml_fragments: + # look for key in the vault and update the toml document + self._update_from_vault(full_key, key, hint, pipeline_name, sections) + value, _ = super().get_value(key, hint, pipeline_name, *sections) + + return value, full_key + + @property + def supports_secrets(self) -> bool: + return True + + def clear_lookup_cache(self) -> None: + self._vault_lookups.clear() + + @abc.abstractmethod + def _look_vault(self, full_key: str, hint: type) -> str: + pass + + def _update_from_vault( + self, full_key: str, key: str, hint: type, pipeline_name: str, sections: Tuple[str, ...] + ) -> None: + if full_key in self._vault_lookups: + return + # print(f"tries '{key}' {pipeline_name} | {sections} at '{full_key}'") + secret = self._look_vault(full_key, hint) + self._vault_lookups[full_key] = pendulum.now() + if secret is not None: + self.set_fragment(key, secret, pipeline_name, *sections) + + @property + def is_empty(self) -> bool: + return False diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index 97803a60e3..dd40d3b775 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -1,5 +1,6 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, cast +from dlt.common.utils import without_none from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue, DictStrAny from dlt.common.configuration.specs import ( @@ -7,7 +8,10 @@ CredentialsWithDefault, configspec, ) -from dlt.common.configuration.specs.exceptions import InvalidBoto3Session +from dlt.common.configuration.specs.exceptions import ( + InvalidBoto3Session, + ObjectStoreRsCredentialsException, +) from dlt import version @@ -47,11 +51,29 @@ def to_session_credentials(self) -> Dict[str, str]: def to_object_store_rs_credentials(self) -> Dict[str, str]: # https://docs.rs/object_store/latest/object_store/aws - assert self.region_name is not None, "`object_store` Rust crate requires AWS region." - creds = self.to_session_credentials() - if creds["aws_session_token"] is None: - creds.pop("aws_session_token") - return {**creds, **{"region": self.region_name}} + creds = cast( + Dict[str, str], + without_none( + dict( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + region=self.region_name, + endpoint_url=self.endpoint_url, + ) + ), + ) + + if "endpoint_url" not in creds: # AWS S3 + if "region" not in creds: + raise ObjectStoreRsCredentialsException( + "`object_store` Rust crate requires AWS region when using AWS S3." + ) + else: # S3-compatible, e.g. 
MinIO + if self.endpoint_url.startswith("http://"): + creds["aws_allow_http"] = "true" + + return creds @configspec diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 642634fb0a..d77d97cee8 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -146,7 +146,7 @@ def _airflow_providers() -> List[ConfigProvider]: from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider # probe if Airflow variable containing all secrets is present - from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY + from dlt.common.configuration.providers.vault import SECRETS_TOML_KEY secrets_toml_var = Variable.get(SECRETS_TOML_KEY, default_var=None) diff --git a/dlt/common/configuration/specs/exceptions.py b/dlt/common/configuration/specs/exceptions.py index 7a0b283630..928e46a8a0 100644 --- a/dlt/common/configuration/specs/exceptions.py +++ b/dlt/common/configuration/specs/exceptions.py @@ -68,3 +68,7 @@ def __init__(self, spec: Type[Any], native_value: Any): " containing credentials" ) super().__init__(spec, native_value, msg) + + +class ObjectStoreRsCredentialsException(ConfigurationException): + pass diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index a1d82fc577..ca5bd076f1 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -33,12 +33,6 @@ class GcpCredentials(CredentialsConfiguration): project_id: str = None - location: ( - str - ) = ( # DEPRECATED! and present only for backward compatibility. please set bigquery location in BigQuery configuration - "US" - ) - def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): raise InvalidGoogleNativeCredentialsType(self.__class__, native_value) diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 74190a87de..bc52241a26 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -14,12 +14,13 @@ get_args, Literal, get_origin, - List, ) from collections.abc import Mapping as C_Mapping +import yaml + from dlt.common.json import json -from dlt.common.typing import AnyType, TAny +from dlt.common.typing import AnyType, DictStrAny, TAny from dlt.common.data_types import coerce_value, py_type_to_sc_type from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.exceptions import ConfigValueCannotBeCoercedException, LookupTrace @@ -118,7 +119,10 @@ def serialize_value(value: Any) -> str: def auto_cast(value: str) -> Any: - # try to cast to bool, int, float and complex (via JSON) + """Parse and cast str `value` to bool, int, float and complex (via JSON) + + F[f]alse and T[t]rue strings are cast to bool values + """ if value.lower() == "true": return True if value.lower() == "false": @@ -132,11 +136,31 @@ def auto_cast(value: str) -> Any: # only lists and dictionaries count if isinstance(c_v, (list, dict)): return c_v - with contextlib.suppress(ValueError): - return tomlkit.parse(value) return value +def auto_config_fragment(value: str) -> Optional[DictStrAny]: + """Tries to parse config fragment assuming toml, yaml and json formats + + Only dicts are considered valid fragments. 
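# NOTE (editorial aside, not part of the patch): a sketch of the reworked
# `to_object_store_rs_credentials` above for an S3-compatible endpoint such as MinIO.
# The import path is assumed and the keys/endpoint below are made up.
from dlt.common.configuration.specs import AwsCredentials

creds = AwsCredentials()
creds.aws_access_key_id = "minio"
creds.aws_secret_access_key = "minio123"
creds.endpoint_url = "http://localhost:9000"
# endpoint_url is set, so no region is required and plain-http access is enabled
print(creds.to_object_store_rs_credentials())  # includes "aws_allow_http": "true"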
+ None is returned when not a fragment + """ + try: + return tomlkit.parse(value).unwrap() + except ValueError: + pass + with contextlib.suppress(Exception): + c_v = yaml.safe_load(value) + if isinstance(c_v, dict): + return c_v + with contextlib.suppress(ValueError): + c_v = json.loads(value) + # only lists and dictionaries count + if isinstance(c_v, dict): + return c_v + return None + + def log_traces( config: Optional[BaseConfiguration], key: str, diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index a4835a8188..be71cb50e9 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -88,6 +88,9 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): max_table_nesting: Optional[int] = None """Allows a destination to overwrite max_table_nesting from source""" + supported_merge_strategies: Sequence["TLoaderMergeStrategy"] = None # type: ignore[name-defined] # noqa: F821 + # TODO: also add `supported_replace_strategies` capability + # do not allow to create default value, destination caps must be always explicitly inserted into container can_create_default: ClassVar[bool] = False @@ -107,6 +110,7 @@ def generic_capabilities( naming_convention: TNamingConventionReferenceArg = None, loader_file_format_adapter: LoaderFileFormatAdapter = None, supported_table_formats: Sequence["TTableFormat"] = None, # type: ignore[name-defined] # noqa: F821 + supported_merge_strategies: Sequence["TLoaderMergeStrategy"] = None, # type: ignore[name-defined] # noqa: F821 ) -> "DestinationCapabilitiesContext": from dlt.common.data_writers.escape import format_datetime_literal @@ -134,6 +138,7 @@ def generic_capabilities( caps.supports_ddl_transactions = True caps.supports_transactions = True caps.supports_multiple_statements = True + caps.supported_merge_strategies = supported_merge_strategies or [] return caps diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index c5f30401df..49c9b822e3 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -126,6 +126,10 @@ def __init__(self, schema_name: str, version_hash: str, stored_version_hash: str ) +class DestinationCapabilitiesException(DestinationException): + pass + + class DestinationInvalidFileFormat(DestinationTerminalException): def __init__( self, destination_type: str, file_format: str, file_name: str, message: str diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index dd4fbc8e13..ca9d6a2d94 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -33,6 +33,7 @@ get_file_format, get_write_disposition, get_table_format, + get_merge_strategy, ) from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration @@ -446,6 +447,8 @@ def prepare_load_table( # add write disposition if not specified - in child tables if "write_disposition" not in table: table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) + if "x-merge-strategy" not in table: + table["x-merge-strategy"] = get_merge_strategy(self.schema.tables, table_name) # type: ignore[typeddict-unknown-key] if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) if "file_format" not in table: diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 
82906b0144..6a86ab5fbe 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -143,6 +143,25 @@ def _to_pip_install(self) -> str: return "\n".join([f'pip install "{d}"' for d in self.dependencies]) +class DependencyVersionException(DltException): + def __init__( + self, pkg_name: str, version_found: str, version_required: str, appendix: str = "" + ) -> None: + self.pkg_name = pkg_name + self.version_found = version_found + self.version_required = version_required + super().__init__(self._get_msg(appendix)) + + def _get_msg(self, appendix: str) -> str: + msg = ( + f"Found `{self.pkg_name}=={self.version_found}`, while" + f" `{self.pkg_name}{self.version_required}` is required." + ) + if appendix: + msg = msg + "\n" + appendix + return msg + + class SystemConfigurationException(DltException): pass diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index 00d8dcc430..72ab453cbf 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -144,7 +144,7 @@ def custom_pua_encode(obj: Any) -> str: elif dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) # type: ignore elif PydanticBaseModel and isinstance(obj, PydanticBaseModel): - return obj.dict() # type: ignore[return-value] + return obj.dict(by_alias=True) # type: ignore[return-value] elif isinstance(obj, Enum): # Enum value is just int or str return obj.value # type: ignore[no-any-return] diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py index 32847303f8..04100b0c6c 100644 --- a/dlt/common/libs/deltalake.py +++ b/dlt/common/libs/deltalake.py @@ -1,15 +1,17 @@ from typing import Optional, Dict, Union +from pathlib import Path from dlt import version from dlt.common import logger from dlt.common.libs.pyarrow import pyarrow as pa -from dlt.common.libs.pyarrow import dataset_to_table, cast_arrow_schema_types +from dlt.common.libs.pyarrow import cast_arrow_schema_types from dlt.common.schema.typing import TWriteDisposition from dlt.common.exceptions import MissingDependencyException from dlt.common.storages import FilesystemConfiguration try: - from deltalake import write_deltalake + from deltalake import write_deltalake, DeltaTable + from deltalake.writer import try_get_deltatable except ModuleNotFoundError: raise MissingDependencyException( "dlt deltalake helpers", @@ -18,10 +20,10 @@ ) -def ensure_delta_compatible_arrow_table(table: pa.table) -> pa.Table: - """Returns Arrow table compatible with Delta table format. +def ensure_delta_compatible_arrow_schema(schema: pa.Schema) -> pa.Schema: + """Returns Arrow schema compatible with Delta table format. - Casts table schema to replace data types not supported by Delta. + Casts schema to replace data types not supported by Delta. """ ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP = { # maps type check function to type factory function @@ -29,10 +31,18 @@ def ensure_delta_compatible_arrow_table(table: pa.table) -> pa.Table: pa.types.is_time: pa.string(), pa.types.is_decimal256: pa.string(), # pyarrow does not allow downcasting to decimal128 } - adjusted_schema = cast_arrow_schema_types( - table.schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP - ) - return table.cast(adjusted_schema) + return cast_arrow_schema_types(schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP) + + +def ensure_delta_compatible_arrow_data( + data: Union[pa.Table, pa.RecordBatchReader] +) -> Union[pa.Table, pa.RecordBatchReader]: + """Returns Arrow data compatible with Delta table format. 
+ + Casts `data` schema to replace data types not supported by Delta. + """ + schema = ensure_delta_compatible_arrow_schema(data.schema) + return data.cast(schema) def get_delta_write_mode(write_disposition: TWriteDisposition) -> str: @@ -49,21 +59,19 @@ def get_delta_write_mode(write_disposition: TWriteDisposition) -> str: def write_delta_table( - path: str, - data: Union[pa.Table, pa.dataset.Dataset], + table_or_uri: Union[str, Path, DeltaTable], + data: Union[pa.Table, pa.RecordBatchReader], write_disposition: TWriteDisposition, storage_options: Optional[Dict[str, str]] = None, ) -> None: """Writes in-memory Arrow table to on-disk Delta table.""" - table = dataset_to_table(data) - # throws warning for `s3` protocol: https://github.com/delta-io/delta-rs/issues/2460 # TODO: upgrade `deltalake` lib after https://github.com/delta-io/delta-rs/pull/2500 # is released write_deltalake( # type: ignore[call-overload] - table_or_uri=path, - data=ensure_delta_compatible_arrow_table(table), + table_or_uri=table_or_uri, + data=ensure_delta_compatible_arrow_data(data), mode=get_delta_write_mode(write_disposition), schema_mode="merge", # enable schema evolution (adding new columns) storage_options=storage_options, diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index ee249b111c..9d3e97421c 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -474,10 +474,6 @@ def pq_stream_with_new_columns( yield tbl -def dataset_to_table(data: Union[pyarrow.Table, pyarrow.dataset.Dataset]) -> pyarrow.Table: - return data.to_table() if isinstance(data, pyarrow.dataset.Dataset) else data - - def cast_arrow_schema_types( schema: pyarrow.Schema, type_map: Dict[Callable[[pyarrow.DataType], bool], Callable[..., pyarrow.DataType]], diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py index 774a1641a7..15e3e53409 100644 --- a/dlt/common/libs/pydantic.py +++ b/dlt/common/libs/pydantic.py @@ -4,6 +4,7 @@ from typing import ( Dict, Generic, + Optional, Set, TypedDict, List, @@ -298,14 +299,16 @@ def create_list_model( ) -def validate_items( +def validate_and_filter_items( table_name: str, list_model: Type[ListModel[_TPydanticModel]], items: List[TDataItem], column_mode: TSchemaEvolutionMode, data_mode: TSchemaEvolutionMode, ) -> List[_TPydanticModel]: - """Validates list of `item` with `list_model` and returns parsed Pydantic models + """Validates list of `item` with `list_model` and returns parsed Pydantic models. If `column_mode` and `data_mode` are set + this function will remove non validating items (`discard_row`) or raise on the first non-validating items (`freeze`). Note + that the model itself may be configured to remove non validating or extra items as well. `list_model` should be created with `create_list_model` and have `items` field which this function returns. 
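# NOTE (editorial aside, not part of the patch): a sketch of the updated
# `write_delta_table` helper above, writing an in-memory Arrow table to a local Delta
# table. The path and data are invented; `deltalake` must be installed.
import pyarrow as pa

from dlt.common.libs.deltalake import write_delta_table

data = pa.table({"id": [1, 2], "value": ["a", "b"]})
write_delta_table("/tmp/delta/my_table", data, write_disposition="append")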
""" @@ -332,6 +335,7 @@ def validate_items( list_model, {"columns": "freeze"}, items, + err["msg"], ) from e # raise on freeze if err["type"] == "extra_forbidden": @@ -345,6 +349,7 @@ def validate_items( list_model, {"columns": "freeze"}, err_item, + err["msg"], ) from e elif column_mode == "discard_row": # pop at the right index @@ -366,6 +371,7 @@ def validate_items( list_model, {"data_type": "freeze"}, err_item, + err["msg"], ) from e elif data_mode == "discard_row": items.pop(err_idx - len(deleted)) @@ -376,17 +382,19 @@ def validate_items( ) # validate again with error items removed - return validate_items(table_name, list_model, items, column_mode, data_mode) + return validate_and_filter_items(table_name, list_model, items, column_mode, data_mode) -def validate_item( +def validate_and_filter_item( table_name: str, model: Type[_TPydanticModel], item: TDataItems, column_mode: TSchemaEvolutionMode, data_mode: TSchemaEvolutionMode, -) -> _TPydanticModel: - """Validates `item` against model `model` and returns an instance of it""" +) -> Optional[_TPydanticModel]: + """Validates `item` against model `model` and returns an instance of it. If `column_mode` and `data_mode` are set + this function will return None (`discard_row`) or raise on non-validating items (`freeze`). Note + that the model itself may be configured to remove non validating or extra items as well.""" try: return model.parse_obj(item) except ValidationError as e: @@ -403,6 +411,7 @@ def validate_item( model, {"columns": "freeze"}, item, + err["msg"], ) from e elif column_mode == "discard_row": return None @@ -420,6 +429,7 @@ def validate_item( model, {"data_type": "freeze"}, item, + err["msg"], ) from e elif data_mode == "discard_row": return None diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 88abd575b0..45ae26e8be 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -34,11 +34,11 @@ def metrics(name: str, extra: Mapping[str, Any], stacklevel: int = 1) -> None: @contextlib.contextmanager -def suppress_and_warn() -> Iterator[None]: +def suppress_and_warn(msg: str) -> Iterator[None]: try: yield except Exception: - LOGGER.warning("Suppressed exception", exc_info=True) + LOGGER.warning(msg, exc_info=True) def is_logging() -> bool: diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 91af42a6c5..8e296445eb 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -2,18 +2,25 @@ from typing import Dict, List, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any from dlt.common.json import json from dlt.common.normalizers.exceptions import InvalidJsonNormalizer -from dlt.common.normalizers.typing import TJSONNormalizer +from dlt.common.normalizers.typing import TJSONNormalizer, TRowIdType from dlt.common.normalizers.utils import generate_dlt_id, DLT_ID_LENGTH_BYTES from dlt.common.typing import DictStrAny, TDataItem, StrAny from dlt.common.schema import Schema from dlt.common.schema.typing import ( + TLoaderMergeStrategy, TColumnSchema, TColumnName, TSimpleRegex, DLT_NAME_PREFIX, ) -from dlt.common.schema.utils import column_name_validator, get_validity_column_names +from dlt.common.schema.utils import ( + column_name_validator, + get_validity_column_names, + get_columns_names_with_prop, + get_first_column_name_with_prop, + get_merge_strategy, +) from dlt.common.schema.exceptions import ColumnNameConflictException from dlt.common.utils import digest128, update_dict_nested from 
dlt.common.normalizers.json import ( @@ -158,7 +165,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - return out_rec_row, out_rec_list @staticmethod - def get_row_hash(row: Dict[str, Any]) -> str: + def get_row_hash(row: Dict[str, Any], subset: Optional[List[str]] = None) -> str: """Returns hash of row. Hash includes column names and values and is ordered by column name. @@ -166,6 +173,8 @@ def get_row_hash(row: Dict[str, Any]) -> str: Can be used as deterministic row identifier. """ row_filtered = {k: v for k, v in row.items() if not k.startswith(DLT_NAME_PREFIX)} + if subset is not None: + row_filtered = {k: v for k, v in row.items() if k in subset} row_str = json.dumps(row_filtered, sort_keys=True) return digest128(row_str, DLT_ID_LENGTH_BYTES) @@ -188,18 +197,39 @@ def _extend_row(extend: DictStrAny, row: DictStrAny) -> None: row.update(extend) def _add_row_id( - self, table: str, row: DictStrAny, parent_row_id: str, pos: int, _r_lvl: int + self, + table: str, + dict_row: DictStrAny, + flattened_row: DictStrAny, + parent_row_id: str, + pos: int, + _r_lvl: int, ) -> str: - # row_id is always random, no matter if primary_key is present or not - row_id = generate_dlt_id() - if _r_lvl > 0: - primary_key = self.schema.filter_row_with_hint(table, "primary_key", row) - if not primary_key: - # child table row deterministic hash - row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) - # link to parent table - DataItemNormalizer._link_row(row, parent_row_id, pos) - row[self.c_dlt_id] = row_id + primary_key = False + if _r_lvl > 0: # child table + primary_key = bool( + self.schema.filter_row_with_hint(table, "primary_key", flattened_row) + ) + row_id_type = self._get_row_id_type(self.schema, table, primary_key, _r_lvl) + + if row_id_type == "random": + row_id = generate_dlt_id() + else: + if _r_lvl == 0: # root table + if row_id_type in ("key_hash", "row_hash"): + subset = None + if row_id_type == "key_hash": + subset = self._get_primary_key(self.schema, table) + # base hash on `dict_row` instead of `flattened_row` + # so changes in child tables lead to new row id + row_id = self.get_row_hash(dict_row, subset=subset) + elif _r_lvl > 0: # child table + if row_id_type == "row_hash": + row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) + # link to parent table + DataItemNormalizer._link_row(flattened_row, parent_row_id, pos) + + flattened_row[self.c_dlt_id] = row_id return row_id def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> StrAny: @@ -268,14 +298,9 @@ def _normalize_row( parent_row_id: Optional[str] = None, pos: Optional[int] = None, _r_lvl: int = 0, - row_hash: bool = False, ) -> TNormalizedRowIterator: schema = self.schema table = schema.naming.shorten_fragments(*parent_path, *ident_path) - # compute row hash and set as row id - if row_hash: - row_id = self.get_row_hash(dict_row) - dict_row[self.c_dlt_id] = row_id # flatten current row and extract all lists to recur into flattened_row, lists = self._flatten(table, dict_row, _r_lvl) # always extend row @@ -283,7 +308,7 @@ def _normalize_row( # infer record hash or leave existing primary key if present row_id = flattened_row.get(self.c_dlt_id, None) if not row_id: - row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) + row_id = self._add_row_id(table, dict_row, flattened_row, parent_row_id, pos, _r_lvl) # find fields to propagate to child tables in config extend.update(self._get_propagated_values(table, 
flattened_row, _r_lvl)) @@ -369,11 +394,7 @@ def normalize_data_item( row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally row[self.c_dlt_load_id] = load_id - - # determine if row hash should be used as dlt id - row_hash = False - if self._is_scd2_table(self.schema, table_name): - row_hash = self._dlt_id_is_row_hash(self.schema, table_name, self.c_dlt_id) + if self._get_merge_strategy(self.schema, table_name) == "scd2": self._validate_validity_column_names( self.schema.name, self._get_validity_column_names(self.schema, table_name), item ) @@ -382,7 +403,6 @@ def normalize_data_item( row, {}, (self.schema.naming.normalize_table_identifier(table_name),), - row_hash=row_hash, ) @classmethod @@ -450,11 +470,16 @@ def _get_table_nesting_level(schema: Schema, table_name: str) -> Optional[int]: @staticmethod @lru_cache(maxsize=None) - def _is_scd2_table(schema: Schema, table_name: str) -> bool: - if table_name in schema.data_table_names(): - if schema.get_table(table_name).get("x-merge-strategy") == "scd2": - return True - return False + def _get_merge_strategy(schema: Schema, table_name: str) -> Optional[TLoaderMergeStrategy]: + return get_merge_strategy(schema.tables, table_name) + + @staticmethod + @lru_cache(maxsize=None) + def _get_primary_key(schema: Schema, table_name: str) -> List[str]: + if table_name not in schema.tables: + return [] + table = schema.get_table(table_name) + return get_columns_names_with_prop(table, "primary_key", include_incomplete=True) @staticmethod @lru_cache(maxsize=None) @@ -463,12 +488,29 @@ def _get_validity_column_names(schema: Schema, table_name: str) -> List[Optional @staticmethod @lru_cache(maxsize=None) - def _dlt_id_is_row_hash(schema: Schema, table_name: str, c_dlt_id: str) -> bool: - return ( - schema.get_table(table_name)["columns"] # type: ignore[return-value] - .get(c_dlt_id, {}) - .get("x-row-version", False) - ) + def _get_row_id_type( + schema: Schema, table_name: str, primary_key: bool, _r_lvl: int + ) -> TRowIdType: + if _r_lvl == 0: # root table + merge_strategy = DataItemNormalizer._get_merge_strategy(schema, table_name) + if merge_strategy == "upsert": + return "key_hash" + elif merge_strategy == "scd2": + x_row_version_col = get_first_column_name_with_prop( + schema.get_table(table_name), + "x-row-version", + include_incomplete=True, + ) + if x_row_version_col == DataItemNormalizer.C_DLT_ID: + return "row_hash" + elif _r_lvl > 0: # child table + merge_strategy = DataItemNormalizer._get_merge_strategy(schema, table_name) + if merge_strategy in ("upsert", "scd2"): + # these merge strategies rely on deterministic child row hash + return "row_hash" + if not primary_key: + return "row_hash" + return "random" @staticmethod def _validate_validity_column_names( diff --git a/dlt/common/normalizers/typing.py b/dlt/common/normalizers/typing.py index 9ea6f3cf11..9840f3a4d2 100644 --- a/dlt/common/normalizers/typing.py +++ b/dlt/common/normalizers/typing.py @@ -1,5 +1,5 @@ +from typing import List, Optional, Type, TypedDict, Literal, Union from types import ModuleType -from typing import List, Optional, Type, TypedDict, Union from dlt.common.typing import StrAny from dlt.common.normalizers.naming import NamingConvention @@ -7,6 +7,9 @@ TNamingConventionReferenceArg = Union[str, Type[NamingConvention], ModuleType] +TRowIdType = Literal["random", "row_hash", "key_hash"] + + class TJSONNormalizer(TypedDict, total=False): module: str config: Optional[StrAny] # config is a free form and is validated by 
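As I read the new `_get_row_id_type` above, the `_dlt_id` derivation is now: root table with `upsert` → hash of the primary key columns; root table with `scd2` where `_dlt_id` carries `x-row-version` → hash of the full row; child tables under `upsert`/`scd2`, or child tables without a primary key → deterministic child row hash; everything else → random. A compact restatement of that decision table (names are mine, not dlt's):

```python
from typing import Literal, Optional

RowIdType = Literal["random", "row_hash", "key_hash"]
MergeStrategy = Optional[Literal["delete-insert", "scd2", "upsert"]]


def pick_row_id_type(
    merge_strategy: MergeStrategy,
    is_root: bool,
    dlt_id_is_row_version: bool,
    has_primary_key: bool,
) -> RowIdType:
    if is_root:
        if merge_strategy == "upsert":
            return "key_hash"
        if merge_strategy == "scd2" and dlt_id_is_row_version:
            return "row_hash"
    else:  # child table
        if merge_strategy in ("upsert", "scd2"):
            return "row_hash"  # these strategies need a deterministic child row hash
        if not has_primary_key:
            return "row_hash"
    return "random"
```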
`module` diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index c6ee27e58b..1e1416eb53 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -602,6 +602,14 @@ def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> self._deferred_pipeline = deferred_pipeline +def current_pipeline() -> SupportsPipeline: + """Gets active pipeline context or None if not found""" + proxy = Container()[PipelineContext] + if not proxy.is_active(): + return None + return proxy.pipeline() + + @configspec class StateInjectableContext(ContainerInjectableContext): state: TPipelineState = None diff --git a/dlt/common/runtime/telemetry.py b/dlt/common/runtime/telemetry.py index 28dde0206c..6b783483cc 100644 --- a/dlt/common/runtime/telemetry.py +++ b/dlt/common/runtime/telemetry.py @@ -14,7 +14,6 @@ disable_anon_tracker, track, ) -from dlt.pipeline.platform import disable_platform_tracker, init_platform_tracker _TELEMETRY_STARTED = False @@ -36,6 +35,10 @@ def start_telemetry(config: RunConfiguration) -> None: init_anon_tracker(config) if config.dlthub_dsn: + # TODO: we need pluggable modules for tracing so import into + # concrete modules is not needed + from dlt.pipeline.platform import init_platform_tracker + init_platform_tracker() _TELEMETRY_STARTED = True @@ -55,6 +58,9 @@ def stop_telemetry() -> None: pass disable_anon_tracker() + + from dlt.pipeline.platform import disable_platform_tracker + disable_platform_tracker() _TELEMETRY_STARTED = False diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index b5081c5ff4..9a4dd51d4b 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -66,7 +66,6 @@ ] """Known hints of a column used to declare hint regexes.""" -TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg", "delta"] TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat] TTypeDetections = Literal[ @@ -168,7 +167,10 @@ class NormalizerInfo(TypedDict, total=True): total=False, ) -TLoaderMergeStrategy = Literal["delete-insert", "scd2"] + +TWriteDisposition = Literal["skip", "append", "replace", "merge"] +TLoaderMergeStrategy = Literal["delete-insert", "scd2", "upsert"] + WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) MERGE_STRATEGIES: Set[TLoaderMergeStrategy] = set(get_args(TLoaderMergeStrategy)) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index cd0cc5aa63..aa5de9611c 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -685,6 +685,13 @@ def get_file_format(tables: TSchemaTables, table_name: str) -> TFileFormat: ) +def get_merge_strategy(tables: TSchemaTables, table_name: str) -> TLoaderMergeStrategy: + return cast( + TLoaderMergeStrategy, + get_inherited_table_hint(tables, table_name, "x-merge-strategy", allow_none=True), + ) + + def fill_hints_from_parent_and_clone_table( tables: TSchemaTables, table: TTableSchema ) -> TTableSchema: diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 7109daf497..c1d130e477 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -10,6 +10,8 @@ from types import ModuleType import traceback import zlib +from importlib.metadata import version as pkg_version +from packaging.version import Version from typing import ( Any, @@ -29,7 +31,12 @@ Iterable, ) -from dlt.common.exceptions import DltException, ExceptionTrace, TerminalException +from dlt.common.exceptions import ( + DltException, + ExceptionTrace, + TerminalException, + 
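The new `current_pipeline()` helper returns the pipeline held by the active `PipelineContext`, or `None` when no pipeline has been activated. A usage sketch, assuming `dlt.pipeline()` activates the context as usual:

```python
import dlt
from dlt.common.pipeline import current_pipeline

# creating a pipeline activates it in the container context ...
pipeline = dlt.pipeline(pipeline_name="demo", destination="duckdb")

# ... so code running inside dlt (adapters, custom destinations, trackers)
# can look it up without having it passed in explicitly
active = current_pipeline()
if active is not None:
    print(active.pipeline_name)
```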
DependencyVersionException, +) from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TAny, TFun @@ -565,3 +572,14 @@ def order_deduped(lst: List[Any]) -> List[Any]: Only works for lists with hashable elements. """ return list(dict.fromkeys(lst)) + + +def assert_min_pkg_version(pkg_name: str, version: str, msg: str = "") -> None: + version_found = pkg_version(pkg_name) + if Version(version_found) < Version(version): + raise DependencyVersionException( + pkg_name=pkg_name, + version_found=version_found, + version_required=">=" + version, + appendix=msg, + ) diff --git a/dlt/destinations/impl/athena/athena_adapter.py b/dlt/destinations/impl/athena/athena_adapter.py index cb600335c0..50f7abc54a 100644 --- a/dlt/destinations/impl/athena/athena_adapter.py +++ b/dlt/destinations/impl/athena/athena_adapter.py @@ -4,7 +4,7 @@ from dlt.common.pendulum import timezone from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate @@ -89,7 +89,7 @@ def athena_adapter( >>> athena_adapter(data, partition=["department", athena_partition.year("date_hired"), athena_partition.bucket(8, "name")]) [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if partition: diff --git a/dlt/destinations/impl/athena/factory.py b/dlt/destinations/impl/athena/factory.py index d4c29a641f..07d784ed49 100644 --- a/dlt/destinations/impl/athena/factory.py +++ b/dlt/destinations/impl/athena/factory.py @@ -46,6 +46,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.schema_supports_numeric_precision = False caps.timestamp_precision = 3 caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps @property diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 0f6b8f4838..095974d186 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -24,6 +24,7 @@ from dlt.common.schema.utils import get_inherited_table_hint from dlt.common.schema.utils import table_schema_has_type from dlt.common.storages.file_storage import FileStorage +from dlt.common.storages.load_package import destination_state from dlt.common.typing import DictStrAny from dlt.destinations.job_impl import DestinationJsonlLoadJob, DestinationParquetLoadJob from dlt.destinations.sql_client import SqlClientBase @@ -36,6 +37,7 @@ LoadJobTerminalException, ) from dlt.destinations.impl.bigquery.bigquery_adapter import ( + AUTODETECT_SCHEMA_HINT, PARTITION_HINT, CLUSTER_HINT, TABLE_DESCRIPTION_HINT, @@ -50,7 +52,6 @@ from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.utils import parse_db_data_type_str_with_precision -from dlt.pipeline.current import destination_state class BigQueryTypeMapper(TypeMapper): @@ -290,6 +291,11 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: + # return empty columns which will skip table CREATE or ALTER + # to let BigQuery autodetect table from data + if 
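`assert_min_pkg_version` compares the installed version (via `importlib.metadata`) against a required minimum using `packaging.Version` and raises `DependencyVersionException` when it is too old. It is used later in this patch to guard the delta path of the filesystem destination, e.g.:

```python
from dlt.common.utils import assert_min_pkg_version

# raises DependencyVersionException if the installed pyarrow is older than 17.0.0
assert_min_pkg_version(
    pkg_name="pyarrow",
    version="17.0.0",
    msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.",
)
```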
self._should_autodetect_schema(table_name): + return [] + table: Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -447,12 +453,6 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON decimal_target_types: Optional[List[str]] = None if ext == "parquet": - # if table contains complex types, we cannot load with parquet - if table_schema_has_type(table, "complex"): - raise LoadJobTerminalException( - file_path, - "Bigquery cannot load into JSON data type from parquet. Use jsonl instead.", - ) source_format = bigquery.SourceFormat.PARQUET # parquet needs NUMERIC type auto-detection decimal_target_types = ["NUMERIC", "BIGNUMERIC"] @@ -467,6 +467,19 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load ignore_unknown_values=False, max_bad_records=0, ) + if self._should_autodetect_schema(table_name): + # allow BigQuery to infer and evolve the schema, note that dlt is not + # creating such tables at all + job_config.autodetect = True + job_config.schema_update_options = bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION + job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED + elif ext == "parquet" and table_schema_has_type(table, "complex"): + # if table contains complex types, we cannot load with parquet + raise LoadJobTerminalException( + file_path, + "Bigquery cannot load into JSON data type from parquet. Enable autodetect_schema in" + " config or via BigQuery adapter or use jsonl format instead.", + ) if bucket_path: return self.sql_client.native_connection.load_table_from_uri( @@ -495,6 +508,11 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) + def _should_autodetect_schema(self, table_name: str) -> bool: + return get_inherited_table_hint( + self.schema._schema_tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True + ) or (self.config.autodetect_schema and table_name not in self.schema.dlt_table_names()) + def _streaming_load( sql_client: SqlClientBase[BigQueryClient], items: List[Dict[Any, Any]], table: Dict[str, Any] diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 38828249ff..55fe1b6b74 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -7,7 +7,7 @@ TColumnNames, TTableSchemaColumns, ) -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate @@ -20,6 +20,7 @@ ROUND_HALF_EVEN_HINT: Literal["x-bigquery-round-half-even"] = "x-bigquery-round-half-even" TABLE_EXPIRATION_HINT: Literal["x-bigquery-table-expiration"] = "x-bigquery-table-expiration" TABLE_DESCRIPTION_HINT: Literal["x-bigquery-table-description"] = "x-bigquery-table-description" +AUTODETECT_SCHEMA_HINT: Literal["x-bigquery-autodetect-schema"] = "x-bigquery-autodetect-schema" def bigquery_adapter( @@ -31,6 +32,7 @@ def bigquery_adapter( table_description: Optional[str] = None, table_expiration_datetime: Optional[str] = None, insert_api: Optional[Literal["streaming", "default"]] = None, + autodetect_schema: Optional[bool] = None, ) -> DltResource: """ Prepares data for loading into BigQuery. 
@@ -62,6 +64,8 @@ def bigquery_adapter( If "streaming" is chosen, the streaming API (https://cloud.google.com/bigquery/docs/streaming-data-into-bigquery) is used. NOTE: due to BigQuery features, streaming insert is only available for `append` write_disposition. + autodetect_schema (bool, optional): If set to True, BigQuery schema autodetection will be used to create data tables. This + allows to create structured types from nested data. Returns: A `DltResource` object that is ready to be loaded into BigQuery. @@ -74,7 +78,7 @@ def bigquery_adapter( >>> bigquery_adapter(data, partition="date_hired", table_expiration_datetime="2024-01-30", table_description="Employee Data") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} column_hints: TTableSchemaColumns = {} @@ -136,6 +140,9 @@ def bigquery_adapter( ) additional_table_hints[TABLE_DESCRIPTION_HINT] = table_description + if autodetect_schema: + additional_table_hints[AUTODETECT_SCHEMA_HINT] = autodetect_schema + if table_expiration_datetime: if not isinstance(table_expiration_datetime, str): raise ValueError( diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index ef4e63ca12..47cc997a4a 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -19,25 +19,21 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): should_set_case_sensitivity_on_new_dataset: bool = False """If True, dlt will set case sensitivity flag on created datasets that corresponds to naming convention""" - http_timeout: float = 15.0 # connection timeout for http request to BigQuery api - file_upload_timeout: float = 30 * 60.0 # a timeout for file upload when loading local files - retry_deadline: float = ( - 60.0 # how long to retry the operation in case of error, the backoff 60 s. - ) + http_timeout: float = 15.0 + """connection timeout for http request to BigQuery api""" + file_upload_timeout: float = 30 * 60.0 + """a timeout for file upload when loading local files""" + retry_deadline: float = 60.0 + """How long to retry the operation in case of error, the backoff 60 s.""" batch_size: int = 500 + """Number of rows in streaming insert batch""" + autodetect_schema: bool = False + """Allow BigQuery to autodetect schemas and create data tables""" __config_gen_annotations__: ClassVar[List[str]] = ["location"] def get_location(self) -> str: - if self.location != "US": - return self.location - # default was changed in credentials, emit deprecation message - if self.credentials.location != "US": - warnings.warn( - "Setting BigQuery location in the credentials is deprecated. Please set the" - " location directly in bigquery section ie. 
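With the new `autodetect_schema` flag, `bigquery_adapter` sets the `x-bigquery-autodetect-schema` hint so dlt skips its own CREATE/ALTER for that table and lets BigQuery infer and evolve the schema from the loaded files. A short usage sketch (resource and field names are placeholders):

```python
import dlt
from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter


@dlt.resource
def events():
    yield {"id": 1, "payload": {"nested": {"value": 42}}}


# let BigQuery autodetect the table schema, e.g. to get structured columns from nested data
bigquery_adapter(events, autodetect_schema=True)
```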
destinations.bigquery.location='EU'" - ) - return self.credentials.location + return self.location def fingerprint(self) -> str: """Returns a fingerprint of project_id""" diff --git a/dlt/destinations/impl/bigquery/factory.py b/dlt/destinations/impl/bigquery/factory.py index b3096e9312..34dd1790ae 100644 --- a/dlt/destinations/impl/bigquery/factory.py +++ b/dlt/destinations/impl/bigquery/factory.py @@ -42,6 +42,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = False caps.supports_clone_table = True caps.schema_supports_numeric_precision = False # no precision information in BigQuery + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index d08e91758a..148fca3f1e 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -2,21 +2,18 @@ import re from copy import deepcopy from textwrap import dedent -from typing import ClassVar, Optional, Dict, List, Sequence, cast, Tuple +from typing import Optional, List, Sequence, cast from urllib.parse import urlparse import clickhouse_connect from clickhouse_connect.driver.tools import insert_file -import dlt from dlt import config from dlt.common.configuration.specs import ( CredentialsConfiguration, AzureCredentialsWithoutDefaults, - GcpCredentials, AwsCredentialsWithoutDefaults, ) -from dlt.destinations.exceptions import DestinationTransientException from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( SupportsStagingDestination, @@ -29,26 +26,27 @@ from dlt.common.schema.typing import ( TTableFormat, TTableSchema, - TColumnHint, TColumnType, - TTableSchemaColumns, - TColumnSchemaBase, ) from dlt.common.storages import FileStorage from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.clickhouse.clickhouse_adapter import ( - TTableEngineType, - TABLE_ENGINE_TYPE_HINT, -) from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ) from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient -from dlt.destinations.impl.clickhouse.utils import ( - convert_storage_to_http_scheme, +from dlt.destinations.impl.clickhouse.typing import ( + HINT_TO_CLICKHOUSE_ATTR, + TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR, +) +from dlt.destinations.impl.clickhouse.typing import ( + TTableEngineType, + TABLE_ENGINE_TYPE_HINT, FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING, SUPPORTED_FILE_FORMATS, ) +from dlt.destinations.impl.clickhouse.utils import ( + convert_storage_to_http_scheme, +) from dlt.destinations.job_client_impl import ( SqlJobClientBase, SqlJobClientWithStaging, @@ -58,18 +56,6 @@ from dlt.destinations.type_mapping import TypeMapper -HINT_TO_CLICKHOUSE_ATTR: Dict[TColumnHint, str] = { - "primary_key": "PRIMARY KEY", - "unique": "", # No unique constraints available in ClickHouse. - "foreign_key": "", # No foreign key constraints support in ClickHouse. 
-} - -TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR: Dict[TTableEngineType, str] = { - "merge_tree": "MergeTree", - "replicated_merge_tree": "ReplicatedMergeTree", -} - - class ClickHouseTypeMapper(TypeMapper): sct_to_unbound_dbt = { "complex": "String", @@ -113,7 +99,8 @@ def from_db_type( if db_type == "DateTime('UTC')": db_type = "DateTime" if datetime_match := re.match( - r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", db_type + r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", + db_type, ): if datetime_match["precision"]: precision = int(datetime_match["precision"]) @@ -131,7 +118,7 @@ def from_db_type( db_type = "Decimal" if db_type == "Decimal" and (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") + return cast(TColumnType, dict(data_type="wei")) return super().from_db_type(db_type, precision, scale) @@ -161,7 +148,7 @@ def __init__( compression = "auto" - # Don't use dbapi driver for local files. + # Don't use the DBAPI driver for local files. if not bucket_path: # Local filesystem. if ext == "jsonl": @@ -182,8 +169,8 @@ def __init__( fmt=clickhouse_format, settings={ "allow_experimental_lightweight_delete": 1, - # "allow_experimental_object_type": 1, "enable_http_compression": 1, + "date_time_input_format": "best_effort", }, compression=compression, ) @@ -201,13 +188,7 @@ def __init__( compression = "none" if config.get("data_writer.disable_compression") else "gz" if bucket_scheme in ("s3", "gs", "gcs"): - if isinstance(staging_credentials, AwsCredentialsWithoutDefaults): - bucket_http_url = convert_storage_to_http_scheme( - bucket_url, endpoint=staging_credentials.endpoint_url - ) - access_key_id = staging_credentials.aws_access_key_id - secret_access_key = staging_credentials.aws_secret_access_key - else: + if not isinstance(staging_credentials, AwsCredentialsWithoutDefaults): raise LoadJobTerminalException( file_path, dedent( @@ -219,6 +200,11 @@ def __init__( ).strip(), ) + bucket_http_url = convert_storage_to_http_scheme( + bucket_url, endpoint=staging_credentials.endpoint_url + ) + access_key_id = staging_credentials.aws_access_key_id + secret_access_key = staging_credentials.aws_secret_access_key auth = "NOSIGN" if access_key_id and secret_access_key: auth = f"'{access_key_id}','{secret_access_key}'" @@ -299,6 +285,7 @@ def __init__( config.normalize_staging_dataset_name(schema), config.credentials, capabilities, + config, ) super().__init__(schema, config, self.sql_client) self.config: ClickHouseClientConfiguration = config @@ -311,10 +298,10 @@ def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> Li def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. 
- hints_str = " ".join( + hints_ = " ".join( self.active_hints.get(hint) for hint in self.active_hints.keys() - if c.get(hint, False) is True + if c.get(cast(str, hint), False) is True and hint not in ("primary_key", "sort") and hint in self.active_hints ) @@ -328,7 +315,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) return ( - f"{self.sql_client.escape_column_name(c['name'])} {type_with_nullability_modifier} {hints_str}" + f"{self.sql_client.escape_column_name(c['name'])} {type_with_nullability_modifier} {hints_}" .strip() ) @@ -343,7 +330,10 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: table: TTableSchema = self.prepare_load_table(table_name, self.in_staging_mode) sql = SqlJobClientBase._get_table_update_sql(self, table_name, new_columns, generate_alter) @@ -351,9 +341,15 @@ def _get_table_update_sql( if generate_alter: return sql - # Default to 'ReplicatedMergeTree' if user didn't explicitly set a table engine hint. + # Default to 'MergeTree' if the user didn't explicitly set a table engine hint. + # Clickhouse Cloud will automatically pick `SharedMergeTree` for this option, + # so it will work on both local and cloud instances of CH. table_type = cast( - TTableEngineType, table.get(TABLE_ENGINE_TYPE_HINT, "replicated_merge_tree") + TTableEngineType, + table.get( + cast(str, TABLE_ENGINE_TYPE_HINT), + self.config.table_engine_type, + ), ) sql[0] = f"{sql[0]}\nENGINE = {TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(table_type)}" diff --git a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py index 1bbde8e45d..41be531b71 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py +++ b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py @@ -1,12 +1,15 @@ -from typing import Any, Literal, Set, get_args, Dict - -from dlt.destinations.utils import ensure_resource +from typing import Any, Dict + +from dlt.destinations.impl.clickhouse.configuration import TTableEngineType +from dlt.destinations.impl.clickhouse.typing import ( + TABLE_ENGINE_TYPES, + TABLE_ENGINE_TYPE_HINT, +) +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate -TTableEngineType = Literal["merge_tree", "replicated_merge_tree"] - """ The table engine (type of table) determines: @@ -19,8 +22,6 @@ See https://clickhouse.com/docs/en/engines/table-engines. 
""" -TABLE_ENGINE_TYPES: Set[TTableEngineType] = set(get_args(TTableEngineType)) -TABLE_ENGINE_TYPE_HINT: Literal["x-table-engine-type"] = "x-table-engine-type" def clickhouse_adapter(data: Any, table_engine_type: TTableEngineType = None) -> DltResource: @@ -45,7 +46,7 @@ def clickhouse_adapter(data: Any, table_engine_type: TTableEngineType = None) -> >>> clickhouse_adapter(data, table_engine_type="merge_tree") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if table_engine_type is not None: diff --git a/dlt/destinations/impl/clickhouse/configuration.py b/dlt/destinations/impl/clickhouse/configuration.py index 483356f9f9..fbda58abc7 100644 --- a/dlt/destinations/impl/clickhouse/configuration.py +++ b/dlt/destinations/impl/clickhouse/configuration.py @@ -1,16 +1,13 @@ import dataclasses -from typing import ClassVar, Dict, List, Any, Final, Literal, cast, Optional +from typing import ClassVar, Dict, List, Any, Final, cast, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.destination.reference import ( DestinationClientDwhWithStagingConfiguration, ) -from dlt.common.libs.sql_alchemy import URL from dlt.common.utils import digest128 - - -TSecureConnection = Literal[0, 1] +from dlt.destinations.impl.clickhouse.typing import TSecureConnection, TTableEngineType @configspec(init=False) @@ -34,10 +31,6 @@ class ClickHouseCredentials(ConnectionStringCredentials): """Timeout for establishing connection. Defaults to 10 seconds.""" send_receive_timeout: int = 300 """Timeout for sending and receiving data. Defaults to 300 seconds.""" - dataset_table_separator: str = "___" - """Separator for dataset table names, defaults to '___', i.e. 'database.dataset___table'.""" - dataset_sentinel_table_name: str = "dlt_sentinel_table" - """Special table to mark dataset as existing""" gcp_access_key_id: Optional[str] = None """When loading from a gcp bucket, you need to provide gcp interoperable keys""" gcp_secret_access_key: Optional[str] = None @@ -67,10 +60,9 @@ def get_query(self) -> Dict[str, Any]: "connect_timeout": str(self.connect_timeout), "send_receive_timeout": str(self.send_receive_timeout), "secure": 1 if self.secure else 0, - # Toggle experimental settings. These are necessary for certain datatypes and not optional. "allow_experimental_lightweight_delete": 1, - # "allow_experimental_object_type": 1, "enable_http_compression": 1, + "date_time_input_format": "best_effort", } ) return query @@ -78,16 +70,26 @@ def get_query(self) -> Dict[str, Any]: @configspec class ClickHouseClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = dataclasses.field(default="clickhouse", init=False, repr=False, compare=False) # type: ignore[misc] + destination_type: Final[str] = dataclasses.field( # type: ignore[misc] + default="clickhouse", init=False, repr=False, compare=False + ) credentials: ClickHouseCredentials = None - # Primary key columns are used to build a sparse primary index which allows for efficient data retrieval, - # but they do not enforce uniqueness constraints. It permits duplicate values even for the primary key - # columns within the same granule. - # See: https://clickhouse.com/docs/en/optimize/sparse-primary-indexes + dataset_table_separator: str = "___" + """Separator for dataset table names, defaults to '___', i.e. 
'database.dataset___table'.""" + table_engine_type: Optional[TTableEngineType] = "merge_tree" + """The default table engine to use. Defaults to 'merge_tree'. Other implemented options are 'shared_merge_tree' and 'replicated_merge_tree'.""" + dataset_sentinel_table_name: str = "dlt_sentinel_table" + """Special table to mark dataset as existing""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "dataset_table_separator", + "dataset_sentinel_table_name", + "table_engine_type", + ] def fingerprint(self) -> str: - """Returns a fingerprint of host part of a connection string.""" + """Returns a fingerprint of the host part of a connection string.""" if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" diff --git a/dlt/destinations/impl/clickhouse/factory.py b/dlt/destinations/impl/clickhouse/factory.py index 52a1694dee..93da6c866a 100644 --- a/dlt/destinations/impl/clickhouse/factory.py +++ b/dlt/destinations/impl/clickhouse/factory.py @@ -1,14 +1,13 @@ import sys import typing as t -from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.data_writers.escape import ( escape_clickhouse_identifier, escape_clickhouse_literal, format_clickhouse_datetime_literal, ) - +from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ClickHouseCredentials, @@ -67,6 +66,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_truncate_command = True + caps.supported_merge_strategies = ["delete-insert", "scd2"] + return caps @property diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index 244db578b1..25914e4093 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -1,3 +1,5 @@ +import datetime # noqa: I251 +from clickhouse_driver import dbapi as clickhouse_dbapi # type: ignore[import-untyped] from contextlib import contextmanager from typing import ( Iterator, @@ -7,21 +9,32 @@ Optional, Sequence, ClassVar, + Literal, Tuple, + cast, ) -import clickhouse_driver # type: ignore[import-untyped] +import clickhouse_driver import clickhouse_driver.errors # type: ignore[import-untyped] from clickhouse_driver.dbapi import OperationalError # type: ignore[import-untyped] from clickhouse_driver.dbapi.extras import DictCursor # type: ignore[import-untyped] +from pendulum import DateTime # noqa: I251 from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.typing import DictStrAny from dlt.destinations.exceptions import ( DatabaseUndefinedRelation, DatabaseTransientException, DatabaseTerminalException, ) -from dlt.destinations.impl.clickhouse.configuration import ClickHouseCredentials +from dlt.destinations.impl.clickhouse.configuration import ( + ClickHouseCredentials, + ClickHouseClientConfiguration, +) +from dlt.destinations.impl.clickhouse.typing import ( + TTableEngineType, + TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR, +) from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, @@ -32,6 +45,7 @@ from dlt.destinations.utils import _convert_to_old_pyformat +TDeployment = Literal["ClickHouseOSS", "ClickHouseCloud"] TRANSACTIONS_UNSUPPORTED_WARNING_MESSAGE = ( "ClickHouse does not support transactions! Each statement is auto-committed separately." 
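With `table_engine_type`, `dataset_table_separator` and the sentinel table name moved onto `ClickHouseClientConfiguration`, the default engine can be overridden destination-wide (assuming the usual dlt config resolution naming) or per table via the adapter:

```python
import os

import dlt
from dlt.destinations.impl.clickhouse.clickhouse_adapter import clickhouse_adapter

# destination-wide default, equivalent to setting table_engine_type
# in the [destination.clickhouse] config section
os.environ["DESTINATION__CLICKHOUSE__TABLE_ENGINE_TYPE"] = "replicated_merge_tree"

# per-table override through the x-table-engine-type hint
data = dlt.resource([{"id": 1}], name="events")
clickhouse_adapter(data, table_engine_type="merge_tree")
```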
) @@ -44,7 +58,7 @@ class ClickHouseDBApiCursorImpl(DBApiCursorImpl): class ClickHouseSqlClient( SqlClientBase[clickhouse_driver.dbapi.connection.Connection], DBTransaction ): - dbapi: ClassVar[DBApi] = clickhouse_driver.dbapi + dbapi: ClassVar[DBApi] = clickhouse_dbapi def __init__( self, @@ -52,17 +66,19 @@ def __init__( staging_dataset_name: str, credentials: ClickHouseCredentials, capabilities: DestinationCapabilitiesContext, + config: ClickHouseClientConfiguration, ) -> None: super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: clickhouse_driver.dbapi.connection = None self.credentials = credentials self.database_name = credentials.database + self.config = config def has_dataset(self) -> bool: - # we do not need to normalize dataset_sentinel_table_name - sentinel_table = self.credentials.dataset_sentinel_table_name + # we do not need to normalize dataset_sentinel_table_name. + sentinel_table = self.config.dataset_sentinel_table_name return sentinel_table in [ - t.split(self.credentials.dataset_table_separator)[1] for t in self._list_tables() + t.split(self.config.dataset_table_separator)[1] for t in self._list_tables() ] def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection: @@ -99,18 +115,21 @@ def execute_sql( return None if curr.description is None else curr.fetchall() def create_dataset(self) -> None: - # We create a sentinel table which defines wether we consider the dataset created + # We create a sentinel table which defines whether we consider the dataset created. sentinel_table_name = self.make_qualified_table_name( - self.credentials.dataset_sentinel_table_name - ) - self.execute_sql( - f"""CREATE TABLE {sentinel_table_name} (_dlt_id String NOT NULL PRIMARY KEY) ENGINE=ReplicatedMergeTree COMMENT 'internal dlt sentinel table'""" + self.config.dataset_sentinel_table_name ) + sentinel_table_type = cast(TTableEngineType, self.config.table_engine_type) + self.execute_sql(f""" + CREATE TABLE {sentinel_table_name} + (_dlt_id String NOT NULL PRIMARY KEY) + ENGINE={TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(sentinel_table_type)} + COMMENT 'internal dlt sentinel table'""") def drop_dataset(self) -> None: - # always try to drop sentinel table + # always try to drop the sentinel table. sentinel_table_name = self.make_qualified_table_name( - self.credentials.dataset_sentinel_table_name + self.config.dataset_sentinel_table_name ) # drop a sentinel table self.execute_sql(f"DROP TABLE {sentinel_table_name} SYNC") @@ -151,6 +170,15 @@ def _list_tables(self) -> List[str]: ) return [row[0] for row in rows] + @staticmethod + def _sanitise_dbargs(db_args: DictStrAny) -> DictStrAny: + """For ClickHouse OSS, the DBapi driver doesn't parse datetime types. + We remove timezone specifications in this case.""" + for key, value in db_args.items(): + if isinstance(value, (DateTime, datetime.datetime)): + db_args[key] = str(value.replace(microsecond=0, tzinfo=None)) + return db_args + @contextmanager @raise_database_error def execute_query( @@ -158,12 +186,14 @@ def execute_query( ) -> Iterator[ClickHouseDBApiCursorImpl]: assert isinstance(query, str), "Query must be a string." 
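`_sanitise_dbargs` above works around the OSS `clickhouse-driver` not parsing timezone-aware datetime parameters: such values are stringified without microseconds and tzinfo before being bound. Roughly:

```python
import datetime


def sanitise_dbargs(db_args: dict) -> dict:
    # mirror of the patch: datetime params become naive strings for the OSS driver
    for key, value in db_args.items():
        if isinstance(value, datetime.datetime):
            db_args[key] = str(value.replace(microsecond=0, tzinfo=None))
    return db_args


args = {"created_at": datetime.datetime(2024, 7, 1, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc)}
print(sanitise_dbargs(args))  # {'created_at': '2024-07-01 12:30:45'}
```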
- db_args = kwargs.copy() + db_args: DictStrAny = kwargs.copy() if args: query, db_args = _convert_to_old_pyformat(query, args, OperationalError) db_args.update(kwargs) + db_args = self._sanitise_dbargs(db_args) + with self._conn.cursor() as cursor: for query_line in query.split(";"): if query_line := query_line.strip(): @@ -188,7 +218,7 @@ def make_qualified_table_name_path( if table_name: # table name combines dataset name and table name table_name = self.capabilities.casefold_identifier( - f"{self.dataset_name}{self.credentials.dataset_table_separator}{table_name}" + f"{self.dataset_name}{self.config.dataset_table_separator}{table_name}" ) if escape: table_name = self.capabilities.escape_identifier(table_name) diff --git a/dlt/destinations/impl/clickhouse/typing.py b/dlt/destinations/impl/clickhouse/typing.py new file mode 100644 index 0000000000..658822149c --- /dev/null +++ b/dlt/destinations/impl/clickhouse/typing.py @@ -0,0 +1,32 @@ +from typing import Literal, Dict, get_args, Set + +from dlt.common.schema import TColumnHint + +TSecureConnection = Literal[0, 1] +TTableEngineType = Literal[ + "merge_tree", + "shared_merge_tree", + "replicated_merge_tree", +] + +HINT_TO_CLICKHOUSE_ATTR: Dict[TColumnHint, str] = { + "primary_key": "PRIMARY KEY", + "unique": "", # No unique constraints available in ClickHouse. + "foreign_key": "", # No foreign key constraints support in ClickHouse. +} + +TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR: Dict[TTableEngineType, str] = { + "merge_tree": "MergeTree", + "shared_merge_tree": "SharedMergeTree", + "replicated_merge_tree": "ReplicatedMergeTree", +} + +TDeployment = Literal["ClickHouseOSS", "ClickHouseCloud"] + +SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"] +FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = { + "jsonl": "JSONEachRow", + "parquet": "Parquet", +} +TABLE_ENGINE_TYPES: Set[TTableEngineType] = set(get_args(TTableEngineType)) +TABLE_ENGINE_TYPE_HINT: Literal["x-table-engine-type"] = "x-table-engine-type" diff --git a/dlt/destinations/impl/clickhouse/utils.py b/dlt/destinations/impl/clickhouse/utils.py index 0e2fa3db00..02e4e93943 100644 --- a/dlt/destinations/impl/clickhouse/utils.py +++ b/dlt/destinations/impl/clickhouse/utils.py @@ -1,16 +1,12 @@ -from typing import Union, Literal, Dict +from typing import Union from urllib.parse import urlparse, ParseResult -SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"] -FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = { - "jsonl": "JSONEachRow", - "parquet": "Parquet", -} - - def convert_storage_to_http_scheme( - url: Union[str, ParseResult], use_https: bool = False, endpoint: str = None, region: str = None + url: Union[str, ParseResult], + use_https: bool = False, + endpoint: str = None, + region: str = None, ) -> str: try: if isinstance(url, str): diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 56462714c1..409d3bc4be 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -42,6 +42,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.alter_add_multi_column = True caps.supports_multiple_statements = False caps.supports_clone_table = True + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps @property diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index c44fd3cca1..976dfa4fb5 100644 --- 
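Since ClickHouse has no separate schema/dataset namespace, dlt emulates datasets by prefixing table names with the dataset name and the (now configuration-level) `dataset_table_separator`:

```python
dataset_name = "my_dataset"
separator = "___"  # ClickHouseClientConfiguration.dataset_table_separator default
table_name = "events"

# the fully qualified name lives directly in the configured ClickHouse database,
# i.e. 'database.my_dataset___events'
qualified = f"{dataset_name}{separator}{table_name}"
print(qualified)  # my_dataset___events
```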
a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -3,9 +3,8 @@ from typing import ClassVar, Optional, Type, Iterable, cast, List from dlt.common.destination.reference import LoadJob -from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import AnyFun -from dlt.pipeline.current import destination_state +from dlt.common.storages.load_package import destination_state from dlt.common.configuration import create_resolved_partial from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -15,6 +14,8 @@ DoNothingJob, JobClientBase, ) + +from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration from dlt.destinations.job_impl import ( DestinationJsonlLoadJob, diff --git a/dlt/destinations/impl/dremio/factory.py b/dlt/destinations/impl/dremio/factory.py index 29a4937c69..b8c7e1b746 100644 --- a/dlt/destinations/impl/dremio/factory.py +++ b/dlt/destinations/impl/dremio/factory.py @@ -40,6 +40,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_clone_table = False caps.supports_multiple_statements = False caps.timestamp_precision = 3 + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps @property diff --git a/dlt/destinations/impl/duckdb/factory.py b/dlt/destinations/impl/duckdb/factory.py index 388f914479..2c4df2cb58 100644 --- a/dlt/destinations/impl/duckdb/factory.py +++ b/dlt/destinations/impl/duckdb/factory.py @@ -35,6 +35,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = True caps.alter_add_multi_column = False caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index c68bc36ca9..c2792fc432 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -27,6 +27,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 65536 caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = False + caps.supported_merge_strategies = ["delete-insert", "upsert"] return caps diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 1e6eec5cce..31b61c6cb1 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -32,6 +32,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: preferred_loader_file_format="jsonl", loader_file_format_adapter=loader_file_format_adapter, supported_table_formats=["delta"], + # TODO: make `supported_merge_strategies` depend on configured + # `table_format` (perhaps with adapter similar to how we handle + # loader file format) + supported_merge_strategies=["upsert"], ) @property diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index bf443e061f..ef4702b17d 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -9,11 +9,18 @@ import dlt from dlt.common import logger, time, json, pendulum +from dlt.common.utils import assert_min_pkg_version from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema +from 
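The filesystem destination now advertises `supported_merge_strategies=["upsert"]`, which together with the `delta` table format enables the Delta merge path implemented below. Assuming the usual dlt hint syntax for merge strategies, a resource opts in roughly like this; `bucket_url` is expected to come from config/secrets:

```python
import dlt


@dlt.resource(
    table_format="delta",
    write_disposition={"disposition": "merge", "strategy": "upsert"},
    primary_key="id",
)
def users():
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]


pipeline = dlt.pipeline(pipeline_name="delta_upsert", destination="filesystem")
pipeline.run(users())
```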
dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop from dlt.common.storages import FileStorage, fsspec_from_config -from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName, TPipelineStateDoc +from dlt.common.storages.load_package import ( + LoadJobInfo, + ParsedLoadJobFileName, + TPipelineStateDoc, + load_package as current_load_package, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( NewLoadJob, @@ -116,18 +123,73 @@ def __init__( def write(self) -> None: from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.deltalake import ( + DeltaTable, write_delta_table, + ensure_delta_compatible_arrow_schema, _deltalake_storage_options, + try_get_deltatable, + ) + + assert_min_pkg_version( + pkg_name="pyarrow", + version="17.0.0", + msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.", ) + # create Arrow dataset from Parquet files file_paths = [job.file_path for job in self.table_jobs] + arrow_ds = pa.dataset.dataset(file_paths) + + # create Delta table object + dt_path = self.client.make_remote_uri(self.make_remote_path()) + storage_options = _deltalake_storage_options(self.client.config) + dt = try_get_deltatable(dt_path, storage_options=storage_options) + + # explicitly check if there is data + # (https://github.com/delta-io/delta-rs/issues/2686) + if arrow_ds.head(1).num_rows == 0: + if dt is None: + # create new empty Delta table with schema from Arrow table + DeltaTable.create( + table_uri=dt_path, + schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema), + mode="overwrite", + ) + return - write_delta_table( - path=self.client.make_remote_uri(self.make_remote_path()), - data=pa.dataset.dataset(file_paths), - write_disposition=self.table["write_disposition"], - storage_options=_deltalake_storage_options(self.client.config), - ) + arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader + + if self.table["write_disposition"] == "merge" and dt is not None: + assert self.table["x-merge-strategy"] in self.client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] + + if self.table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] + if "parent" in self.table: + unique_column = get_first_column_name_with_prop(self.table, "unique") + predicate = f"target.{unique_column} = source.{unique_column}" + else: + primary_keys = get_columns_names_with_prop(self.table, "primary_key") + predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys]) + + qry = ( + dt.merge( + source=arrow_rbr, + predicate=predicate, + source_alias="source", + target_alias="target", + ) + .when_matched_update_all() + .when_not_matched_insert_all() + ) + + qry.execute() + + else: + write_delta_table( + table_or_uri=dt_path if dt is None else dt, + data=arrow_rbr, + write_disposition=self.table["write_disposition"], + storage_options=storage_options, + ) def make_remote_path(self) -> str: # directory path, not file path @@ -424,11 +486,9 @@ def _store_current_state(self, load_id: str) -> None: # don't save the state this way when used as staging if self.config.as_staging: return - # get state doc from current pipeline - from dlt.pipeline.current import load_package - - pipeline_state_doc = load_package()["state"].get("pipeline_state") + # get state doc from current pipeline + pipeline_state_doc = current_load_package()["state"].get("pipeline_state") if not pipeline_state_doc: return @@ -555,7 
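The Delta upsert path above builds its merge predicate from the primary key columns (or, for child tables, the `unique` column) and uses `deltalake`'s merge builder. A standalone sketch of the same pattern against a local Delta table (path and column names are illustrative):

```python
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

path = "/tmp/users_delta"
write_deltalake(path, pa.table({"id": [1, 2], "name": ["alice", "bob"]}))

source = pa.table({"id": [2, 3], "name": ["bobby", "carol"]})
predicate = " AND ".join(f"target.{c} = source.{c}" for c in ["id"])

dt = DeltaTable(path)
(
    dt.merge(source=source, predicate=predicate, source_alias="source", target_alias="target")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute()
)
```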
+615,9 @@ def get_table_jobs( if table_format == "delta": delta_jobs = [ DeltaLoadFilesystemJob( - self, table, get_table_jobs(completed_table_chain_jobs, table["name"]) + self, + table=self.prepare_load_table(table["name"]), + table_jobs=get_table_jobs(completed_table_chain_jobs, table["name"]), ) for table in table_chain ] diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index bb33632b48..99d5ef43c6 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -1,7 +1,7 @@ from typing import Any from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource @@ -32,7 +32,7 @@ def lancedb_adapter( >>> lancedb_adapter(data, embed="description") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/motherduck/factory.py b/dlt/destinations/impl/motherduck/factory.py index df7418b9db..a9bab96d08 100644 --- a/dlt/destinations/impl/motherduck/factory.py +++ b/dlt/destinations/impl/motherduck/factory.py @@ -36,6 +36,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = False caps.alter_add_multi_column = False caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index 6912510995..85c94c21b7 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -39,6 +39,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = True caps.max_rows_per_insert = 1000 caps.timestamp_precision = 7 + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps diff --git a/dlt/destinations/impl/postgres/factory.py b/dlt/destinations/impl/postgres/factory.py index 0fe8c6d13e..e14aa61465 100644 --- a/dlt/destinations/impl/postgres/factory.py +++ b/dlt/destinations/impl/postgres/factory.py @@ -41,6 +41,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 1024 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = True + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py index 2bface0938..f994948d91 100644 --- a/dlt/destinations/impl/qdrant/factory.py +++ b/dlt/destinations/impl/qdrant/factory.py @@ -7,7 +7,7 @@ from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration if t.TYPE_CHECKING: - from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient + from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient class qdrant(Destination[QdrantClientConfiguration, "QdrantClient"]): @@ -44,7 +44,7 @@ def adjust_capabilities( @property def client_class(self) -> t.Type["QdrantClient"]: - from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient + from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient return QdrantClient diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py 
b/dlt/destinations/impl/qdrant/qdrant_adapter.py index 215d87a920..e39d3e3644 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -2,7 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter VECTORIZE_HINT = "x-qdrant-embed" @@ -32,7 +32,7 @@ def qdrant_adapter( >>> qdrant_adapter(data, embed="description") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/qdrant/qdrant_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py similarity index 99% rename from dlt/destinations/impl/qdrant/qdrant_client.py rename to dlt/destinations/impl/qdrant/qdrant_job_client.py index 080c277edd..28d7388701 100644 --- a/dlt/destinations/impl/qdrant/qdrant_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -317,8 +317,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_created_at = self.schema.naming.normalize_identifier("created_at") - limit = 100 - offset = None + limit = 10 + start_from = None while True: try: scroll_table_name = self._make_qualified_collection_name( @@ -337,14 +337,15 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: order_by=models.OrderBy( key=p_created_at, direction=models.Direction.DESC, + start_from=start_from, ), limit=limit, - offset=offset, ) if len(state_records) == 0: return None for state_record in state_records: state = state_record.payload + start_from = state[p_created_at] load_id = state[p_dlt_load_id] scroll_table_name = self._make_qualified_collection_name( self.schema.loads_table_name @@ -361,7 +362,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ), ) if load_records.count == 0: - return None + continue return StateInfo.from_normalized_mapping(state, self.schema.naming) except UnexpectedResponse as e: if e.status_code == 404: diff --git a/dlt/destinations/impl/redshift/factory.py b/dlt/destinations/impl/redshift/factory.py index 7e6638be1e..ef1ee6b754 100644 --- a/dlt/destinations/impl/redshift/factory.py +++ b/dlt/destinations/impl/redshift/factory.py @@ -40,6 +40,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = True caps.alter_add_multi_column = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 1211b78672..08fc132fc3 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -83,9 +83,6 @@ def parse_native_representation(self, native_value: Any) -> None: if param in self.query: setattr(self, param, self.query.get(param)) - # if not self.is_partial() and (self.password or self.private_key): - # self.resolve() - def on_resolved(self) -> None: if not self.password and not self.private_key and not self.authenticator: raise ConfigurationValueError( @@ -139,6 +136,9 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) csv_format: Optional[CsvFormatConfiguration] = None 
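The qdrant state lookup above replaces `offset` paging with keyset pagination: the last seen `created_at` is fed back as `order_by.start_from`, and a state record whose load id never completed is now skipped (`continue`) instead of aborting the search. A schematic sketch of that pagination pattern (it assumes an existing collection with a payload index on the sort key, which `order_by` requires):

```python
from typing import Any, Iterator, Optional

from qdrant_client import QdrantClient, models


def newest_payloads(client: QdrantClient, collection: str, key: str = "created_at") -> Iterator[Any]:
    """Keyset pagination: pass the last seen sort value back as start_from."""
    start_from: Optional[Any] = None
    while True:
        records, _ = client.scroll(
            collection_name=collection,
            with_payload=True,
            order_by=models.OrderBy(key=key, direction=models.Direction.DESC, start_from=start_from),
            limit=10,
        )
        if not records:
            return
        for record in records:
            start_from = record.payload[key]
            yield record.payload
```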
"""Optional csv format configuration""" + query_tag: Optional[str] = None + """A tag with placeholders to tag sessions executing jobs""" + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/snowflake/factory.py b/dlt/destinations/impl/snowflake/factory.py index f531b8704e..c5fbd8600b 100644 --- a/dlt/destinations/impl/snowflake/factory.py +++ b/dlt/destinations/impl/snowflake/factory.py @@ -41,6 +41,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = True caps.alter_add_multi_column = True caps.supports_clone_table = True + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] return caps @property diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 532ff404ae..bf175ba911 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -15,11 +15,14 @@ AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults, ) +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.storages.load_package import ParsedLoadJobFileName +from dlt.common.typing import TLoaderFileFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException @@ -90,38 +93,83 @@ def __init__( ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) + # resolve reference + is_local_file = not NewReferenceJob.is_reference_job(file_path) + file_url = file_path if is_local_file else NewReferenceJob.resolve_reference(file_path) + # take file name + file_name = FileStorage.get_file_name_from_file_path(file_url) + file_format = file_name.rsplit(".", 1)[-1] qualified_table_name = client.make_qualified_table_name(table_name) + # this means we have a local file + stage_file_path: str = "" + if is_local_file: + if not stage_name: + # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" + stage_name = client.make_qualified_table_name("%" + table_name) + stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - # extract and prepare some vars - bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) - else "" + copy_sql = self.gen_copy_sql( + file_url, + qualified_table_name, + file_format, # type: ignore[arg-type] + client.capabilities.generates_case_sensitive_identifiers(), + stage_name, + stage_file_path, + staging_credentials, + config.csv_format, ) - file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + + with client.begin_transaction(): + # PUT and COPY in one tx if local file, otherwise only copy + if is_local_file: + client.execute_sql( + f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' + " AUTO_COMPRESS = FALSE" + ) + client.execute_sql(copy_sql) + if stage_file_path and not keep_staged_files: + client.execute_sql(f"REMOVE {stage_file_path}") + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() + + @classmethod + def 
gen_copy_sql( + cls, + file_url: str, + qualified_table_name: str, + loader_file_format: TLoaderFileFormat, + is_case_sensitive: bool, + stage_name: Optional[str] = None, + local_stage_file_path: Optional[str] = None, + staging_credentials: Optional[CredentialsConfiguration] = None, + csv_format: Optional[CsvFormatConfiguration] = None, + ) -> str: + parsed_file_url = urlparse(file_url) + # check if local filesystem (file scheme or just a local file in native form) + is_local = parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path( + file_url ) + # file_name = FileStorage.get_file_name_from_file_path(file_url) + from_clause = "" credentials_clause = "" files_clause = "" - stage_file_path = "" on_error_clause = "" - case_folding = ( - "CASE_SENSITIVE" - if client.capabilities.generates_case_sensitive_identifiers() - else "CASE_INSENSITIVE" - ) + case_folding = "CASE_SENSITIVE" if is_case_sensitive else "CASE_INSENSITIVE" column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'" - if bucket_path: - bucket_url = urlparse(bucket_path) - bucket_scheme = bucket_url.scheme + if not is_local: + bucket_scheme = parsed_file_url.scheme # referencing an external s3/azure stage does not require explicit AWS credentials if bucket_scheme in ["s3", "az", "abfs"] and stage_name: from_clause = f"FROM '@{stage_name}'" - files_clause = f"FILES = ('{bucket_url.path.lstrip('/')}')" + files_clause = f"FILES = ('{parsed_file_url.path.lstrip('/')}')" # referencing an staged files via a bucket URL requires explicit AWS credentials elif ( bucket_scheme == "s3" @@ -129,7 +177,7 @@ def __init__( and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) ): credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" - from_clause = f"FROM '{bucket_path}'" + from_clause = f"FROM '{file_url}'" elif ( bucket_scheme in ["az", "abfs"] and staging_credentials @@ -139,48 +187,43 @@ def __init__( credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" # Converts an az:/// to azure://.blob.core.windows.net// # as required by snowflake - _path = "/" + bucket_url.netloc + bucket_url.path - bucket_path = urlunparse( - bucket_url._replace( + _path = "/" + parsed_file_url.netloc + parsed_file_url.path + file_url = urlunparse( + parsed_file_url._replace( scheme="azure", netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net", path=_path, ) ) - from_clause = f"FROM '{bucket_path}'" + from_clause = f"FROM '{file_url}'" else: # ensure that gcs bucket path starts with gcs://, this is a requirement of snowflake - bucket_path = bucket_path.replace("gs://", "gcs://") + file_url = file_url.replace("gs://", "gcs://") if not stage_name: # when loading from bucket stage must be given raise LoadJobTerminalException( - file_path, - f"Cannot load from bucket path {bucket_path} without a stage name. See" + file_url, + f"Cannot load from bucket path {file_url} without a stage name. 
See" " https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for" " instructions on setting up the `stage_name`", ) from_clause = f"FROM @{stage_name}/" - files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" + files_clause = f"FILES = ('{urlparse(file_url).path.lstrip('/')}')" else: - # this means we have a local file - if not stage_name: - # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name("%" + table_name) - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - from_clause = f"FROM {stage_file_path}" + from_clause = f"FROM {local_stage_file_path}" # decide on source format, stage_file_path will either be a local file or a bucket path - if file_name.endswith("jsonl"): + if loader_file_format == "jsonl": source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" - elif file_name.endswith("parquet"): + elif loader_file_format == "parquet": source_format = ( "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" # TODO: USE_VECTORIZED_SCANNER inserts null strings into VARIANT JSON # " USE_VECTORIZED_SCANNER = TRUE)" ) - elif file_name.endswith("csv"): + elif loader_file_format == "csv": # empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL - csv_format = config.csv_format or CsvFormatConfiguration() + csv_format = csv_format or CsvFormatConfiguration() source_format = ( "(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER =" f" {csv_format.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF =" @@ -193,31 +236,16 @@ def __init__( if csv_format.on_error_continue: on_error_clause = "ON_ERROR = CONTINUE" else: - raise ValueError(file_name) + raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.") - with client.begin_transaction(): - # PUT and COPY in one tx if local file, otherwise only copy - if not bucket_path: - client.execute_sql( - f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' - " AUTO_COMPRESS = FALSE" - ) - client.execute_sql(f"""COPY INTO {qualified_table_name} - {from_clause} - {files_clause} - {credentials_clause} - FILE_FORMAT = {source_format} - {column_match_clause} - {on_error_clause} - """) - if stage_file_path and not keep_staged_files: - client.execute_sql(f"REMOVE {stage_file_path}") - - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() + return f"""COPY INTO {qualified_table_name} + {from_clause} + {files_clause} + {credentials_clause} + FILE_FORMAT = {source_format} + {column_match_clause} + {on_error_clause} + """ class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): @@ -232,6 +260,7 @@ def __init__( config.normalize_staging_dataset_name(schema), config.credentials, capabilities, + config.query_tag, ) super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index fbc80b7b6c..8d11c23363 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Dict, Iterator, Optional, Sequence, List import snowflake.connector as snowflake_lib @@ -12,6 +12,7 @@ from dlt.destinations.sql_client import ( 
DBApiCursorImpl, SqlClientBase, + TJobQueryTags, raise_database_error, raise_open_connection_error, ) @@ -37,10 +38,12 @@ def __init__( staging_dataset_name: str, credentials: SnowflakeCredentials, capabilities: DestinationCapabilitiesContext, + query_tag: Optional[str] = None, ) -> None: super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: snowflake_lib.SnowflakeConnection = None self.credentials = credentials + self.query_tag = query_tag def open_connection(self) -> snowflake_lib.SnowflakeConnection: conn_params = self.credentials.to_connector_params() @@ -120,6 +123,20 @@ def _reset_connection(self) -> None: self._conn.rollback() self._conn.autocommit(True) + def set_query_tags(self, tags: TJobQueryTags) -> None: + super().set_query_tags(tags) + if self.query_tag: + self._tag_session() + + def _tag_session(self) -> None: + """Wraps query with Snowflake query tag""" + if self._query_tags: + tag = self.query_tag.format(**self._query_tags) + tag_query = f"ALTER SESSION SET QUERY_TAG = '{tag}'" + else: + tag_query = "ALTER SESSION UNSET QUERY_TAG" + self.execute_sql(tag_query) + @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, snowflake_lib.errors.ProgrammingError): diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 4820056e66..bb117e48d2 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -73,6 +73,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 caps.timestamp_precision = 7 + caps.supported_merge_strategies = ["delete-insert", "scd2"] + return caps @property diff --git a/dlt/destinations/impl/synapse/synapse_adapter.py b/dlt/destinations/impl/synapse/synapse_adapter.py index 8b262f3621..e12823c7bf 100644 --- a/dlt/destinations/impl/synapse/synapse_adapter.py +++ b/dlt/destinations/impl/synapse/synapse_adapter.py @@ -3,7 +3,7 @@ from dlt.extract import DltResource, resource as make_resource from dlt.extract.items import TTableHintTemplate from dlt.extract.hints import TResourceHints -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter TTableIndexType = Literal["heap", "clustered_columnstore_index"] """ @@ -37,7 +37,7 @@ def synapse_adapter(data: Any, table_index_type: TTableIndexType = None) -> DltR >>> synapse_adapter(data, table_index_type="clustered_columnstore_index") [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if table_index_type is not None: diff --git a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index a290ac65b4..9bd0b41783 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -2,7 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource -from dlt.destinations.utils import ensure_resource +from dlt.destinations.utils import get_resource_for_adapter TTokenizationTMethod = Literal["word", "lowercase", "whitespace", "field"] TOKENIZATION_METHODS: Set[TTokenizationTMethod] = set(get_args(TTokenizationTMethod)) @@ -54,7 +54,7 @@ def weaviate_adapter( >>> 
weaviate_adapter(data, vectorize="description", tokenization={"description": "word"}) [DltResource with hints applied] """ - resource = ensure_resource(data) + resource = get_resource_for_adapter(data) column_hints: TTableSchemaColumns = {} if vectorize: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index e00b7ebb05..dd0e783414 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -28,6 +28,7 @@ TTableFormat, ) from dlt.common.schema.utils import ( + get_inherited_table_hint, loads_table, normalize_table_identifiers, version_table, @@ -262,6 +263,7 @@ def create_table_chain_completed_followup_jobs( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" + self._set_query_tags_for_job(load_id, table) if SqlLoadJob.is_sql_job(file_path): # execute sql load job return SqlLoadJob(file_path, self.sql_client) @@ -676,6 +678,27 @@ def _verify_schema(self) -> None: logger.error(str(exception)) raise exceptions[0] + def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: + """Sets query tags in sql_client for a job in package `load_id`, starting for a particular `table`""" + from dlt.common.pipeline import current_pipeline + + pipeline = current_pipeline() + pipeline_name = pipeline.pipeline_name if pipeline else "" + self.sql_client.set_query_tags( + { + "source": self.schema.name, + "resource": ( + get_inherited_table_hint( + self.schema._schema_tables, table["name"], "resource", allow_none=True + ) + or "" + ), + "table": table["name"], + "load_id": load_id, + "pipeline_name": pipeline_name, + } + ) + class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): in_staging_mode: bool = False diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index a4e4b998af..9a8f7277b7 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -5,6 +5,7 @@ from dlt.common.json import json from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob +from dlt.common.storages.load_package import commit_load_package_state from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage from dlt.common.typing import TDataItems @@ -14,8 +15,6 @@ TDestinationCallable, ) -from dlt.pipeline.current import commit_load_package_state - class EmptyLoadJobWithoutFollowup(LoadJob): def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index fbe2b17fc2..27d1bc7ce5 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -7,6 +7,7 @@ Any, ClassVar, ContextManager, + Dict, Generic, Iterator, Optional, @@ -15,6 +16,7 @@ Type, AnyStr, List, + TypedDict, ) from dlt.common.typing import TFun @@ -28,6 +30,16 @@ from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction +class TJobQueryTags(TypedDict): + """Applied to sql client when a job using it starts. 
Using to tag queries""" + + source: str + resource: str + table: str + load_id: str + pipeline_name: str + + class SqlClientBase(ABC, Generic[TNativeConn]): dbapi: ClassVar[DBApi] = None @@ -53,6 +65,7 @@ def __init__( self.staging_dataset_name = staging_dataset_name self.database_name = database_name self.capabilities = capabilities + self._query_tags: TJobQueryTags = None @abstractmethod def open_connection(self) -> TNativeConn: @@ -162,8 +175,13 @@ def catalog_name(self, escape: bool = True) -> Optional[str]: # connection is scoped to a current database return None - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ".".join(self.make_qualified_table_name_path(None, escape=escape)) + def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str: + if staging: + with self.with_staging_dataset(): + path = self.make_qualified_table_name_path(None, escape=escape) + else: + path = self.make_qualified_table_name_path(None, escape=escape) + return ".".join(path) def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: return ".".join(self.make_qualified_table_name_path(table_name, escape=escape)) @@ -188,6 +206,12 @@ def make_qualified_table_name_path( path.append(table_name) return path + def get_qualified_table_names(self, table_name: str, escape: bool = True) -> Tuple[str, str]: + """Returns qualified names for table and corresponding staging table as tuple.""" + with self.with_staging_dataset(): + staging_table_name = self.make_qualified_table_name(table_name, escape) + return self.make_qualified_table_name(table_name, escape), staging_table_name + def escape_column_name(self, column_name: str, escape: bool = True) -> str: column_name = self.capabilities.casefold_identifier(column_name) if escape: @@ -210,6 +234,10 @@ def with_alternative_dataset_name( def with_staging_dataset(self) -> ContextManager["SqlClientBase[TNativeConn]"]: return self.with_alternative_dataset_name(self.staging_dataset_name) + def set_query_tags(self, tags: TJobQueryTags) -> None: + """Sets current schema (source), resource, load_id and table name when a job starts""" + self._query_tags = tags + def _ensure_native_conn(self) -> None: if not self.native_connection: raise LoadClientNotConnected(type(self).__name__, self.dataset_name) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 1715389e17..e67be049ab 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional +from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional, Callable, Union import yaml from dlt.common.logger import pretty_format_exception @@ -6,6 +6,7 @@ from dlt.common.schema.typing import ( TTableSchema, TSortOrder, + TColumnProp, ) from dlt.common.schema.utils import ( get_columns_names_with_prop, @@ -16,12 +17,12 @@ DEFAULT_MERGE_STRATEGY, ) from dlt.common.storages.load_storage import ParsedLoadJobFileName +from dlt.common.storages.load_package import load_package as current_load_package from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase -from dlt.pipeline.current import load_package as current_load_package class SqlJobParams(TypedDict, total=False): @@ -158,6 +159,8 @@ def 
generate_sql( # type: ignore[return] merge_strategy = table_chain[0].get("x-merge-strategy", DEFAULT_MERGE_STRATEGY) if merge_strategy == "delete-insert": return cls.gen_merge_sql(table_chain, sql_client) + elif merge_strategy == "upsert": + return cls.gen_upsert_sql(table_chain, sql_client) elif merge_strategy == "scd2": return cls.gen_scd2_sql(table_chain, sql_client) @@ -342,6 +345,107 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: """ return f"CREATE TEMP TABLE {temp_table_name} AS {select_sql};" + @classmethod + def gen_update_table_prefix(cls, table_name: str) -> str: + return f"UPDATE {table_name} SET" + + @classmethod + def requires_temp_table_for_delete(cls) -> bool: + """Whether a temporary table is required to delete records. + + Must be `True` for destinations that don't support correlated subqueries. + """ + return False + + @classmethod + def _escape_list(cls, list_: List[str], escape_id: Callable[[str], str]) -> List[str]: + return list(map(escape_id, list_)) + + @classmethod + def _get_hard_delete_col_and_cond( + cls, + table: TTableSchema, + escape_id: Callable[[str], str], + escape_lit: Callable[[Any], Any], + invert: bool = False, + ) -> Tuple[Optional[str], Optional[str]]: + """Returns tuple of hard delete column name and SQL condition statement. + + Returns tuple of `None` values if no column has `hard_delete` hint. + Condition statement can be used to filter deleted records. + Set `invert=True` to filter non-deleted records instead. + """ + + col = get_first_column_name_with_prop(table, "hard_delete") + if col is None: + return (None, None) + cond = f"{escape_id(col)} IS NOT NULL" + if invert: + cond = f"{escape_id(col)} IS NULL" + if table["columns"][col]["data_type"] == "bool": + if invert: + cond += f" OR {escape_id(col)} = {escape_lit(False)}" + else: + cond = f"{escape_id(col)} = {escape_lit(True)}" + return (col, cond) + + @classmethod + def _get_unique_col( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + table: TTableSchema, + ) -> str: + """Returns name of first column in `table` with `unique` property. + + Raises `MergeDispositionException` if no such column exists. + """ + return cls._get_prop_col_or_raise( + table, + "unique", + MergeDispositionException( + sql_client.fully_qualified_dataset_name(), + sql_client.fully_qualified_dataset_name(staging=True), + [t["name"] for t in table_chain], + f"No `unique` column (e.g. `_dlt_id`) in table `{table['name']}`.", + ), + ) + + @classmethod + def _get_root_key_col( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + table: TTableSchema, + ) -> str: + """Returns name of first column in `table` with `root_key` property. + + Raises `MergeDispositionException` if no such column exists. + """ + return cls._get_prop_col_or_raise( + table, + "root_key", + MergeDispositionException( + sql_client.fully_qualified_dataset_name(), + sql_client.fully_qualified_dataset_name(staging=True), + [t["name"] for t in table_chain], + f"No `root_key` column (e.g. `_dlt_root_id`) in table `{table['name']}`.", + ), + ) + + @classmethod + def _get_prop_col_or_raise( + cls, table: TTableSchema, prop: Union[TColumnProp, str], exception: Exception + ) -> str: + """Returns name of first column in `table` with `prop` property. + + Raises `exception` if no such column exists. 
+ """ + col = get_first_column_name_with_prop(table, prop) + if col is None: + raise exception + return col + @classmethod def gen_merge_sql( cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] @@ -367,22 +471,18 @@ def gen_merge_sql( escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal # get top level table full identifiers - root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - with sql_client.with_staging_dataset(): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) # get merge and primary keys from top level - primary_keys = list( - map( - escape_column_id, - get_columns_names_with_prop(root_table, "primary_key"), - ) + primary_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "primary_key"), + escape_column_id, ) - merge_keys = list( - map( - escape_column_id, - get_columns_names_with_prop(root_table, "merge_key"), - ) + merge_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "merge_key"), + escape_column_id, ) # if we do not have any merge keys to select from, we will fall back to a staged append, i.E. @@ -407,18 +507,9 @@ def gen_merge_sql( root_table_name, staging_root_table_name, key_clauses, for_delete=False ) # use unique hint to create temp table with all identifiers to delete - unique_columns = get_columns_names_with_prop(root_table, "unique") - if not unique_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - "There is no unique column (ie _dlt_id) in top table" - f" {root_table['name']} so it is not possible to link child tables to it.", - ) - # get first unique column - unique_column = escape_column_id(unique_columns[0]) - # create temp table with unique identifier + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) create_delete_temp_table_sql, delete_temp_table_name = ( cls.gen_delete_temp_table_sql( root_table["name"], unique_column, key_table_clauses, sql_client @@ -430,17 +521,9 @@ def gen_merge_sql( # but uses temporary views instead for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) - root_key_columns = get_columns_names_with_prop(table, "root_key") - if not root_key_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - "There is no root foreign key (ie _dlt_root_id) in child table" - f" {table['name']} so it is not possible to refer to top level table" - f" {root_table['name']} unique column {unique_column}", - ) - root_key_column = escape_column_id(root_key_columns[0]) + root_key_column = escape_column_id( + cls._get_root_key_col(table_chain, sql_client, table) + ) sql.append( cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column @@ -454,15 +537,13 @@ def gen_merge_sql( ) ) - # get name of column with hard_delete hint, if specified - not_deleted_cond: str = None - hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") - if hard_delete_col is not None: - # any value indicates a delete for non-boolean columns - not_deleted_cond = f"{escape_column_id(hard_delete_col)} IS NULL" - if root_table["columns"][hard_delete_col]["data_type"] == "bool": - # only True values indicate a 
delete for boolean columns - not_deleted_cond += f" OR {escape_column_id(hard_delete_col)} = {escape_lit(False)}" + # get hard delete information + hard_delete_col, not_deleted_cond = cls._get_hard_delete_col_and_cond( + root_table, + escape_column_id, + escape_lit, + invert=True, + ) # get dedup sort information dedup_sort = get_dedup_sort_tuple(root_table) @@ -470,7 +551,8 @@ def gen_merge_sql( insert_temp_table_name: str = None if len(table_chain) > 1: if len(primary_keys) > 0 or hard_delete_col is not None: - condition_columns = [hard_delete_col] if not_deleted_cond is not None else None + # condition_columns = [hard_delete_col] if not_deleted_cond is not None else None + condition_columns = None if hard_delete_col is None else [hard_delete_col] ( create_insert_temp_table_sql, insert_temp_table_name, @@ -488,9 +570,7 @@ def gen_merge_sql( # insert from staging to dataset for table in table_chain: - table_name = sql_client.make_qualified_table_name(table["name"]) - with sql_client.with_staging_dataset(): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" if (len(primary_keys) > 0 and len(table_chain) > 1) or ( @@ -513,6 +593,97 @@ def gen_merge_sql( sql.append(f"INSERT INTO {table_name}({col_str}) {select_sql};") return sql + @classmethod + def gen_upsert_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: + sql: List[str] = [] + root_table = table_chain[0] + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) + escape_column_id = sql_client.escape_column_name + escape_lit = sql_client.capabilities.escape_literal + if escape_lit is None: + escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal + + # process table hints + primary_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "primary_key"), + escape_column_id, + ) + hard_delete_col, deleted_cond = cls._get_hard_delete_col_and_cond( + root_table, + escape_column_id, + escape_lit, + ) + + # generate merge statement for root table + on_str = " AND ".join([f"d.{c} = s.{c}" for c in primary_keys]) + root_table_column_names = list(map(escape_column_id, root_table["columns"])) + update_str = ", ".join([c + " = " + "s." 
+ c for c in root_table_column_names]) + col_str = ", ".join(["{alias}" + c for c in root_table_column_names]) + delete_str = ( + "" if hard_delete_col is None else f"WHEN MATCHED AND s.{deleted_cond} THEN DELETE" + ) + + sql.append(f""" + MERGE INTO {root_table_name} d USING {staging_root_table_name} s + ON {on_str} + {delete_str} + WHEN MATCHED + THEN UPDATE SET {update_str} + WHEN NOT MATCHED + THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")}); + """) + + # generate statements for child tables if they exist + child_tables = table_chain[1:] + if child_tables: + root_unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) + for table in child_tables: + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, table) + ) + root_key_column = escape_column_id( + cls._get_root_key_col(table_chain, sql_client, table) + ) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) + + # delete records for elements no longer in the list + sql.append(f""" + DELETE FROM {table_name} + WHERE {root_key_column} IN (SELECT {root_unique_column} FROM {staging_root_table_name}) + AND {unique_column} NOT IN (SELECT {unique_column} FROM {staging_table_name}); + """) + + # insert records for new elements in the list + table_column_names = list(map(escape_column_id, table["columns"])) + update_str = ", ".join([c + " = " + "s." + c for c in table_column_names]) + col_str = ", ".join(["{alias}" + c for c in table_column_names]) + sql.append(f""" + MERGE INTO {table_name} d USING {staging_table_name} s + ON d.{unique_column} = s.{unique_column} + WHEN MATCHED + THEN UPDATE SET {update_str} + WHEN NOT MATCHED + THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")}); + """) + + # delete hard-deleted records + if hard_delete_col is not None: + sql.append(f""" + DELETE FROM {table_name} + WHERE {root_key_column} IN ( + SELECT {root_unique_column} + FROM {staging_root_table_name} + WHERE {deleted_cond} + ); + """) + return sql + @classmethod def gen_scd2_sql( cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] @@ -526,9 +697,9 @@ def gen_scd2_sql( """ sql: List[str] = [] root_table = table_chain[0] - root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - with sql_client.with_staging_dataset(): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) # get column names caps = sql_client.capabilities @@ -580,27 +751,15 @@ def gen_scd2_sql( # insert list elements for new active records in child tables child_tables = table_chain[1:] if child_tables: - unique_column: str = None - # use unique hint to create temp table with all identifiers to delete - unique_columns = get_columns_names_with_prop(root_table, "unique") - if not unique_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so" - " it is not possible to link child tables to it.", - ) - # get first unique column - unique_column = escape_column_id(unique_columns[0]) + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) # TODO: - based on deterministic child hashes (OK) # - if row hash changes all is right # - if it does not we 
only capture new records, while we should replace existing with those in stage # - this write disposition is way more similar to regular merge (how root tables are handled is different, other tables handled same) for table in child_tables: - table_name = sql_client.make_qualified_table_name(table["name"]) - with sql_client.with_staging_dataset(): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) sql.append(f""" INSERT INTO {table_name} SELECT * @@ -609,15 +768,3 @@ def gen_scd2_sql( """) return sql - - @classmethod - def gen_update_table_prefix(cls, table_name: str) -> str: - return f"UPDATE {table_name} SET" - - @classmethod - def requires_temp_table_for_delete(cls) -> bool: - """Whether a temporary table is required to delete records. - - Must be `True` for destinations that don't support correlated subqueries. - """ - return False diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py index d24ad7c5a7..fcc2c4fd16 100644 --- a/dlt/destinations/utils.py +++ b/dlt/destinations/utils.py @@ -1,4 +1,6 @@ import re +import inspect + from typing import Any, List, Optional, Tuple from dlt.common import logger @@ -14,16 +16,35 @@ from typing import Any, cast, Tuple, Dict, Type from dlt.destinations.exceptions import DatabaseTransientException -from dlt.extract import DltResource, resource as make_resource +from dlt.extract import DltResource, resource as make_resource, DltSource RE_DATA_TYPE = re.compile(r"([A-Z]+)\((\d+)(?:,\s?(\d+))?\)") -def ensure_resource(data: Any) -> DltResource: - """Wraps `data` in a DltResource if it's not a DltResource already.""" +def get_resource_for_adapter(data: Any) -> DltResource: + """ + Helper function for adapters. Wraps `data` in a DltResource if it's not a DltResource already. + Alternatively if `data` is a DltSource, throws an error if there are multiple resource in the source + or returns the single resource if available. + """ if isinstance(data, DltResource): return data - resource_name = None if hasattr(data, "__name__") else "content" + # prevent accidentally wrapping sources with adapters + if isinstance(data, DltSource): + if len(data.selected_resources.keys()) == 1: + return list(data.selected_resources.values())[0] + else: + raise ValueError( + "You are trying to use an adapter on a DltSource with multiple resources. You can" + " only use adapters on pure data, direclty on a DltResouce or a DltSource" + " containing a single DltResource." + ) + + resource_name = None + if not hasattr(data, "__name__"): + logger.info("Setting default resource name to `content` for adapted resource.") + resource_name = "content" + return cast(DltResource, make_resource(data, name=resource_name)) @@ -77,17 +98,32 @@ def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[ f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""", ) ) - if ( - table.get("x-merge-strategy") == "delete-insert" - and not has_column_with_prop(table, "primary_key") - and not has_column_with_prop(table, "merge_key") - ): - log( - f"Table {table_name} has `write_disposition` set to `merge`" - " and `merge_strategy` set to `delete-insert`, but no primary or" - " merge keys defined." - " dlt will fall back to `append` for this table." 
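# --- Editor's note (not part of the diff): a minimal sketch of a resource that
# satisfies the merge-strategy checks handled in verify_sql_job_client_schema:
# `delete-insert` only warns and falls back to `append` when no primary or merge
# key is defined, while the new `upsert` strategy requires a `primary_key`.
# The resource name and data below are hypothetical.
import dlt

@dlt.resource(
    primary_key="id",  # required by the `upsert` strategy
    write_disposition={"disposition": "merge", "strategy": "upsert"},
)
def users():
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]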
- ) + if table.get("x-merge-strategy") == "delete-insert": + if not has_column_with_prop(table, "primary_key") and not has_column_with_prop( + table, "merge_key" + ): + log( + f"Table {table_name} has `write_disposition` set to `merge`" + " and `merge_strategy` set to `delete-insert`, but no primary or" + " merge keys defined." + " dlt will fall back to `append` for this table." + ) + elif table.get("x-merge-strategy") == "upsert": + if not has_column_with_prop(table, "primary_key"): + exception_log.append( + SchemaCorruptedException( + schema.name, + f"No primary key defined for table `{table['name']}`." + " `primary_key` needs to be set when using the `upsert`" + " merge strategy.", + ) + ) + if has_column_with_prop(table, "merge_key"): + log( + f"Found `merge_key` for table `{table['name']}` with" + " `upsert` merge strategy. Merge key is not supported" + " for this strategy and will be ignored." + ) if has_column_with_prop(table, "hard_delete"): if len(get_columns_names_with_prop(table, "hard_delete")) > 1: exception_log.append( diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 7a24b7f225..485a01eb99 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -374,7 +374,7 @@ def extract( ) -> str: # generate load package to be able to commit all the sources together later load_id = self.extract_storage.create_load_package( - source.discover_schema(), reuse_exiting_package=True + source.schema, reuse_exiting_package=True ) with Container().injectable_context( SourceSchemaInjectableContext(source.schema) diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index bc10177223..dce375afb0 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -15,6 +15,7 @@ TTableFormat, TSchemaContract, DEFAULT_VALIDITY_COLUMN_NAMES, + MERGE_STRATEGIES, ) from dlt.common.schema.utils import ( DEFAULT_WRITE_DISPOSITION, @@ -26,6 +27,7 @@ ) from dlt.common.typing import TDataItem from dlt.common.utils import clone_dict_nested +from dlt.common.normalizers.json.relational import DataItemNormalizer from dlt.common.validation import validate_dict_ignoring_xkeys from dlt.extract.exceptions import ( DataItemRequiredForDynamicTableHints, @@ -342,6 +344,7 @@ def _set_hints( self, hints_template: TResourceHints, create_table_variant: bool = False ) -> None: DltResourceHints.validate_dynamic_hints(hints_template) + DltResourceHints.validate_write_disposition_hint(hints_template.get("write_disposition")) if create_table_variant: table_name: str = hints_template["name"] # type: ignore[assignment] # incremental cannot be specified in variant @@ -436,13 +439,11 @@ def _merge_write_disposition_dict(dict_: Dict[str, Any]) -> None: @staticmethod def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: - """Merges merge disposition dict into x-hints on in place.""" + """Merges merge disposition dict into x-hints in place.""" mddict: TMergeDispositionDict = deepcopy(dict_["write_disposition"]) if mddict is not None: - dict_["x-merge-strategy"] = ( - mddict["strategy"] if "strategy" in mddict else DEFAULT_MERGE_STRATEGY - ) + dict_["x-merge-strategy"] = mddict.get("strategy", DEFAULT_MERGE_STRATEGY) # add columns for `scd2` merge strategy if dict_.get("x-merge-strategy") == "scd2": if mddict.get("validity_column_names") is None: @@ -464,7 +465,7 @@ def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: "x-valid-to": True, "x-active-record-timestamp": mddict.get("active_record_timestamp"), } - hash_ = mddict.get("row_version_column_name", "_dlt_id") + hash_ = 
mddict.get("row_version_column_name", DataItemNormalizer.C_DLT_ID) dict_["columns"][hash_] = { "name": hash_, "nullable": False, @@ -496,3 +497,13 @@ def validate_dynamic_hints(template: TResourceHints) -> None: raise InconsistentTableTemplate( f"Table name {table_name} must be a function if any other table hint is a function" ) + + @staticmethod + def validate_write_disposition_hint(wd: TTableHintTemplate[TWriteDispositionConfig]) -> None: + if isinstance(wd, dict) and wd["disposition"] == "merge": + wd = cast(TMergeDispositionDict, wd) + if "strategy" in wd and wd["strategy"] not in MERGE_STRATEGIES: + raise ValueError( + f'`{wd["strategy"]}` is not a valid merge strategy. ' + f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" + ) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 11f989e0b2..c1117370b5 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -107,6 +107,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa # incremental acting as empty EMPTY: ClassVar["Incremental[Any]"] = None + placement_affinity: ClassVar[float] = 1 # stick to end def __init__( self, @@ -496,6 +497,8 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: class IncrementalResourceWrapper(ItemTransform[TDataItem]): + placement_affinity: ClassVar[float] = 1 # stick to end + _incremental: Optional[Incremental[Any]] = None """Keeps the injectable incremental""" _from_hints: bool = False diff --git a/dlt/extract/items.py b/dlt/extract/items.py index 4cf8d2191f..d721e8094e 100644 --- a/dlt/extract/items.py +++ b/dlt/extract/items.py @@ -3,6 +3,7 @@ from typing import ( Any, Callable, + ClassVar, Generic, Iterator, Iterable, @@ -135,6 +136,9 @@ class ItemTransform(ABC, Generic[TAny]): _f_meta: ItemTransformFunctionWithMeta[TAny] = None _f: ItemTransformFunctionNoMeta[TAny] = None + placement_affinity: ClassVar[float] = 0 + """Tell how strongly an item sticks to start (-1) or end (+1) of pipe.""" + def __init__(self, transform_f: ItemTransformFunc[TAny]) -> None: # inspect the signature sig = inspect.signature(transform_f) @@ -223,6 +227,8 @@ class ValidateItem(ItemTransform[TDataItem]): See `PydanticValidator` for possible implementation. 
""" + placement_affinity: ClassVar[float] = 0.9 # stick to end but less than incremental + table_name: str def bind(self, pipe: SupportsPipe) -> ItemTransform[TDataItem]: diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 6517273db5..02b52c4623 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -1,7 +1,18 @@ import inspect import makefun from copy import copy -from typing import Any, AsyncIterator, Optional, Union, Callable, Iterable, Iterator, List, Tuple +from typing import ( + Any, + AsyncIterator, + ClassVar, + Optional, + Union, + Callable, + Iterable, + Iterator, + List, + Tuple, +) from dlt.common.typing import AnyFun, AnyType, TDataItems from dlt.common.utils import get_callable_name @@ -31,7 +42,9 @@ ) -class ForkPipe: +class ForkPipe(ItemTransform[ResolvablePipeItem]): + placement_affinity: ClassVar[float] = 2 + def __init__(self, pipe: "Pipe", step: int = -1, copy_on_fork: bool = False) -> None: """A transformer that forks the `pipe` and sends the data items to forks added via `add_pipe` method.""" self._pipes: List[Tuple["Pipe", int]] = [] @@ -46,7 +59,7 @@ def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: def has_pipe(self, pipe: "Pipe") -> bool: return pipe in [p[0] for p in self._pipes] - def __call__(self, item: TDataItems, meta: Any) -> Iterator[ResolvablePipeItem]: + def __call__(self, item: TDataItems, meta: Any = None) -> Iterator[ResolvablePipeItem]: for i, (pipe, step) in enumerate(self._pipes): if i == 0 or not self.copy_on_fork: _it = item @@ -65,8 +78,8 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No self.parent = parent # add the steps, this will check and mod transformations if steps: - for step in steps: - self.append_step(step) + for index, step in enumerate(steps): + self.insert_step(step, index) @classmethod def from_data( @@ -123,7 +136,8 @@ def fork(self, child_pipe: "Pipe", child_step: int = -1, copy_on_fork: bool = Fa fork_step = self.tail if not isinstance(fork_step, ForkPipe): fork_step = ForkPipe(child_pipe, child_step, copy_on_fork) - self.append_step(fork_step) + # always add this at the end + self.insert_step(fork_step, len(self)) else: if not fork_step.has_pipe(child_pipe): fork_step.add_pipe(child_pipe, child_step) @@ -131,19 +145,34 @@ def fork(self, child_pipe: "Pipe", child_step: int = -1, copy_on_fork: bool = Fa def append_step(self, step: TPipeStep) -> "Pipe": """Appends pipeline step. On first added step performs additional verification if step is a valid data generator""" - step_no = len(self._steps) - if step_no == 0 and not self.has_parent: + steps_count = len(self._steps) + + if steps_count == 0 and not self.has_parent: self._verify_head_step(step) else: - step = self._wrap_transform_step_meta(step_no, step) - - self._steps.append(step) + step = self._wrap_transform_step_meta(steps_count, step) + + # find the insert position using particular + if steps_count > 0: + affinity = step.placement_affinity if isinstance(step, ItemTransform) else 0 + for index in reversed(range(0, steps_count)): + step_at_idx = self._steps[index] + affinity_at_idx = ( + step_at_idx.placement_affinity if isinstance(step_at_idx, ItemTransform) else 0 + ) + if affinity_at_idx <= affinity: + self._insert_at_pos(step, index + 1) + return self + # insert at the start due to strong affinity + self._insert_at_pos(step, 0) + else: + self._steps.append(step) return self def insert_step(self, step: TPipeStep, index: int) -> "Pipe": """Inserts step at a given index in the pipeline. 
Allows prepending only for transformers""" - step_no = len(self._steps) - if step_no == 0: + steps_count = len(self._steps) + if steps_count == 0: return self.append_step(step) if index == 0: if not self.has_parent: @@ -153,11 +182,7 @@ def insert_step(self, step: TPipeStep, index: int) -> "Pipe": " transformer", ) step = self._wrap_transform_step_meta(index, step) - # actually insert in the list - self._steps.insert(index, step) - # increase the _gen_idx if added before generator - if index <= self._gen_idx: - self._gen_idx += 1 + self._insert_at_pos(step, index) return self def remove_step(self, index: int) -> None: @@ -327,15 +352,6 @@ def _wrap_transform_step_meta(self, step_no: int, step: TPipeStep) -> TPipeStep: # check the signature sig = inspect.signature(step) meta_arg = check_compat_transformer(self.name, step, sig) - # sig_arg_count = len(sig.parameters) - # callable_name = get_callable_name(step) - # if sig_arg_count == 0: - # raise InvalidStepFunctionArguments(self.name, callable_name, sig, "Function takes no arguments") - # # see if meta is present in kwargs - # meta_arg = next((p for p in sig.parameters.values() if p.name == "meta"), None) - # if meta_arg is not None: - # if meta_arg.kind not in (meta_arg.KEYWORD_ONLY, meta_arg.POSITIONAL_OR_KEYWORD): - # raise InvalidStepFunctionArguments(self.name, callable_name, sig, "'meta' cannot be pos only argument '") if meta_arg is None: # add meta parameter when not present orig_step = step @@ -408,6 +424,17 @@ def _clone(self, new_name: str = None, with_parent: bool = False) -> "Pipe": p._steps = self._steps.copy() return p + def _insert_at_pos(self, step: Any, index: int) -> None: + # shift right if no parent + if index == 0 and not self.has_parent: + # put after gen + index += 1 + # actually insert in the list + self._steps.insert(index, step) + # increase the _gen_idx if added before generator + if index <= self._gen_idx: + self._gen_idx += 1 + def __repr__(self) -> str: if self.has_parent: bound_str = " data bound to " + repr(self.parent) diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 93eb9d1189..55c0bd728f 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -338,10 +338,14 @@ def add_filter( return self def add_limit(self: TDltResourceImpl, max_items: int) -> TDltResourceImpl: # noqa: A003 - """Adds a limit `max_items` to the resource pipe + """Adds a limit `max_items` to the resource pipe. - This mutates the encapsulated generator to stop after `max_items` items are yielded. This is useful for testing and debugging. It is - a no-op for transformers. Those should be limited by their input data. + This mutates the encapsulated generator to stop after `max_items` items are yielded. This is useful for testing and debugging. + + Notes: + 1. Transformers won't be limited. They should process all the data they receive fully to avoid inconsistencies in generated datasets. + 2. Each yielded item may contain several records. `add_limit` only limits the "number of yields", not the total number of records. + 3. Async resources with a limit added may occasionally produce one item more than the limit on some runs. This behavior is not deterministic. 
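# --- Editor's note (not part of the diff): a small sketch, with hypothetical
# data, of the add_limit semantics spelled out in the notes above: the limit
# counts yields, not records, so a resource that yields whole pages keeps them.
import dlt

@dlt.resource
def pages():
    for page_number in range(10):
        # each yield is one "item" for add_limit, even though it carries 100 records
        yield [{"page": page_number, "row": row} for row in range(100)]

# stops after 2 yielded pages (roughly 200 records), not after 2 records
limited_pages = pages().add_limit(2)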
Args: max_items (int): The maximum number of items to yield diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 7732c4f056..88a98e14f3 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -277,7 +277,7 @@ def schema_contract(self, settings: TSchemaContract) -> None: @property def exhausted(self) -> bool: - """check all selected pipes wether one of them has started. if so, the source is exhausted.""" + """Check all selected pipes whether one of them has started. if so, the source is exhausted.""" for resource in self._resources.extracted.values(): item = resource._pipe.gen if inspect.isgenerator(item): @@ -345,6 +345,10 @@ def add_limit(self, max_items: int) -> "DltSource": # noqa: A003 This is useful for testing, debugging and generating sample datasets for experimentation. You can easily get your test dataset in a few minutes, when otherwise you'd need to wait hours for the full loading to complete. + Notes: + 1. Transformers resources won't be limited. They should process all the data they receive fully to avoid inconsistencies in generated datasets. + 2. Each yielded item may contain several records. `add_limit` only limits the "number of yields", not the total number of records. + Args: max_items (int): The maximum number of items to yield Returns: diff --git a/dlt/extract/validation.py b/dlt/extract/validation.py index 504eee1bfc..4cd321b88c 100644 --- a/dlt/extract/validation.py +++ b/dlt/extract/validation.py @@ -30,20 +30,26 @@ def __init__( self.model = apply_schema_contract_to_model(model, column_mode, data_mode) self.list_model = create_list_model(self.model, data_mode) - def __call__( - self, item: TDataItems, meta: Any = None - ) -> Union[_TPydanticModel, List[_TPydanticModel]]: + def __call__(self, item: TDataItems, meta: Any = None) -> TDataItems: """Validate a data item against the pydantic model""" if item is None: return None - from dlt.common.libs.pydantic import validate_item, validate_items + from dlt.common.libs.pydantic import validate_and_filter_item, validate_and_filter_items if isinstance(item, list): - return validate_items( - self.table_name, self.list_model, item, self.column_mode, self.data_mode - ) - return validate_item(self.table_name, self.model, item, self.column_mode, self.data_mode) + return [ + model.dict(by_alias=True) + for model in validate_and_filter_items( + self.table_name, self.list_model, item, self.column_mode, self.data_mode + ) + ] + item = validate_and_filter_item( + self.table_name, self.model, item, self.column_mode, self.data_mode + ) + if item is not None: + item = item.dict(by_alias=True) + return item def __str__(self, *args: Any, **kwargs: Any) -> str: return f"PydanticValidator(model={self.model.__qualname__})" diff --git a/dlt/helpers/streamlit_app/utils.py b/dlt/helpers/streamlit_app/utils.py index e3f2069c3c..00ebe8d137 100644 --- a/dlt/helpers/streamlit_app/utils.py +++ b/dlt/helpers/streamlit_app/utils.py @@ -52,7 +52,7 @@ def do_query( # type: ignore[return] except SqlClientNotAvailable: st.error("🚨 Cannot load data - SqlClient not available") - return do_query # type: ignore + return do_query # type: ignore[unused-ignore, no-any-return] def query_data( diff --git a/dlt/load/load.py b/dlt/load/load.py index a8dfb7002e..2290d40a1e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,19 +1,15 @@ import contextlib from functools import reduce -import datetime # noqa: 251 from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence from concurrent.futures import Executor import 
os -from copy import deepcopy from dlt.common import logger from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo from dlt.common.schema.utils import get_top_level_table -from dlt.common.schema.typing import TTableSchema from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.storages.load_package import ( LoadPackageStateInjectableContext, diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 235ba3485a..723e0ded83 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -1,11 +1,13 @@ from typing import Any, Optional +import dlt from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration from dlt.common.typing import AnyFun, TSecretValue from dlt.common.utils import digest256 from dlt.common.destination import TLoaderFileFormat from dlt.common.pipeline import TRefreshMode +from dlt.common.configuration.exceptions import ConfigurationValueError @configspec @@ -18,6 +20,8 @@ class PipelineConfiguration(BaseConfiguration): staging_name: Optional[str] = None loader_file_format: Optional[TLoaderFileFormat] = None dataset_name: Optional[str] = None + dataset_name_layout: Optional[str] = None + """Layout for dataset_name, where %s is replaced with dataset_name. For example: 'prefix_%s'""" pipeline_salt: Optional[TSecretValue] = None restore_from_destination: bool = True """Enables the `run` method of the `Pipeline` object to restore the pipeline state and schemas from the destination""" @@ -41,6 +45,11 @@ def on_resolved(self) -> None: self.runtime.pipeline_name = self.pipeline_name if not self.pipeline_salt: self.pipeline_salt = TSecretValue(digest256(self.pipeline_name)) + if self.dataset_name_layout and "%s" not in self.dataset_name_layout: + raise ConfigurationValueError( + "The dataset_name_layout must contain a '%s' placeholder for dataset_name. For" + " example: 'prefix_%s'" + ) def ensure_correct_pipeline_kwargs(f: AnyFun, **kwargs: Any) -> None: diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 25fd398623..2ae74e2532 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -1,7 +1,7 @@ """Easy access to active pipelines, state, sources and schemas""" from dlt.common.pipeline import source_state as _state, resource_state, get_current_pipe_name -from dlt.pipeline import pipeline as _pipeline +from dlt.pipeline.pipeline import Pipeline from dlt.extract.decorators import get_source_schema from dlt.common.storages.load_package import ( load_package, @@ -11,10 +11,15 @@ ) from dlt.extract.decorators import get_source_schema, get_source -pipeline = _pipeline -"""Alias for dlt.pipeline""" + +def pipeline() -> Pipeline: + """Currently active pipeline ie. 
the most recently created or run""" + from dlt import _pipeline + + return _pipeline() + + state = source_state = _state -"""Alias for dlt.state""" source_schema = get_source_schema source = get_source pipe_name = get_current_pipe_name diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index ac5d3b90e4..4f29ca4c87 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -12,7 +12,6 @@ Optional, Sequence, Tuple, - Type, cast, get_type_hints, ContextManager, @@ -29,12 +28,14 @@ from dlt.common.configuration.exceptions import ( ConfigFieldMissingException, ContextDefaultCannotBeCreated, + ConfigurationValueError, ) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.destination.exceptions import ( DestinationIncompatibleLoaderFileFormatException, DestinationNoStagingMode, DestinationUndefinedEntity, + DestinationCapabilitiesException, ) from dlt.common.exceptions import MissingDependencyException from dlt.common.runtime import signals, initialize_runtime @@ -44,7 +45,6 @@ TWriteDispositionConfig, TAnySchemaColumns, TSchemaContract, - TTableSchema, ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound @@ -458,6 +458,37 @@ def extract( step_info, ) from exc + def _verify_destination_capabilities( + self, + caps: DestinationCapabilitiesContext, + loader_file_format: TLoaderFileFormat, + ) -> None: + # verify loader file format + if loader_file_format and loader_file_format not in caps.supported_loader_file_formats: + raise DestinationIncompatibleLoaderFileFormatException( + self.destination.destination_name, + (self.staging.destination_name if self.staging else None), + loader_file_format, + set(caps.supported_loader_file_formats), + ) + + # verify merge strategy + for table in self.default_schema.data_tables(include_incomplete=True): + if ( + "x-merge-strategy" in table + and caps.supported_merge_strategies + and table["x-merge-strategy"] not in caps.supported_merge_strategies # type: ignore[typeddict-item] + ): + if self.destination.destination_name == "filesystem" and table["x-merge-strategy"] == "delete-insert": # type: ignore[typeddict-item] + # `filesystem` does not support `delete-insert`, but no + # error should be raised because it falls back to `append` + pass + else: + raise DestinationCapabilitiesException( + f"`{table.get('x-merge-strategy')}` merge strategy not supported" + f" for `{self.destination.destination_name}` destination." 
+ ) + @with_runtime_trace() @with_schemas_sync @with_config_section((known_sections.NORMALIZE,)) @@ -487,13 +518,8 @@ def normalize( ) # run with destination context with self._maybe_destination_capabilities() as caps: - if loader_file_format and loader_file_format not in caps.supported_loader_file_formats: - raise DestinationIncompatibleLoaderFileFormatException( - self.destination.destination_name, - (self.staging.destination_name if self.staging else None), - loader_file_format, - set(caps.supported_loader_file_formats), - ) + self._verify_destination_capabilities(caps, loader_file_format) + # shares schema storage with the pipeline so we do not need to install normalize_step: Normalize = Normalize( collector=self.collector, @@ -1400,6 +1426,10 @@ def _set_dataset_name(self, new_dataset_name: str) -> None: new_dataset_name += self._pipeline_instance_id self.dataset_name = new_dataset_name + # normalizes the dataset name using the dataset_name_layout + if self.config.dataset_name_layout: + self.dataset_name = self.config.dataset_name_layout % self.dataset_name + def _set_default_schema_name(self, schema: Schema) -> None: assert self.default_schema_name is None self.default_schema_name = schema.name diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index fc15654949..29770966a6 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -3,7 +3,6 @@ import os import pickle import datetime # noqa: 251 -import dataclasses from typing import Any, List, NamedTuple, Optional, Protocol, Sequence import humanize @@ -26,6 +25,7 @@ SupportsPipeline, ) from dlt.common.source import get_current_pipe_name +from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import DictStrAny, StrAny, SupportsHumanize from dlt.common.utils import uniq_id, get_exception_trace_chain @@ -103,7 +103,8 @@ def asdict(self) -> DictStrAny: d["step_info"] = {} # take only the base keys for prop in self.step_info._astuple()._asdict(): - d["step_info"][prop] = step_info_dict.pop(prop) + if prop in step_info_dict: + d["step_info"][prop] = step_info_dict.pop(prop) # replace the attributes in exception traces with json dumps if self.exception_traces: # do not modify original traces @@ -232,7 +233,7 @@ def start_trace(step: TPipelineStep, pipeline: SupportsPipeline) -> PipelineTrac resolved_config_values=[], ) for module in TRACKING_MODULES: - with suppress_and_warn(): + with suppress_and_warn(f"on_start_trace on module {module} failed"): module.on_start_trace(trace, step, pipeline) return trace @@ -242,7 +243,7 @@ def start_trace_step( ) -> PipelineStepTrace: trace_step = PipelineStepTrace(uniq_id(), step, pendulum.now()) for module in TRACKING_MODULES: - with suppress_and_warn(): + with suppress_and_warn(f"start_trace_step on module {module} failed"): module.on_start_trace_step(trace, step, pipeline) return trace_step @@ -292,7 +293,7 @@ def end_trace_step( trace.resolved_config_values[:] = list(resolved_values) trace.steps.append(step) for module in TRACKING_MODULES: - with suppress_and_warn(): + with suppress_and_warn(f"end_trace_step on module {module} failed"): module.on_end_trace_step(trace, step, pipeline, step_info, send_state) return trace @@ -304,7 +305,7 @@ def end_trace( if trace_path: save_trace(trace_path, trace) for module in TRACKING_MODULES: - with suppress_and_warn(): + with suppress_and_warn(f"end_trace on module {module} failed"): module.on_end_trace(trace, pipeline, send_state) return trace @@ -324,8 +325,13 @@ def merge_traces(last_trace: PipelineTrace, new_trace: 
PipelineTrace) -> Pipelin def save_trace(trace_path: str, trace: PipelineTrace) -> None: - with open(os.path.join(trace_path, TRACE_FILE_NAME), mode="bw") as f: - f.write(pickle.dumps(trace)) + # remove previous file, we do not want to keep old trace even if we fail later + trace_dump_path = os.path.join(trace_path, TRACE_FILE_NAME) + if os.path.isfile(trace_dump_path): + os.unlink(trace_dump_path) + with suppress_and_warn("Failed to create trace dump via pickle"): + trace_dump = pickle.dumps(trace) + FileStorage.save_atomic(trace_path, TRACE_FILE_NAME, trace_dump, file_type="b") def load_trace(trace_path: str) -> PipelineTrace: diff --git a/dlt/sources/helpers/requests/session.py b/dlt/sources/helpers/requests/session.py index 0a4d277848..5ba4d9b611 100644 --- a/dlt/sources/helpers/requests/session.py +++ b/dlt/sources/helpers/requests/session.py @@ -1,6 +1,5 @@ from requests import Session as BaseSession -from tenacity import Retrying, retry_if_exception_type -from typing import Optional, TYPE_CHECKING, Sequence, Union, Tuple, Type, TypeVar +from typing import Optional, TYPE_CHECKING, Union, Tuple, TypeVar from dlt.sources.helpers.requests.typing import TRequestTimeout from dlt.common.typing import TimedeltaSeconds @@ -56,3 +55,7 @@ def request(self, *args, **kwargs): # type: ignore[no-untyped-def,no-redef] if self.raise_for_status: resp.raise_for_status() return resp + + def send(self, request, **kwargs): # type: ignore[no-untyped-def] + kwargs.setdefault("timeout", self.timeout) + return super().send(request, **kwargs) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 2cc19f6624..73ae064299 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -116,7 +116,7 @@ def _create_request( hooks=hooks, ) - def _send_request(self, request: Request) -> Response: + def _send_request(self, request: Request, **kwargs: Any) -> Response: logger.info( f"Making {request.method.upper()} request to {request.url}" f" with params={request.params}, json={request.json}" @@ -125,18 +125,26 @@ def _send_request(self, request: Request) -> Response: prepared_request = self.session.prepare_request(request) send_kwargs = self.session.merge_environment_settings( - prepared_request.url, {}, None, None, None + prepared_request.url, + kwargs.pop("proxies", {}), + kwargs.pop("stream", None), + kwargs.pop("verify", None), + kwargs.pop("cert", None), ) + send_kwargs.update(**kwargs) # type: ignore[call-arg] return self.session.send(prepared_request, **send_kwargs) def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> Response: prepared_request = self._create_request( path=path, method=method, - **kwargs, + params=kwargs.pop("params", None), + json=kwargs.pop("json", None), + auth=kwargs.pop("auth", None), + hooks=kwargs.pop("hooks", None), ) - return self._send_request(prepared_request) + return self._send_request(prepared_request, **kwargs) def get(self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response: return self.request(path, method="GET", params=params, **kwargs) @@ -154,6 +162,7 @@ def paginate( paginator: Optional[BasePaginator] = None, data_selector: Optional[jsonpath.TJsonPath] = None, hooks: Optional[Hooks] = None, + **kwargs: Any, ) -> Iterator[PageData[Any]]: """Iterates over paginated API responses, yielding pages of data. @@ -170,6 +179,9 @@ def paginate( hooks (Optional[Hooks]): Hooks to modify request/response objects. 
Note that when hooks are not provided, the default behavior is to raise an exception on error status codes. + **kwargs (Any): Optional arguments to that the Request library accepts, such as + `stream`, `verify`, `proxies`, `cert`, `timeout`, and `allow_redirects`. + Yields: PageData[Any]: A page of data from the paginated API response, along with request and response context. @@ -183,7 +195,6 @@ def paginate( >>> for page in client.paginate("/search", method="post", json={"query": "foo"}): >>> print(page) """ - paginator = paginator if paginator else copy.deepcopy(self.paginator) auth = auth or self.auth data_selector = data_selector or self.data_selector @@ -204,7 +215,7 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: while True: try: - response = self._send_request(request) + response = self._send_request(request, **kwargs) except IgnoreResponseException: break diff --git a/dlt/sources/helpers/rest_client/detector.py b/dlt/sources/helpers/rest_client/detector.py index 19a1e83a82..511c9ce981 100644 --- a/dlt/sources/helpers/rest_client/detector.py +++ b/dlt/sources/helpers/rest_client/detector.py @@ -8,7 +8,7 @@ from .paginators import ( BasePaginator, HeaderLinkPaginator, - JSONResponsePaginator, + JSONLinkPaginator, JSONResponseCursorPaginator, SinglePagePaginator, PageNumberPaginator, @@ -156,7 +156,7 @@ def header_links_detector(response: Response) -> Tuple[HeaderLinkPaginator, floa return None, None -def json_links_detector(response: Response) -> Tuple[JSONResponsePaginator, float]: +def json_links_detector(response: Response) -> Tuple[JSONLinkPaginator, float]: dictionary = response.json() next_path_parts, next_href = find_next_page_path(dictionary) @@ -166,7 +166,7 @@ def json_links_detector(response: Response) -> Tuple[JSONResponsePaginator, floa try: urlparse(next_href) if next_href.startswith("http") or next_href.startswith("/"): - return JSONResponsePaginator(next_url_path=".".join(next_path_parts)), 1.0 + return JSONLinkPaginator(next_url_path=".".join(next_path_parts)), 1.0 except Exception: pass diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 701f0c914b..4c8ce70bb2 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -1,3 +1,4 @@ +import warnings from abc import ABC, abstractmethod from typing import Optional, Dict, Any from urllib.parse import urlparse, urljoin @@ -422,7 +423,7 @@ class BaseNextUrlPaginator(BaseReferencePaginator): Subclasses should implement the `update_state` method to extract the next page URL and set the `_next_reference` attribute accordingly. - See `HeaderLinkPaginator` and `JSONResponsePaginator` for examples. + See `HeaderLinkPaginator` and `JSONLinkPaginator` for examples. """ def update_request(self, request: Request) -> None: @@ -491,7 +492,7 @@ def __str__(self) -> str: return super().__str__() + f": links_next_key: {self.links_next_key}" -class JSONResponsePaginator(BaseNextUrlPaginator): +class JSONLinkPaginator(BaseNextUrlPaginator): """Locates the next page URL within the JSON response body. The key containing the URL can be specified using a JSON path. @@ -511,12 +512,12 @@ class JSONResponsePaginator(BaseNextUrlPaginator): The link to the next page (`https://api.example.com/items?page=2`) is located in the 'next' key of the 'pagination' object. 
You can use - `JSONResponsePaginator` to paginate through the API endpoint: + `JSONLinkPaginator` to paginate through the API endpoint: from dlt.sources.helpers.rest_client import RESTClient client = RESTClient( base_url="https://api.example.com", - paginator=JSONResponsePaginator(next_url_path="pagination.next") + paginator=JSONLinkPaginator(next_url_path="pagination.next") ) @dlt.resource @@ -547,6 +548,20 @@ def __str__(self) -> str: return super().__str__() + f": next_url_path: {self.next_url_path}" +class JSONResponsePaginator(JSONLinkPaginator): + def __init__( + self, + next_url_path: jsonpath.TJsonPath = "next", + ) -> None: + warnings.warn( + "JSONResponsePaginator is deprecated and will be removed in version 1.0.0. Use" + " JSONLinkPaginator instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(next_url_path) + + class JSONResponseCursorPaginator(BaseReferencePaginator): """Uses a cursor parameter for pagination, with the cursor value found in the JSON response body. diff --git a/docs/examples/custom_config_provider/.dlt/config.toml b/docs/examples/custom_config_provider/.dlt/config.toml new file mode 100644 index 0000000000..8b99f2d496 --- /dev/null +++ b/docs/examples/custom_config_provider/.dlt/config.toml @@ -0,0 +1 @@ +dlt_config_profile_name="prod" \ No newline at end of file diff --git a/docs/examples/custom_config_provider/__init__.py b/docs/examples/custom_config_provider/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/custom_config_provider/custom_config_provider.py b/docs/examples/custom_config_provider/custom_config_provider.py new file mode 100644 index 0000000000..e4b438cb4f --- /dev/null +++ b/docs/examples/custom_config_provider/custom_config_provider.py @@ -0,0 +1,94 @@ +""" +--- +title: Use custom yaml file for config and secrets +description: We show how to keep configuration in yaml file with switchable profiles and simple templates +keywords: [config, yaml config, profiles] +--- + +This example shows how to replace secrets/config toml files with a yaml file that contains several profiles (prod and dev) and jinja-like +placeholders that are replaced with corresponding env variables. +`dlt` resolves configuration by querying so called config providers (to ie. query env variables or content of a toml file). +Here we will instantiate a provider with a custom loader and register it to be queried. At the end we demonstrate (using mock github source) +that `dlt` uses it along other (standard) providers to resolve configuration. 
+ +In this example you will learn to: + +* Implement custom configuration loader that parses yaml file, manipulates it and then returns final Python dict +* Instantiate custom provider (CustomLoaderDocProvider) from the loader +* Register provider instance to be queried + +""" + +import os +import re +import dlt +import yaml +import functools + +from dlt.common.configuration.providers import CustomLoaderDocProvider +from dlt.common.utils import map_nested_in_place + + +# config for all resources found in this file will be grouped in this source level config section +__source_name__ = "github_api" + + +def eval_placeholder(value): + """Replaces jinja placeholders {{ PLACEHOLDER }} with environment variables""" + if isinstance(value, str): + + def replacer(match): + return os.environ[match.group(1)] + + return re.sub(r"\{\{\s*(\w+)\s*\}\}", replacer, value) + return value + + +def loader(profile_name: str): + """Loads yaml file from profiles.yaml in current working folder, selects profile, replaces + placeholders with env variables and returns Python dict with final config + """ + path = os.path.abspath("profiles.yaml") + with open(path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + # get the requested environment + config = config.get(profile_name, None) + if config is None: + raise RuntimeError(f"Profile with name {profile_name} not found in {os.path.abspath(path)}") + # evaluate all placeholders + # NOTE: this method only works with placeholders wrapped as strings in yaml. use jinja lib for real templating + return map_nested_in_place(eval_placeholder, config) + + +@dlt.resource(standalone=True) +def github(url: str = dlt.config.value, api_key=dlt.secrets.value): + # just return the injected config and secret + yield url, api_key + + +if __name__ == "__main__": + # mock env variables to fill placeholders in profiles.yaml + os.environ["GITHUB_API_KEY"] = "secret_key" # mock expected var + + # dlt standard providers work at this point (we have profile name in config.toml) + profile_name = dlt.config["dlt_config_profile_name"] + + # instantiate custom provider using `prod` profile + # NOTE: all placeholders (ie. GITHUB_API_KEY) will be evaluated in next line! + provider = CustomLoaderDocProvider("profiles", functools.partial(loader, profile_name)) + # register provider, it will be added as the last one in chain + dlt.config.register_provider(provider) + + # your pipeline will now be able to use your yaml provider + # p = Pipeline(...) + # p.run(...) 
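+    # a concrete (hypothetical) setup could now look like this, since the registered
+    # provider is queried just like the built-in toml providers:
+    # pipeline = dlt.pipeline(pipeline_name="github_pipeline", destination="duckdb")
+    # pipeline.run(github())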
+ + # show the final config + print(provider.to_yaml()) + # or if you like toml + print(provider.to_toml()) + + # inject && evaluate resource + config_vals = list(github()) + print(config_vals) + assert config_vals[0] == ("https://github.com/api", "secret_key") diff --git a/docs/examples/custom_config_provider/profiles.yaml b/docs/examples/custom_config_provider/profiles.yaml new file mode 100644 index 0000000000..f2ec783e3d --- /dev/null +++ b/docs/examples/custom_config_provider/profiles.yaml @@ -0,0 +1,12 @@ +prod: + sources: + github_api: # source level + github: # resource level + url: https://github.com/api + api_key: "{{GITHUB_API_KEY}}" + +dev: + sources: + github_api: + url: https://github.com/api + api_key: "" # no keys in dev env diff --git a/docs/examples/qdrant_zendesk/qdrant_zendesk.py b/docs/examples/qdrant_zendesk/qdrant_zendesk.py index 5416f2f2d0..9b6fbee150 100644 --- a/docs/examples/qdrant_zendesk/qdrant_zendesk.py +++ b/docs/examples/qdrant_zendesk/qdrant_zendesk.py @@ -165,14 +165,13 @@ def get_pages( dataset_name="zendesk_data", ) - # run the dlt pipeline and save info about the load process - load_info = pipeline.run( - # here we use a special function to tell Qdrant which fields to embed - qdrant_adapter( - zendesk_support(), # retrieve tickets data - embed=["subject", "description"], - ) - ) + # here we instantiate the source + source = zendesk_support() + # ...and apply special hints on the ticket resource to tell qdrant which fields to embed + qdrant_adapter(source.tickets_data, embed=["subject", "description"]) + + # run the dlt pipeline and print info about the load process + load_info = pipeline.run(source) print(load_info) @@ -189,7 +188,7 @@ def get_pages( # query Qdrant with prompt: getting tickets info close to "cancellation" response = qdrant_client.query( - "zendesk_data_content", # collection/dataset name with the 'content' suffix -> tickets content table + "zendesk_data_tickets_data", # tickets_data collection query_text="cancel subscription", # prompt to search limit=3, # limit the number of results to the nearest 3 embeddings ) diff --git a/docs/website/docs/_book-onboarding-call.md b/docs/website/docs/_book-onboarding-call.md index 5f6d5df81b..4725128bf0 100644 --- a/docs/website/docs/_book-onboarding-call.md +++ b/docs/website/docs/_book-onboarding-call.md @@ -1 +1 @@ -book a call with our support engineer Violetta \ No newline at end of file +book a call with a dltHub Solutions Engineer diff --git a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md index 4d92043fb5..51d124251a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md +++ b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md @@ -136,6 +136,58 @@ def streamed_resource(): streamed_resource.apply_hints(additional_table_hints={"x-insert-api": "streaming"}) ``` +### Use BigQuery schema autodetect for nested fields +You can let BigQuery to infer schemas and create destination tables instead of `dlt`. As a consequence, nested fields (ie. `RECORD`), which `dlt` does not support at +this moment (they are stored as JSON), may be created. You select certain resources with [BigQuery Adapter](#bigquery-adapter) or all of them with the following config option: +```toml +[destination.bigquery] +autodetect_schema=true +``` +We recommend to yield [arrow tables](../verified-sources/arrow-pandas.md) from your resources and `parquet` file format to load the data. 
In that case the schemas generated by `dlt` and BigQuery +will be identical. BigQuery will also preserve the column order from the generated parquet files. You can convert `json` data into arrow tables with [pyarrow or duckdb](../verified-sources/arrow-pandas.md#loading-json-documents). + +```py +import pyarrow.json as paj + +import dlt +from dlt.destinations.adapters import bigquery_adapter + +@dlt.resource(name="cve") +def load_cve(): + with open("cve.json", 'rb') as f: + # autodetect arrow schema and yields arrow table + yield paj.read_json(f) + +pipeline = dlt.pipeline("load_json_struct", destination="bigquery") +pipeline.run( + bigquery_adapter(load_cve(), autodetect_schema=True) +) +``` +Above, we use `pyarrow` library to convert `json` document into `arrow` table and use `biguery_adapter` to enable schema autodetect for **cve** resource. + +Yielding Python dicts/lists and loading them as `jsonl` works as well. In many cases, the resulting nested structure is simpler than those obtained via pyarrow/duckdb and parquet. However there are slight differences in inferred types from `dlt` (BigQuery coerces types more aggressively). BigQuery also does not try to preserve the column order in relation to the order of fields in JSON. + +```py +import dlt +from dlt.destinations.adapters import bigquery_adapter + +@dlt.resource(name="cve", max_table_nesting=1) +def load_cve(): + with open("cve.json", 'rb') as f: + yield json.load(f) + +pipeline = dlt.pipeline("load_json_struct", destination="bigquery") +pipeline.run( + bigquery_adapter(load_cve(), autodetect_schema=True) +) +``` +In the example below we represent `json` data as tables up until nesting level 1. Above this nesting level, we let BigQuery to create nested fields. + +:::caution +If you yield data as Python objects (dicts) and load this data as `parquet`, the nested fields will be converted into strings. This is one of the consequences of +`dlt` not being able to infer nested fields. +::: + ## Supported File Formats You can configure the following file formats to load data to BigQuery: @@ -148,7 +200,11 @@ When staging is enabled: * [jsonl](../file-formats/jsonl.md) is used by default. * [parquet](../file-formats/parquet.md) is supported. -> ❗ **Bigquery cannot load JSON columns from `parquet` files**. `dlt` will fail such jobs permanently. Switch to `jsonl` to load and parse JSON properly. +:::caution +**Bigquery cannot load JSON columns from `parquet` files**. `dlt` will fail such jobs permanently. Instead: +* Switch to `jsonl` to load and parse JSON properly. +* Use schema [autodetect and nested fields](#use-bigquery-schema-autodetect-for-nested-fields) +::: ## Supported Column Hints diff --git a/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md b/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md index b1dde5a328..bf8e2bce02 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md @@ -37,7 +37,7 @@ or with `pip install "dlt[clickhouse]"`, which installs the `dlt` library and th ### 2. Setup ClickHouse database -To load data into ClickHouse, you need to create a ClickHouse database. While we recommend asking our GPT-4 assistant for details, we have provided a general outline of the process below: +To load data into ClickHouse, you need to create a ClickHouse database. While we recommend asking our GPT-4 assistant for details, we've provided a general outline of the process below: 1. 
You can use an existing ClickHouse database or create a new one. @@ -59,35 +59,52 @@ To load data into ClickHouse, you need to create a ClickHouse database. While we ```toml [destination.clickhouse.credentials] - database = "dlt" # The database name you created - username = "dlt" # ClickHouse username, default is usually "default" - password = "Dlt*12345789234567" # ClickHouse password if any - host = "localhost" # ClickHouse server host - port = 9000 # ClickHouse HTTP port, default is 9000 - http_port = 8443 # HTTP Port to connect to ClickHouse server's HTTP interface. Defaults to 8443. + database = "dlt" # The database name you created. + username = "dlt" # ClickHouse username, default is usually "default". + password = "Dlt*12345789234567" # ClickHouse password if any. + host = "localhost" # ClickHouse server host. + port = 9000 # ClickHouse native TCP protocol port, default is 9000. + http_port = 8443 # ClickHouse HTTP port, default is 9000. secure = 1 # Set to 1 if using HTTPS, else 0. - dataset_table_separator = "___" # Separator for dataset table names from dataset. ``` - :::info http_port - The `http_port` parameter specifies the port number to use when connecting to the ClickHouse server's HTTP interface. This is different from default port 9000, which is used for the native TCP - protocol. + :::info Network Ports + The `http_port` parameter specifies the port number to use when connecting to the ClickHouse server's HTTP interface. + The default non-secure HTTP port for ClickHouse is `8123`. + This is different from the default port `9000`, which is used for the native TCP protocol. - You must set `http_port` if you are not using external staging (i.e. you don't set the staging parameter in your pipeline). This is because dlt's built-in ClickHouse local storage staging uses the - [clickhouse-connect](https://github.com/ClickHouse/clickhouse-connect) library, which communicates with ClickHouse over HTTP. + You must set `http_port` if you are not using external staging (i.e. you don't set the `staging` parameter in your pipeline). This is because dlt's built-in ClickHouse local storage staging uses the [clickhouse-connect](https://github.com/ClickHouse/clickhouse-connect) library, which communicates with ClickHouse over HTTP. - Make sure your ClickHouse server is configured to accept HTTP connections on the port specified by `http_port`. For example, if you set `http_port = 8443`, then ClickHouse should be listening for - HTTP - requests on port 8443. If you are using external staging, you can omit the `http_port` parameter, since clickhouse-connect will not be used in this case. + Make sure your ClickHouse server is configured to accept HTTP connections on the port specified by `http_port`. For example: + + - If you set `http_port = 8123` (default non-secure HTTP port), then ClickHouse should be listening for HTTP requests on port 8123. + - If you set `http_port = 8443`, then ClickHouse should be listening for secure HTTPS requests on port 8443. + + If you're using external staging, you can omit the `http_port` parameter, since clickhouse-connect will not be used in this case. + + For local development and testing with ClickHouse running locally, it is recommended to use the default non-secure HTTP port `8123` by setting `http_port=8123` or omitting the parameter. + + Please see the [ClickHouse network port documentation](https://clickhouse.com/docs/en/guides/sre/network-ports) for further reference. ::: 2. 
You can pass a database connection string similar to the one used by the `clickhouse-driver` library. The credentials above will look like this: ```toml - # keep it at the top of your toml file, before any section starts. + # keep it at the top of your toml file before any section starts. destination.clickhouse.credentials="clickhouse://dlt:Dlt*12345789234567@localhost:9000/dlt?secure=1" ``` +### 3. Add configuration options + +You can set the following configuration options in the `.dlt/secrets.toml` file: + +```toml +[destination.clickhouse] +dataset_table_separator = "___" # The default separator for dataset table names from dataset. +table_engine_type = "merge_tree" # The default table engine to use. +dataset_sentinel_table_name = "dlt_sentinel_table" # The default name for sentinel tables. +``` + ## Write disposition All [write dispositions](../../general-usage/incremental-loading#choosing-a-write-disposition) are supported. @@ -104,7 +121,8 @@ Data is loaded into ClickHouse using the most efficient method depending on the `Clickhouse` does not support multiple datasets in one database, dlt relies on datasets to exist for multiple reasons. To make `clickhouse` work with `dlt`, tables generated by `dlt` in your `clickhouse` database will have their name prefixed with the dataset name separated by -the configurable `dataset_table_separator`. Additionally, a special sentinel table that does not contain any data will be created, so dlt knows which virtual datasets already exist in a +the configurable `dataset_table_separator`. +Additionally, a special sentinel table that doesn't contain any data will be created, so dlt knows which virtual datasets already exist in a clickhouse destination. @@ -115,14 +133,15 @@ destination. The `clickhouse` destination has a few specific deviations from the default sql destinations: -1. `Clickhouse` has an experimental `object` datatype, but we have found it to be a bit unpredictable, so the dlt clickhouse destination will load the complex datatype to a `text` column. If you need +1. `Clickhouse` has an experimental `object` datatype, but we've found it to be a bit unpredictable, so the dlt clickhouse destination will load the complex datatype to a `text` column. + If you need this feature, get in touch with our Slack community, and we will consider adding it. 2. `Clickhouse` does not support the `time` datatype. Time will be loaded to a `text` column. 3. `Clickhouse` does not support the `binary` datatype. Binary will be loaded to a `text` column. When loading from `jsonl`, this will be a base64 string, when loading from parquet this will be the `binary` object converted to `text`. -4. `Clickhouse` accepts adding columns to a populated table that are not null. -5. `Clickhouse` can produce rounding errors under certain conditions when using the float / double datatype. Make sure to use decimal if you cannot afford to have rounding errors. Loading the value - 12.7001 to a double column with the loader file format jsonl set will predictbly produce a rounding error for example. +4. `Clickhouse` accepts adding columns to a populated table that aren’t null. +5. `Clickhouse` can produce rounding errors under certain conditions when using the float / double datatype. Make sure to use decimal if you can’t afford to have rounding errors. Loading the value + 12.7001 to a double column with the loader file format jsonl set will predictably produce a rounding error, for example. 
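+If you cannot afford rounding errors, one way to force `decimal` is a column hint on the resource. A minimal sketch, assuming a hypothetical `payments` resource with an `amount` column (the precision and scale values are illustrative):
+
+```py
+import dlt
+
+@dlt.resource(
+    name="payments",
+    columns={"amount": {"data_type": "decimal", "precision": 38, "scale": 9}},
+)
+def payments():
+    # "amount" is created as a decimal column instead of double,
+    # sidestepping the float rounding issue described above
+    yield {"id": 1, "amount": 12.7001}
+```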
## Supported column hints @@ -130,31 +149,46 @@ ClickHouse supports the following [column hints](../../general-usage/schema#tabl - `primary_key` - marks the column as part of the primary key. Multiple columns can have this hint to create a composite primary key. -## Table Engine +## Choosing a Table Engine + +dlt defaults to `MergeTree` table engine. You can specify an alternate table engine in two ways: -By default, tables are created using the `ReplicatedMergeTree` table engine in ClickHouse. You can specify an alternate table engine using the `table_engine_type` with the clickhouse adapter: +### Setting a default table engine in the configuration + +You can set a default table engine for all resources and dlt tables by adding the `table_engine_type` parameter to your ClickHouse credentials in the `.dlt/secrets.toml` file: + +```toml +[destination.clickhouse] +# ... (other configuration options) +table_engine_type = "merge_tree" # The default table engine to use. +``` + +### Setting the table engine for specific resources + +You can also set the table engine for specific resources using the clickhouse_adapter, which will override the default engine set in `.dlt/secrets.toml`, for that resource: ```py from dlt.destinations.adapters import clickhouse_adapter - @dlt.resource() def my_resource(): - ... - + ... clickhouse_adapter(my_resource, table_engine_type="merge_tree") - ``` -Supported values are: +Supported values for `table_engine_type` are: + +- `merge_tree` (default) - creates tables using the `MergeTree` engine, suitable for most use cases. [Learn more about MergeTree](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree). +- `shared_merge_tree` - creates tables using the `SharedMergeTree` engine, optimized for cloud-native environments with shared storage. This table is **only** available on ClickHouse Cloud, and it the default selection if `merge_tree` is selected. [Learn more about SharedMergeTree](https://clickhouse.com/docs/en/cloud/reference/shared-merge-tree). +- `replicated_merge_tree` - creates tables using the `ReplicatedMergeTree` engine, which supports data replication across multiple nodes for high availability. [Learn more about ReplicatedMergeTree](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/replication). This defaults to `shared_merge_tree` on ClickHouse Cloud. +- Experimental support for the `Log` engine family with `stripe_log` and `tiny_log`. -- `merge_tree` - creates tables using the `MergeTree` engine -- `replicated_merge_tree` (default) - creates tables using the `ReplicatedMergeTree` engine +For local development and testing with ClickHouse running locally, the `MergeTree` engine is recommended. ## Staging support -ClickHouse supports Amazon S3, Google Cloud Storage and Azure Blob Storage as file staging destinations. +ClickHouse supports Amazon S3, Google Cloud Storage, and Azure Blob Storage as file staging destinations. `dlt` will upload Parquet or JSONL files to the staging location and use ClickHouse table functions to load the data directly from the staged files. @@ -214,7 +248,7 @@ dlt's staging mechanisms for ClickHouse. ### dbt support -Integration with [dbt](../transformations/dbt/dbt.md) is generally supported via dbt-clickhouse, but not tested by us. +Integration with [dbt](../transformations/dbt/dbt.md) is generally supported via dbt-clickhouse but not tested by us. 
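+Since dbt-clickhouse is not part of our test suite, the snippet below is only a rough, untested sketch of wiring the generic `dlt` dbt runner to a ClickHouse pipeline; the package path and pipeline name are placeholders:
+
+```py
+import dlt
+
+pipeline = dlt.pipeline(
+    pipeline_name="clickhouse_dbt", destination="clickhouse", dataset_name="dlt_data"
+)
+
+# point the runner at a dbt package (hypothetical path)
+dbt = dlt.dbt.package(pipeline, "path/to/dbt_package")
+models = dbt.run_all()
+for m in models:
+    print(f"Model {m.model_name} materialized in {m.time} with status {m.status}")
+```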
### Syncing of `dlt` state diff --git a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md index e93ffb54d4..bbe21b7ea7 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md +++ b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md @@ -265,6 +265,29 @@ The filesystem destination handles the write dispositions as follows: - `replace` - all files that belong to such tables are deleted from the dataset folder, and then the current set of files is added. - `merge` - falls back to `append` +### 🧪 `merge` with `delta` table format +The [`upsert`](../../general-usage/incremental-loading.md#upsert-strategy) merge strategy is supported when using the [`delta`](#delta-table-format) table format. + +:::caution +The `upsert` merge strategy for the `filesystem` destination with `delta` table format is considered experimental. +::: + +```py +@dlt.resource( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="my_primary_key", + table_format="delta" +) +def my_upsert_resource(): + ... +... +``` + +#### Known limitations +- `hard_delete` hint not supported +- deleting records from child tables not supported + - This means updates to complex columns that involve element removals are not propagated. For example, if you first load `{"key": 1, "complex": [1, 2]}` and then load `{"key": 1, "complex": [1]}`, then the record for element `2` will not be deleted from the child table. + ## File Compression The filesystem destination in the dlt library uses `gzip` compression by default for efficiency, which may result in the files being stored in a compressed format. This format may not be easily readable as plain text or JSON Lines (`jsonl`) files. If you encounter files that seem unreadable, they may be compressed. diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 2b8d73d4cb..181d024a2f 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -310,6 +310,27 @@ Above we set `csv` file without header, with **|** as a separator and we request You'll need those setting when [importing external files](../../general-usage/resource.md#import-external-files) ::: +### Query Tagging +`dlt` [tags sessions](https://docs.snowflake.com/en/sql-reference/parameters#query-tag) that execute loading jobs with following job properties: +* **source** - name of the source (identical with the name of `dlt` schema) +* **resource** - name of the resource (if known, else empty string) +* **table** - name of the table loaded by the job +* **load_id** - load id of the job +* **pipeline_name** - name of the active pipeline (or empty string if not found) + +You can define query tag by defining a query tag placeholder in snowflake credentials: +```toml +[destination.snowflake] +query_tag='{{"source":"{source}", "resource":"{resource}", "table": "{table}", "load_id":"{load_id}", "pipeline_name":"{pipeline_name}"}}' +``` +which contains Python named formatters corresponding to tag names ie. `{source}` will assume the name of the dlt source. + +:::note +1. query tagging is off by default. `query_tag` configuration field is `None` by default and must be set to enable tagging. +2. only sessions associated with a job are tagged. sessions that migrate schemas remain untagged +3. jobs processing table chains (ie. 
sql merge jobs) will use top level table as **table** +::: + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-snowflake](https://github.com/dbt-labs/dbt-snowflake). Both password and key pair authentication are supported and shared with dbt runners. diff --git a/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx b/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx new file mode 100644 index 0000000000..e2cce374a2 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx @@ -0,0 +1,31 @@ +import CodeBlock from '@theme/CodeBlock'; + +There are several ways of configuring dlt to use {props.file_type} file format for normalization step and to store your data at the destination: + +1. You can set the loader_file_format argument to {props.file_type} in the run command: + +
+info = pipeline.run(some_source(), loader_file_format="{props.file_type}")
+
+ + +2. You can set the `loader_file_format` in `config.toml` or `secrets.toml`: + +
+[normalize]{'\n'}
+loader_file_format="{props.file_type}"
+
+ +3. You can set the `loader_file_format` via ENV variable: + +
+export NORMALIZE__LOADER_FILE_FORMAT="{props.file_type}"
+
+ +4. You can set the file type directly in [the resource decorator](../../general-usage/resource#pick-loader-file-format-for-a-particular-resource). + +
+@dlt.resource(file_format="{props.file_type}"){'\n'}
+def generate_rows(nr):{'\n'}
+    pass
+
diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md index 02a7e81def..242a8282d1 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -3,6 +3,7 @@ title: csv description: The csv file format keywords: [csv, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # CSV file format @@ -13,16 +14,13 @@ Internally we use two implementations: - **pyarrow** csv writer - very fast, multithreaded writer for the [arrow tables](../verified-sources/arrow-pandas.md) - **python stdlib writer** - a csv writer included in the Python standard library for Python objects - ## Supported Destinations -Supported by: **Postgres**, **Filesystem**, **snowflake** +The `csv` format is supported by the following destinations: **Postgres**, **Filesystem**, **Snowflake** -By setting the `loader_file_format` argument to `csv` in the run command, the pipeline will store your data in the csv format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="csv") -``` + ## Default Settings `dlt` attempts to make both writers to generate similarly looking files diff --git a/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md b/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md index 641be9a106..c6742c2584 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md @@ -3,6 +3,7 @@ title: INSERT description: The INSERT file format keywords: [insert values, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # SQL INSERT File Format @@ -21,10 +22,8 @@ This file format is [compressed](../../reference/performance.md#disabling-and-en This format is used by default by: **DuckDB**, **Postgres**, **Redshift**. -It is also supported by: **filesystem**. +It is also supported by: **Filesystem**. -By setting the `loader_file_format` argument to `insert_values` in the run command, the pipeline will store your data in the INSERT format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="insert_values") -``` + diff --git a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md index 7467c6f639..72168b38f0 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md @@ -3,6 +3,7 @@ title: jsonl description: The jsonl file format keywords: [jsonl, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # jsonl - JSON Delimited @@ -22,11 +23,8 @@ This file format is ## Supported Destinations -This format is used by default by: **BigQuery**, **Snowflake**, **filesystem**. +This format is used by default by: **BigQuery**, **Snowflake**, **Filesystem**. 
-By setting the `loader_file_format` argument to `jsonl` in the run command, the pipeline will store -your data in the jsonl format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="jsonl") -``` + diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 414eaf2cb8..5d85b7a557 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -3,6 +3,7 @@ title: Parquet description: The parquet file format keywords: [parquet, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # Parquet file format @@ -16,13 +17,11 @@ pip install "dlt[parquet]" ## Supported Destinations -Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **filesystem**, **Athena**, **Databricks**, **Synapse** +Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **Filesystem**, **Athena**, **Databricks**, **Synapse** -By setting the `loader_file_format` argument to `parquet` in the run command, the pipeline will store your data in the parquet format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="parquet") -``` + ## Destination AutoConfig `dlt` uses [destination capabilities](../../walkthroughs/create-new-destination.md#3-set-the-destination-capabilities) to configure the parquet writer: diff --git a/docs/website/docs/dlt-ecosystem/transformations/sql.md b/docs/website/docs/dlt-ecosystem/transformations/sql.md index ad37c61bd8..b358e97b4c 100644 --- a/docs/website/docs/dlt-ecosystem/transformations/sql.md +++ b/docs/website/docs/dlt-ecosystem/transformations/sql.md @@ -16,7 +16,7 @@ connection. pipeline = dlt.pipeline(destination="bigquery", dataset_name="crm") try: with pipeline.sql_client() as client: - client.sql_client.execute_sql( + client.execute_sql( "INSERT INTO customers VALUES (%s, %s, %s)", 10, "Fred", diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md index f9ceb99a90..cb14db7ae7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md @@ -70,9 +70,12 @@ The output file format is chosen automatically based on the destination's capabi * snowflake * filesystem * athena +* databricks +* dremio +* synapse -## Normalize configuration +## Add `_dlt_load_id` and `_dlt_id` to your tables `dlt` does not add any data lineage columns by default when loading Arrow tables. This is to give the best performance and avoid unnecessary data copying. @@ -120,6 +123,21 @@ pipeline.run(orders) Look at the [Connector X + Arrow Example](../../examples/connector_x_arrow/) to see how to load data from production databases fast. ::: +## Loading `json` documents +If you want to skip default `dlt` JSON normalizer, you can use any available method to convert json documents into tabular data. 
+* **pandas** has `read_json` and `json_normalize` methods +* **pyarrow** can infer table schema and convert json files into tables with `read_json` +* **duckdb** can do the same with `read_json_auto` + +```py +import duckdb + +conn = duckdb.connect() +table = conn.execute(f"SELECT * FROM read_json_auto('{json_file_path}')").fetch_arrow_table() +``` + +Note that **duckdb** and **pyarrow** methods will generate [nested types](#loading-nested-types) for nested data, which are only partially supported by `dlt`. + ## Supported Arrow data types The Arrow data types are translated to dlt data types as follows: @@ -141,7 +159,7 @@ The Arrow data types are translated to dlt data types as follows: ## Loading nested types All struct types are represented as `complex` and will be loaded as JSON (if destination permits) or a string. Currently we do not support **struct** types, -even if they are present in the destination. +even if they are present in the destination (except **BigQuery** which can be [configured to handle them](../destinations/bigquery.md#use-bigquery-schema-autodetect-for-nested-fields)) If you want to represent nested data as separated tables, you must yield panda frames and arrow tables as records. In the examples above: ```py diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md b/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md index 357d50582f..83077270c7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md @@ -283,5 +283,25 @@ verified source. 1. This function loads data incrementally and tracks the `occurred_at.last_value` parameter from the previous pipeline run. Refer to our official documentation for more information on [incremental loading](../../general-usage/incremental-loading.md). +### Additional info +If you encounter the following error while processing your request: +:::warning ERROR +Your request to HubSpot is too long to process. Maximum allowed query length is 2000 symbols, ... while your list is +2125 symbols long. +::: + +Please note that by default, HubSpot requests all default properties and all custom properties (which are +user-created properties in HubSpot). Therefore, you need to request specific properties for each entity (contacts, +companies, tickets, etc.). + +Default properties are defined in `settings.py`, and you can change them. + +The custom properties could cause the error as there might be too many of them available in your HubSpot. +To change this, you can pass `include_custom_props=False` when initializing the source: + +```py +info = p.run(hubspot(include_custom_props=False)) +``` +Or, if you wish to include them, you can modify `settings.py`. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/index.md b/docs/website/docs/dlt-ecosystem/verified-sources/index.md index 7b5d9e2bcb..d105dccb9c 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/index.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/index.md @@ -8,7 +8,7 @@ import Link from '../../_book-onboarding-call.md'; Choose from our collection of verified sources, developed and maintained by the dlt team and community. Each source is rigorously tested against a real API and provided as Python code for easy customization. -Planning to use dlt in production and need a source that isn't listed? We're happy to build it for you: . +Planning to use dlt in production and need a source that isn't listed? We're happy to help you build it: . 
### Popular sources diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md b/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md index 6d69f09cd3..a12c831137 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md @@ -17,6 +17,10 @@ Resources that can be loaded using this verified source are: | -------------------- | ----------------------------------------------- | | replication_resource | Load published messages from a replication slot | +:::info +The postgres replication source currently **does not** suppport the [scd2 merge strategy](../../general-usage/incremental-loading#scd2-strategy). +::: + ## Setup Guide ### Setup user @@ -268,4 +272,4 @@ If you wish to create your own pipelines, you can leverage source and resource m ) ``` - Similarly, to replicate changes from selected columns, you can use the `table_names` and `include_columns` arguments in the `replication_resource` function. \ No newline at end of file + Similarly, to replicate changes from selected columns, you can use the `table_names` and `include_columns` arguments in the `replication_resource` function. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 8d43e471c8..9475dad578 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -386,7 +386,7 @@ from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator { "path": "posts", - "paginator": JSONResponsePaginator( + "paginator": JSONLinkPaginator( next_url_path="pagination.next" ), } @@ -400,16 +400,30 @@ These are the available paginators: | `type` | Paginator class | Description | | ------------ | -------------- | ----------- | -| `json_response` | [JSONResponsePaginator](../../general-usage/http/rest-client.md#jsonresponsepaginator) | The link to the next page is in the body (JSON) of the response.
*Parameters:*
  • `next_url_path` (str) - the JSONPath to the next page URL
| +| `json_link` | [JSONLinkPaginator](../../general-usage/http/rest-client.md#jsonlinkpaginator) | The link to the next page is in the body (JSON) of the response.
*Parameters:*
  • `next_url_path` (str) - the JSONPath to the next page URL
| | `header_link` | [HeaderLinkPaginator](../../general-usage/http/rest-client.md#headerlinkpaginator) | The links to the next page are in the response headers.
*Parameters:*
  • `link_header` (str) - the name of the header containing the links. Default is "next".
| | `offset` | [OffsetPaginator](../../general-usage/http/rest-client.md#offsetpaginator) | The pagination is based on an offset parameter. With total items count either in the response body or explicitly provided.
*Parameters:*
  • `limit` (int) - the maximum number of items to retrieve in each request
  • `offset` (int) - the initial offset for the first request. Defaults to `0`
  • `offset_param` (str) - the name of the query parameter used to specify the offset. Defaults to "offset"
  • `limit_param` (str) - the name of the query parameter used to specify the limit. Defaults to "limit"
  • `total_path` (str) - a JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset`
  • `maximum_offset` (int) - optional maximum offset value. Limits pagination even without total count
| -| `page_number` | [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided.
*Parameters:*
  • `initial_page` (int) - the starting page number. Defaults to `0`
  • `page_param` (str) - the query parameter name for the page number. Defaults to "page"
  • `total_path` (str) - a JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`
  • `maximum_page` (int) - optional maximum page number. Stops pagination once this page is reached
| +| `page_number` | [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided.
*Parameters:*
  • `base_page` (int) - the starting page number. Defaults to `0`
  • `page_param` (str) - the query parameter name for the page number. Defaults to "page"
  • `total_path` (str) - a JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`
  • `maximum_page` (int) - optional maximum page number. Stops pagination once this page is reached
| | `cursor` | [JSONResponseCursorPaginator](../../general-usage/http/rest-client.md#jsonresponsecursorpaginator) | The pagination is based on a cursor parameter. The value of the cursor is in the response body (JSON).
*Parameters:*
  • `cursor_path` (str) - the JSONPath to the cursor value. Defaults to "cursors.next"
  • `cursor_param` (str) - the query parameter name for the cursor. Defaults to "after"
| | `single_page` | SinglePagePaginator | The response will be interpreted as a single-page response, ignoring possible pagination metadata. | | `auto` | `None` | Explicitly specify that the source should automatically detect the pagination method. | For more complex pagination methods, you can implement a [custom paginator](../../general-usage/http/rest-client.md#implementing-a-custom-paginator), instantiate it, and use it in the configuration. +Alternatively, you can use the dictionary configuration syntax also for custom paginators. For this, you need to register your custom paginator: + +```py +rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) + +{ + # ... + "paginator": { + "type": "custom_paginator", + "next_url_path": "paging.nextLink", + } +} +``` + ### Data selection The `data_selector` field in the endpoint configuration allows you to specify a JSONPath to select the data from the response. By default, the source will try to detect locations of the data automatically. @@ -693,7 +707,7 @@ Let's break down the configuration. ### Incremental loading using the `incremental` field -The alternative method is to use the `incremental` field in the [endpoint configuration](#endpoint-configuration). This method is more flexible and allows you to specify the start and end conditions for the incremental loading. +The alternative method is to use the `incremental` field in the [endpoint configuration](#endpoint-configuration). This configuration is more powerful than the method shown above because it also allows you to specify not only the start parameter and value but also the end parameter and value for the incremental loading. Let's take the same example as above and configure it using the `incremental` field: @@ -721,6 +735,7 @@ The full available configuration for the `incremental` field is: "cursor_path": "", "initial_value": "", "end_value": "", + "convert": a_callable, } } ``` @@ -732,11 +747,44 @@ The fields are: - `cursor_path` (str): The JSONPath to the field within each item in the list. This is the field that will be used to track the incremental loading. In the example above, it's `"created_at"`. - `initial_value` (str): The initial value for the cursor. This is the value that will initialize the state of incremental loading. - `end_value` (str): The end value for the cursor to stop the incremental loading. This is optional and can be omitted if you only need to track the start condition. If you set this field, `initial_value` needs to be set as well. +- `convert` (callable): A callable that converts the cursor value into the format that the query parameter requires. For example, a UNIX timestamp can be converted into an ISO 8601 date or a date can be converted into `created_at+gt+{date}`. See the [incremental loading](../../general-usage/incremental-loading.md#incremental-loading-with-a-cursor-field) guide for more details. If you encounter issues with incremental loading, see the [troubleshooting section](../../general-usage/incremental-loading.md#troubleshooting) in the incremental loading guide. +### Convert the incremental value before calling the API + +If you need to transform the values in the cursor field before passing them to the API endpoint, you can specify a callable under the key `convert`. For example, the API might return UNIX epoch timestamps but expects to be queried with an ISO 8601 date. To achieve that, we can specify a function that converts from the date format returned by the API to the date format required for API requests. 
+ +In the following examples, `1704067200` is returned from the API in the field `updated_at` but the API will be called with `?created_since=2024-01-01`. + +Incremental loading using the `params` field: +```py +{ + "created_since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "1704067200", + "convert": lambda epoch: pendulum.from_timestamp(int(epoch)).to_date_string(), + } +} +``` + +Incremental loading using the `incremental` field: +```py +{ + "path": "posts", + "data_selector": "results", + "incremental": { + "start_param": "created_since", + "cursor_path": "updated_at", + "initial_value": "1704067200", + "convert": lambda epoch: pendulum.from_timestamp(int(epoch)).to_date_string(), + }, +} +``` + ## Advanced configuration `rest_api_source()` function creates the [dlt source](../../general-usage/source.md) and lets you configure the following parameters: @@ -751,13 +799,24 @@ If you encounter issues with incremental loading, see the [troubleshooting secti ### Response actions -The `response_actions` field in the endpoint configuration allows you to specify how to handle specific responses from the API based on status codes or content substrings. This is useful for handling edge cases like ignoring responses on specific conditions. +The `response_actions` field in the endpoint configuration allows you to specify how to handle specific responses or all responses from the API. For example, responses with specific status codes or content substrings can be ignored. +Additionally, all responses or only responses with specific status codes or content substrings can be transformed with a custom callable, such as a function. This callable is passed on to the requests library as a [response hook](https://requests.readthedocs.io/en/latest/user/advanced/#event-hooks). The callable can modify the response object and has to return it for the modifications to take effect. :::caution Experimental Feature This is an experimental feature and may change in future releases. ::: -#### Example +**Fields:** + +- `status_code` (int, optional): The HTTP status code to match. +- `content` (str, optional): A substring to search for in the response content. +- `action` (str or Callable or List[Callable], optional): The action to take when the condition is met. Currently supported actions: + - `"ignore"`: Ignore the response. + - a callable accepting and returning the response object. + - a list of callables, each accepting and returning the response object. + + +#### Example A ```py { @@ -772,12 +831,78 @@ This is an experimental feature and may change in future releases. In this example, the source will ignore responses with a status code of 404, responses with the content "Not found", and responses with a status code of 200 _and_ content "some text". -**Fields:** +#### Example B + +```py +def set_encoding(response, *args, **kwargs): + # sets the encoding in case it's not correctly detected + response.encoding = 'windows-1252' + return response + + +def add_and_remove_fields(response: Response, *args, **kwargs) -> Response: + payload = response.json() + for record in payload["data"]: + record["custom_field"] = "foobar" + record.pop("email", None) + modified_content: bytes = json.dumps(payload).encode("utf-8") + response._content = modified_content + return response + + +source_config = { + "client": { + # ... 
+ }, + "resources": [ + { + "name": "issues", + "endpoint": { + "path": "issues", + "response_actions": [ + set_encoding, + { + "status_code": 200, + "content": "some text", + "action": add_and_remove_fields, + }, + ], + }, + }, + ], +} +``` + +In this example, the resource will set the correct encoding for all responses first. Thereafter, for all responses that have the status code 200, we will add a field `custom_field` and remove the field `email`. + +#### Example C + +```py +def set_encoding(response, *args, **kwargs): + # sets the encoding in case it's not correctly detected + response.encoding = 'windows-1252' + return response + +source_config = { + "client": { + # ... + }, + "resources": [ + { + "name": "issues", + "endpoint": { + "path": "issues", + "response_actions": [ + set_encoding, + ], + }, + }, + ], +} +``` + +In this example, the resource will set the correct encoding for all responses. More callables can be added to the list of response_actions. -- `status_code` (int, optional): The HTTP status code to match. -- `content` (str, optional): A substring to search for in the response content. -- `action` (str): The action to take when the condition is met. Currently supported actions: - - `ignore`: Ignore the response. ## Troubleshooting @@ -866,4 +991,4 @@ If experiencing 401 (Unauthorized) errors, this could indicate: The `rest_api` source uses the [RESTClient](../../general-usage/http/rest-client.md) class for HTTP requests. Refer to the RESTClient [troubleshooting guide](../../general-usage/http/rest-client.md#troubleshooting) for debugging tips. -For further assistance, join our [Slack community](https://dlthub.com/community). We're here to help! \ No newline at end of file +For further assistance, join our [Slack community](https://dlthub.com/community). We're here to help! diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index a5869e99bd..eeb717515a 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -301,6 +301,15 @@ With dataset above and local postgres instance, connectorx is 2x faster than pya #### Postgres / MSSQL No issues found. Postgres is the only backend where we observed 2x speedup with connector x. On other db systems it performs same as `pyarrrow` backend or slower. +### Notes on data types + +#### JSON +JSON data type is represented as Python object for the **sqlalchemy** backend and as JSON string for the **pyarrow** backend. Currently it does not work correctly +with **pandas** and **connector-x** which cast Python objects to str generating invalid JSON strings that cannot be loaded into destination. + +#### UUID +UUIDs are represented as string by default. You can switch that behavior by using table adapter callback and modifying properties of the UUID type for a particular column. + ## Incremental Loading Efficient data management often requires loading only new or updated data from your SQL databases, rather than reprocessing the entire dataset. This is where incremental loading comes into play. diff --git a/docs/website/docs/general-usage/credentials/config_providers.md b/docs/website/docs/general-usage/credentials/config_providers.md index 626263c332..3dbe88893b 100644 --- a/docs/website/docs/general-usage/credentials/config_providers.md +++ b/docs/website/docs/general-usage/credentials/config_providers.md @@ -30,9 +30,13 @@ providers. 
configuration values and secrets. `secrets.toml` is dedicated to sensitive information, while `config.toml` contains non-sensitive configuration data.
-4. **Default Argument Values**: These are the values specified in the function's signature.
+4. **Custom Providers** added with `register_provider`: These are your own provider implementations
+   that you can use to connect to any backend. See [adding custom providers](#adding-custom-providers) for more information.
+
+5. **Default Argument Values**: These are the values specified in the function's signature.
   They have the lowest priority in the provider hierarchy.
+
### Example

```py
@@ -79,6 +83,32 @@ secrets (to reduce the number of requests done by `dlt` when searching sections)
Context-aware providers will activate in the right environments i.e. on Airflow or AWS/GCP VMachines.
:::

+### Adding Custom Providers
+
+You can use the `CustomLoaderDocProvider` classes to supply a custom dictionary obtained from any source to dlt for use
+as a source of `config` and `secret` values. The code below demonstrates how to use a config stored in config.json.
+
+```py
+import json
+
+import dlt
+
+from dlt.common.configuration.providers import CustomLoaderDocProvider
+
+# create a function that loads and returns a dict
+def load_config():
+    with open("config.json", "rb") as f:
+        return json.load(f)
+
+# create the custom provider
+provider = CustomLoaderDocProvider("my_json_provider", load_config)
+
+# register provider
+dlt.config.register_provider(provider)
+```
+
+:::tip
+Check our [example](../../examples/custom_config_provider) for a `yaml` based config provider that supports switchable profiles.
+:::
+
## Provider key formats

### TOML vs. Environment Variables
diff --git a/docs/website/docs/general-usage/http/overview.md b/docs/website/docs/general-usage/http/overview.md
index 2d193ceb2c..7358e577f4 100644
--- a/docs/website/docs/general-usage/http/overview.md
+++ b/docs/website/docs/general-usage/http/overview.md
@@ -58,12 +58,12 @@ Note that we do not explicitly specify the pagination parameters in the example.
```py
import dlt
from dlt.sources.helpers.rest_client import RESTClient
-from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
+from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator

github_client = RESTClient(
    base_url="https://pokeapi.co/api/v2",
-    paginator=JSONResponsePaginator(next_url_path="next"),    # (1)
-    data_selector="results",    # (2)
+    paginator=JSONLinkPaginator(next_url_path="next"),    # (1)
+    data_selector="results",    # (2)
)

@dlt.resource
@@ -86,6 +86,6 @@ print(load_info)
```

In the example above:
-1. We create a `RESTClient` instance with the base URL of the API: in this case, the [PokéAPI](https://pokeapi.co/). We also specify the paginator to use explicitly: `JSONResponsePaginator` with the `next_url_path` set to `"next"`. This tells the paginator to look for the next page URL in the `next` key of the JSON response.
+1. We create a `RESTClient` instance with the base URL of the API: in this case, the [PokéAPI](https://pokeapi.co/). We also specify the paginator to use explicitly: `JSONLinkPaginator` with the `next_url_path` set to `"next"`. This tells the paginator to look for the next page URL in the `next` key of the JSON response.
2. In `data_selector` we specify the JSON path to extract the data from the response. This is used to extract the data from the response JSON.
3. By default the number of items per page is limited to 20. We override this by specifying the `limit` parameter in the API call.
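To make point 3 concrete, here is a minimal, self-contained sketch of how that override might look. It mirrors the snippet above; the resource name `get_pokemons`, the page size of `1000`, and the pipeline settings are illustrative assumptions, not part of this change.

```py
import dlt
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator

# same client as in the snippet above
github_client = RESTClient(
    base_url="https://pokeapi.co/api/v2",
    paginator=JSONLinkPaginator(next_url_path="next"),
    data_selector="results",
)

@dlt.resource
def get_pokemons():
    # override the default page size of 20 by passing the `limit` query parameter
    for page in github_client.paginate("/pokemon", params={"limit": 1000}):
        yield page

# assumed pipeline settings, for illustration only
pipeline = dlt.pipeline(
    pipeline_name="get_pokemons",
    destination="duckdb",
    dataset_name="github_data",
)
load_info = pipeline.run(get_pokemons)
print(load_info)
```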
diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index d3a06a1d28..ddd66a233b 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -1,7 +1,7 @@ --- title: RESTClient description: Learn how to use the RESTClient class to interact with RESTful APIs -keywords: [api, http, rest, request, extract, restclient, client, pagination, json, response, data_selector, session, auth, paginator, jsonresponsepaginator, headerlinkpaginator, offsetpaginator, jsonresponsecursorpaginator, queryparampaginator, bearer, token, authentication] +keywords: [api, http, rest, request, extract, restclient, client, pagination, json, response, data_selector, session, auth, paginator, JSONLinkPaginator, headerlinkpaginator, offsetpaginator, jsonresponsecursorpaginator, queryparampaginator, bearer, token, authentication] --- The `RESTClient` class offers an interface for interacting with RESTful APIs, including features like: @@ -16,13 +16,13 @@ This guide shows how to use the `RESTClient` class to read data from APIs, focus ```py from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.auth import BearerTokenAuth -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator client = RESTClient( base_url="https://api.example.com", headers={"User-Agent": "MyApp/1.0"}, auth=BearerTokenAuth(token="your_access_token_here"), # type: ignore - paginator=JSONResponsePaginator(next_url_path="pagination.next"), + paginator=JSONLinkPaginator(next_url_path="pagination.next"), data_selector="data", session=MyCustomSession() ) @@ -111,7 +111,7 @@ Each `PageData` instance contains the data for a single page, along with context Paginators are used to handle paginated responses. The `RESTClient` class comes with built-in paginators for common pagination mechanisms: -- [JSONResponsePaginator](#jsonresponsepaginator) - link to the next page is included in the JSON response. +- [JSONLinkPaginator](#JSONLinkPaginator) - link to the next page is included in the JSON response. - [HeaderLinkPaginator](#headerlinkpaginator) - link to the next page is included in the response headers. - [OffsetPaginator](#offsetpaginator) - pagination based on offset and limit query parameters. - [PageNumberPaginator](#pagenumberpaginator) - pagination based on page numbers. @@ -119,9 +119,9 @@ Paginators are used to handle paginated responses. The `RESTClient` class comes If the API uses a non-standard pagination, you can [implement a custom paginator](#implementing-a-custom-paginator) by subclassing the `BasePaginator` class. -#### JSONResponsePaginator +#### JSONLinkPaginator -`JSONResponsePaginator` is designed for APIs where the next page URL is included in the response's JSON body. This paginator uses a JSONPath to locate the next page URL within the JSON response. +`JSONLinkPaginator` is designed for APIs where the next page URL is included in the response's JSON body. This paginator uses a JSONPath to locate the next page URL within the JSON response. 
**Parameters:** @@ -144,15 +144,15 @@ Suppose the API response for `https://api.example.com/posts` looks like this: } ``` -To paginate this response, you can use the `JSONResponsePaginator` with the `next_url_path` set to `"pagination.next"`: +To paginate this response, you can use the `JSONLinkPaginator` with the `next_url_path` set to `"pagination.next"`: ```py from dlt.sources.helpers.rest_client import RESTClient -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator client = RESTClient( base_url="https://api.example.com", - paginator=JSONResponsePaginator(next_url_path="pagination.next") + paginator=JSONLinkPaginator(next_url_path="pagination.next") ) @dlt.resource @@ -306,7 +306,7 @@ client = RESTClient( ### Implementing a custom paginator -When working with APIs that use non-standard pagination schemes, or when you need more control over the pagination process, you can implement a custom paginator by subclassing the `BasePaginator` class and implementing `init_request`, `update_state` and `update_request` methods: +When working with APIs that use non-standard pagination schemes, or when you need more control over the pagination process, you can implement a custom paginator by subclassing the `BasePaginator` class and implementing the methods `init_request`, `update_state` and `update_request`. - `init_request(request: Request) -> None`: This method is called before making the first API call in the `RESTClient.paginate` method. You can use this method to set up the initial request query parameters, headers, etc. For example, you can set the initial page number or cursor value. @@ -566,7 +566,7 @@ client = RESTClient( ## Advanced usage -`RESTClient.paginate()` allows to specify a custom hook function that can be used to modify the response objects. For example, to handle specific HTTP status codes gracefully: +`RESTClient.paginate()` allows to specify a [custom hook function](https://requests.readthedocs.io/en/latest/user/advanced/#event-hooks) that can be used to modify the response objects. For example, to handle specific HTTP status codes gracefully: ```py def custom_response_handler(response): @@ -590,6 +590,22 @@ for page in paginate("https://api.example.com/posts"): print(page) ``` + +## Retry + +You can customize how the RESTClient retries failed requests by editing your `config.toml`. +See more examples and explanations in our [documentation on retry rules](requests#retry-rules). + +Example: + +```toml +[runtime] +request_max_attempts = 10 # Stop after 10 retry attempts instead of 5 +request_backoff_factor = 1.5 # Multiplier applied to the exponential delays. 
Default is 1
+request_timeout = 120  # Timeout in seconds
+request_max_retry_delay = 30  # Cap exponential delay to 30 seconds
+```
+
## Troubleshooting

### `RESTClient.get()` and `RESTClient.post()` methods
@@ -625,11 +641,11 @@ and [response](https://docs.python-requests.org/en/latest/api/#requests.Response

```py
from dlt.sources.helpers.rest_client import RESTClient
-from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
+from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator

client = RESTClient(
    base_url="https://api.example.com",
-    paginator=JSONResponsePaginator(next_url_path="pagination.next")
+    paginator=JSONLinkPaginator(next_url_path="pagination.next")
)

for page in client.paginate("/posts"):
diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md
index b21a5779bc..b130f7a4f5 100644
--- a/docs/website/docs/general-usage/incremental-loading.md
+++ b/docs/website/docs/general-usage/incremental-loading.md
@@ -41,7 +41,7 @@ user's profile Stateless data cannot change - for example, a recorded event, suc
Because stateless data does not need to be updated, we can just append it.

-For stateful data, comes a second question - Can I extract it incrementally from the source? If yes, you should use [slowly changing dimensions (Type-2)](#scd2-strategy), which allow you to maintain historical records of data changes over time.
+For stateful data, a second question arises - Can I extract it incrementally from the source? If yes, you should use [slowly changing dimensions (Type-2)](#scd2-strategy), which allow you to maintain historical records of data changes over time.
If not, then we need to replace the entire data set. If however we can request the data
incrementally such as "all users added or modified since yesterday" then we can simply apply changes to our existing
@@ -49,9 +49,10 @@ dataset with the merge write disposition.

## Merge incremental loading

-The `merge` write disposition can be used with two different strategies:
+The `merge` write disposition can be used with three different strategies:
1) `delete-insert` (default strategy)
2) `scd2`
+3) `upsert`

### `delete-insert` strategy

@@ -391,6 +392,44 @@ must be unique for a root table. We are working to allow `updated_at` style trac
column in the root table to stamp changes in nested data.
* `merge_key(s)` are (for now) ignored.

+### `upsert` strategy
+
+:::caution
+The `upsert` merge strategy is currently supported for these destinations:
+- `athena`
+- `bigquery`
+- `databricks`
+- `mssql`
+- `postgres`
+- `snowflake`
+- 🧪 `filesystem` with `delta` table format (see limitations [here](../dlt-ecosystem/destinations/filesystem.md#known-limitations))
+:::
+
+The `upsert` merge strategy does primary-key based *upserts*:
+- *update* a record if the key exists in the target table
+- *insert* a record if the key does not exist in the target table
+
+You can [delete records](#delete-records) with the `hard_delete` hint.
+
+#### `upsert` versus `delete-insert`
+
+Unlike the default `delete-insert` merge strategy, the `upsert` strategy:
+1. needs a `primary_key`
+2. expects this `primary_key` to be unique (`dlt` does not deduplicate)
+3. does not support `merge_key`
+4. uses `MERGE` or `UPDATE` operations to process updates
+
+#### Example: `upsert` merge strategy
+```py
+@dlt.resource(
+    write_disposition={"disposition": "merge", "strategy": "upsert"},
+    primary_key="my_primary_key"
+)
+def my_upsert_resource():
+    ...
+...
+```
+
## Incremental loading with a cursor field

In most of the REST APIs (and other data sources i.e. database tables) you can request new or updated
@@ -460,15 +499,14 @@ We just yield all the events and `dlt` does the filtering (using `id` column dec
Github returns events ordered from newest to oldest. So we declare the `rows_order` as **descending** to [stop requesting more pages once the incremental value is out of range](#declare-row-order-to-not-request-unnecessary-data). We stop requesting more data from the API after finding the first event with `created_at` earlier than `initial_value`.

:::note
-**Note on Incremental Cursor Behavior:**
-When using incremental cursors for loading data, it's essential to understand how `dlt` handles records in relation to the cursor's
-last value. By default, `dlt` will load only those records for which the incremental cursor value is higher than the last known value of the cursor.
-This means that any records with a cursor value lower than or equal to the last recorded value will be ignored during the loading process.
-This behavior ensures efficiency by avoiding the reprocessing of records that have already been loaded, but it can lead to confusion if
-there are expectations of loading older records that fall below the current cursor threshold. If your use case requires the inclusion of
-such records, you can consider adjusting your data extraction logic, using a full refresh strategy where appropriate or using `last_value_func` as discussed in the subsquent section.
-:::
+`dlt.sources.incremental` is implemented as a [filter function](resource.md#filter-transform-and-pivot-data) that is executed **after** all other transforms
+you add with `add_map` / `add_filter`. This means that you can manipulate the data item before the incremental filter sees it. For example:
+* you can create a surrogate primary key from other columns
+* you can modify the cursor value or create a new field composed from other fields
+* dump Pydantic models to Python dicts to allow incremental to find cursor values
+[Data validation with Pydantic](schema-contracts.md#use-pydantic-models-for-data-validation) happens **before** incremental filtering.
+:::

### max, min or custom `last_value_func`

@@ -840,6 +878,59 @@ Consider the example below for reading incremental loading parameters from "conf
```
`id_after` incrementally stores the latest `cursor_path` value for future pipeline runs.

+### Loading NULL values in the incremental cursor field
+
+When loading incrementally with a cursor field, each row is expected to contain a value at the cursor field that is not `None`.
+For example, the following source data will raise an error:
+```py
+@dlt.resource
+def some_data(updated_at=dlt.sources.incremental("updated_at")):
+    yield [
+        {"id": 1, "created_at": 1, "updated_at": 1},
+        {"id": 2, "created_at": 2, "updated_at": 2},
+        {"id": 3, "created_at": 4, "updated_at": None},
+    ]
+
+list(some_data())
+```
+
+If you want to load data that includes `None` values you can transform the records before the incremental processing.
+You can add steps to the pipeline that [filter, transform, or pivot your data](../general-usage/resource.md#filter-transform-and-pivot-data).
+
+:::caution
+It is important to set the `insert_at` parameter of the `add_map` function to control the order of execution and ensure that your custom steps are executed before the incremental processing starts.
+In the following example, the step of data yielding is at `index = 0`, the custom transformation at `index = 1`, and the incremental processing at `index = 2`. +::: + +See below how you can modify rows before the incremental processing using `add_map()` and filter rows using `add_filter()`. + +```py +@dlt.resource +def some_data(updated_at=dlt.sources.incremental("updated_at")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 2, "updated_at": 2}, + {"id": 3, "created_at": 4, "updated_at": None}, + ] + +def set_default_updated_at(record): + if record.get("updated_at") is None: + record["updated_at"] = record.get("created_at") + return record + +# modifies records before the incremental processing +with_default_values = some_data().add_map(set_default_updated_at, insert_at=1) +result = list(with_default_values) +assert len(result) == 3 +assert result[2]["updated_at"] == 4 + +# removes records before the incremental processing +without_none = some_data().add_filter(lambda r: r.get("updated_at") is not None, insert_at=1) +result_filtered = list(without_none) +assert len(result_filtered) == 2 +``` + + ## Doing a full refresh You may force a full refresh of a `merge` and `append` pipelines: diff --git a/docs/website/docs/general-usage/resource.md b/docs/website/docs/general-usage/resource.md index 14f8d73b58..20149cfa0b 100644 --- a/docs/website/docs/general-usage/resource.md +++ b/docs/website/docs/general-usage/resource.md @@ -13,9 +13,9 @@ resource, we add the `@dlt.resource` decorator to that function. Commonly used arguments: -- `name` The name of the table generated by this resource. Defaults to decorated function name. -- `write_disposition` How should the data be loaded at destination? Currently, supported: `append`, - `replace` and `merge`. Defaults to `append.` +- `name` The name of the table generated by this resource. Defaults to the decorated function name. +- `write_disposition` How should the data be loaded at the destination? Currently supported: `append`, + `replace`, and `merge`. Defaults to `append.` Example: @@ -47,41 +47,42 @@ function. `dlt` will infer [schema](schema.md) for tables associated with resources from the resource's data. You can modify the generation process by using the table and column hints. Resource decorator -accepts following arguments: +accepts the following arguments: -1. `table_name` the name of the table, if different from resource name. -1. `primary_key` and `merge_key` define name of the columns (compound keys are allowed) that will +1. `table_name` the name of the table, if different from the resource name. +1. `primary_key` and `merge_key` define the name of the columns (compound keys are allowed) that will receive those hints. Used in [incremental loading](incremental-loading.md). -1. `columns` let's you define one or more columns, including the data types, nullability and other - hints. The column definition is a `TypedDict`: `TTableSchemaColumns`. In example below, we tell - `dlt` that column `tags` (containing a list of tags) in `user` table should have type `complex` - which means that it will be loaded as JSON/struct and not as child table. +1. `columns` let's you define one or more columns, including the data types, nullability, and other + hints. The column definition is a `TypedDict`: `TTableSchemaColumns`. 
In the example below, we tell + `dlt` that column `tags` (containing a list of tags) in the `user` table should have type `complex`, + which means that it will be loaded as JSON/struct and not as a child table. ```py @dlt.resource(name="user", columns={"tags": {"data_type": "complex"}}) def get_users(): ... - # the `table_schema` method gets table schema generated by a resource + # the `table_schema` method gets the table schema generated by a resource print(get_users().compute_table_schema()) ``` -> 💡 You can pass dynamic hints which are functions that take the data item as input and return a -> hint value. This let's you create table and column schemas depending on the data. See example in -> next section. +:::note +You can pass dynamic hints which are functions that take the data item as input and return a +hint value. This lets you create table and column schemas depending on the data. See an [example below](#adjust-schema-when-you-yield-data). +::: -> 💡 You can mark some resource arguments as [configuration and credentials](credentials) -> values so `dlt` can pass them automatically to your functions. +:::tip +You can mark some resource arguments as [configuration and credentials](credentials) values so `dlt` can pass them automatically to your functions. +::: -### Put a contract on a tables, columns and data -Use the `schema_contract` argument to tell dlt how to [deal with new tables, data types and bad data types](schema-contracts.md). For example if you set it to **freeze**, `dlt` will not allow for any new tables, columns or data types to be introduced to the schema - it will raise an exception. Learn more in on available contract modes [here](schema-contracts.md#setting-up-the-contract) +### Put a contract on tables, columns, and data +Use the `schema_contract` argument to tell dlt how to [deal with new tables, data types, and bad data types](schema-contracts.md). For example, if you set it to **freeze**, `dlt` will not allow for any new tables, columns, or data types to be introduced to the schema - it will raise an exception. Learn more on available contract modes [here](schema-contracts.md#setting-up-the-contract) ### Define a schema with Pydantic You can alternatively use a [Pydantic](https://pydantic-docs.helpmanual.io/) model to define the schema. For example: - ```py from pydantic import BaseModel @@ -106,7 +107,7 @@ def get_users(): ... ``` -The data types of the table columns are inferred from the types of the pydantic fields. These use the same type conversions +The data types of the table columns are inferred from the types of the Pydantic fields. These use the same type conversions as when the schema is automatically generated from the data. Pydantic models integrate well with [schema contracts](schema-contracts.md) as data validators. @@ -114,8 +115,8 @@ Pydantic models integrate well with [schema contracts](schema-contracts.md) as d Things to note: - Fields with an `Optional` type are marked as `nullable` -- Fields with a `Union` type are converted to the first (not `None`) type listed in the union. E.g. `status: Union[int, str]` results in a `bigint` column. -- `list`, `dict` and nested pydantic model fields will use the `complex` type which means they'll be stored as a JSON object in the database instead of creating child tables. +- Fields with a `Union` type are converted to the first (not `None`) type listed in the union. For example, `status: Union[int, str]` results in a `bigint` column. 
+- `list`, `dict`, and nested Pydantic model fields will use the `complex` type which means they'll be stored as a JSON object in the database instead of creating child tables. You can override this by configuring the Pydantic model @@ -132,14 +133,14 @@ def get_users(): ``` `"skip_complex_types"` omits any `dict`/`list`/`BaseModel` type fields from the schema, so dlt will fall back on the default -behaviour of creating child tables for these fields. +behavior of creating child tables for these fields. -We do not support `RootModel` that validate simple types. You can add such validator yourself, see [data filtering section](#filter-transform-and-pivot-data). +We do not support `RootModel` that validate simple types. You can add such a validator yourself, see [data filtering section](#filter-transform-and-pivot-data). ### Dispatch data to many tables You can load data to many tables from a single resource. The most common case is a stream of events -of different types, each with different data schema. To deal with this, you can use `table_name` +of different types, each with different data schema. To deal with this, you can use the `table_name` argument on `dlt.resource`. You could pass the table name as a function with the data item as an argument and the `table_name` string as a return value. @@ -152,7 +153,7 @@ and `comment` events to separate tables. The type of the event is in the "type" def repo_events() -> Iterator[TDataItems]: yield item -# the `table_schema` method gets table schema generated by a resource and takes optional +# the `table_schema` method gets the table schema generated by a resource and takes an optional # data item to evaluate dynamic hints print(repo_events().compute_table_schema({"type": "WatchEvent", id:...})) ``` @@ -163,7 +164,7 @@ resource function: ```py @dlt.resource def repo_events() -> Iterator[TDataItems]: - # mark the "item" to be sent to table with name item["type"] + # mark the "item" to be sent to the table with the name item["type"] yield dlt.mark.with_table_name(item, item["type"]) ``` @@ -213,7 +214,7 @@ def users_details(user_item): # dlt figures out dependencies for you. pipeline.run(user_details) ``` -In the example above, `user_details` will receive data from default instance of `users` resource (with `limit` set to `None`). You can also use +In the example above, `user_details` will receive data from the default instance of the `users` resource (with `limit` set to `None`). You can also use **pipe |** operator to bind resources dynamically ```py # you can be more explicit and use a pipe operator. @@ -242,9 +243,9 @@ print(list([1,2] | pokemon())) ::: ### Declare a standalone resource -A standalone resource is defined on a function that is top level in a module (not inner function) that accepts config and secrets values. Additionally -if `standalone` flag is specified, the decorated function signature and docstring will be preserved. `dlt.resource` will just wrap the -decorated function and user must call the wrapper to get the actual resource. Below we declare a `filesystem` resource that must be called before use. +A standalone resource is defined on a function that is top level in a module (not an inner function) that accepts config and secrets values. Additionally, +if the `standalone` flag is specified, the decorated function signature and docstring will be preserved. `dlt.resource` will just wrap the +decorated function, and the user must call the wrapper to get the actual resource. 
Below we declare a `filesystem` resource that must be called before use. ```py @dlt.resource(standalone=True) def filesystem(bucket_url=dlt.config.value): @@ -255,7 +256,7 @@ def filesystem(bucket_url=dlt.config.value): pipeline.run(filesystem("s3://my-bucket/reports"), table_name="reports") ``` -Standalone may have dynamic name that depends on the arguments passed to the decorated function. For example:: +Standalone may have a dynamic name that depends on the arguments passed to the decorated function. For example: ```py @dlt.resource(standalone=True, name=lambda args: args["stream_name"]) def kinesis(stream_name: str): @@ -301,15 +302,15 @@ Please find more details in [extract performance](../reference/performance.md#ex ### Filter, transform and pivot data -You can attach any number of transformations that are evaluated on item per item basis to your +You can attach any number of transformations that are evaluated on an item per item basis to your resource. The available transformation types: -- map - transform the data item (`resource.add_map`). -- filter - filter the data item (`resource.add_filter`). -- yield map - a map that returns iterator (so single row may generate many rows - +- **map** - transform the data item (`resource.add_map`). +- **filter** - filter the data item (`resource.add_filter`). +- **yield map** - a map that returns an iterator (so a single row may generate many rows - `resource.add_yield_map`). -Example: We have a resource that loads a list of users from an api endpoint. We want to customize it +Example: We have a resource that loads a list of users from an API endpoint. We want to customize it so: 1. We remove users with `user_id == "me"`. @@ -350,8 +351,8 @@ and generate child tables for all nested lists, without limit. :::note `max_table_nesting` is optional so you can skip it, in this case dlt will -use it from the source if it is specified there or fallback to default -value which has 1000 as maximum nesting level. +use it from the source if it is specified there or fallback to the default +value which has 1000 as the maximum nesting level. ::: ```py @@ -378,13 +379,13 @@ def my_resource(): } ``` -In the example above we want only 1 level of child tables to be generated (so there are no child +In the example above, we want only 1 level of child tables to be generated (so there are no child tables of child tables). Typical settings: - `max_table_nesting=0` will not generate child tables at all and all nested data will be - represented as json. -- `max_table_nesting=1` will generate child tables of top level tables and nothing more. All nested - data in child tables will be represented as json. + represented as JSON. +- `max_table_nesting=1` will generate child tables of top-level tables and nothing more. All nested + data in child tables will be represented as JSON. You can achieve the same effect after the resource instance is created: @@ -401,38 +402,43 @@ produces the clearest and human-readable schemas. ### Sample from large data -If your resource loads thousands of pages of data from a REST API or millions of rows from a db -table, you may want to just sample a fragment of it in order i.e. to quickly see the dataset with -example data and test your transformations etc. In order to do that, you limit how many items will -be yielded by a resource by calling `resource.add_limit` method. In the example below we load just -10 first items from and infinite counter - that would otherwise never end. 
+If your resource loads thousands of pages of data from a REST API or millions of rows from a db table, you may want to just sample a fragment of it in order to quickly see the dataset with example data and test your transformations, etc. In order to do that, you limit how many items will be yielded by a resource (or source) by calling the `add_limit` method. This method will close the generator which produces the data after the limit is reached. + +In the example below, we load just 10 first items from an infinite counter - that would otherwise never end. ```py r = dlt.resource(itertools.count(), name="infinity").add_limit(10) assert list(r) == list(range(10)) ``` -> 💡 We are not skipping any items. We are closing the iterator/generator that produces data after -> limit is reached. +:::note +Note that `add_limit` **does not limit the number of records** but rather the "number of yields". Depending on how your resource is set up, the number of extracted rows may vary. For example, consider this resource: + +```py +@dlt.resource +def my_resource(): + for i in range(100): + yield [{"record_id": j} for j in range(15)] + +dlt.pipeline(destination="duckdb").run(my_resource().add_limit(10)) +``` +The code above will extract `15*10=150` records. This is happening because in each iteration, 15 records are yielded, and we're limiting the number of iterations to 10. +::: -> 💡 You cannot limit transformers. They should process all the data they receive fully to avoid -> inconsistencies in generated datasets. +Some constraints of `add_limit` include: -> 💡 If you are paremetrizing the value of `add_limit` and sometimes need it to be disabled, you can set `None` or `-1` -> to disable the limiting. You can also set the limit to `0` for the resource to not yield any items. +1. `add_limit` does not skip any items. It closes the iterator/generator that produces data after the limit is reached. +1. You cannot limit transformers. They should process all the data they receive fully to avoid inconsistencies in generated datasets. +1. Async resources with a limit added may occasionally produce one item more than the limit on some runs. This behavior is not deterministic. -> 💡 For internal reasons, async resources with a limit added, occassionally produce one item more than the limit -> on some runs. This behavior is not deterministic. +:::tip +If you are parameterizing the value of `add_limit` and sometimes need it to be disabled, you can set `None` or `-1` to disable the limiting. +You can also set the limit to `0` for the resource to not yield any items. +::: ### Set table name and adjust schema -You can change the schema of a resource, be it standalone or as a part of a source. Look for method -named `apply_hints` which takes the same arguments as resource decorator. Obviously you should call -this method before data is extracted from the resource. Example below converts an `append` resource -loading the `users` table into [merge](incremental-loading.md#merge-incremental_loading) resource -that will keep just one updated record per `user_id`. It also adds -["last value" incremental loading](incremental-loading.md#incremental_loading-with-last-value) on -`created_at` column to prevent requesting again the already loaded records: +You can change the schema of a resource, be it standalone or as a part of a source. Look for a method named `apply_hints` which takes the same arguments as the resource decorator. Obviously, you should call this method before data is extracted from the resource. 
The example below converts an `append` resource loading the `users` table into a [merge](incremental-loading.md#merge-incremental_loading) resource that will keep just one updated record per `user_id`. It also adds ["last value" incremental loading](incremental-loading.md#incremental_loading-with-last-value) on the `created_at` column to prevent requesting again the already loaded records: ```py tables = sql_database() @@ -444,7 +450,7 @@ tables.users.apply_hints( pipeline.run(tables) ``` -To just change a name of a table to which resource will load data, do the following: +To just change the name of a table to which the resource will load data, do the following: ```py tables = sql_database() tables.users.table_name = "other_users" @@ -452,15 +458,12 @@ tables.users.table_name = "other_users" ### Adjust schema when you yield data -You can set or update the table name, columns and other schema elements when your resource is executed and you already yield data. Such changes will be merged -with the existing schema in the same way `apply_hints` method above works. There are many reason to adjust schema at runtime. For example when using Airflow, you -should avoid lengthy operations (ie. reflecting database tables) during creation of the DAG so it is better do do it when DAG executes. You may also emit partial -hints (ie. precision and scale for decimal types) for column to help `dlt` type inference. +You can set or update the table name, columns, and other schema elements when your resource is executed and you already yield data. Such changes will be merged with the existing schema in the same way the `apply_hints` method above works. There are many reasons to adjust the schema at runtime. For example, when using Airflow, you should avoid lengthy operations (i.e. reflecting database tables) during the creation of the DAG, so it is better to do it when the DAG executes. You may also emit partial hints (i.e. precision and scale for decimal types) for columns to help `dlt` type inference. ```py @dlt.resource def sql_table(credentials, schema, table): - # create sql alchemy engine + # create a SQL Alchemy engine engine = engine_from_credentials(credentials) engine.execution_options(stream_results=True) metadata = MetaData(schema=schema) @@ -469,7 +472,7 @@ def sql_table(credentials, schema, table): for idx, batch in enumerate(table_rows(engine, table_obj)): if idx == 0: - # emit first row with hints, table_to_columns and get_primary_key are helpers that extract dlt schema from + # emit the first row with hints, table_to_columns and get_primary_key are helpers that extract dlt schema from # SqlAlchemy model yield dlt.mark.with_hints( batch, @@ -481,16 +484,14 @@ def sql_table(credentials, schema, table): ``` -In the example above we use `dlt.mark.with_hints` and `dlt.mark.make_hints` to emit columns and primary key with the first extracted item. Table schema will -be adjusted after the `batch` is processed in the extract pipeline but before any schema contracts are applied and data is persisted in load package. +In the example above, we use `dlt.mark.with_hints` and `dlt.mark.make_hints` to emit columns and primary key with the first extracted item. The table schema will be adjusted after the `batch` is processed in the extract pipeline but before any schema contracts are applied and data is persisted in the load package. :::tip -You can emit columns as Pydantic model and use dynamic hints (ie. lambda for table name) as well. You should avoid redefining `Incremental` this way. 
+You can emit columns as a Pydantic model and use dynamic hints (i.e. lambda for table name) as well. You should avoid redefining `Incremental` this way. ::: ### Import external files -You can import external files ie. `csv`, `parquet` and `jsonl` by yielding items marked with `with_file_import`, optionally passing table schema corresponding -the the imported file. `dlt` will not read, parse and normalize any names (ie. `csv` or `arrow` headers) and will attempt to copy the file into the destination as is. +You can import external files i.e. `csv`, `parquet`, and `jsonl` by yielding items marked with `with_file_import`, optionally passing a table schema corresponding to the imported file. `dlt` will not read, parse, and normalize any names (i.e. `csv` or `arrow` headers) and will attempt to copy the file into the destination as is. ```py import os import dlt @@ -510,22 +511,22 @@ import_folder = "/tmp/import" @dlt.transformer(columns=columns) def orders(items: Iterator[FileItemDict]): for item in items: - # copy file locally + # copy the file locally dest_file = os.path.join(import_folder, item["file_name"]) - # download file + # download the file item.fsspec.download(item["file_url"], dest_file) # tell dlt to import the dest_file as `csv` yield dlt.mark.with_file_import(dest_file, "csv") -# use filesystem verified source to glob a bucket +# use the filesystem verified source to glob a bucket downloader = filesystem( bucket_url="s3://my_bucket/csv", file_glob="today/*.csv.gz") | orders info = pipeline.run(orders, destination="snowflake") ``` -In the example above, we glob all zipped csv files present on **my_bucket/csv/today** (using `filesystem` verified source) and send file descriptors to `orders` transformer. Transformer downloads and imports the files into extract package. At the end, `dlt` sends them to snowflake (the table will be created because we use `column` hints to define the schema). +In the example above, we glob all zipped csv files present on **my_bucket/csv/today** (using the `filesystem` verified source) and send file descriptors to the `orders` transformer. The transformer downloads and imports the files into the extract package. At the end, `dlt` sends them to Snowflake (the table will be created because we use `column` hints to define the schema). If imported `csv` files are not in `dlt` [default format](../dlt-ecosystem/file-formats/csv.md#default-settings), you may need to pass additional configuration. ```toml @@ -535,15 +536,15 @@ include_header=false on_error_continue=true ``` -You can sniff the schema from the data ie. using `duckdb` to infer the table schema from `csv` file. `dlt.mark.with_file_import` accepts additional arguments that you can use to pass hints at run time. +You can sniff the schema from the data i.e. using `duckdb` to infer the table schema from a `csv` file. `dlt.mark.with_file_import` accepts additional arguments that you can use to pass hints at runtime. :::note -* If you do not define any columns, the table will not be created in the destination. `dlt` will still attempt to load data into it, so you create a fitting table upfront, the load process will succeed. -* Files are imported using hard links if possible to avoid copying and duplicating storage space needed. +* If you do not define any columns, the table will not be created in the destination. `dlt` will still attempt to load data into it, so if you create a fitting table upfront, the load process will succeed. 
+* Files are imported using hard links if possible to avoid copying and duplicating the storage space needed. ::: ### Duplicate and rename resources -There are cases when you your resources are generic (ie. bucket filesystem) and you want to load several instances of it (ie. files from different folders) to separate tables. In example below we use `filesystem` source to load csvs from two different folders into separate tables: +There are cases when your resources are generic (i.e. bucket filesystem) and you want to load several instances of it (i.e. files from different folders) to separate tables. In the example below, we use the `filesystem` source to load csvs from two different folders into separate tables: ```py @dlt.resource(standalone=True) def filesystem(bucket_url): @@ -552,11 +553,11 @@ def filesystem(bucket_url): @dlt.transformer def csv_reader(file_item): - # load csv, parse and yield rows in file_item + # load csv, parse, and yield rows in file_item ... -# create two extract pipes that list files from the bucket and send to them to the reader. -# by default both pipes will load data to the same table (csv_reader) +# create two extract pipes that list files from the bucket and send them to the reader. +# by default, both pipes will load data to the same table (csv_reader) reports_pipe = filesystem("s3://my-bucket/reports") | load_csv() transactions_pipe = filesystem("s3://my-bucket/transactions") | load_csv() @@ -566,14 +567,11 @@ pipeline.run( ) ``` -`with_name` method returns a deep copy of the original resource, its data pipe and the data pipes of a parent resources. A renamed clone is fully separated from the original resource (and other clones) when loading: - it maintains a separate [resource state](state.md#read-and-write-pipeline-state-in-a-resource) and will load to a table +The `with_name` method returns a deep copy of the original resource, its data pipe, and the data pipes of a parent resource. A renamed clone is fully separated from the original resource (and other clones) when loading: it maintains a separate [resource state](state.md#read-and-write-pipeline-state-in-a-resource) and will load to a table ## Load resources -You can pass individual resources or list of resources to the `dlt.pipeline` object. The resources -loaded outside the source context, will be added to the [default schema](schema.md) of the -pipeline. +You can pass individual resources or a list of resources to the `dlt.pipeline` object. The resources loaded outside the source context will be added to the [default schema](schema.md) of the pipeline. ```py @dlt.resource(name='table_name', write_disposition='replace') @@ -586,7 +584,7 @@ pipeline = dlt.pipeline( destination="duckdb", dataset_name="rows_data" ) -# load individual resource +# load an individual resource pipeline.run(generate_rows(10)) # load a list of resources pipeline.run([generate_rows(10), generate_rows(20)]) @@ -600,15 +598,15 @@ def generate_rows(nr): for i in range(nr): yield {'id':i, 'example_string':'abc'} ``` -Resource above will be saved and loaded from a `parquet` file (if destination supports it). +The resource above will be saved and loaded from a `parquet` file (if the destination supports it). :::note -A special `file_format`: **preferred** will load resource using a format that is preferred by a destination. This settings supersedes the `loader_file_format` passed to `run` method. +A special `file_format`: **preferred** will load the resource using a format that is preferred by a destination. 
This setting supersedes the `loader_file_format` passed to the `run` method. ::: ### Do a full refresh -To do a full refresh of an `append` or `merge` resources you set the `refresh` argument on `run` method to `drop_data`. This will truncate the tables without dropping them. +To do a full refresh of an `append` or `merge` resource, you set the `refresh` argument on the `run` method to `drop_data`. This will truncate the tables without dropping them. ```py p.run(merge_source(), refresh="drop_data") diff --git a/docs/website/docs/general-usage/source.md b/docs/website/docs/general-usage/source.md index bcdd137dce..936a3160f0 100644 --- a/docs/website/docs/general-usage/source.md +++ b/docs/website/docs/general-usage/source.md @@ -100,7 +100,7 @@ Find more on transforms [here](resource.md#filter-transform-and-pivot-data). You can limit the number of items produced by each resource by calling a `add_limit` method on a source. This is useful for testing, debugging and generating sample datasets for experimentation. You can easily get your test dataset in a few minutes, when otherwise you'd need to wait hours for -the full loading to complete. Below we limit the `pipedrive` source to just get 10 pages of data +the full loading to complete. Below we limit the `pipedrive` source to just get **10 pages** of data from each endpoint. Mind that the transformers will be evaluated fully: ```py @@ -111,6 +111,10 @@ load_info = pipeline.run(pipedrive_source().add_limit(10)) print(load_info) ``` +:::note +Note that `add_limit` **does not limit the number of records** but rather the "number of yields". `dlt` will close the iterator/generator that produces data after the limit is reached. +::: + Find more on sampling data [here](resource.md#sample-from-large-data). ### Add more resources to existing source diff --git a/docs/website/docs/general-usage/state.md b/docs/website/docs/general-usage/state.md index 4a9e453ea4..b34d37c8b1 100644 --- a/docs/website/docs/general-usage/state.md +++ b/docs/website/docs/general-usage/state.md @@ -96,7 +96,7 @@ about the pipeline, pipeline run (that the state belongs to) and state blob. if you are not able to implement it with the standard incremental construct. - Store the custom fields dictionaries, dynamic configurations and other source-scoped state. -## When not to use pipeline state +## Do not use pipeline state if it can grow to millions of records Do not use dlt state when it may grow to millions of elements. Do you plan to store modification timestamps of all of your millions of user records? This is probably a bad idea! In that case you @@ -109,6 +109,39 @@ could: [sqlclient](../dlt-ecosystem/transformations/sql.md) and load the data of interest. In that case try at least to process your user records in batches. +### Access data in the destination instead of pipeline state +In the example below, we load recent comments made by given `user_id`. We access `user_comments` table to select +maximum comment id for a given user. 
+```py +import dlt + +@dlt.resource(name="user_comments") +def comments(user_id: str): + current_pipeline = dlt.current.pipeline() + # find last comment id for given user_id by looking in destination + max_id: int = 0 + # on first pipeline run, user_comments table does not yet exist so do not check at all + # alternatively catch DatabaseUndefinedRelation which is raised when unknown table is selected + if not current_pipeline.first_run: + with current_pipeline.sql_client() as client: + # we may get last user comment or None which we replace with 0 + max_id = ( + client.execute_sql( + "SELECT MAX(_id) FROM user_comments WHERE user_id=?", user_id + )[0][0] + or 0 + ) + # use max_id to filter our results (we simulate API query) + yield from [ + {"_id": i, "value": letter, "user_id": user_id} + for i, letter in zip([1, 2, 3], ["A", "B", "C"]) + if i > max_id + ] +``` +When pipeline is first run, the destination dataset and `user_comments` table do not yet exist. We skip the destination +query by using `first_run` property of the pipeline. We also handle a situation where there are no comments for a user_id +by replacing None with 0 as `max_id`. + ## Inspect the pipeline state You can inspect pipeline state with diff --git a/docs/website/docs/intro.md b/docs/website/docs/intro.md index 0374802b7d..c269c987b8 100644 --- a/docs/website/docs/intro.md +++ b/docs/website/docs/intro.md @@ -17,6 +17,10 @@ from various and often messy data sources into well-structured, live datasets. T ```sh pip install dlt ``` +:::tip +We recommend using a clean virtual environment for your experiments! Here are [detailed instructions](/reference/installation). +::: + Unlike other solutions, with dlt, there's no need to use any backends or containers. Simply import `dlt` in a Python file or a Jupyter Notebook cell, and create a pipeline to load data into any of the [supported destinations](dlt-ecosystem/destinations/). You can load data from any source that produces Python data structures, including APIs, files, databases, and more. `dlt` also supports building a [custom destination](dlt-ecosystem/destinations/destination.md), which you can use as reverse ETL. The library will create or update tables, infer data types, and handle nested data automatically. Here are a few example pipelines: diff --git a/docs/website/docs/reference/installation.md b/docs/website/docs/reference/installation.md index a23ce82c97..8fd80e52ff 100644 --- a/docs/website/docs/reference/installation.md +++ b/docs/website/docs/reference/installation.md @@ -6,21 +6,21 @@ keywords: [installation, environment, pip install] # Installation -## Set up environment +## Setting up your environment -### Make sure you are using **Python 3.8-3.12** and have `pip` installed +### 1. Make sure you are using **Python 3.8-3.12** and have `pip` installed ```sh python --version pip --version ``` -### If not, then please follow the instructions below to install it +If you have a different python version installed or are missing pip, follow the instructions below to update your python version and/or install `pip`. -You can install Python 3.10 with an `apt` command. +You can install Python 3.10 with `apt`. ```sh sudo apt update @@ -31,7 +31,7 @@ sudo apt install python3.10-venv -Once you have installed [Homebrew](https://brew.sh), you can install Python 3.10. +On MacOS you can use [Homebrew](https://brew.sh) to install Python 3.10. 
```sh brew update @@ -41,8 +41,7 @@ brew install python@3.10 -You need to install [Python 3.10 (64-bit version) for Windows](https://www.python.org/downloads/windows/). -After this, you can then install `pip`. +After installing [Python 3.10 (64-bit version) for Windows](https://www.python.org/downloads/windows/), you can install `pip`. ```sh C:\> pip3 install -U pip @@ -51,13 +50,16 @@ C:\> pip3 install -U pip -### Once Python is installed, you should create virtual environment +### 2. Set up and activate a virtual environment for your python project + +We recommend working within a [virtual environment](https://docs.python.org/3/library/venv.html) when creating python projects. +This way all the dependencies for your current project will be isolated from packages in other projects. -Create a new virtual environment by making a `./env` directory to hold it. +Create a new virtual environment in your working folder. This will create an `./env` directory where your virtual environment will be stored: ```sh python -m venv ./env @@ -72,7 +74,7 @@ source ./env/bin/activate -Create a new virtual environment by making a `./env` directory to hold it. +Create a new virtual environment in your working folder. This will create an `./env` directory where your virtual environment will be stored: ```sh python -m venv ./env @@ -87,7 +89,7 @@ source ./env/bin/activate -Create a new virtual environment by making a `./env` directory to hold it. +Create a new virtual environment in your working folder. This will create an `./env` directory where your virtual environment will be stored: ```bat C:\> python -m venv ./env @@ -102,15 +104,24 @@ C:\> .\env\Scripts\activate -## Install `dlt` library +### 3. Install `dlt` library -You can install `dlt` in your virtual environment by running: +You can now install `dlt` in your virtual environment by running: ```sh +# install the newest dlt version or upgrade the exisint version to the newest one pip install -U dlt ``` -## Install dlt via Pixi and Conda +Other installation examples: +```sh +# install dlt with support for duckdb +pip install "dlt[duckdb]" +# install dlt version smaller than 0.5.0 +pip install "dlt<0.5.0" +``` + +### 3.1. Install dlt via Pixi or Conda Install dlt using `pixi`: @@ -123,3 +134,7 @@ Install dlt using `conda`: ```sh conda install -c conda-forge dlt ``` + +### 4. Done! 
+ +You are now ready to [build your first pipeline](../getting-started) :) \ No newline at end of file diff --git a/docs/website/tools/preprocess_docs.js b/docs/website/tools/preprocess_docs.js index edc3d5c021..28b1d11474 100644 --- a/docs/website/tools/preprocess_docs.js +++ b/docs/website/tools/preprocess_docs.js @@ -14,6 +14,8 @@ const DOCS_EXTENSIONS = [".md", ".mdx"]; const SNIPPETS_FILE_SUFFIX = "-snippets.py" +const NUM_TUBA_LINKS = 10; + // examples settings const EXAMPLES_DESTINATION_DIR = `./${MD_TARGET_DIR}examples/`; const EXAMPLES_SOURCE_DIR = "../examples/"; @@ -175,11 +177,18 @@ function insertTubaLinks(lines) { for (let line of lines) { if (line.includes(TUBA_MARKER)) { const tubaTag = extractMarkerContent(TUBA_MARKER, line); - const links = tubaConfig.filter((link) => link.tags.includes(tubaTag)); + let links = tubaConfig.filter((link) => link.tags.includes(tubaTag)); if (links.length > 0) { result.push("## Additional Setup guides") - for (const link of links) { + // shuffle links + links = links.sort(() => 0.5 - Math.random()); + let count = 0; + for (const link of links) { result.push(`- [${link.title}](${link.public_url})`) + count += 1; + if (count >= NUM_TUBA_LINKS) { + break; + } } } else { // we could warn here, but it is a bit too verbose diff --git a/poetry.lock b/poetry.lock index a7d754f5a8..d54a73a2ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "about-time" @@ -3725,6 +3725,106 @@ files = [ {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:222fc2ee0e40522de0b21ad3bc90ab8983be3bf3cec3d349c80d76c8bb1a4beb"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d4763b0b9195b72132a4e7de8e5a9bf1f05542f442a9115aa27cfc2a8004f581"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:209649da10c9d4a93d8a4d100ecbf9cc3b0252169426bec3e8b4ad7e57d600cf"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:68813aa333c1604a2df4a495b2a6ed065d7c8aebf26cc7e7abb5a6835d08353c"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:370a23ec775ad14e9d1e71474d56f381224dcf3e72b15d8ca7b4ad7dd9cd5853"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:14664a66a3ddf6bc9e56f401bf029db2d169982c53eff3f5876399104df0e9a6"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ea3722cc4932cbcebd553b69dce1b4a73572823cff4e6a244f1c855da21d511"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e14bb264c40fd7c627ef5678e295370cd6ba95ca71d835798b6e37502fc4c690"}, + {file = "google_re2-1.1-5-cp310-cp310-win32.whl", hash = "sha256:39512cd0151ea4b3969c992579c79b423018b464624ae955be685fc07d94556c"}, + {file = "google_re2-1.1-5-cp310-cp310-win_amd64.whl", hash = 
"sha256:ac66537aa3bc5504320d922b73156909e3c2b6da19739c866502f7827b3f9fdf"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b5ea68d54890c9edb1b930dcb2658819354e5d3f2201f811798bbc0a142c2b4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:33443511b6b83c35242370908efe2e8e1e7cae749c766b2b247bf30e8616066c"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:413d77bdd5ba0bfcada428b4c146e87707452ec50a4091ec8e8ba1413d7e0619"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:5171686e43304996a34baa2abcee6f28b169806d0e583c16d55e5656b092a414"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b284db130283771558e31a02d8eb8fb756156ab98ce80035ae2e9e3a5f307c4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:296e6aed0b169648dc4b870ff47bd34c702a32600adb9926154569ef51033f47"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:38d50e68ead374160b1e656bbb5d101f0b95fb4cc57f4a5c12100155001480c5"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a0416a35921e5041758948bcb882456916f22845f66a93bc25070ef7262b72a"}, + {file = "google_re2-1.1-5-cp311-cp311-win32.whl", hash = "sha256:a1d59568bbb5de5dd56dd6cdc79907db26cce63eb4429260300c65f43469e3e7"}, + {file = "google_re2-1.1-5-cp311-cp311-win_amd64.whl", hash = "sha256:72f5a2f179648b8358737b2b493549370debd7d389884a54d331619b285514e3"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cbc72c45937b1dc5acac3560eb1720007dccca7c9879138ff874c7f6baf96005"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5fadd1417fbef7235fa9453dba4eb102e6e7d94b1e4c99d5fa3dd4e288d0d2ae"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:040f85c63cc02696485b59b187a5ef044abe2f99b92b4fb399de40b7d2904ccc"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:64e3b975ee6d9bbb2420494e41f929c1a0de4bcc16d86619ab7a87f6ea80d6bd"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ee370413e00f4d828eaed0e83b8af84d7a72e8ee4f4bd5d3078bc741dfc430a"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:5b89383001079323f693ba592d7aad789d7a02e75adb5d3368d92b300f5963fd"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63cb4fdfbbda16ae31b41a6388ea621510db82feb8217a74bf36552ecfcd50ad"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ebedd84ae8be10b7a71a16162376fd67a2386fe6361ef88c622dcf7fd679daf"}, + {file = "google_re2-1.1-5-cp312-cp312-win32.whl", hash = "sha256:c8e22d1692bc2c81173330c721aff53e47ffd3c4403ff0cd9d91adfd255dd150"}, + {file = "google_re2-1.1-5-cp312-cp312-win_amd64.whl", hash = "sha256:5197a6af438bb8c4abda0bbe9c4fbd6c27c159855b211098b29d51b73e4cbcf6"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b6727e0b98417e114b92688ad2aa256102ece51f29b743db3d831df53faf1ce3"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:711e2b6417eb579c61a4951029d844f6b95b9b373b213232efd413659889a363"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:71ae8b3df22c5c154c8af0f0e99d234a450ef1644393bc2d7f53fc8c0a1e111c"}, + {file = 
"google_re2-1.1-5-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:94a04e214bc521a3807c217d50cf099bbdd0c0a80d2d996c0741dbb995b5f49f"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:a770f75358508a9110c81a1257721f70c15d9bb592a2fb5c25ecbd13566e52a5"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:07c9133357f7e0b17c6694d5dcb82e0371f695d7c25faef2ff8117ef375343ff"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:204ca6b1cf2021548f4a9c29ac015e0a4ab0a7b6582bf2183d838132b60c8fda"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b95857c2c654f419ca684ec38c9c3325c24e6ba7d11910a5110775a557bb18"}, + {file = "google_re2-1.1-5-cp38-cp38-win32.whl", hash = "sha256:347ac770e091a0364e822220f8d26ab53e6fdcdeaec635052000845c5a3fb869"}, + {file = "google_re2-1.1-5-cp38-cp38-win_amd64.whl", hash = "sha256:ec32bb6de7ffb112a07d210cf9f797b7600645c2d5910703fa07f456dd2150e0"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb5adf89060f81c5ff26c28e261e6b4997530a923a6093c9726b8dec02a9a326"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a22630c9dd9ceb41ca4316bccba2643a8b1d5c198f21c00ed5b50a94313aaf10"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:544dc17fcc2d43ec05f317366375796351dec44058e1164e03c3f7d050284d58"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:19710af5ea88751c7768575b23765ce0dfef7324d2539de576f75cdc319d6654"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:f82995a205e08ad896f4bd5ce4847c834fab877e1772a44e5f262a647d8a1dec"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:63533c4d58da9dc4bc040250f1f52b089911699f0368e0e6e15f996387a984ed"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79e00fcf0cb04ea35a22b9014712d448725ce4ddc9f08cc818322566176ca4b0"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bc41afcefee2da6c4ed883a93d7f527c4b960cd1d26bbb0020a7b8c2d341a60a"}, + {file = "google_re2-1.1-5-cp39-cp39-win32.whl", hash = "sha256:486730b5e1f1c31b0abc6d80abe174ce4f1188fe17d1b50698f2bf79dc6e44be"}, + {file = "google_re2-1.1-5-cp39-cp39-win_amd64.whl", hash = "sha256:4de637ca328f1d23209e80967d1b987d6b352cd01b3a52a84b4d742c69c3da6c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:621e9c199d1ff0fdb2a068ad450111a84b3bf14f96dfe5a8a7a0deae5f3f4cce"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:220acd31e7dde95373f97c3d1f3b3bd2532b38936af28b1917ee265d25bebbf4"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:db34e1098d164f76251a6ece30e8f0ddfd65bb658619f48613ce71acb3f9cbdb"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:5152bac41d8073977582f06257219541d0fc46ad99b0bbf30e8f60198a43b08c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6191294799e373ee1735af91f55abd23b786bdfd270768a690d9d55af9ea1b0d"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:070cbafbb4fecbb02e98feb28a1eb292fb880f434d531f38cc33ee314b521f1f"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:8437d078b405a59a576cbed544490fe041140f64411f2d91012e8ec05ab8bf86"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f00f9a9af8896040e37896d9b9fc409ad4979f1ddd85bb188694a7d95ddd1164"}, + {file = "google_re2-1.1-6-cp310-cp310-win32.whl", hash = "sha256:df26345f229a898b4fd3cafd5f82259869388cee6268fc35af16a8e2293dd4e5"}, + {file = "google_re2-1.1-6-cp310-cp310-win_amd64.whl", hash = "sha256:3665d08262c57c9b28a5bdeb88632ad792c4e5f417e5645901695ab2624f5059"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b26b869d8aa1d8fe67c42836bf3416bb72f444528ee2431cfb59c0d3e02c6ce3"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:41fd4486c57dea4f222a6bb7f1ff79accf76676a73bdb8da0fcbd5ba73f8da71"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:0ee378e2e74e25960070c338c28192377c4dd41e7f4608f2688064bd2badc41e"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:a00cdbf662693367b36d075b29feb649fd7ee1b617cf84f85f2deebeda25fc64"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c09455014217a41499432b8c8f792f25f3df0ea2982203c3a8c8ca0e7895e69"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6501717909185327935c7945e23bb5aa8fc7b6f237b45fe3647fa36148662158"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3510b04790355f199e7861c29234081900e1e1cbf2d1484da48aa0ba6d7356ab"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8c0e64c187ca406764f9e9ad6e750d62e69ed8f75bf2e865d0bfbc03b642361c"}, + {file = "google_re2-1.1-6-cp311-cp311-win32.whl", hash = "sha256:2a199132350542b0de0f31acbb3ca87c3a90895d1d6e5235f7792bb0af02e523"}, + {file = "google_re2-1.1-6-cp311-cp311-win_amd64.whl", hash = "sha256:83bdac8ceaece8a6db082ea3a8ba6a99a2a1ee7e9f01a9d6d50f79c6f251a01d"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:81985ff894cd45ab5a73025922ac28c0707759db8171dd2f2cc7a0e856b6b5ad"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5635af26065e6b45456ccbea08674ae2ab62494008d9202df628df3b267bc095"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:813b6f04de79f4a8fdfe05e2cb33e0ccb40fe75d30ba441d519168f9d958bd54"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:5ec2f5332ad4fd232c3f2d6748c2c7845ccb66156a87df73abcc07f895d62ead"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5a687b3b32a6cbb731647393b7c4e3fde244aa557f647df124ff83fb9b93e170"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:39a62f9b3db5d3021a09a47f5b91708b64a0580193e5352751eb0c689e4ad3d7"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca0f0b45d4a1709cbf5d21f355e5809ac238f1ee594625a1e5ffa9ff7a09eb2b"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64b3796a7a616c7861247bd061c9a836b5caf0d5963e5ea8022125601cf7b09"}, + {file = "google_re2-1.1-6-cp312-cp312-win32.whl", hash = "sha256:32783b9cb88469ba4cd9472d459fe4865280a6b1acdad4480a7b5081144c4eb7"}, + {file = "google_re2-1.1-6-cp312-cp312-win_amd64.whl", hash = "sha256:259ff3fd2d39035b9cbcbf375995f83fa5d9e6a0c5b94406ff1cc168ed41d6c6"}, + {file = 
"google_re2-1.1-6-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e4711bcffe190acd29104d8ecfea0c0e42b754837de3fb8aad96e6cc3c613cdc"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:4d081cce43f39c2e813fe5990e1e378cbdb579d3f66ded5bade96130269ffd75"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:4f123b54d48450d2d6b14d8fad38e930fb65b5b84f1b022c10f2913bd956f5b5"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:e1928b304a2b591a28eb3175f9db7f17c40c12cf2d4ec2a85fdf1cc9c073ff91"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:3a69f76146166aec1173003c1f547931bdf288c6b135fda0020468492ac4149f"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:fc08c388f4ebbbca345e84a0c56362180d33d11cbe9ccfae663e4db88e13751e"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b057adf38ce4e616486922f2f47fc7d19c827ba0a7f69d540a3664eba2269325"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4138c0b933ab099e96f5d8defce4486f7dfd480ecaf7f221f2409f28022ccbc5"}, + {file = "google_re2-1.1-6-cp38-cp38-win32.whl", hash = "sha256:9693e45b37b504634b1abbf1ee979471ac6a70a0035954592af616306ab05dd6"}, + {file = "google_re2-1.1-6-cp38-cp38-win_amd64.whl", hash = "sha256:5674d437baba0ea287a5a7f8f81f24265d6ae8f8c09384e2ef7b6f84b40a7826"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7783137cb2e04f458a530c6d0ee9ef114815c1d48b9102f023998c371a3b060e"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a49b7153935e7a303675f4deb5f5d02ab1305adefc436071348706d147c889e0"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a96a8bb309182090704593c60bdb369a2756b38fe358bbf0d40ddeb99c71769f"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:dff3d4be9f27ef8ec3705eed54f19ef4ab096f5876c15fe011628c69ba3b561c"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:40f818b0b39e26811fa677978112a8108269977fdab2ba0453ac4363c35d9e66"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:8a7e53538cdb40ef4296017acfbb05cab0c19998be7552db1cfb85ba40b171b9"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ee18e7569fb714e5bb8c42809bf8160738637a5e71ed5a4797757a1fb4dc4de"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cda4f6d1a7d5b43ea92bc395f23853fba0caf8b1e1efa6e8c48685f912fcb89"}, + {file = "google_re2-1.1-6-cp39-cp39-win32.whl", hash = "sha256:6a9cdbdc36a2bf24f897be6a6c85125876dc26fea9eb4247234aec0decbdccfd"}, + {file = "google_re2-1.1-6-cp39-cp39-win_amd64.whl", hash = "sha256:73f646cecfad7cc5b4330b4192c25f2e29730a3b8408e089ffd2078094208196"}, ] [[package]] @@ -4465,42 +4565,6 @@ completion = ["shtab"] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] -[[package]] -name = "lancedb" -version = "0.6.13" -description = "lancedb" -optional = false -python-versions = ">=3.8" -files = [ - {file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = 
"sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"}, - {file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"}, - {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"}, - {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"}, - {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"}, - {file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"}, -] - -[package.dependencies] -attrs = ">=21.3.0" -cachetools = "*" -deprecation = "*" -overrides = ">=0.7" -pydantic = ">=1.10" -pylance = "0.10.12" -ratelimiter = ">=1.0,<2.0" -requests = ">=2.31.0" -retry = ">=0.9.2" -semver = "*" -tqdm = ">=4.27.0" - -[package.extras] -azure = ["adlfs (>=2024.2.0)"] -clip = ["open-clip", "pillow", "torch"] -dev = ["pre-commit", "ruff"] -docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] -embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] -tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] - [[package]] name = "lancedb" version = "0.9.0" @@ -6927,32 +6991,6 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] -[[package]] -name = "pylance" -version = "0.10.12" -description = "python wrapper for Lance columnar format" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"}, - {file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"}, - {file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"}, - {file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"}, - {file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"}, - {file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"}, -] - -[package.dependencies] -numpy = ">=1.22" -pyarrow = ">=12,<15.0.1" - -[package.extras] -benchmarks = ["pytest-benchmark"] -dev = ["ruff (==0.2.2)"] -ray = ["ray[data]"] -tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] -torch = ["torch"] - [[package]] name = "pylance" version = "0.13.0" @@ -9658,4 +9696,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = 
"1205791c3a090cf55617833ef566f1d55e6fcfa7209079bca92277f217130549" +content-hash = "a64fdd2845d27c9abc344809be68cba08f46641aabdc07416c37c802450fe4f3" diff --git a/pyproject.toml b/pyproject.toml index f8c34a767e..45f6297b9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.5.1" +version = "0.5.2" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -79,7 +79,7 @@ qdrant-client = {version = ">=1.8", optional = true, extras = ["fastembed"]} databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } -lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'" } +lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true } deltalake = { version = ">=0.17.4", optional = true } [tool.poetry.extras] @@ -220,7 +220,7 @@ pandas = ">2" alive-progress = ">=3.0.1" pyarrow = ">=14.0.0" psycopg2-binary = ">=2.9" -lancedb = ">=0.6.13" +lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.35" [tool.black] # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file diff --git a/pytest.ini b/pytest.ini index 07de69d3e3..1d4e0df6dc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,4 +10,5 @@ python_functions = *_test test_* *_snippet filterwarnings= ignore::DeprecationWarning markers = essential: marks all essential tests - no_load: marks tests that do not load anything \ No newline at end of file + no_load: marks tests that do not load anything + needspyarrow17: marks tests that need pyarrow>=17.0.0 (deselected by default) \ No newline at end of file diff --git a/tests/cases.py b/tests/cases.py index fa346b8b49..aa2e8ed494 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -303,6 +303,7 @@ def arrow_table_all_data_types( include_date: bool = True, include_not_normalized_name: bool = True, include_name_clash: bool = False, + include_null: bool = True, num_rows: int = 3, tz="UTC", ) -> Tuple[Any, List[Dict[str, Any]], Dict[str, List[Any]]]: @@ -323,9 +324,11 @@ def arrow_table_all_data_types( "float_null": [round(random.uniform(0, 100), 4) for _ in range(num_rows - 1)] + [ None ], # decrease precision - "null": pd.Series([None for _ in range(num_rows)]), } + if include_null: + data["null"] = pd.Series([None for _ in range(num_rows)]) + if include_name_clash: data["pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)] include_not_normalized_name = True @@ -373,7 +376,7 @@ def arrow_table_all_data_types( "Pre Normalized Column": "pre_normalized_column", } ) - .drop(columns=["null"]) + .drop(columns=(["null"] if include_null else [])) .to_dict("records") ) if object_format == "object": diff --git a/tests/common/cases/configuration/config.yml b/tests/common/cases/configuration/config.yml new file mode 100644 index 0000000000..1c9c253a65 --- /dev/null +++ b/tests/common/cases/configuration/config.yml @@ -0,0 +1,28 @@ +destination: + postgres: + credentials: postgresql://dlt-loader:loader@localhost:5432/dlt_data + athena: + query_result_bucket: s3://dlt-ci-test-bucket + credentials: + aws_access_key_id: AK + aws_secret_access_key: 
b+secret + snowflake: + stage_name: PUBLIC.my_s3_stage + connection_timeout: 60.4 + csv_format: + delimiter: '|' + include_header: false + on_error_continue: true + credentials: + query_tag: '{{"source":"{source}", "resource":"{resource}", "table": "{table}", + "load_id":"{load_id}", "pipeline_name":"{pipeline_name}"}}' +sources: + zendesk: + credentials: + subdomain: subdomain + email: set me up + password: set me up + token: set me up + oauth_token: set me up +data_types: + datetime: 1979-05-27 07:32:00-08:00 \ No newline at end of file diff --git a/tests/common/configuration/test_accessors.py b/tests/common/configuration/test_accessors.py index dc8761110f..6a73636421 100644 --- a/tests/common/configuration/test_accessors.py +++ b/tests/common/configuration/test_accessors.py @@ -11,6 +11,7 @@ ConfigTomlProvider, SecretsTomlProvider, ) +from dlt.common.configuration.providers.toml import CustomLoaderDocProvider from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.specs import ( GcpServiceAccountCredentialsWithoutDefaults, @@ -184,7 +185,7 @@ def test_setter(toml_providers: ConfigProvidersContext, environment: Any) -> Non dlt.secrets["pipeline.new.credentials"] = {"api_key": "skjo87a7nnAAaa"} assert dlt.secrets["pipeline.new.credentials"] == {"api_key": "skjo87a7nnAAaa"} # check the toml directly - assert dlt.secrets.writable_provider._toml["pipeline"]["new"]["credentials"] == {"api_key": "skjo87a7nnAAaa"} # type: ignore[attr-defined] + assert dlt.secrets.writable_provider._config_doc["pipeline"]["new"]["credentials"] == {"api_key": "skjo87a7nnAAaa"} # type: ignore[attr-defined] # mod the config and use it to resolve the configuration dlt.config["pool"] = {"pool_type": "process", "workers": 21} @@ -224,3 +225,25 @@ def the_source( credentials=dlt.secrets["destination.credentials"], databricks_creds=dlt.secrets["databricks.credentials"], ) + + +def test_provider_registration(toml_providers: ConfigProvidersContext) -> None: + toml_providers.providers.clear() + + def loader(): + return {"api_url": "https://example.com/api"} + + @dlt.source + def test_source(api_url=dlt.config.value): + assert api_url == "https://example.com/api" + return dlt.resource([1, 2, 3], name="data") + + provider = CustomLoaderDocProvider("mock", loader, False) + assert provider.supports_secrets is False + + with pytest.raises(ConfigFieldMissingException): + test_source() + + # now register + dlt.config.register_provider(provider) + test_source() diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index d382a95a44..1c6319b551 100644 --- a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -236,11 +236,9 @@ def test_gcp_service_credentials_native_representation(environment) -> None: assert gcpc.private_key == "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n" assert gcpc.project_id == "chat-analytics" assert gcpc.client_email == "loader@iam.gserviceaccount.com" - # location is present but deprecated - assert gcpc.location == "US" # get native representation, it will also location _repr = gcpc.to_native_representation() - assert "location" in _repr + assert "project_id" in _repr # parse again gcpc_2 = GcpServiceAccountCredentials() gcpc_2.parse_native_representation(_repr) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 13d68b53e9..0dc7e53357 100644 --- a/tests/common/configuration/test_inject.py +++ 
b/tests/common/configuration/test_inject.py @@ -661,7 +661,7 @@ def postgres_direct(local_credentials: ConnectionStringCredentials = dlt.secrets environment.clear() # pass via toml - secrets_toml = toml_providers[SECRETS_TOML]._toml # type: ignore[attr-defined] + secrets_toml = toml_providers[SECRETS_TOML]._config_doc # type: ignore[attr-defined] secrets_toml["local_credentials"] = conn_str assert isinstance(postgres_direct(), ConnectionStringCredentials) assert isinstance(postgres_union(), ConnectionStringCredentials) @@ -669,7 +669,7 @@ def postgres_direct(local_credentials: ConnectionStringCredentials = dlt.secrets # make sure config is successfully deleted with pytest.raises(ConfigFieldMissingException): postgres_union() - # config_toml = toml_providers[CONFIG_TOML]._toml + # config_toml = toml_providers[CONFIG_TOML]._config_doc secrets_toml["local_credentials"] = {} for k, v in conn_dict.items(): secrets_toml["local_credentials"][k] = v diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 5271c68633..3b16a930e6 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -1,11 +1,11 @@ import os import pytest -import tomlkit -from typing import Any, Type +import yaml +from typing import Any, Dict, Type import datetime # noqa: I251 import dlt -from dlt.common import pendulum, Decimal +from dlt.common import pendulum, json from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve from dlt.common.configuration.container import Container from dlt.common.configuration.inject import with_config @@ -14,7 +14,8 @@ from dlt.common.configuration.providers.toml import ( SECRETS_TOML, CONFIG_TOML, - BaseTomlProvider, + BaseDocProvider, + CustomLoaderDocProvider, SecretsTomlProvider, ConfigTomlProvider, StringTomlProvider, @@ -54,8 +55,8 @@ class EmbeddedWithGcpCredentials(BaseConfiguration): def test_secrets_from_toml_secrets(toml_providers: ConfigProvidersContext) -> None: # remove secret_value to trigger exception - del toml_providers["secrets.toml"]._toml["secret_value"] # type: ignore[attr-defined] - del toml_providers["secrets.toml"]._toml["credentials"] # type: ignore[attr-defined] + del toml_providers["secrets.toml"]._config_doc["secret_value"] # type: ignore[attr-defined] + del toml_providers["secrets.toml"]._config_doc["credentials"] # type: ignore[attr-defined] with pytest.raises(ConfigFieldMissingException) as py_ex: resolve.resolve_configuration(SecretConfiguration()) @@ -208,8 +209,8 @@ def test_secrets_toml_credentials_from_native_repr( environment: Any, toml_providers: ConfigProvidersContext ) -> None: # cfg = toml_providers["secrets.toml"] - # print(cfg._toml) - # print(cfg._toml["source"]["credentials"]) + # print(cfg._config_doc) + # print(cfg._config_doc["source"]["credentials"]) # resolve gcp_credentials by parsing initial value which is str holding json doc c = resolve.resolve_configuration( GcpServiceAccountCredentialsWithoutDefaults(), sections=("source",) @@ -263,7 +264,8 @@ def test_toml_global_config() -> None: # create instance with global toml enabled config = ConfigTomlProvider(add_global_config=True) assert config._add_global_config is True - assert isinstance(config._toml, tomlkit.TOMLDocument) + assert isinstance(config._config_doc, dict) + assert len(config._config_doc) > 0 # kept from global v, key = config.get_value("dlthub_telemetry", bool, None, "runtime") assert v is False @@ -278,15 +280,14 @@ def 
test_toml_global_config() -> None: assert v == "a" secrets = SecretsTomlProvider(add_global_config=True) - assert isinstance(secrets._toml, tomlkit.TOMLDocument) assert secrets._add_global_config is True # check if values from project exist secrets_project = SecretsTomlProvider(add_global_config=False) - assert secrets._toml == secrets_project._toml + assert secrets._config_doc == secrets_project._config_doc def test_write_value(toml_providers: ConfigProvidersContext) -> None: - provider: BaseTomlProvider + provider: BaseDocProvider for provider in toml_providers.providers: # type: ignore[assignment] if not provider.is_writable: continue @@ -298,7 +299,10 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: assert provider.get_value("_new_key_literal", TAny, None) == ("literal", "_new_key_literal") # this will create path of tables provider.set_value("deep_int", 2137, "deep_pipeline", "deep", "deep", "deep", "deep") - assert provider._toml["deep_pipeline"]["deep"]["deep"]["deep"]["deep"]["deep_int"] == 2137 # type: ignore[index] + assert ( + provider._config_doc["deep_pipeline"]["deep"]["deep"]["deep"]["deep"]["deep_int"] + == 2137 + ) assert provider.get_value( "deep_int", TAny, "deep_pipeline", "deep", "deep", "deep", "deep" ) == (2137, "deep_pipeline.deep.deep.deep.deep.deep_int") @@ -326,9 +330,6 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: [1, 2, 3, 4], "deep.deep.deep.deep_list", ) - # invalid type - with pytest.raises(ValueError): - provider.set_value("deep_decimal", Decimal("1.2"), None, "deep", "deep", "deep", "deep") # write new dict to a new key test_d1 = {"key": "top", "embed": {"inner": "bottom", "inner_2": True}} @@ -371,66 +372,96 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # write configuration pool = PoolRunnerConfiguration(pool_type="none", workers=10) provider.set_value("runner_config", dict(pool), "new_pipeline") - # print(provider._toml["new_pipeline"]["runner_config"].as_string()) + # print(provider._config_doc["new_pipeline"]["runner_config"].as_string()) expected_pool = dict(pool) # None is removed expected_pool.pop("start_method") - assert provider._toml["new_pipeline"]["runner_config"] == expected_pool # type: ignore[index] + assert provider._config_doc["new_pipeline"]["runner_config"] == expected_pool + +def test_set_spec_value(toml_providers: ConfigProvidersContext) -> None: + provider: BaseDocProvider + for provider in toml_providers.providers: # type: ignore[assignment] + if not provider.is_writable: + continue + provider._config_doc = {} # dict creates only shallow dict so embedded credentials will fail creds = WithCredentialsConfiguration() - creds.credentials = SecretCredentials(secret_value=TSecretValue("***** ***")) - with pytest.raises(ValueError): - provider.set_value("written_creds", dict(creds), None) + credentials = SecretCredentials(secret_value=TSecretValue("***** ***")) + creds.credentials = credentials + # use dataclass to dict to recursively convert base config to dict + import dataclasses -def test_write_toml_value(toml_providers: ConfigProvidersContext) -> None: - provider: BaseTomlProvider + provider.set_value("written_creds", dataclasses.asdict(creds), None) + # resolve config + resolved_config = resolve.resolve_configuration( + WithCredentialsConfiguration(), sections=("written_creds",) + ) + assert resolved_config.credentials.secret_value == "***** ***" + + +def test_set_fragment(toml_providers: ConfigProvidersContext) -> None: + provider: BaseDocProvider for 
provider in toml_providers.providers: # type: ignore[assignment] - if not provider.is_writable: + if not isinstance(provider, BaseDocProvider): continue - - new_doc = tomlkit.parse(""" -int_val=2232 + new_toml = """ +int_val = 2232 [table] -inner_int_val=2121 - """) +inner_int_val = 2121 +""" # key == None replaces the whole document - provider.set_value(None, new_doc, None) - assert provider._toml == new_doc + provider.set_fragment(None, new_toml, None) + print(provider.to_yaml()) + assert provider.to_toml().strip() == new_toml.strip() + val, _ = provider.get_value("table", dict, None) + assert val is not None # key != None merges documents - to_merge_doc = tomlkit.parse(""" -int_val=2137 + to_merge_yaml = """ +int_val: 2137 -[babble] -word1="do" -word2="you" - - """) - provider.set_value("", to_merge_doc, None) - merged_doc = tomlkit.parse(""" -int_val=2137 +babble: + word1: do + word2: you -[babble] -word1="do" -word2="you" +""" + provider.set_fragment("", to_merge_yaml, None) + merged_doc = """ +int_val = 2137 [table] -inner_int_val=2121 +inner_int_val = 2121 - """) - assert provider._toml == merged_doc +[babble] +word1 = "do" +word2 = "you" + +""" + assert provider.to_toml().strip() == merged_doc.strip() # currently we ignore the key when merging tomlkit - provider.set_value("level", to_merge_doc, None) - assert provider._toml == merged_doc + provider.set_fragment("level", to_merge_yaml, None) + assert provider.to_toml().strip() == merged_doc.strip() - # only toml accepted with empty key + # use JSON: empty key replaces dict + provider.set_fragment(None, json.dumps({"prop1": "A", "nested": {"propN": "N"}}), None) + assert provider._config_doc == {"prop1": "A", "nested": {"propN": "N"}} + # key cannot be empty for set_value with pytest.raises(ValueError): - provider.set_value(None, {}, None) + provider.set_value(None, "VAL", None) + # dict always merges from the top level doc, ignoring the key + provider.set_fragment( + "nested", json.dumps({"prop2": "B", "nested": {"prop3": "C"}, "prop1": ""}), None + ) + assert provider._config_doc == { + "prop2": "B", + "nested": {"propN": "N", "prop3": "C"}, + "prop1": "", + } def test_toml_string_provider() -> None: @@ -466,3 +497,33 @@ def test_toml_string_provider() -> None: [section2.subsection] key1 = \"other_value\" """ + + +def test_custom_loader(toml_providers: ConfigProvidersContext) -> None: + def loader() -> Dict[str, Any]: + with open("tests/common/cases/configuration/config.yml", "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + # remove all providers + toml_providers.providers.clear() + # create new provider + provider = CustomLoaderDocProvider("yaml", loader, True) + assert provider.name == "yaml" + assert provider.supports_secrets is True + assert provider.to_toml().startswith("[destination]") + assert provider.to_yaml().startswith("destination:") + value, _ = provider.get_value("datetime", datetime.datetime, None, "data_types") + assert value == pendulum.parse("1979-05-27 07:32:00-08:00") + + # add to context + toml_providers.add_provider(provider) + + # resolve one of configs + config = resolve.resolve_configuration( + ConnectionStringCredentials(), + sections=( + "destination", + "postgres", + ), + ) + assert config.username == "dlt-loader" diff --git a/tests/conftest.py b/tests/conftest.py index 7ed546dfea..6c0384ea8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ import os import dataclasses import logging -from typing import List +import sys +import pytest +from typing import List, Iterator 
+from importlib.metadata import version as pkg_version +from packaging.version import Version # patch which providers to enable from dlt.common.configuration.providers import ( @@ -115,4 +119,38 @@ def _create_pipeline_instance_id(self) -> str: # disable httpx request logging (too verbose when testing qdrant) logging.getLogger("httpx").setLevel("WARNING") - logging.getLogger("airflow.models.variable").setLevel("CRITICAL") + # reset and init airflow db + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + + try: + from airflow.utils import db + import contextlib + import io + + for log in [ + "airflow.models.crypto", + "airflow.models.variable", + "airflow", + "alembic", + "alembic.runtime.migration", + ]: + logging.getLogger(log).setLevel("ERROR") + + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): + db.resetdb() + + except Exception: + pass + + +@pytest.fixture(autouse=True) +def pyarrow17_check(request) -> Iterator[None]: + if "needspyarrow17" in request.keywords: + if "pyarrow" not in sys.modules or Version(pkg_version("pyarrow")) < Version("17.0.0"): + pytest.skip("test needs `pyarrow>=17.0.0`") + yield diff --git a/tests/destinations/test_utils.py b/tests/destinations/test_utils.py new file mode 100644 index 0000000000..32fc286830 --- /dev/null +++ b/tests/destinations/test_utils.py @@ -0,0 +1,43 @@ +import dlt +import pytest + +from dlt.destinations.utils import get_resource_for_adapter +from dlt.extract import DltResource + + +def test_get_resource_for_adapter() -> None: + # test on pure data + data = [1, 2, 3] + adapted_resource = get_resource_for_adapter(data) + assert isinstance(adapted_resource, DltResource) + assert list(adapted_resource) == [1, 2, 3] + assert adapted_resource.name == "content" + + # test on resource + @dlt.resource(table_name="my_table") + def some_resource(): + yield [1, 2, 3] + + adapted_resource = get_resource_for_adapter(some_resource) + assert adapted_resource == some_resource + assert adapted_resource.name == "some_resource" + + # test on source with one resource + @dlt.source + def source(): + return [some_resource] + + adapted_resource = get_resource_for_adapter(source()) + assert adapted_resource.table_name == "my_table" + + # test on source with multiple resources + @dlt.resource(table_name="my_table") + def other_resource(): + yield [1, 2, 3] + + @dlt.source + def other_source(): + return [some_resource, other_resource] + + with pytest.raises(ValueError): + get_resource_for_adapter(other_source()) diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index d285181c55..d40639a594 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -1,7 +1,7 @@ import os import asyncio import inspect -from typing import List, Sequence +from typing import ClassVar, List, Sequence import time import pytest @@ -236,6 +236,56 @@ def tx_minus(item, meta): assert [pi.item for pi in _l] == [4, 8, 12] +def test_append_transform_with_placement_affinity() -> None: + class FilterItemStart(FilterItem): + placement_affinity: ClassVar[float] = -1 + + class FilterItemEnd(FilterItem): + placement_affinity: ClassVar[float] = 1 + + assert FilterItemStart(lambda _: True).placement_affinity == -1 + assert FilterItemEnd(lambda _: True).placement_affinity == 1 + + data = [1, 2, 3] + # data_iter = iter(data) + pp = Pipe.from_data("data", data) + + pp.append_step(FilterItemEnd(lambda _: True)) + 
pp.append_step(FilterItemStart(lambda _: True)) + assert len(pp) == 3 + # gen must always be first + assert pp._steps[0] == data + assert isinstance(pp._steps[1], FilterItemStart) + assert isinstance(pp._steps[2], FilterItemEnd) + + def regular_lambda(item): + return True + + pp.append_step(regular_lambda) + assert pp._steps[-2].__name__ == "regular_lambda" # type: ignore[union-attr] + + # explicit insert works as before, ignores affinity + end_aff_2 = FilterItemEnd(lambda _: True) + start_aff_2 = FilterItemStart(lambda _: True) + pp.insert_step(end_aff_2, 1) + assert pp._steps[1] is end_aff_2 + pp.insert_step(start_aff_2, len(pp)) + assert pp._steps[-1] is start_aff_2 + + def tx(item): + yield item * 2 + + # create pipe with transformer + p = Pipe.from_data("tx", tx, parent=pp) + p.append_step(FilterItemEnd(lambda _: True)) + p.append_step(FilterItemStart(lambda _: True)) + assert len(p) == 3 + # note that in case of start affinity, the transform gets placed BEFORE the transformer + assert isinstance(p._steps[0], FilterItemStart) + assert p._steps[1].__name__ == "tx" # type: ignore[union-attr] + assert isinstance(p._steps[2], FilterItemEnd) + + def test_pipe_propagate_meta() -> None: data = [1, 2, 3] _meta = ["M1", {"A": 1}, [1, 2, 3]] diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 26158177ff..f4082a7d86 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -25,6 +25,7 @@ from dlt.extract import DltSource from dlt.extract.exceptions import InvalidStepFunctionArguments +from dlt.extract.items import ValidateItem from dlt.extract.resource import DltResource from dlt.sources.helpers.transform import take_first from dlt.extract.incremental import IncrementalResourceWrapper, Incremental @@ -862,6 +863,9 @@ def child(item): state["mark"] = f"mark:{item['delta']}" yield item + print(parent_r._pipe._steps) + print(child._pipe._steps) + # also transformer will not receive new data info = p.run(child) assert len(info.loads_ids) == 0 @@ -2007,3 +2011,57 @@ def test_type_3(): r.incremental.allow_external_schedulers = True result = data_item_to_list(item_type, list(r)) assert len(result) == 3 + + +@pytest.mark.parametrize("yield_pydantic", (True, False)) +def test_pydantic_columns_validator(yield_pydantic: bool) -> None: + from pydantic import BaseModel, Field, ConfigDict + + # forbid extra fields so "id" in json is not a valid field BUT + # add alias for id_ that will serde "id" correctly + class TestRow(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid") + + id_: int = Field(alias="id") + example_string: str + ts: datetime + + @dlt.resource(name="table_name", columns=TestRow, primary_key="id", write_disposition="replace") + def generate_rows(): + for i in range(10): + item = {"id": i, "example_string": "abc", "ts": datetime.now()} + yield TestRow.model_validate(item) if yield_pydantic else item + + @dlt.resource(name="table_name", columns=TestRow, primary_key="id", write_disposition="replace") + def generate_rows_incremental( + ts: dlt.sources.incremental[datetime] = dlt.sources.incremental(cursor_path="ts"), + ): + for i in range(10): + item = {"id": i, "example_string": "abc", "ts": datetime.now()} + yield TestRow.model_validate(item) if yield_pydantic else item + if ts.end_out_of_range: + return + + @dlt.source + def test_source_incremental(): + return generate_rows_incremental + + @dlt.source + def test_source(): + return generate_rows + + pip_1_name = "test_pydantic_columns_validator_" + uniq_id() + pipeline = 
dlt.pipeline(pipeline_name=pip_1_name, destination="duckdb") + + info = pipeline.run(test_source()) + info.raise_on_failed_jobs() + + info = pipeline.run(test_source_incremental()) + info.raise_on_failed_jobs() + + # verify that right steps are at right place + steps = test_source().table_name._pipe._steps + assert isinstance(steps[-1], ValidateItem) + incremental_steps = test_source_incremental().table_name._pipe._steps + assert isinstance(incremental_steps[-2], ValidateItem) + assert isinstance(incremental_steps[-1], IncrementalResourceWrapper) diff --git a/tests/extract/test_validation.py b/tests/extract/test_validation.py index b9307ab97c..138589bb06 100644 --- a/tests/extract/test_validation.py +++ b/tests/extract/test_validation.py @@ -214,15 +214,15 @@ def some_data() -> t.Iterator[TDataItems]: items = list(r) assert len(items) == 3 # fully valid - assert items[0].a == 1 - assert items[0].b == "z" + assert items[0]["a"] == 1 + assert items[0]["b"] == "z" # data type not valid - assert items[1].a == "not_int" - assert items[1].b == "x" + assert items[1]["a"] == "not_int" + assert items[1]["b"] == "x" # extra attr and data invalid - assert items[2].a is None - assert items[2].b is None - assert items[2].c == "not_int" + assert items[2]["a"] is None + assert items[2]["b"] is None + assert items[2]["c"] == "not_int" # let it drop r = dlt.resource(some_data(), schema_contract="discard_row", columns=SimpleModel) @@ -232,8 +232,8 @@ def some_data() -> t.Iterator[TDataItems]: assert validator.model.__name__.endswith("ExtraForbid") items = list(r) assert len(items) == 1 - assert items[0].a == 1 - assert items[0].b == "z" + assert items[0]["a"] == 1 + assert items[0]["b"] == "z" # filter just offending values with pytest.raises(NotImplementedError): @@ -252,4 +252,4 @@ def some_data() -> t.Iterator[TDataItems]: items = list(r) assert len(items) == 3 # c is gone from the last model - assert not hasattr(items[2], "c") + assert "c" not in items[2] diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 61ccc4d5f4..7364ef7243 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -46,6 +46,8 @@ def expect_extracted_file( class AssertItems(ItemTransform[TDataItem]): + placement_affinity = 2.0 # even more sticky than incremental so gathers data after it + def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "object") -> None: self.expected_items = expected_items self.item_type = item_type diff --git a/tests/helpers/airflow_tests/test_airflow_provider.py b/tests/helpers/airflow_tests/test_airflow_provider.py index 68e426deb9..b31e78f986 100644 --- a/tests/helpers/airflow_tests/test_airflow_provider.py +++ b/tests/helpers/airflow_tests/test_airflow_provider.py @@ -10,7 +10,7 @@ from dlt.common import pendulum from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY +from dlt.common.configuration.providers.vault import SECRETS_TOML_KEY DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz="Europe/Berlin") # Test data @@ -212,7 +212,7 @@ def test_task(): provider for provider in providers if isinstance(provider, AirflowSecretsTomlProvider) ) return { - "airflow_secrets_toml": provider._toml.as_string(), + "airflow_secrets_toml": provider.to_toml(), } task = PythonOperator(task_id="test_task", python_callable=test_task, dag=dag) diff --git a/tests/helpers/airflow_tests/utils.py 
b/tests/helpers/airflow_tests/utils.py index 50aab77505..8a6b32191e 100644 --- a/tests/helpers/airflow_tests/utils.py +++ b/tests/helpers/airflow_tests/utils.py @@ -8,7 +8,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY +from dlt.common.configuration.providers.vault import SECRETS_TOML_KEY @pytest.fixture(scope="function", autouse=True) diff --git a/tests/helpers/providers/test_google_secrets_provider.py b/tests/helpers/providers/test_google_secrets_provider.py index 00c54b5705..4a3bf972b8 100644 --- a/tests/helpers/providers/test_google_secrets_provider.py +++ b/tests/helpers/providers/test_google_secrets_provider.py @@ -7,18 +7,18 @@ from dlt.common.configuration.specs.run_configuration import RunConfiguration from dlt.common.configuration.specs import GcpServiceAccountCredentials, known_sections from dlt.common.typing import AnyType -from dlt.common.utils import custom_environ from dlt.common.configuration.resolve import resolve_configuration DLT_SECRETS_TOML_CONTENT = """ -secret_value=2137 -api.secret_key="ABCD" +secret_value = 2137 +[api] +secret_key = "ABCD" [credentials] -secret_value="2138" -project_id="mock-credentials" +secret_value = "2138" +project_id = "mock-credentials" """ @@ -32,7 +32,7 @@ def test_regular_keys() -> None: # c = secrets.get("destination.credentials", GcpServiceAccountCredentials) # print(c) provider: GoogleSecretsProvider = _google_secrets_provider() # type: ignore[assignment] - assert provider._toml.as_string().strip() == DLT_SECRETS_TOML_CONTENT.strip() + assert provider.to_toml().strip() == DLT_SECRETS_TOML_CONTENT.strip() assert provider.get_value("secret_value", AnyType, "pipeline x !!") == ( None, "pipelinex-secret_value", @@ -83,11 +83,18 @@ def test_regular_keys() -> None: "pipeline-destination-filesystem-url", ) - # try a single secret value + # try a single secret value - not found until single values enabled assert provider.get_value("secret", TSecretValue, "pipeline") == (None, "pipeline-secret") # enable the single secrets provider.only_toml_fragments = False + assert provider.get_value("secret", TSecretValue, "pipeline") == ( + "THIS IS SECRET VALUE", + "pipeline-secret", + ) + del provider._config_doc["pipeline"]["secret"] + provider.clear_lookup_cache() + # but request as not secret value -> still not found assert provider.get_value("secret", str, "pipeline") == (None, "pipeline-secret") provider.only_secrets = False @@ -99,8 +106,8 @@ def test_regular_keys() -> None: # request json # print(provider._toml.as_string()) - assert provider.get_value("halo", str, None, "halo") == ({"halo": True}, "halo-halo") - assert provider.get_value("halo", str, None, "halo", "halo") == (True, "halo-halo-halo") + assert provider.get_value("halo", str, "halo") == ({"halo": True}, "halo-halo") + assert provider.get_value("halo", bool, "halo", "halo") == (True, "halo-halo-halo") # def test_special_sections() -> None: diff --git a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py index 691af8a9d1..744510afcf 100644 --- a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py +++ b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py @@ -13,7 +13,7 @@ import dlt -from streamlit.testing.v1 import AppTest # type: ignore[import-not-found] +from streamlit.testing.v1 import AppTest # type: 
ignore[import-not-found, unused-ignore] from dlt.helpers.streamlit_app.utils import render_with_pipeline from dlt.pipeline.exceptions import CannotRestorePipelineException @@ -102,8 +102,8 @@ def test_multiple_resources_pipeline(): assert streamlit_app.session_state["color_mode"] == "dark" # Check page links in sidebar - assert "Explore data" in streamlit_app.sidebar[2].label - assert "Load info" in streamlit_app.sidebar[3].label + assert "Explore data" in streamlit_app.sidebar[2].label # type: ignore[union-attr, unused-ignore] + assert "Load info" in streamlit_app.sidebar[3].label # type: ignore[union-attr, unused-ignore] # Check that at least 4 content sections rendered assert len(streamlit_app.subheader) > 4 diff --git a/tests/libs/test_deltalake.py b/tests/libs/test_deltalake.py index d55f788fbe..a162ff427b 100644 --- a/tests/libs/test_deltalake.py +++ b/tests/libs/test_deltalake.py @@ -1,5 +1,5 @@ import os -from typing import Iterator, Tuple, cast +from typing import Iterator, Tuple, Union, cast import pytest from deltalake import DeltaTable @@ -76,7 +76,21 @@ def test_deltalake_storage_options() -> None: assert _deltalake_storage_options(config)["aws_access_key_id"] == "i_will_overwrite" -def test_write_delta_table(filesystem_client) -> None: +@pytest.mark.needspyarrow17 +@pytest.mark.parametrize("arrow_data_type", (pa.Table, pa.RecordBatchReader)) +def test_write_delta_table( + filesystem_client, + arrow_data_type: Union[pa.Table, pa.RecordBatchReader], +) -> None: + def arrow_data( # type: ignore[return] + arrow_table: pa.Table, + return_type: Union[pa.Table, pa.RecordBatchReader], + ) -> Union[pa.Table, pa.RecordBatchReader]: + if return_type == pa.Table: + return arrow_table + elif return_type == pa.RecordBatchReader: + return arrow_table.to_reader() + client, remote_dir = filesystem_client client = cast(FilesystemClient, client) storage_options = _deltalake_storage_options(client.config) @@ -102,7 +116,10 @@ def test_write_delta_table(filesystem_client) -> None: # first write should create Delta table with same shape as input Arrow table write_delta_table( - remote_dir, arrow_table, write_disposition="append", storage_options=storage_options + remote_dir, + arrow_data(arrow_table, arrow_data_type), + write_disposition="append", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 0 @@ -117,7 +134,10 @@ def test_write_delta_table(filesystem_client) -> None: # another `append` should create a new table version with twice the number of rows write_delta_table( - remote_dir, arrow_table, write_disposition="append", storage_options=storage_options + remote_dir, + arrow_data(arrow_table, arrow_data_type), + write_disposition="append", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 1 @@ -125,7 +145,10 @@ def test_write_delta_table(filesystem_client) -> None: # the `replace` write disposition should trigger a "logical delete" write_delta_table( - remote_dir, arrow_table, write_disposition="replace", storage_options=storage_options + remote_dir, + arrow_data(arrow_table, arrow_data_type), + write_disposition="replace", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 2 @@ -137,7 +160,10 @@ def test_write_delta_table(filesystem_client) -> None: # `merge` should resolve to `append` behavior write_delta_table( - remote_dir, arrow_table, write_disposition="merge", 
storage_options=storage_options + remote_dir, + arrow_data(arrow_table, arrow_data_type), + write_disposition="merge", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 3 @@ -153,7 +179,10 @@ def test_write_delta_table(filesystem_client) -> None: # new column should be propagated to Delta table (schema evolution is supported) write_delta_table( - remote_dir, evolved_arrow_table, write_disposition="append", storage_options=storage_options + remote_dir, + arrow_data(evolved_arrow_table, arrow_data_type), + write_disposition="append", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 4 @@ -164,7 +193,10 @@ def test_write_delta_table(filesystem_client) -> None: # providing a subset of columns should lead to missing columns being null-filled write_delta_table( - remote_dir, arrow_table, write_disposition="append", storage_options=storage_options + remote_dir, + arrow_data(arrow_table, arrow_data_type), + write_disposition="append", + storage_options=storage_options, ) dt = DeltaTable(remote_dir, storage_options=storage_options) assert dt.version() == 5 @@ -176,7 +208,7 @@ def test_write_delta_table(filesystem_client) -> None: # unsupported value for `write_disposition` should raise ValueError write_delta_table( remote_dir, - arrow_table, + arrow_data(arrow_table, arrow_data_type), write_disposition="foo", # type:ignore[arg-type] storage_options=storage_options, ) diff --git a/tests/libs/test_pydantic.py b/tests/libs/test_pydantic.py index 951eabbde4..2222d13197 100644 --- a/tests/libs/test_pydantic.py +++ b/tests/libs/test_pydantic.py @@ -29,8 +29,8 @@ DltConfig, pydantic_to_table_schema_columns, apply_schema_contract_to_model, - validate_item, - validate_items, + validate_and_filter_item, + validate_and_filter_items, create_list_model, ) from pydantic import UUID4, BaseModel, Json, AnyHttpUrl, ConfigDict, ValidationError @@ -432,7 +432,7 @@ class ItemModel(BaseModel): discard_model = apply_schema_contract_to_model(ItemModel, "discard_row", "discard_row") discard_list_model = create_list_model(discard_model) # violate data type - items = validate_items( + items = validate_and_filter_items( "items", discard_list_model, [{"b": True}, {"b": 2, "opt": "not int", "extra": 1.2}, {"b": 3}, {"b": False}], @@ -445,7 +445,7 @@ class ItemModel(BaseModel): assert items[0].b is True assert items[1].b is False # violate extra field - items = validate_items( + items = validate_and_filter_items( "items", discard_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False, "a": False}], @@ -460,7 +460,7 @@ class ItemModel(BaseModel): freeze_list_model = create_list_model(freeze_model) # violate data type with pytest.raises(DataValidationError) as val_ex: - validate_items( + validate_and_filter_items( "items", freeze_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False}], @@ -476,7 +476,7 @@ class ItemModel(BaseModel): assert val_ex.value.data_item == {"b": 2} # extra type with pytest.raises(DataValidationError) as val_ex: - validate_items( + validate_and_filter_items( "items", freeze_list_model, [{"b": True}, {"a": 2, "b": False}, {"b": 3}, {"b": False}], @@ -495,7 +495,7 @@ class ItemModel(BaseModel): discard_value_model = apply_schema_contract_to_model(ItemModel, "discard_value", "freeze") discard_list_model = create_list_model(discard_value_model) # violate extra field - items = validate_items( + items = validate_and_filter_items( "items", discard_list_model, [{"b": 
True}, {"b": False, "a": False}], @@ -513,7 +513,7 @@ class ItemModel(BaseModel): evolve_model = apply_schema_contract_to_model(ItemModel, "evolve", "evolve") evolve_list_model = create_list_model(evolve_model) # for data types a lenient model will be created that accepts any type - items = validate_items( + items = validate_and_filter_items( "items", evolve_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False}], @@ -524,7 +524,7 @@ class ItemModel(BaseModel): assert items[0].b is True assert items[1].b == 2 # extra fields allowed - items = validate_items( + items = validate_and_filter_items( "items", evolve_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False, "a": False}], @@ -539,7 +539,7 @@ class ItemModel(BaseModel): mixed_model = apply_schema_contract_to_model(ItemModel, "discard_row", "evolve") mixed_list_model = create_list_model(mixed_model) # for data types a lenient model will be created that accepts any type - items = validate_items( + items = validate_and_filter_items( "items", mixed_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False}], @@ -550,7 +550,7 @@ class ItemModel(BaseModel): assert items[0].b is True assert items[1].b == 2 # extra fields forbidden - full rows discarded - items = validate_items( + items = validate_and_filter_items( "items", mixed_list_model, [{"b": True}, {"b": 2}, {"b": 3}, {"b": False, "a": False}], @@ -568,10 +568,13 @@ class ItemModel(BaseModel): # non validating items removed from the list (both extra and declared) discard_model = apply_schema_contract_to_model(ItemModel, "discard_row", "discard_row") # violate data type - assert validate_item("items", discard_model, {"b": 2}, "discard_row", "discard_row") is None + assert ( + validate_and_filter_item("items", discard_model, {"b": 2}, "discard_row", "discard_row") + is None + ) # violate extra field assert ( - validate_item( + validate_and_filter_item( "items", discard_model, {"b": False, "a": False}, "discard_row", "discard_row" ) is None @@ -581,7 +584,7 @@ class ItemModel(BaseModel): freeze_model = apply_schema_contract_to_model(ItemModel, "freeze", "freeze") # violate data type with pytest.raises(DataValidationError) as val_ex: - validate_item("items", freeze_model, {"b": 2}, "freeze", "freeze") + validate_and_filter_item("items", freeze_model, {"b": 2}, "freeze", "freeze") assert val_ex.value.schema_name is None assert val_ex.value.table_name == "items" assert val_ex.value.column_name == str(("b",)) # pydantic location @@ -591,7 +594,7 @@ class ItemModel(BaseModel): assert val_ex.value.data_item == {"b": 2} # extra type with pytest.raises(DataValidationError) as val_ex: - validate_item("items", freeze_model, {"a": 2, "b": False}, "freeze", "freeze") + validate_and_filter_item("items", freeze_model, {"a": 2, "b": False}, "freeze", "freeze") assert val_ex.value.schema_name is None assert val_ex.value.table_name == "items" assert val_ex.value.column_name == str(("a",)) # pydantic location @@ -603,7 +606,7 @@ class ItemModel(BaseModel): # discard values discard_value_model = apply_schema_contract_to_model(ItemModel, "discard_value", "freeze") # violate extra field - item = validate_item( + item = validate_and_filter_item( "items", discard_value_model, {"b": False, "a": False}, "discard_value", "freeze" ) # "a" extra got removed @@ -612,21 +615,25 @@ class ItemModel(BaseModel): # evolve data types and extras evolve_model = apply_schema_contract_to_model(ItemModel, "evolve", "evolve") # for data types a lenient model will be created that accepts any type - item = 
validate_item("items", evolve_model, {"b": 2}, "evolve", "evolve") + item = validate_and_filter_item("items", evolve_model, {"b": 2}, "evolve", "evolve") assert item.b == 2 # extra fields allowed - item = validate_item("items", evolve_model, {"b": False, "a": False}, "evolve", "evolve") + item = validate_and_filter_item( + "items", evolve_model, {"b": False, "a": False}, "evolve", "evolve" + ) assert item.b is False assert item.a is False # type: ignore[attr-defined] # accept new types but discard new columns mixed_model = apply_schema_contract_to_model(ItemModel, "discard_row", "evolve") # for data types a lenient model will be created that accepts any type - item = validate_item("items", mixed_model, {"b": 3}, "discard_row", "evolve") + item = validate_and_filter_item("items", mixed_model, {"b": 3}, "discard_row", "evolve") assert item.b == 3 # extra fields forbidden - full rows discarded assert ( - validate_item("items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve") + validate_and_filter_item( + "items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve" + ) is None ) diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index e8b5dab8fd..a74ab11860 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -16,12 +16,14 @@ ) from dlt.common.configuration.specs import gcp_credentials from dlt.common.configuration.specs.exceptions import InvalidGoogleNativeCredentialsType +from dlt.common.schema.utils import new_table from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException +from dlt.destinations.impl.bigquery.bigquery_adapter import AUTODETECT_SCHEMA_HINT from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import json_case_path as common_json_case_path from tests.common.configuration.utils import environment @@ -217,15 +219,15 @@ def test_bigquery_configuration() -> None: assert config.fingerprint() == digest128("chat-analytics-rasa-ci") # credential location is deprecated - os.environ["CREDENTIALS__LOCATION"] = "EU" - config = resolve_configuration( - BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), - sections=("destination", "bigquery"), - ) - assert config.location == "US" - assert config.credentials.location == "EU" - # but if it is set, we propagate it to the config - assert config.get_location() == "EU" + # os.environ["CREDENTIALS__LOCATION"] = "EU" + # config = resolve_configuration( + # BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + # sections=("destination", "bigquery"), + # ) + # assert config.location == "US" + # assert config.credentials.location == "EU" + # # but if it is set, we propagate it to the config + # assert config.get_location() == "EU" os.environ["LOCATION"] = "ATLANTIS" config = resolve_configuration( BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), @@ -245,6 +247,27 @@ def test_bigquery_configuration() -> None: ) +def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: + # no schema autodetect + assert client._should_autodetect_schema("event_slot") is False + assert client._should_autodetect_schema("_dlt_loads") is False + # add parent table + child = 
new_table("event_slot__values", "event_slot") + client.schema.update_table(child) + assert client._should_autodetect_schema("event_slot__values") is False + # enable global config + client.config.autodetect_schema = True + assert client._should_autodetect_schema("event_slot") is True + assert client._should_autodetect_schema("_dlt_loads") is False + assert client._should_autodetect_schema("event_slot__values") is True + # enable hint per table + client.config.autodetect_schema = False + client.schema.get_table("event_slot")[AUTODETECT_SCHEMA_HINT] = True # type: ignore[typeddict-unknown-key] + assert client._should_autodetect_schema("event_slot") is True + assert client._should_autodetect_schema("_dlt_loads") is False + assert client._should_autodetect_schema("event_slot__values") is True + + def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: # non existing job with pytest.raises(LoadJobNotExistsException): @@ -290,7 +313,7 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) @pytest.mark.parametrize("location", ["US", "EU"]) def test_bigquery_location(location: str, file_storage: FileStorage, client) -> None: with cm_yield_client_with_storage( - "bigquery", default_config_values={"credentials": {"location": location}} + "bigquery", default_config_values={"location": location} ) as client: user_table_name = prepare_table(client) load_json = { diff --git a/tests/load/cases/loading/cve.json b/tests/load/cases/loading/cve.json new file mode 100644 index 0000000000..58796ed8c5 --- /dev/null +++ b/tests/load/cases/loading/cve.json @@ -0,0 +1,397 @@ +{ + "CVE_data_meta": { + "ASSIGNER": "security@apache.org", + "ID": "CVE-2021-44228", + "STATE": "PUBLIC", + "TITLE": "Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints" + }, + "affects": { + "vendor": { + "vendor_data": [ + { + "product": { + "product_data": [ + { + "product_name": "Apache Log4j2", + "version": { + "version_data": [ + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.0-beta9" + }, + { + "version_affected": "<", + "version_name": "log4j-core", + "version_value": "2.3.1" + }, + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.4" + }, + { + "version_affected": "<", + "version_name": "log4j-core" + }, + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.13.0" + }, + { + "version_affected": "<", + "version_name": "log4j-core", + "version_value": "2.15.0" + } + ] + } + } + ] + }, + "vendor_name": "Apache Software Foundation" + } + ] + } + }, + "credit": [ + { + "lang": "eng", + "value": "This issue was discovered by Chen Zhaojun of Alibaba Cloud Security Team." + } + ], + "data_format": "MITRE", + "data_type": "CVE", + "data_version": "4.0", + "description": { + "description_data": [ + { + "lang": "eng", + "value": "Apache Log4j2 2.0-beta9 through 2.15.0 (excluding security releases 2.12.2, 2.12.3, and 2.3.1) JNDI features used in configuration, log messages, and parameters do not protect against attacker controlled LDAP and other JNDI related endpoints. An attacker who can control log messages or log message parameters can execute arbitrary code loaded from LDAP servers when message lookup substitution is enabled. From log4j 2.15.0, this behavior has been disabled by default. From version 2.16.0 (along with 2.12.2, 2.12.3, and 2.3.1), this functionality has been completely removed. 
Note that this vulnerability is specific to log4j-core and does not affect log4net, log4cxx, or other Apache Logging Services projects." + } + ] + }, + "generator": { + "engine": "Vulnogram 0.0.9" + }, + "impact": [ + { + "other": "critical" + } + ], + "problemtype": { + "problemtype_data": [ + { + "description": [ + { + "lang": "eng", + "value": "CWE-502 Deserialization of Untrusted Data" + } + ] + }, + { + "description": [ + { + "lang": "eng", + "value": "CWE-400 Uncontrolled Resource Consumption" + } + ] + }, + { + "description": [ + { + "lang": "eng", + "value": "CWE-20 Improper Input Validation" + } + ] + } + ] + }, + "references": { + "reference_data": [ + { + "refsource": "MISC", + "url": "https://logging.apache.org/log4j/2.x/security.html", + "name": "https://logging.apache.org/log4j/2.x/security.html" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/1" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 Re: CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/2" + }, + { + "refsource": "CISCO", + "name": "20211210 Vulnerability in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 Re: CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/3" + }, + { + "refsource": "CONFIRM", + "name": "https://security.netapp.com/advisory/ntap-20211210-0007/", + "url": "https://security.netapp.com/advisory/ntap-20211210-0007/" + }, + { + "refsource": "CONFIRM", + "name": "https://security.netapp.com/advisory/ntap-20211210-0007/", + "url": "https://security.netapp.com/advisory/ntap-20211210-0007/" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "CONFIRM", + "name": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032", + "url": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032" + }, + { + "refsource": "CONFIRM", + "name": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032", + "url": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032" + }, + { + "refsource": "CONFIRM", + "name": "https://www.oracle.com/security-alerts/alert-cve-2021-44228.html", + "url": "https://www.oracle.com/security-alerts/alert-cve-2021-44228.html" + }, + { + "refsource": "DEBIAN", + "name": "DSA-5020", + "url": "https://www.debian.org/security/2021/dsa-5020" + }, + { + "refsource": "MLIST", + "name": "[debian-lts-announce] 20211212 [SECURITY] [DLA 2842-1] apache-log4j2 security update", + "url": 
"https://lists.debian.org/debian-lts-announce/2021/12/msg00007.html" + }, + { + "refsource": "FEDORA", + "name": "FEDORA-2021-f0f501d01f", + "url": "https://lists.fedoraproject.org/archives/list/package-announce@lists.fedoraproject.org/message/VU57UJDCFIASIO35GC55JMKSRXJMCDFM/" + }, + { + "refsource": "MS", + "name": "Microsoft\u2019s Response to CVE-2021-44228 Apache Log4j 2", + "url": "https://msrc-blog.microsoft.com/2021/12/11/microsofts-response-to-cve-2021-44228-apache-log4j2/" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211213 Re: CVE-2021-4104: Deserialization of untrusted data in JMSAppender in Apache Log4j 1.2", + "url": "http://www.openwall.com/lists/oss-security/2021/12/13/2" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211213 CVE-2021-4104: Deserialization of untrusted data in JMSAppender in Apache Log4j 1.2", + "url": "http://www.openwall.com/lists/oss-security/2021/12/13/1" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211214 CVE-2021-45046: Apache Log4j2 Thread Context Message Pattern and Context Lookup Pattern vulnerable to a denial of service attack", + "url": "http://www.openwall.com/lists/oss-security/2021/12/14/4" + }, + { + "refsource": "CISCO", + "name": "20211210 A Vulnerability in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "CERT-VN", + "name": "VU#930724", + "url": "https://www.kb.cert.org/vuls/id/930724" + }, + { + "refsource": "MISC", + "name": "https://twitter.com/kurtseifried/status/1469345530182455296", + "url": "https://twitter.com/kurtseifried/status/1469345530182455296" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-661247.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-661247.pdf" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165260/VMware-Security-Advisory-2021-0028.html", + "url": "http://packetstormsecurity.com/files/165260/VMware-Security-Advisory-2021-0028.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165270/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165270/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165261/Apache-Log4j2-2.14.1-Information-Disclosure.html", + "url": "http://packetstormsecurity.com/files/165261/Apache-Log4j2-2.14.1-Information-Disclosure.html" + }, + { + "refsource": "CONFIRM", + "name": "https://www.intel.com/content/www/us/en/security-center/advisory/intel-sa-00646.html", + "url": "https://www.intel.com/content/www/us/en/security-center/advisory/intel-sa-00646.html" + }, + { + "refsource": "CISCO", + "name": "20211210 Vulnerabilities in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211215 Re: CVE-2021-45046: Apache Log4j2 Thread Context Message Pattern and Context Lookup Pattern vulnerable to a denial of service attack", + "url": "http://www.openwall.com/lists/oss-security/2021/12/15/3" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165282/Log4j-Payload-Generator.html", + "url": 
"http://packetstormsecurity.com/files/165282/Log4j-Payload-Generator.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165281/Log4j2-Log4Shell-Regexes.html", + "url": "http://packetstormsecurity.com/files/165281/Log4j2-Log4Shell-Regexes.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165307/Log4j-Remote-Code-Execution-Word-Bypassing.html", + "url": "http://packetstormsecurity.com/files/165307/Log4j-Remote-Code-Execution-Word-Bypassing.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165311/log4j-scan-Extensive-Scanner.html", + "url": "http://packetstormsecurity.com/files/165311/log4j-scan-Extensive-Scanner.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165306/L4sh-Log4j-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165306/L4sh-Log4j-Remote-Code-Execution.html" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-714170.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-714170.pdf" + }, + { + "refsource": "FEDORA", + "name": "FEDORA-2021-66d6c484f3", + "url": "https://lists.fedoraproject.org/archives/list/package-announce@lists.fedoraproject.org/message/M5CSVUNV4HWZZXGOKNSK6L7RPM7BOKIB/" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165371/VMware-Security-Advisory-2021-0028.4.html", + "url": "http://packetstormsecurity.com/files/165371/VMware-Security-Advisory-2021-0028.4.html" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-397453.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-397453.pdf" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-479842.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-479842.pdf" + }, + { + "url": "https://www.oracle.com/security-alerts/cpujan2022.html", + "refsource": "MISC", + "name": "https://www.oracle.com/security-alerts/cpujan2022.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165532/Log4Shell-HTTP-Header-Injection.html", + "url": "http://packetstormsecurity.com/files/165532/Log4Shell-HTTP-Header-Injection.html" + }, + { + "refsource": "MISC", + "name": "https://github.com/cisagov/log4j-affected-db/blob/develop/SOFTWARE-LIST.md", + "url": "https://github.com/cisagov/log4j-affected-db/blob/develop/SOFTWARE-LIST.md" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165642/VMware-vCenter-Server-Unauthenticated-Log4Shell-JNDI-Injection-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165642/VMware-vCenter-Server-Unauthenticated-Log4Shell-JNDI-Injection-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165673/UniFi-Network-Application-Unauthenticated-Log4Shell-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165673/UniFi-Network-Application-Unauthenticated-Log4Shell-Remote-Code-Execution.html" + }, + { + "refsource": "FULLDISC", + "name": "20220314 APPLE-SA-2022-03-14-7 Xcode 13.3", + "url": "http://seclists.org/fulldisclosure/2022/Mar/23" + }, + { + "refsource": "MISC", + "name": "https://www.bentley.com/en/common-vulnerability-exposure/be-2022-0001", + "url": "https://www.bentley.com/en/common-vulnerability-exposure/be-2022-0001" + }, + { + "refsource": "MISC", + "name": 
"https://github.com/cisagov/log4j-affected-db", + "url": "https://github.com/cisagov/log4j-affected-db" + }, + { + "refsource": "CONFIRM", + "name": "https://support.apple.com/kb/HT213189", + "url": "https://support.apple.com/kb/HT213189" + }, + { + "url": "https://www.oracle.com/security-alerts/cpuapr2022.html", + "refsource": "MISC", + "name": "https://www.oracle.com/security-alerts/cpuapr2022.html" + }, + { + "refsource": "MISC", + "name": "https://github.com/nu11secur1ty/CVE-mitre/tree/main/CVE-2021-44228", + "url": "https://github.com/nu11secur1ty/CVE-mitre/tree/main/CVE-2021-44228" + }, + { + "refsource": "MISC", + "name": "https://www.nu11secur1ty.com/2021/12/cve-2021-44228.html", + "url": "https://www.nu11secur1ty.com/2021/12/cve-2021-44228.html" + }, + { + "refsource": "FULLDISC", + "name": "20220721 Open-Xchange Security Advisory 2022-07-21", + "url": "http://seclists.org/fulldisclosure/2022/Jul/11" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/167794/Open-Xchange-App-Suite-7.10.x-Cross-Site-Scripting-Command-Injection.html", + "url": "http://packetstormsecurity.com/files/167794/Open-Xchange-App-Suite-7.10.x-Cross-Site-Scripting-Command-Injection.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/167917/MobileIron-Log4Shell-Remote-Command-Execution.html", + "url": "http://packetstormsecurity.com/files/167917/MobileIron-Log4Shell-Remote-Command-Execution.html" + }, + { + "refsource": "FULLDISC", + "name": "20221208 Intel Data Center Manager <= 5.1 Local Privileges Escalation", + "url": "http://seclists.org/fulldisclosure/2022/Dec/2" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/171626/AD-Manager-Plus-7122-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/171626/AD-Manager-Plus-7122-Remote-Code-Execution.html" + } + ] + }, + "source": { + "discovery": "UNKNOWN" + } +} \ No newline at end of file diff --git a/tests/load/clickhouse/clickhouse-compose.yml b/tests/load/clickhouse/clickhouse-compose.yml new file mode 100644 index 0000000000..b6415b120a --- /dev/null +++ b/tests/load/clickhouse/clickhouse-compose.yml @@ -0,0 +1,26 @@ +--- +services: + clickhouse: + image: clickhouse/clickhouse-server + ports: + - "9000:9000" + - "8123:8123" + environment: + - CLICKHOUSE_DB=dlt_data + - CLICKHOUSE_USER=loader + - CLICKHOUSE_PASSWORD=loader + - CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT=1 + volumes: + - clickhouse_data:/var/lib/clickhouse/ + - clickhouse_logs:/var/log/clickhouse-server/ + restart: unless-stopped + healthcheck: + test: [ "CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8123/ping" ] + interval: 3s + timeout: 5s + retries: 5 + + +volumes: + clickhouse_data: + clickhouse_logs: diff --git a/tests/load/clickhouse/test_clickhouse_adapter.py b/tests/load/clickhouse/test_clickhouse_adapter.py index ea3116c25b..e8e2b327c0 100644 --- a/tests/load/clickhouse/test_clickhouse_adapter.py +++ b/tests/load/clickhouse/test_clickhouse_adapter.py @@ -1,61 +1,112 @@ +from typing import Generator, Dict, cast + import dlt +from dlt.common.utils import custom_environ from dlt.destinations.adapters import clickhouse_adapter +from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient +from dlt.destinations.impl.clickhouse.typing import TDeployment +from tests.load.clickhouse.utils import get_deployment_type from tests.pipeline.utils import assert_load_info def test_clickhouse_adapter() -> None: @dlt.resource - def merge_tree_resource(): 
+ def merge_tree_resource() -> Generator[Dict[str, int], None, None]: yield {"field1": 1, "field2": 2} + # `ReplicatedMergeTree` has been supplanted by `SharedMergeTree` on CH Cloud, + # which is automatically selected even if `MergeTree` is selected. + # See https://clickhouse.com/docs/en/cloud/reference/shared-merge-tree. + + # The `Log` family of engines is only supported in self-managed deployments, + # so it can't be tested in CH Cloud CI. + @dlt.resource - def replicated_merge_tree_resource(): + def replicated_merge_tree_resource() -> Generator[Dict[str, int], None, None]: yield {"field1": 1, "field2": 2} @dlt.resource - def not_annotated_resource(): + def not_annotated_resource() -> Generator[Dict[str, int], None, None]: + """A non-annotated resource will default to `SharedMergeTree` for CH Cloud + and `MergeTree` for a self-managed installation.""" yield {"field1": 1, "field2": 2} clickhouse_adapter(merge_tree_resource, table_engine_type="merge_tree") clickhouse_adapter(replicated_merge_tree_resource, table_engine_type="replicated_merge_tree") pipe = dlt.pipeline(pipeline_name="adapter_test", destination="clickhouse", dev_mode=True) - pack = pipe.run([merge_tree_resource, replicated_merge_tree_resource, not_annotated_resource]) + + with pipe.sql_client() as client: + deployment_type: TDeployment = get_deployment_type(cast(ClickHouseSqlClient, client)) + + if deployment_type == "ClickHouseCloud": + pack = pipe.run( + [ + merge_tree_resource, + replicated_merge_tree_resource, + not_annotated_resource, + ] + ) + else: + # `ReplicatedMergeTree` is not supported on a single node. + pack = pipe.run([merge_tree_resource, not_annotated_resource]) assert_load_info(pack) with pipe.sql_client() as client: - # get map of table names to full table names + # Get a map of table names to full table names. tables = {} for table in client._list_tables(): if "resource" in table: tables[table.split("___")[1]] = table - assert (len(tables.keys())) == 3 + if deployment_type == "ClickHouseCloud": + assert (len(tables.keys())) == 3 + else: + assert (len(tables.keys())) == 2 - # check content + # Check the table content. for full_table_name in tables.values(): with client.execute_query(f"SELECT * FROM {full_table_name};") as cursor: res = cursor.fetchall() assert tuple(res[0])[:2] == (1, 2) - # check table format - # fails now, because we do not have a cluster (I think), it will fall back to SharedMergeTree - for full_table_name in tables.values(): + # Check the table engine. + for table_name, full_table_name in tables.items(): with client.execute_query( - "SELECT database, name, engine, engine_full FROM system.tables WHERE name =" - f" '{full_table_name}';" + "SELECT database, name, engine, engine_full FROM system.tables " + f"WHERE name = '{full_table_name}';" ) as cursor: res = cursor.fetchall() - # this should test that two tables should be replicatedmergetree tables - assert tuple(res[0])[2] == "SharedMergeTree" + if table_name in ( + "merge_tree_resource", + "replicated_merge_tree_resource", + ): + if deployment_type == "ClickHouseCloud": + assert tuple(res[0])[2] in ( + "MergeTree", + "SharedMergeTree", + "ReplicatedMergeTree", + ) + else: + assert tuple(res[0])[2] in ("MergeTree",) + else: + # A non-annotated resource defaults to the detected installation + # type, i.e. cloud or self-managed. + # CI runs on CH Cloud, so it will be `SharedMergeTree`.
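# Illustrative sketch (assumptions: the pipeline name, resource name, and the simplified
# LIKE filter are made up): pinning a table engine on a resource with clickhouse_adapter
# and checking which engine the server actually created. On ClickHouse Cloud the MergeTree
# family is transparently backed by SharedMergeTree, so the check accepts either name.
import dlt
from dlt.destinations.adapters import clickhouse_adapter

@dlt.resource
def events():
    yield {"field1": 1, "field2": 2}

clickhouse_adapter(events, table_engine_type="merge_tree")

pipe = dlt.pipeline(pipeline_name="engine_demo", destination="clickhouse", dev_mode=True)
pipe.run(events)

with pipe.sql_client() as client:
    with client.execute_query(
        "SELECT engine FROM system.tables WHERE name LIKE '%events%'"
    ) as cursor:
        assert cursor.fetchall()[0][0] in ("MergeTree", "SharedMergeTree")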
+ if deployment_type == "ClickHouseCloud": + assert tuple(res[0])[2] == "SharedMergeTree" + else: + assert tuple(res[0])[2] == "MergeTree" - # we can check the gen table sql though + # We can check the generated table's SQL, though. with pipe.destination_client() as dest_client: - for table in tables.keys(): + for table in tables: sql = dest_client._get_table_update_sql( # type: ignore[attr-defined] - table, pipe.default_schema.tables[table]["columns"].values(), generate_alter=False + table, + pipe.default_schema.tables[table]["columns"].values(), + generate_alter=False, ) - if table == "merge_tree_resource": - assert "ENGINE = MergeTree" in sql[0] - else: + if table == "replicated_merge_tree_resource": assert "ENGINE = ReplicatedMergeTree" in sql[0] + else: + assert "ENGINE = MergeTree" or "ENGINE = SharedMergeTree" in sql[0] diff --git a/tests/load/clickhouse/test_clickhouse_configuration.py b/tests/load/clickhouse/test_clickhouse_configuration.py index eb02155406..a4e8abc8dd 100644 --- a/tests/load/clickhouse/test_clickhouse_configuration.py +++ b/tests/load/clickhouse/test_clickhouse_configuration.py @@ -1,8 +1,7 @@ -from typing import Any, Iterator +from typing import Iterator import pytest -import dlt from dlt.common.configuration.resolve import resolve_configuration from dlt.common.libs.sql_alchemy import make_url from dlt.common.utils import digest128 @@ -11,11 +10,6 @@ ClickHouseCredentials, ClickHouseClientConfiguration, ) -from dlt.destinations.impl.snowflake.configuration import ( - SnowflakeClientConfiguration, - SnowflakeCredentials, -) -from tests.common.configuration.utils import environment from tests.load.utils import yield_client_with_storage @@ -27,8 +21,8 @@ def client() -> Iterator[ClickHouseClient]: def test_clickhouse_connection_string_with_all_params() -> None: url = ( "clickhouse://user1:pass1@host1:9000/testdb?allow_experimental_lightweight_delete=1&" - "allow_experimental_object_type=1&connect_timeout=230&enable_http_compression=1&secure=0" - "&send_receive_timeout=1000" + "allow_experimental_object_type=1&connect_timeout=230&date_time_input_format=best_effort&" + "enable_http_compression=1&secure=0&send_receive_timeout=1000" ) creds = ClickHouseCredentials() @@ -53,15 +47,15 @@ def test_clickhouse_configuration() -> None: # def empty fingerprint assert ClickHouseClientConfiguration().fingerprint() == "" # based on host - c = resolve_configuration( - SnowflakeCredentials(), + config = resolve_configuration( + ClickHouseCredentials(), explicit_value="clickhouse://user1:pass1@host1:9000/db1", ) - assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1") + assert ClickHouseClientConfiguration(credentials=config).fingerprint() == digest128("host1") def test_clickhouse_connection_settings(client: ClickHouseClient) -> None: - """Test experimental settings are set correctly for session.""" + """Test experimental settings are set correctly for the session.""" conn = client.sql_client.open_connection() cursor1 = conn.cursor() cursor2 = conn.cursor() @@ -74,3 +68,4 @@ def test_clickhouse_connection_settings(client: ClickHouseClient) -> None: assert ("allow_experimental_lightweight_delete", "1") in res assert ("enable_http_compression", "1") in res + assert ("date_time_input_format", "best_effort") in res diff --git a/tests/load/clickhouse/test_clickhouse_table_builder.py b/tests/load/clickhouse/test_clickhouse_table_builder.py index 867102dde9..433383b631 100644 --- a/tests/load/clickhouse/test_clickhouse_table_builder.py +++ 
b/tests/load/clickhouse/test_clickhouse_table_builder.py @@ -6,7 +6,6 @@ from dlt.common.schema import Schema from dlt.common.utils import custom_environ, digest128 from dlt.common.utils import uniq_id - from dlt.destinations import clickhouse from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient from dlt.destinations.impl.clickhouse.configuration import ( @@ -140,7 +139,9 @@ def test_clickhouse_alter_table(clickhouse_client: ClickHouseClient) -> None: @pytest.mark.usefixtures("empty_schema") -def test_clickhouse_create_table_with_primary_keys(clickhouse_client: ClickHouseClient) -> None: +def test_clickhouse_create_table_with_primary_keys( + clickhouse_client: ClickHouseClient, +) -> None: mod_update = deepcopy(TABLE_UPDATE) mod_update[1]["primary_key"] = True @@ -172,3 +173,28 @@ def test_clickhouse_create_table_with_hints(client: ClickHouseClient) -> None: # No hints. assert "`col3` boolean NOT NULL" in sql assert "`col4` timestamp with time zone NOT NULL" in sql + + +def test_clickhouse_table_engine_configuration() -> None: + with custom_environ( + { + "DESTINATION__CLICKHOUSE__CREDENTIALS__HOST": "localhost", + "DESTINATION__CLICKHOUSE__DATASET_NAME": f"test_{uniq_id()}", + } + ): + config = resolve_configuration( + ClickHouseClientConfiguration(), sections=("destination", "clickhouse") + ) + assert config.table_engine_type == "merge_tree" + + with custom_environ( + { + "DESTINATION__CLICKHOUSE__CREDENTIALS__HOST": "localhost", + "DESTINATION__CLICKHOUSE__TABLE_ENGINE_TYPE": "replicated_merge_tree", + "DESTINATION__CLICKHOUSE__DATASET_NAME": f"test_{uniq_id()}", + } + ): + config = resolve_configuration( + ClickHouseClientConfiguration(), sections=("destination", "clickhouse") + ) + assert config.table_engine_type == "replicated_merge_tree" diff --git a/tests/load/clickhouse/utils.py b/tests/load/clickhouse/utils.py new file mode 100644 index 0000000000..5c34d52148 --- /dev/null +++ b/tests/load/clickhouse/utils.py @@ -0,0 +1,9 @@ +from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient +from dlt.destinations.impl.clickhouse.typing import TDeployment + + +def get_deployment_type(client: ClickHouseSqlClient) -> TDeployment: + cloud_mode = int(client.execute_sql(""" + SELECT value FROM system.settings WHERE name = 'cloud_mode' + """)[0][0]) + return "ClickHouseCloud" if cloud_mode else "ClickHouseOSS" diff --git a/tests/load/filesystem/test_object_store_rs_credentials.py b/tests/load/filesystem/test_object_store_rs_credentials.py index 524cd4425d..90530218d9 100644 --- a/tests/load/filesystem/test_object_store_rs_credentials.py +++ b/tests/load/filesystem/test_object_store_rs_credentials.py @@ -18,12 +18,19 @@ GcpServiceAccountCredentialsWithoutDefaults, GcpOAuthCredentialsWithoutDefaults, ) +from dlt.common.configuration.specs.exceptions import ObjectStoreRsCredentialsException + +from tests.load.utils import ( + AZ_BUCKET, + AWS_BUCKET, + GCS_BUCKET, + R2_BUCKET_CONFIG, + ALL_FILESYSTEM_DRIVERS, +) -from tests.load.utils import AZ_BUCKET, AWS_BUCKET, GCS_BUCKET, ALL_FILESYSTEM_DRIVERS - -if all(driver not in ALL_FILESYSTEM_DRIVERS for driver in ("az", "s3", "gs")): +if all(driver not in ALL_FILESYSTEM_DRIVERS for driver in ("az", "s3", "gs", "r2")): pytest.skip( - "Requires at least one of `az`, `s3`, `gs` in `ALL_FILESYSTEM_DRIVERS`.", + "Requires at least one of `az`, `s3`, `gs`, `r2` in `ALL_FILESYSTEM_DRIVERS`.", allow_module_level=True, ) @@ -53,10 +60,10 @@ def can_connect(bucket_url: str, object_store_rs_credentials: Dict[str, str]) -> 
return False -@pytest.mark.skipif( - "az" not in ALL_FILESYSTEM_DRIVERS, reason="`az` not in `ALL_FILESYSTEM_DRIVERS`" +@pytest.mark.parametrize( + "driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("az")] ) -def test_azure_object_store_rs_credentials() -> None: +def test_azure_object_store_rs_credentials(driver: str) -> None: creds: AnyAzureCredentials creds = AzureServicePrincipalCredentialsWithoutDefaults( @@ -78,55 +85,80 @@ def test_azure_object_store_rs_credentials() -> None: assert can_connect(AZ_BUCKET, creds.to_object_store_rs_credentials()) -@pytest.mark.skipif( - "s3" not in ALL_FILESYSTEM_DRIVERS, reason="`s3` not in `ALL_FILESYSTEM_DRIVERS`" +@pytest.mark.parametrize( + "driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("s3", "r2")] ) -def test_aws_object_store_rs_credentials() -> None: +def test_aws_object_store_rs_credentials(driver: str) -> None: creds: AwsCredentialsWithoutDefaults + fs_creds = FS_CREDS + if driver == "r2": + fs_creds = R2_BUCKET_CONFIG["credentials"] # type: ignore[assignment] + + # AwsCredentialsWithoutDefaults: no user-provided session token + creds = AwsCredentialsWithoutDefaults( + aws_access_key_id=fs_creds["aws_access_key_id"], + aws_secret_access_key=fs_creds["aws_secret_access_key"], + region_name=fs_creds.get("region_name"), + endpoint_url=fs_creds.get("endpoint_url"), + ) + assert creds.aws_session_token is None + object_store_rs_creds = creds.to_object_store_rs_credentials() + assert "aws_session_token" not in object_store_rs_creds # no auto-generated token + assert can_connect(AWS_BUCKET, object_store_rs_creds) + # AwsCredentials: no user-provided session token creds = AwsCredentials( - aws_access_key_id=FS_CREDS["aws_access_key_id"], - aws_secret_access_key=FS_CREDS["aws_secret_access_key"], - # region_name must be configured in order for data lake to work - region_name=FS_CREDS["region_name"], + aws_access_key_id=fs_creds["aws_access_key_id"], + aws_secret_access_key=fs_creds["aws_secret_access_key"], + region_name=fs_creds.get("region_name"), + endpoint_url=fs_creds.get("endpoint_url"), ) assert creds.aws_session_token is None object_store_rs_creds = creds.to_object_store_rs_credentials() - assert object_store_rs_creds["aws_session_token"] is not None # auto-generated token + assert "aws_session_token" not in object_store_rs_creds # no auto-generated token assert can_connect(AWS_BUCKET, object_store_rs_creds) + # exception should be raised if both `endpoint_url` and `region_name` are + # not provided + with pytest.raises(ObjectStoreRsCredentialsException): + AwsCredentials( + aws_access_key_id=fs_creds["aws_access_key_id"], + aws_secret_access_key=fs_creds["aws_secret_access_key"], + ).to_object_store_rs_credentials() + + if "endpoint_url" in object_store_rs_creds: + # TODO: make sure this case is tested on GitHub CI, e.g. 
by adding + # a local MinIO bucket to the set of tested buckets + if object_store_rs_creds["endpoint_url"].startswith("http://"): + assert object_store_rs_creds["aws_allow_http"] == "true" + + # remainder of tests use session tokens + # we don't run them on S3 compatible storage because session tokens + # may not be available + return + # AwsCredentials: user-provided session token # use previous credentials to create session token for new credentials + assert isinstance(creds, AwsCredentials) sess_creds = creds.to_session_credentials() creds = AwsCredentials( aws_access_key_id=sess_creds["aws_access_key_id"], aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]), aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]), - region_name=FS_CREDS["region_name"], + region_name=fs_creds["region_name"], ) assert creds.aws_session_token is not None object_store_rs_creds = creds.to_object_store_rs_credentials() assert object_store_rs_creds["aws_session_token"] is not None assert can_connect(AWS_BUCKET, object_store_rs_creds) - # AwsCredentialsWithoutDefaults: no user-provided session token - creds = AwsCredentialsWithoutDefaults( - aws_access_key_id=FS_CREDS["aws_access_key_id"], - aws_secret_access_key=FS_CREDS["aws_secret_access_key"], - region_name=FS_CREDS["region_name"], - ) - assert creds.aws_session_token is None - object_store_rs_creds = creds.to_object_store_rs_credentials() - assert "aws_session_token" not in object_store_rs_creds # no auto-generated token - assert can_connect(AWS_BUCKET, object_store_rs_creds) - # AwsCredentialsWithoutDefaults: user-provided session token creds = AwsCredentialsWithoutDefaults( aws_access_key_id=sess_creds["aws_access_key_id"], aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]), aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]), - region_name=FS_CREDS["region_name"], + region_name=fs_creds["region_name"], ) assert creds.aws_session_token is not None object_store_rs_creds = creds.to_object_store_rs_credentials() @@ -134,10 +166,10 @@ def test_aws_object_store_rs_credentials() -> None: assert can_connect(AWS_BUCKET, object_store_rs_creds) -@pytest.mark.skipif( - "gs" not in ALL_FILESYSTEM_DRIVERS, reason="`gs` not in `ALL_FILESYSTEM_DRIVERS`" +@pytest.mark.parametrize( + "driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("gs")] ) -def test_gcp_object_store_rs_credentials() -> None: +def test_gcp_object_store_rs_credentials(driver) -> None: creds = GcpServiceAccountCredentialsWithoutDefaults( project_id=FS_CREDS["project_id"], private_key=FS_CREDS["private_key"], diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index f4fdef8665..fd0a55e273 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -1,6 +1,9 @@ import pytest +import io -from dlt.common import Decimal +import dlt +from dlt.common import Decimal, json +from dlt.common.typing import TLoaderFileFormat from tests.pipeline.utils import assert_load_info from tests.load.utils import destinations_configs, DestinationTestConfiguration @@ -38,3 +41,107 @@ def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration row = q.fetchone() assert row[0] == data[0]["col_big_numeric"] assert row[1] == data[0]["col_numeric"] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["bigquery"]), + ids=lambda x: x.name, +) 
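# Illustrative sketch (the sample payload and pipeline name are made up; the adapter call
# mirrors the test below): enabling BigQuery schema auto-detection on a resource so nested
# JSON is typed by BigQuery itself instead of being flattened into child tables by dlt.
import dlt
from dlt.destinations.adapters import bigquery_adapter

@dlt.resource(name="cve", max_table_nesting=0)
def load_cve():
    yield {"CVE_data_meta": {"ID": "CVE-2021-44228", "STATE": "PUBLIC"}}

cve = bigquery_adapter(load_cve(), autodetect_schema=True)
pipeline = dlt.pipeline(pipeline_name="cve_autodetect_demo", destination="bigquery")
pipeline.run(cve)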
+@pytest.mark.parametrize("file_format", ("parquet", "jsonl")) +def test_bigquery_autodetect_schema( + destination_config: DestinationTestConfiguration, file_format: TLoaderFileFormat +) -> None: + from dlt.destinations.adapters import bigquery_adapter + from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient + + @dlt.resource(name="cve", max_table_nesting=0, file_format=file_format) + def load_cve(stage: int): + with open("tests/load/cases/loading/cve.json", "rb") as f: + cve = json.load(f) + if stage == 0: + # remove a whole struct field + del cve["references"] + if stage == 1: + # remove a field from struct + for item in cve["references"]["reference_data"]: + del item["refsource"] + if file_format == "jsonl": + yield cve + else: + import pyarrow.json as paj + + table = paj.read_json(io.BytesIO(json.dumpb(cve))) + yield table + + pipeline = destination_config.setup_pipeline("test_bigquery_autodetect_schema", dev_mode=True) + # run without one nested field + cve = bigquery_adapter(load_cve(0), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + client: BigQuerySqlClient + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "source") + # not repeatable + assert field.field_type == "RECORD" + assert field.mode == "NULLABLE" + field = next(field for field in table.schema if field.name == "credit") + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.field_type == "RECORD" + assert field.mode == "REPEATED" + # no references + field = next((field for field in table.schema if field.name == "references"), None) + assert field is None + + # evolve schema - add nested field + cve = bigquery_adapter(load_cve(1), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "references") + field = field.fields[0] + assert field.name == "reference_data" + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.mode == "REPEATED" + # and enclosed in another type 🤷 + field = field.fields[0] + else: + assert field.mode == "REPEATED" + # make sure url is there + nested_field = next(f for f in field.fields if f.name == "url") + assert nested_field.field_type == "STRING" + # refsource not there + nested_field = next((f for f in field.fields if f.name == "refsource"), None) + assert nested_field is None + + # evolve schema - add field to a nested struct + cve = bigquery_adapter(load_cve(2), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "references") + field = field.fields[0] + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.mode == "REPEATED" + # and enclosed in another type 🤷 + field = 
field.fields[0] + # it looks like BigQuery can evolve structs and the field is added + nested_field = next(f for f in field.fields if f.name == "refsource") diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 3f0352cab7..7ad571f2aa 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -3,6 +3,8 @@ import posixpath from pathlib import Path from typing import Any, Callable, List, Dict, cast +from importlib.metadata import version as pkg_version +from packaging.version import Version from pytest_mock import MockerFixture import dlt @@ -12,6 +14,7 @@ from dlt.common import pendulum from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.utils import uniq_id +from dlt.common.exceptions import DependencyVersionException from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders @@ -24,6 +27,8 @@ from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, + MEMORY_BUCKET, + FILE_BUCKET, ) from tests.pipeline.utils import load_table_counts, assert_load_info, load_tables_to_dicts @@ -32,12 +37,6 @@ skip_if_not_active("filesystem") -@pytest.fixture -def local_filesystem_pipeline() -> dlt.Pipeline: - os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" - return dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", dev_mode=True) - - def test_pipeline_merge_write_disposition(default_buckets_env: str) -> None: """Run pipeline twice with merge write disposition Regardless wether primary key is set or not, filesystem appends @@ -223,28 +222,50 @@ def some_source(): assert table.column("value").to_pylist() == [1, 2, 3, 4, 5] +def test_delta_table_pyarrow_version_check() -> None: + """Tests pyarrow version checking for `delta` table format. + + DependencyVersionException should be raised if pyarrow<17.0.0. + """ + # test intentionally does not use destination_configs(), because that + # function automatically marks `delta` table format configs as + # `needspyarrow17`, which should not happen for this test to run in an + # environment where pyarrow<17.0.0 + + assert Version(pkg_version("pyarrow")) < Version("17.0.0"), "test assumes `pyarrow<17.0.0`" + + @dlt.resource(table_format="delta") + def foo(): + yield {"foo": 1, "bar": 2} + + pipeline = dlt.pipeline(destination=filesystem(FILE_BUCKET)) + + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(foo()) + assert isinstance(pip_ex.value.__context__, DependencyVersionException) + + @pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_exclude=(MEMORY_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_core( - default_buckets_env: str, - local_filesystem_pipeline: dlt.Pipeline, + destination_config: DestinationTestConfiguration, ) -> None: """Tests core functionality for `delta` table format. - Tests all data types, all filesystems, all write dispositions. + Tests all data types, all filesystems. + Tests `append` and `replace` write dispositions (`merge` is tested elsewhere). 
""" from tests.pipeline.utils import _get_delta_table - if default_buckets_env.startswith("memory://"): - pytest.skip( - "`deltalake` library does not support `memory` protocol (write works, read doesn't)" - ) - if default_buckets_env.startswith("s3://"): - # https://delta-io.github.io/delta-rs/usage/writing/writing-to-s3-with-locking-provider/ - os.environ["DESTINATION__FILESYSTEM__DELTALAKE_STORAGE_OPTIONS"] = ( - '{"AWS_S3_ALLOW_UNSAFE_RENAME": "true"}' - ) - # create resource that yields rows with all data types column_schemas, row = table_update_and_row() @@ -253,8 +274,10 @@ def data_types(): nonlocal row yield [row] * 10 + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + # run pipeline, this should create Delta table - info = local_filesystem_pipeline.run(data_types()) + info = pipeline.run(data_types()) assert_load_info(info) # `delta` table format should use `parquet` file format @@ -266,42 +289,37 @@ def data_types(): # 10 rows should be loaded to the Delta table and the content of the first # row should match expected values - rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[ - "data_types" - ] + rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"] assert len(rows) == 10 assert_all_data_types_row(rows[0], schema=column_schemas) # another run should append rows to the table - info = local_filesystem_pipeline.run(data_types()) + info = pipeline.run(data_types()) assert_load_info(info) - rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[ - "data_types" - ] + rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"] assert len(rows) == 20 # ensure "replace" write disposition is handled # should do logical replace, increasing the table version - info = local_filesystem_pipeline.run(data_types(), write_disposition="replace") + info = pipeline.run(data_types(), write_disposition="replace") assert_load_info(info) - client = cast(FilesystemClient, local_filesystem_pipeline.destination_client()) + client = cast(FilesystemClient, pipeline.destination_client()) assert _get_delta_table(client, "data_types").version() == 2 - rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[ - "data_types" - ] + rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"] assert len(rows) == 10 - # `merge` resolves to `append` behavior - info = local_filesystem_pipeline.run(data_types(), write_disposition="merge") - assert_load_info(info) - rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[ - "data_types" - ] - assert len(rows) == 20 - +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_multiple_files( - local_filesystem_pipeline: dlt.Pipeline, + destination_config: DestinationTestConfiguration, ) -> None: """Tests loading multiple files into a Delta table. 
@@ -316,7 +334,9 @@ def test_delta_table_multiple_files( def delta_table(): yield [{"foo": True}] * 10 - info = local_filesystem_pipeline.run(delta_table()) + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + info = pipeline.run(delta_table()) assert_load_info(info) # multiple Parquet files should have been created @@ -330,16 +350,23 @@ def delta_table(): assert len(delta_table_parquet_jobs) == 5 # 10 records, max 2 per file # all 10 records should have been loaded into a Delta table in a single commit - client = cast(FilesystemClient, local_filesystem_pipeline.destination_client()) + client = cast(FilesystemClient, pipeline.destination_client()) assert _get_delta_table(client, "delta_table").version() == 0 - rows = load_tables_to_dicts(local_filesystem_pipeline, "delta_table", exclude_system_cols=True)[ - "delta_table" - ] + rows = load_tables_to_dicts(pipeline, "delta_table", exclude_system_cols=True)["delta_table"] assert len(rows) == 10 +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_child_tables( - local_filesystem_pipeline: dlt.Pipeline, + destination_config: DestinationTestConfiguration, ) -> None: """Tests child table handling for `delta` table format.""" @@ -358,10 +385,12 @@ def complex_table(): }, ] - info = local_filesystem_pipeline.run(complex_table()) + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + info = pipeline.run(complex_table()) assert_load_info(info) rows_dict = load_tables_to_dicts( - local_filesystem_pipeline, + pipeline, "complex_table", "complex_table__child", "complex_table__child__grandchild", @@ -377,10 +406,10 @@ def complex_table(): assert rows_dict["complex_table__child__grandchild"][0].keys() == {"value"} # test write disposition handling with child tables - info = local_filesystem_pipeline.run(complex_table()) + info = pipeline.run(complex_table()) assert_load_info(info) rows_dict = load_tables_to_dicts( - local_filesystem_pipeline, + pipeline, "complex_table", "complex_table__child", "complex_table__child__grandchild", @@ -390,10 +419,10 @@ def complex_table(): assert len(rows_dict["complex_table__child"]) == 3 * 2 assert len(rows_dict["complex_table__child__grandchild"]) == 5 * 2 - info = local_filesystem_pipeline.run(complex_table(), write_disposition="replace") + info = pipeline.run(complex_table(), write_disposition="replace") assert_load_info(info) rows_dict = load_tables_to_dicts( - local_filesystem_pipeline, + pipeline, "complex_table", "complex_table__child", "complex_table__child__grandchild", @@ -404,8 +433,108 @@ def complex_table(): assert len(rows_dict["complex_table__child__grandchild"]) == 5 +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) +def test_delta_table_empty_source( + destination_config: DestinationTestConfiguration, +) -> None: + """Tests empty source handling for `delta` table format. + + Tests both empty Arrow table and `dlt.mark.materialize_table_schema()`. 
+ """ + from dlt.common.libs.pyarrow import pyarrow as pa + from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_data + from tests.pipeline.utils import _get_delta_table, users_materialize_table_schema + + @dlt.resource(table_format="delta") + def delta_table(data): + yield data + + # create empty Arrow table with schema + arrow_table = arrow_table_all_data_types( + "arrow-table", + include_decimal_default_precision=False, + include_decimal_arrow_max_precision=True, + include_not_normalized_name=False, + include_null=False, + num_rows=2, + )[0] + empty_arrow_table = arrow_table.schema.empty_table() + assert empty_arrow_table.num_rows == 0 # it's empty + assert empty_arrow_table.schema.equals(arrow_table.schema) # it has a schema + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + # run 1: empty Arrow table with schema + # this should create empty Delta table with same schema as Arrow table + info = pipeline.run(delta_table(empty_arrow_table)) + assert_load_info(info) + client = cast(FilesystemClient, pipeline.destination_client()) + dt = _get_delta_table(client, "delta_table") + assert dt.version() == 0 + dt_arrow_table = dt.to_pyarrow_table() + assert dt_arrow_table.shape == (0, empty_arrow_table.num_columns) + assert dt_arrow_table.schema.equals( + ensure_delta_compatible_arrow_data(empty_arrow_table).schema + ) + + # run 2: non-empty Arrow table with same schema as run 1 + # this should load records into Delta table + info = pipeline.run(delta_table(arrow_table)) + assert_load_info(info) + dt = _get_delta_table(client, "delta_table") + assert dt.version() == 1 + dt_arrow_table = dt.to_pyarrow_table() + assert dt_arrow_table.shape == (2, empty_arrow_table.num_columns) + assert dt_arrow_table.schema.equals( + ensure_delta_compatible_arrow_data(empty_arrow_table).schema + ) + + # run 3: empty Arrow table with different schema + # this should not alter the Delta table + empty_arrow_table_2 = pa.schema( + [pa.field("foo", pa.int64()), pa.field("bar", pa.string())] + ).empty_table() + + info = pipeline.run(delta_table(empty_arrow_table_2)) + assert_load_info(info) + dt = _get_delta_table(client, "delta_table") + assert dt.version() == 1 # still 1, no new commit was done + dt_arrow_table = dt.to_pyarrow_table() + assert dt_arrow_table.shape == (2, empty_arrow_table.num_columns) # shape did not change + assert dt_arrow_table.schema.equals( # schema did not change + ensure_delta_compatible_arrow_data(empty_arrow_table).schema + ) + + # test `dlt.mark.materialize_table_schema()` + users_materialize_table_schema.apply_hints(table_format="delta") + info = pipeline.run(users_materialize_table_schema()) + assert_load_info(info) + dt = _get_delta_table(client, "users") + assert dt.version() == 0 + dt_arrow_table = dt.to_pyarrow_table() + assert dt_arrow_table.num_rows == 0 + assert "id", "name" == dt_arrow_table.schema.names[:2] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_mixed_source( - local_filesystem_pipeline: dlt.Pipeline, + destination_config: DestinationTestConfiguration, ) -> None: """Tests file format handling in mixed source. 
@@ -424,9 +553,9 @@ def non_delta_table(): def s(): return [delta_table(), non_delta_table()] - info = local_filesystem_pipeline.run( - s(), loader_file_format="jsonl" - ) # set file format at pipeline level + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + info = pipeline.run(s(), loader_file_format="jsonl") # set file format at pipeline level assert_load_info(info) completed_jobs = info.load_packages[0].jobs["completed_jobs"] @@ -444,8 +573,17 @@ def s(): assert non_delta_table_job.file_path.endswith(".jsonl") +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_dynamic_dispatch( - local_filesystem_pipeline: dlt.Pipeline, + destination_config: DestinationTestConfiguration, ) -> None: @dlt.resource(primary_key="id", table_name=lambda i: i["type"], table_format="delta") def github_events(): @@ -454,7 +592,9 @@ def github_events(): ) as f: yield json.load(f) - info = local_filesystem_pipeline.run(github_events()) + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + info = pipeline.run(github_events()) assert_load_info(info) completed_jobs = info.load_packages[0].jobs["completed_jobs"] # 20 event types, two jobs per table (.parquet and .reference), 1 job for _dlt_pipeline_state diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 2c1d1346f1..63188d4f5e 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -11,18 +11,29 @@ from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext from dlt.common.schema.utils import has_table_seen_data -from dlt.common.schema.exceptions import SchemaException +from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import StrAny from dlt.common.utils import digest128 +from dlt.common.destination import TDestination +from dlt.common.destination.exceptions import DestinationCapabilitiesException from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first from dlt.pipeline.exceptions import PipelineStepFailed -from tests.pipeline.utils import assert_load_info, load_table_counts, select_data +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, + select_data, + load_tables_to_dicts, + assert_records_as_set, +) from tests.load.utils import ( normalize_storage_table_cols, destinations_configs, DestinationTestConfiguration, + FILE_BUCKET, + AZ_BUCKET, ) # uncomment add motherduck tests @@ -30,13 +41,38 @@ # ACTIVE_DESTINATIONS += ["motherduck"] +def skip_if_not_supported( + merge_strategy: TLoaderMergeStrategy, + destination: TDestination, +) -> None: + if merge_strategy not in destination.capabilities().supported_merge_strategies: + pytest.skip( + f"`{merge_strategy}` merge strategy not supported for `{destination.destination_name}`" + " destination." 
+ ) + + @pytest.mark.essential @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs( + default_sql_configs=True, + all_buckets_filesystem_configs=True, + table_format_filesystem_configs=True, + supports_merge=True, + bucket_subset=(FILE_BUCKET, AZ_BUCKET), # test one local, one remote + ), + ids=lambda x: x.name, ) -def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_merge_on_keys_in_schema( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: p = destination_config.setup_pipeline("eth_2", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) + with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) @@ -45,18 +81,22 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio del schema.tables["blocks__uncles"]["x-normalizer"] assert not has_table_seen_data(schema.tables["blocks__uncles"]) - with open( - "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", - "r", - encoding="utf-8", - ) as f: - data = json.load(f) + @dlt.resource( + table_name="blocks", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + table_format=destination_config.table_format, + ) + def data(slice_: slice = None): + with open( + "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", + "r", + encoding="utf-8", + ) as f: + yield json.load(f) if slice_ is None else json.load(f)[slice_] # take only the first block. 
the first block does not have uncles so this table should not be created and merged info = p.run( - data[:1], - table_name="blocks", - write_disposition="merge", + data(slice(1)), schema=schema, loader_file_format=destination_config.file_format, ) @@ -73,8 +113,6 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # if the table would be created before the whole load would fail because new columns have hints info = p.run( data, - table_name="blocks", - write_disposition="merge", schema=schema, loader_file_format=destination_config.file_format, ) @@ -84,8 +122,6 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # make sure we have same record after merging full dataset again info = p.run( data, - table_name="blocks", - write_disposition="merge", schema=schema, loader_file_format=destination_config.file_format, ) @@ -97,22 +133,155 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio assert eth_2_counts == eth_3_counts +@pytest.mark.essential @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs( + default_sql_configs=True, + local_filesystem_configs=True, + table_format_filesystem_configs=True, + supports_merge=True, + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_merge_record_updates( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: + p = destination_config.setup_pipeline("test_merge_record_updates", dev_mode=True) + + skip_if_not_supported(merge_strategy, p.destination) + + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + primary_key="id", + table_format=destination_config.table_format, + ) + def r(data): + yield data + + # initial load + run_1 = [ + {"id": 1, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + ] + info = p.run(r(run_1)) + assert_load_info(info) + assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { + "parent": 2, + "parent__child": 2, + "parent__child__grandchild": 2, + } + tables = load_tables_to_dicts(p, "parent", exclude_system_cols=True) + assert_records_as_set( + tables["parent"], + [ + {"id": 1, "foo": 1}, + {"id": 2, "foo": 1}, + ], + ) + + # update record — change at parent level + run_2 = [ + {"id": 1, "foo": 2, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + ] + info = p.run(r(run_2)) + assert_load_info(info) + assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { + "parent": 2, + "parent__child": 2, + "parent__child__grandchild": 2, + } + tables = load_tables_to_dicts(p, "parent", exclude_system_cols=True) + assert_records_as_set( + tables["parent"], + [ + {"id": 1, "foo": 2}, + {"id": 2, "foo": 1}, + ], + ) + + # update record — change at child level + run_3 = [ + {"id": 1, "foo": 2, "child": [{"bar": 2, "grandchild": [{"baz": 1}]}]}, + {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + ] + info = p.run(r(run_3)) + assert_load_info(info) + assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { + "parent": 2, + "parent__child": 2, + 
"parent__child__grandchild": 2, + } + tables = load_tables_to_dicts(p, "parent", "parent__child", exclude_system_cols=True) + assert_records_as_set( + tables["parent__child"], + [ + {"bar": 2}, + {"bar": 1}, + ], + ) + + # update record — change at grandchild level + run_3 = [ + {"id": 1, "foo": 2, "child": [{"bar": 2, "grandchild": [{"baz": 2}]}]}, + {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, + ] + info = p.run(r(run_3)) + assert_load_info(info) + assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { + "parent": 2, + "parent__child": 2, + "parent__child__grandchild": 2, + } + tables = load_tables_to_dicts(p, "parent__child__grandchild", exclude_system_cols=True) + assert_records_as_set( + tables["parent__child__grandchild"], + [ + {"baz": 2}, + {"baz": 1}, + ], + ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, + local_filesystem_configs=True, + table_format_filesystem_configs=True, + supports_merge=True, + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, ) -def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_merge_on_ad_hoc_primary_key( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: p = destination_config.setup_pipeline("github_1", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) - with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" - ) as f: - data = json.load(f) - # note: NodeId will be normalized to "node_id" which exists in the schema - info = p.run( - data[:17], + @dlt.resource( table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, primary_key="NodeId", + table_format=destination_config.table_format, + ) + def data(slice_: slice = None): + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: + yield json.load(f) if slice_ is None else json.load(f)[slice_] + + # note: NodeId will be normalized to "node_id" which exists in the schema + info = p.run( + data(slice(0, 17)), loader_file_format=destination_config.file_format, ) assert_load_info(info) @@ -125,10 +294,7 @@ def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfigur assert p.default_schema.tables["issues"]["columns"]["node_id"]["nullable"] is False info = p.run( - data[5:], - table_name="issues", - write_disposition="merge", - primary_key="node_id", + data(slice(5, None)), loader_file_format=destination_config.file_format, ) assert_load_info(info) @@ -572,30 +738,57 @@ def duplicates_no_child(): @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, supports_merge=True), + destinations_configs( + default_sql_configs=True, + local_filesystem_configs=True, + table_format_filesystem_configs=True, + supports_merge=True, + bucket_subset=(FILE_BUCKET), + ), ids=lambda x: x.name, ) -def test_complex_column_missing(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_complex_column_missing( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: + if destination_config.table_format == "delta": + pytest.skip( + "Record updates that involve 
removing elements from a complex" + " column are not supported for `delta` table format." + ) + + table_name = "test_complex_column_missing" - @dlt.resource(name=table_name, write_disposition="merge", primary_key="id") + @dlt.resource( + name=table_name, + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + primary_key="id", + table_format=destination_config.table_format, + ) def r(data): yield data p = destination_config.setup_pipeline("abstract", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) - data = [{"id": 1, "simple": "foo", "complex": [1, 2, 3]}] + data = [ + {"id": 1, "simple": "foo", "complex": [1, 2, 3]}, + {"id": 2, "simple": "foo", "complex": [1, 2]}, + ] info = p.run(r(data), loader_file_format=destination_config.file_format) assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 1 - assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 3 + assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 5 # complex column is missing, previously inserted records should be deleted from child table - data = [{"id": 1, "simple": "bar"}] + data = [ + {"id": 1, "simple": "bar"}, + ] info = p.run(r(data), loader_file_format=destination_config.file_format) assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 1 - assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 0 + assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 2 @pytest.mark.parametrize( @@ -604,14 +797,21 @@ def r(data): ids=lambda x: x.name, ) @pytest.mark.parametrize("key_type", ["primary_key", "merge_key", "no_key"]) -def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_hard_delete_hint( + destination_config: DestinationTestConfiguration, + key_type: str, + merge_strategy: TLoaderMergeStrategy, +) -> None: + if merge_strategy == "upsert" and key_type != "primary_key": + pytest.skip("`upsert` merge strategy requires `primary_key`") # no_key setting will have the effect that hard deletes have no effect, since hard delete records # can not be matched table_name = "test_hard_delete_hint" @dlt.resource( name=table_name, - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, columns={"deleted": {"hard_delete": True}}, ) def data_resource(data): @@ -626,6 +826,7 @@ def data_resource(data): pass p = destination_config.setup_pipeline(f"abstract_{key_type}", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) # insert two records data = [ @@ -667,6 +868,8 @@ def data_resource(data): {"id": 3, "val": "foo", "deleted": False}, {"id": 3, "val": "bar", "deleted": False}, ] + if merge_strategy == "upsert": + del data[0] # `upsert` requires unique `primary_key` info = p.run(data_resource(data), loader_file_format=destination_config.file_format) assert_load_info(info) counts = load_table_counts(p, table_name)[table_name] @@ -759,12 +962,16 @@ def data_resource(data): destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name, ) -def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def 
test_hard_delete_hint_config( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: table_name = "test_hard_delete_hint_non_bool" @dlt.resource( name=table_name, - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, primary_key="id", columns={ "deleted_timestamp": {"data_type": "timestamp", "nullable": True, "hard_delete": True} @@ -774,6 +981,7 @@ def data_resource(data): yield data p = destination_config.setup_pipeline("abstract", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) # insert two records data = [ @@ -982,17 +1190,55 @@ def r(): info = p.run(r(), loader_file_format=destination_config.file_format) +def test_merge_strategy_config() -> None: + # merge strategy invalid + with pytest.raises(ValueError): + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "foo"}) # type: ignore[call-overload] + def invalid_resource(): + yield {"foo": "bar"} + + p = dlt.pipeline( + pipeline_name="dummy_pipeline", + destination="dummy", + full_refresh=True, + ) + + # merge strategy not supported by destination + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "scd2"}) + def r(): + yield {"foo": "bar"} + + assert "scd2" not in p.destination.capabilities().supported_merge_strategies + with pytest.raises(DestinationCapabilitiesException): + p.run(r()) + + @pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, + table_format_filesystem_configs=True, + supports_merge=True, + subset=["postgres", "filesystem"], # test one SQL and one non-SQL destination + ), + ids=lambda x: x.name, ) -def test_invalid_merge_strategy(destination_config: DestinationTestConfiguration) -> None: - @dlt.resource(write_disposition={"disposition": "merge", "strategy": "foo"}) # type: ignore[call-overload] +def test_upsert_merge_strategy_config(destination_config: DestinationTestConfiguration) -> None: + if destination_config.destination == "filesystem": + # TODO: implement validation and remove this test exception + pytest.skip( + "`upsert` merge strategy configuration validation has not yet been" + " implemented for `filesystem` destination." 
+ ) + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "upsert"}) def r(): yield {"foo": "bar"} - p = destination_config.setup_pipeline("abstract", dev_mode=True) + # `upsert` merge strategy without `primary_key` should error + p = destination_config.setup_pipeline("upsert_pipeline", dev_mode=True) + assert "primary_key" not in r._hints with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) - assert isinstance(pip_ex.value.__context__, SchemaException) + assert isinstance(pip_ex.value.__context__, SchemaCorruptedException) diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index d263f165b7..c3968e2e74 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -11,6 +11,7 @@ from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.utils import uniq_id from dlt.common.destination.exceptions import DestinationUndefinedEntity +from dlt.common.destination.reference import WithStateSync from dlt.load import Load from dlt.pipeline.exceptions import SqlClientNotAvailable @@ -453,8 +454,19 @@ def test_ignore_state_unfinished_load(destination_config: DestinationTestConfigu @dlt.resource def some_data(param: str) -> Any: dlt.current.source_state()[param] = param - yield param + yield {"col1": param, param: 1} + + job_client: WithStateSync + # Load some complete load packages with state to the destination + p.run(some_data("state1"), loader_file_format=destination_config.file_format) + p.run(some_data("state2"), loader_file_format=destination_config.file_format) + p.run(some_data("state3"), loader_file_format=destination_config.file_format) + + with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] + state = load_pipeline_state_from_destination(pipeline_name, job_client) + assert state and state["_state_version"] == 3 + # Simulate a load package that stores state but is not completed (no entry in loads table) def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = False): # complete in local storage but skip call to the database self.load_storage.complete_load_package(load_id, aborted) @@ -463,11 +475,18 @@ def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = Fa p.run(some_data("fix_1"), loader_file_format=destination_config.file_format) # assert complete_package.called - job_client: SqlJobClientBase with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] # state without completed load id is not visible state = load_pipeline_state_from_destination(pipeline_name, job_client) - assert state is None + # Restored state version has not changed + assert state and state["_state_version"] == 3 + + newest_schema_hash = p.default_schema.version_hash + p._wipe_working_folder() + p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p.sync_destination() + + assert p.default_schema.version_hash == newest_schema_hash @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index b33c5a2590..8b41c354b2 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -21,7 +21,12 @@ destinations_configs, DestinationTestConfiguration, ) -from tests.pipeline.utils import load_tables_to_dicts, assert_load_info, load_table_counts +from tests.pipeline.utils import ( + load_tables_to_dicts, + assert_load_info, + load_table_counts, + 
assert_records_as_set, +) from tests.utils import TPythonTableFormat @@ -64,13 +69,6 @@ def strip_timezone(ts: datetime) -> datetime: ) -def assert_records_as_set(actual: List[Dict[str, Any]], expected: List[Dict[str, Any]]) -> None: - """Compares two lists of dicts regardless of order""" - actual_set = set(frozenset(dict_.items()) for dict_ in actual) - expected_set = set(frozenset(dict_.items()) for dict_ in expected) - assert actual_set == expected_set - - @pytest.mark.essential @pytest.mark.parametrize( "destination_config,simple,validity_column_names,active_record_timestamp", @@ -103,6 +101,10 @@ def test_core_functionality( validity_column_names: List[str], active_record_timestamp: Optional[pendulum.DateTime], ) -> None: + # somehow destination_config comes through as ParameterSet instead of + # DestinationTestConfiguration + destination_config = destination_config.values[0] # type: ignore[attr-defined] + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( diff --git a/tests/load/pipeline/test_snowflake_pipeline.py b/tests/load/pipeline/test_snowflake_pipeline.py index 3cfa9e8b21..0203a39147 100644 --- a/tests/load/pipeline/test_snowflake_pipeline.py +++ b/tests/load/pipeline/test_snowflake_pipeline.py @@ -1,10 +1,13 @@ +import os import pytest +from pytest_mock import MockerFixture import dlt -from dlt.common import Decimal from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DatabaseUndefinedRelation + +from tests.load.snowflake.test_snowflake_client import QUERY_TAG from tests.pipeline.utils import assert_load_info from tests.load.utils import destinations_configs, DestinationTestConfiguration @@ -18,9 +21,13 @@ ids=lambda x: x.name, ) def test_snowflake_case_sensitive_identifiers( - destination_config: DestinationTestConfiguration, + destination_config: DestinationTestConfiguration, mocker: MockerFixture ) -> None: + from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient + snow_ = dlt.destinations.snowflake(naming_convention="sql_cs_v1") + # we make sure that session was not tagged (lack of query tag in config) + tag_query_spy = mocker.spy(SnowflakeSqlClient, "_tag_session") dataset_name = "CaseSensitive_Dataset_" + uniq_id() pipeline = destination_config.setup_pipeline( @@ -36,6 +43,7 @@ def test_snowflake_case_sensitive_identifiers( # load some case sensitive data info = pipeline.run([{"Id": 1, "Capital": 0.0}], table_name="Expenses") assert_load_info(info) + tag_query_spy.assert_not_called() with pipeline.sql_client() as client: assert client.has_dataset() # use the same case sensitive dataset @@ -53,3 +61,21 @@ def test_snowflake_case_sensitive_identifiers( print(rows) with pytest.raises(DatabaseUndefinedRelation): client.execute_sql('SELECT "Id", "Capital" FROM Expenses') + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) +def test_snowflake_query_tagging( + destination_config: DestinationTestConfiguration, mocker: MockerFixture +): + from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient + + os.environ["DESTINATION__SNOWFLAKE__QUERY_TAG"] = QUERY_TAG + tag_query_spy = mocker.spy(SnowflakeSqlClient, "_tag_session") + pipeline = destination_config.setup_pipeline("test_snowflake_case_sensitive_identifiers") + info = pipeline.run([1, 2, 3], table_name="digits") + assert_load_info(info) + assert tag_query_spy.call_count == 2 diff --git a/tests/load/qdrant/test_pipeline.py 
b/tests/load/qdrant/test_pipeline.py index a33ecd2a8d..73f53221ed 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -10,7 +10,7 @@ from dlt.destinations.adapters import qdrant_adapter from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter, VECTORIZE_HINT -from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient from tests.pipeline.utils import assert_load_info from tests.load.qdrant.utils import drop_active_pipeline_data, assert_collection from tests.load.utils import sequence_generator diff --git a/tests/load/qdrant/test_restore_state.py b/tests/load/qdrant/test_restore_state.py new file mode 100644 index 0000000000..31bc725d24 --- /dev/null +++ b/tests/load/qdrant/test_restore_state.py @@ -0,0 +1,70 @@ +from typing import TYPE_CHECKING +import pytest +from qdrant_client import models + +import dlt +from tests.load.utils import destinations_configs, DestinationTestConfiguration + +from dlt.common.destination.reference import JobClientBase, WithStateSync +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_vector_configs=True, subset=["qdrant"]), + ids=lambda x: x.name, +) +def test_uncommitted_state(destination_config: DestinationTestConfiguration): + """Load uncommitted state into qdrant, meaning that data is written to the state + table but load is not completed (nothing is added to loads table) + + Ensure that state restoration does not include such state + """ + # Type hint of JobClientBase with WithStateSync mixin + + pipeline = destination_config.setup_pipeline("uncommitted_state", dev_mode=True) + + state_val = 0 + + @dlt.resource + def dummy_table(): + dlt.current.resource_state("dummy_table")["val"] = state_val + yield [1, 2, 3] + + # Create > 10 load packages to be above pagination size when restoring state + for _ in range(12): + state_val += 1 + pipeline.extract(dummy_table) + + pipeline.normalize() + info = pipeline.load(raise_on_failed_jobs=True) + + client: QdrantClient + with pipeline.destination_client() as client: # type: ignore[assignment] + state = client.get_stored_state(pipeline.pipeline_name) + + assert state and state.version == state_val + + # Delete last 10 _dlt_loads entries so pagination is triggered when restoring state + with pipeline.destination_client() as client: # type: ignore[assignment] + table_name = client._make_qualified_collection_name( + pipeline.default_schema.loads_table_name + ) + p_load_id = pipeline.default_schema.naming.normalize_identifier("load_id") + + client.db_client.delete( + table_name, + points_selector=models.Filter( + must=[ + models.FieldCondition( + key=p_load_id, match=models.MatchAny(any=info.loads_ids[2:]) + ) + ] + ), + ) + + with pipeline.destination_client() as client: # type: ignore[assignment] + state = client.get_stored_state(pipeline.pipeline_name) + + # Latest committed state is restored + assert state and state.version == 2 diff --git a/tests/load/qdrant/utils.py b/tests/load/qdrant/utils.py index 3b12d15f86..e96e06be87 100644 --- a/tests/load/qdrant/utils.py +++ b/tests/load/qdrant/utils.py @@ -5,7 +5,7 @@ from dlt.common.pipeline import PipelineContext from dlt.common.configuration.container import Container -from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient def 
assert_unordered_list_equal(list1: List[Any], list2: List[Any]) -> None: diff --git a/tests/load/snowflake/test_snowflake_client.py b/tests/load/snowflake/test_snowflake_client.py new file mode 100644 index 0000000000..aebf514b56 --- /dev/null +++ b/tests/load/snowflake/test_snowflake_client.py @@ -0,0 +1,61 @@ +import os +from typing import Iterator +from pytest_mock import MockerFixture +import pytest + +from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.destinations.job_client_impl import SqlJobClientBase + +from dlt.destinations.sql_client import TJobQueryTags + +from tests.load.utils import yield_client_with_storage + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + +QUERY_TAG = ( + '{{"source":"{source}", "resource":"{resource}", "table": "{table}", "load_id":"{load_id}",' + ' "pipeline_name":"{pipeline_name}"}}' +) +QUERY_TAGS_DICT: TJobQueryTags = { + "source": "test_source", + "resource": "test_resource", + "table": "test_table", + "load_id": "1109291083091", + "pipeline_name": "test_pipeline", +} + + +@pytest.fixture(scope="function") +def client() -> Iterator[SqlJobClientBase]: + os.environ["QUERY_TAG"] = QUERY_TAG + yield from yield_client_with_storage("snowflake") + + +def test_query_tag(client: SnowflakeClient, mocker: MockerFixture): + assert client.config.query_tag == QUERY_TAG + # make sure we generate proper query + execute_sql_spy = mocker.spy(client.sql_client, "execute_sql") + # reset the query if tags are not set + client.sql_client.set_query_tags(None) + execute_sql_spy.assert_called_once_with(sql="ALTER SESSION UNSET QUERY_TAG") + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags({}) # type: ignore[typeddict-item] + execute_sql_spy.assert_called_once_with(sql="ALTER SESSION UNSET QUERY_TAG") + execute_sql_spy.reset_mock() + # set query tags + client.sql_client.set_query_tags(QUERY_TAGS_DICT) + execute_sql_spy.assert_called_once_with( + sql=( + 'ALTER SESSION SET QUERY_TAG = \'{"source":"test_source", "resource":"test_resource",' + ' "table": "test_table", "load_id":"1109291083091", "pipeline_name":"test_pipeline"}\'' + ) + ) + # remove query tag from config + client.sql_client.query_tag = None + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags(QUERY_TAGS_DICT) + execute_sql_spy.assert_not_called() + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags(None) + execute_sql_spy.assert_not_called() diff --git a/tests/load/utils.py b/tests/load/utils.py index 95083b7d31..4b6c01c916 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -2,7 +2,7 @@ import contextlib import codecs import os -from typing import Any, Iterator, List, Sequence, IO, Tuple, Optional, Dict, Union, Generator +from typing import Any, Iterator, List, Sequence, IO, Tuple, Optional, Dict, Union, Generator, cast import shutil from pathlib import Path from urllib.parse import urlparse @@ -27,6 +27,7 @@ from dlt.common.data_writers import DataWriter from dlt.common.pipeline import PipelineContext from dlt.common.schema import TTableSchemaColumns, Schema +from dlt.common.schema.typing import TTableFormat from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration from dlt.common.schema.utils import new_table, normalize_table_identifiers from dlt.common.storages import ParsedLoadJobFileName, LoadStorage, PackageStorage @@ -120,6 +121,7 @@ class DestinationTestConfiguration: destination: str staging: Optional[TDestinationReferenceArg] = None file_format: 
Optional[TLoaderFileFormat] = None + table_format: Optional[TTableFormat] = None bucket_url: Optional[str] = None stage_name: Optional[str] = None staging_iam_role: Optional[str] = None @@ -131,12 +133,15 @@ class DestinationTestConfiguration: disable_compression: bool = False dev_mode: bool = False credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None + env_vars: Optional[Dict[str, str]] = None @property def name(self) -> str: name: str = self.destination if self.file_format: name += f"-{self.file_format}" + if self.table_format: + name += f"-{self.table_format}" if not self.staging: name += "-no-staging" else: @@ -172,6 +177,10 @@ def setup(self) -> None: for key, value in dict(self.credentials).items(): os.environ[f"DESTINATION__CREDENTIALS__{key.upper()}"] = str(value) + if self.env_vars is not None: + for k, v in self.env_vars.items(): + os.environ[k] = v + def setup_pipeline( self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs ) -> dlt.Pipeline: @@ -202,9 +211,13 @@ def destinations_configs( all_staging_configs: bool = False, local_filesystem_configs: bool = False, all_buckets_filesystem_configs: bool = False, + table_format_filesystem_configs: bool = False, subset: Sequence[str] = (), + bucket_subset: Sequence[str] = (), exclude: Sequence[str] = (), + bucket_exclude: Sequence[str] = (), file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None, + table_format: Union[TTableFormat, Sequence[TTableFormat]] = None, supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, force_iceberg: Optional[bool] = None, @@ -460,17 +473,26 @@ def destinations_configs( if local_filesystem_configs: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values" + destination="filesystem", + bucket_url=FILE_BUCKET, + file_format="insert_values", + supports_merge=False, ) ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", bucket_url=FILE_BUCKET, file_format="parquet" + destination="filesystem", + bucket_url=FILE_BUCKET, + file_format="parquet", + supports_merge=False, ) ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", bucket_url=FILE_BUCKET, file_format="jsonl" + destination="filesystem", + bucket_url=FILE_BUCKET, + file_format="jsonl", + supports_merge=False, ) ] @@ -478,7 +500,31 @@ def destinations_configs( for bucket in DEFAULT_BUCKETS: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", bucket_url=bucket, extra_info=bucket + destination="filesystem", + bucket_url=bucket, + extra_info=bucket, + supports_merge=False, + ) + ] + + if table_format_filesystem_configs: + for bucket in DEFAULT_BUCKETS: + destination_configs += [ + DestinationTestConfiguration( + destination="filesystem", + bucket_url=bucket, + extra_info=bucket, + table_format="delta", + supports_merge=True, + env_vars=( + { + "DESTINATION__FILESYSTEM__DELTALAKE_STORAGE_OPTIONS": ( + '{"AWS_S3_ALLOW_UNSAFE_RENAME": "true"}' + ) + } + if bucket == AWS_BUCKET + else None + ), ) ] @@ -490,10 +536,22 @@ def destinations_configs( # filter out destinations not in subset if subset: destination_configs = [conf for conf in destination_configs if conf.destination in subset] + if bucket_subset: + destination_configs = [ + conf + for conf in destination_configs + if conf.destination != "filesystem" or conf.bucket_url in bucket_subset + ] if exclude: destination_configs = [ conf for conf in 
destination_configs if conf.destination not in exclude ] + if bucket_exclude: + destination_configs = [ + conf + for conf in destination_configs + if conf.destination != "filesystem" or conf.bucket_url not in bucket_exclude + ] if file_format: if not isinstance(file_format, Sequence): file_format = [file_format] @@ -502,6 +560,14 @@ def destinations_configs( for conf in destination_configs if conf.file_format and conf.file_format in file_format ] + if table_format: + if not isinstance(table_format, Sequence): + table_format = [table_format] + destination_configs = [ + conf + for conf in destination_configs + if conf.table_format and conf.table_format in table_format + ] if supports_merge is not None: destination_configs = [ conf for conf in destination_configs if conf.supports_merge == supports_merge @@ -521,6 +587,18 @@ def destinations_configs( conf for conf in destination_configs if conf.force_iceberg is force_iceberg ] + # add marks + destination_configs = [ + cast( + DestinationTestConfiguration, + pytest.param( + conf, + marks=pytest.mark.needspyarrow17 if conf.table_format == "delta" else [], + ), + ) + for conf in destination_configs + ] + return destination_configs diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index a267d3106d..7c7dac8e71 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -16,9 +16,6 @@ from dlt.common import json, pendulum from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ConfigFieldMissingException, InvalidNativeValue -from dlt.common.configuration.specs.aws_credentials import AwsCredentials -from dlt.common.configuration.specs.exceptions import NativeValueError -from dlt.common.configuration.specs.gcp_credentials import GcpOAuthCredentials from dlt.common.data_writers.exceptions import FileImportNotFound, SpecLookupFailed from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import WithStateSync @@ -90,6 +87,36 @@ def test_default_pipeline() -> None: assert p.default_schema_name in ["dlt_pytest", "dlt"] +def test_default_pipeline_dataset_layout(environment) -> None: + # Set dataset_name_layout to "bobby_%s" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + p = dlt.pipeline() + # this is a name of executing test harness or blank pipeline on windows + possible_names = ["dlt_pytest", "dlt_pipeline"] + possible_dataset_names = [ + dataset_name_layout % "dlt_pytest_dataset", + dataset_name_layout % "dlt_pipeline_dataset", + ] + assert p.pipeline_name in possible_names + assert p.pipelines_dir == os.path.abspath(os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines")) + assert p.runtime_config.pipeline_name == p.pipeline_name + # dataset that will be used to load data is the pipeline name + assert p.dataset_name in possible_dataset_names + assert p.destination is None + assert p.default_schema_name is None + + # this is the same pipeline + p2 = dlt.pipeline() + assert p is p2 + + # this will create default schema + p.extract(["a", "b", "c"], table_name="data") + # `_pipeline` is removed from default schema name + assert p.default_schema_name in ["dlt_pytest", "dlt"] + + def test_default_pipeline_dataset() -> None: # dummy does not need a dataset p = dlt.pipeline(destination="dummy") @@ -101,6 +128,40 @@ def test_default_pipeline_dataset() -> None: assert p.dataset_name in possible_dataset_names +def test_default_pipeline_dataset_name(environment) -> 
None: + environment["DATASET_NAME"] = "dataset" + environment["DATASET_NAME_LAYOUT"] = "prefix_%s" + + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name == "prefix_dataset" + + +def test_default_pipeline_dataset_layout_exception(environment) -> None: + # Set dataset_name_layout without placeholder %s + environment["DATASET_NAME_LAYOUT"] = "bobby_" + + with pytest.raises(ValueError): + dlt.pipeline(destination="filesystem") + + +def test_default_pipeline_dataset_layout_placeholder(environment) -> None: + # Set dataset_name_layout only with placeholder + environment["DATASET_NAME_LAYOUT"] = "%s" + + possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name in possible_dataset_names + + +def test_default_pipeline_dataset_layout_empty(environment) -> None: + # Set dataset_name_layout empty + environment["DATASET_NAME_LAYOUT"] = "" + + possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name in possible_dataset_names + + def test_run_dev_mode_default_dataset() -> None: p = dlt.pipeline(dev_mode=True, destination="filesystem") assert p.dataset_name.endswith(p._pipeline_instance_id) @@ -119,6 +180,39 @@ def test_run_dev_mode_default_dataset() -> None: assert p.dataset_name and p.dataset_name.endswith(p._pipeline_instance_id) +def test_run_dev_mode_default_dataset_layout(environment) -> None: + # Set dataset_name_layout to "bobby_%s" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + p = dlt.pipeline(dev_mode=True, destination="filesystem") + assert p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + # restore this pipeline + r_p = dlt.attach(dev_mode=False) + assert r_p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + + # dummy does not need dataset + p = dlt.pipeline(dev_mode=True, destination="dummy") + assert p.dataset_name is None + + # simulate set new dataset + p._set_destinations("filesystem") + assert p.dataset_name is None + p._set_dataset_name(None) + + # full refresh is still observed + assert p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + + def test_run_dev_mode_underscored_dataset() -> None: p = dlt.pipeline(dev_mode=True, dataset_name="_main_") assert p.dataset_name.endswith(p._pipeline_instance_id) @@ -178,6 +272,16 @@ def test_invalid_dataset_name() -> None: assert p.dataset_name == "!" +def test_invalid_dataset_layout(environment) -> None: + # Set dataset_name_prefix to "bobby" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + # this is invalid dataset name but it will be normalized within a destination + p = dlt.pipeline(dataset_name="!") + assert p.dataset_name == dataset_name_layout % "!" 
+ + def test_pipeline_context_deferred_activation() -> None: ctx = Container()[PipelineContext] assert ctx.is_active() is False @@ -2468,6 +2572,56 @@ def test_static_staging_dataset() -> None: assert_data_table_counts(pipeline_2, {"letters": 4}) +def test_underscore_tables_and_columns() -> None: + pipeline = dlt.pipeline("test_underscore_tables_and_columns", destination="duckdb") + + @dlt.resource + def ids(_id=dlt.sources.incremental("_id", initial_value=2)): + yield from [{"_id": i, "value": l} for i, l in zip([1, 2, 3], ["A", "B", "C"])] + + info = pipeline.run(ids, table_name="_ids") + assert_load_info(info) + print(pipeline.default_schema.to_pretty_yaml()) + assert pipeline.last_trace.last_normalize_info.row_counts["_ids"] == 2 + + +def test_access_pipeline_in_resource() -> None: + pipeline = dlt.pipeline("test_access_pipeline_in_resource", destination="duckdb") + + @dlt.resource(name="user_comments") + def comments(user_id: str): + current_pipeline = dlt.current.pipeline() + # find last comment id for given user_id by looking in destination + max_id: int = 0 + # on first pipeline run, user_comments table does not yet exist so do not check at all + # alternatively catch DatabaseUndefinedRelation which is raised when unknown table is selected + if not current_pipeline.first_run: + with current_pipeline.sql_client() as client: + # we may get last user comment or None which we replace with 0 + max_id = ( + client.execute_sql( + "SELECT MAX(_id) FROM user_comments WHERE user_id=?", user_id + )[0][0] + or 0 + ) + # use max_id to filter our results + yield from [ + {"_id": i, "value": l, "user_id": user_id} + for i, l in zip([1, 2, 3], ["A", "B", "C"]) + if i > max_id + ] + + info = pipeline.run(comments("USER_A")) + assert_load_info(info) + assert pipeline.last_trace.last_normalize_info.row_counts["user_comments"] == 3 + info = pipeline.run(comments("USER_A")) + # no more data for USER_A + assert_load_info(info, 0) + info = pipeline.run(comments("USER_B")) + assert_load_info(info) + assert pipeline.last_trace.last_normalize_info.row_counts["user_comments"] == 3 + + def assert_imported_file( pipeline: Pipeline, table_name: str, diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 308cdcd91d..d3e44198b4 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -3,6 +3,8 @@ from typing import Any, ClassVar, Dict, Iterator, List, Optional import pytest +from dlt.pipeline.exceptions import PipelineStepFailed + try: from pydantic import BaseModel from dlt.common.libs.pydantic import DltConfig @@ -237,6 +239,35 @@ def with_mark(): assert table["columns"]["name"]["data_type"] == "text" +def test_dump_trace_freeze_exception() -> None: + class TestRow(BaseModel): + id_: int + example_string: str + + # yield model in resource so incremental fails when looking for "id" + # TODO: support pydantic models in incremental + + @dlt.resource(name="table_name", primary_key="id", write_disposition="replace") + def generate_rows_incremental( + ts: dlt.sources.incremental[int] = dlt.sources.incremental(cursor_path="id"), + ): + for i in range(10): + yield TestRow(id_=i, example_string="abc") + if ts.end_out_of_range: + return + + pipeline = dlt.pipeline(pipeline_name="test_dump_trace_freeze_exception", destination="duckdb") + + with pytest.raises(PipelineStepFailed): + # must raise because incremental failed + pipeline.run(generate_rows_incremental()) + + # force to reload trace from storage + pipeline._last_trace = None + # 
trace file not present because we tried to pickle TestRow which is a local object + assert pipeline.last_trace is None + + @pytest.mark.parametrize("file_format", ("parquet", "insert_values", "jsonl")) def test_columns_hint_with_file_formats(file_format: TLoaderFileFormat) -> None: @dlt.resource(write_disposition="replace", columns=[{"name": "text", "data_type": "text"}]) @@ -437,28 +468,20 @@ def pandas_incremental(numbers=dlt.sources.incremental("Numbers")): def test_empty_parquet(test_storage: FileStorage) -> None: from dlt.destinations import filesystem + from tests.pipeline.utils import users_materialize_table_schema local = filesystem(os.path.abspath(TEST_STORAGE_ROOT)) # we have two options to materialize columns: add columns hint or use dlt.mark to emit schema # at runtime. below we use the second option - @dlt.resource - def users(): - yield dlt.mark.with_hints( - # this is a special empty item which will materialize table schema - dlt.mark.materialize_table_schema(), - # emit table schema with the item - dlt.mark.make_hints( - columns=[ - {"name": "id", "data_type": "bigint", "precision": 4, "nullable": False}, - {"name": "name", "data_type": "text", "nullable": False}, - ] - ), - ) - # write parquet file to storage - info = dlt.run(users, destination=local, loader_file_format="parquet", dataset_name="user_data") + info = dlt.run( + users_materialize_table_schema, + destination=local, + loader_file_format="parquet", + dataset_name="user_data", + ) assert_load_info(info) assert set(info.pipeline.default_schema.tables["users"]["columns"].keys()) == {"id", "name", "_dlt_load_id", "_dlt_id"} # type: ignore # find parquet file diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index 609897f161..bdb3e3eb22 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -1,3 +1,4 @@ +from copy import deepcopy import io import os import asyncio @@ -267,7 +268,12 @@ def test_save_load_trace() -> None: loaded_trace = load_trace(pipeline.working_dir) print(loaded_trace.asstr(2)) assert len(trace.steps) == 4 - assert loaded_trace.asdict() == trace.asdict() + loaded_trace_dict = deepcopy(loaded_trace.asdict()) + trace_dict = deepcopy(trace.asdict()) + assert loaded_trace_dict == trace_dict + # do it again to check if we are not popping + assert loaded_trace_dict == loaded_trace.asdict() + assert trace_dict == trace.asdict() # exception also saves trace @dlt.resource @@ -294,6 +300,29 @@ def data(): assert pipeline.last_trace.last_normalize_info is None +def test_save_load_empty_trace() -> None: + os.environ["COMPLETED_PROB"] = "1.0" + os.environ["RESTORE_FROM_DESTINATION"] = "false" + pipeline = dlt.pipeline() + pipeline.run([], table_name="data", destination="dummy") + trace = pipeline.last_trace + assert_trace_printable(trace) + assert len(trace.steps) == 4 + + pipeline.activate() + + # load trace and check if all elements are present + loaded_trace = load_trace(pipeline.working_dir) + print(loaded_trace.asstr(2)) + assert len(trace.steps) == 4 + loaded_trace_dict = deepcopy(loaded_trace.asdict()) + trace_dict = deepcopy(trace.asdict()) + assert loaded_trace_dict == trace_dict + # do it again to check if we are not popping + assert loaded_trace_dict == loaded_trace.asdict() + assert trace_dict == trace.asdict() + + def test_disable_trace(environment: DictStrStr) -> None: environment["ENABLE_RUNTIME_TRACE"] = "false" environment["COMPLETED_PROB"] = "1.0" @@ -499,7 +528,9 @@ def assert_trace_printable(trace: 
PipelineTrace) -> None: str(trace) trace.asstr(0) trace.asstr(1) - trace.asdict() + trace_dict = deepcopy(trace.asdict()) + # check if we do not pop + assert trace_dict == trace.asdict() with io.BytesIO() as b: json.typed_dump(trace, b, pretty=True) b.getvalue() diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index f2e0058891..bd62f76dc1 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -76,6 +76,21 @@ def many_delayed(many, iters): yield dlt.resource(run_deferred(iters), name="resource_" + str(n)) +@dlt.resource(table_name="users") +def users_materialize_table_schema(): + yield dlt.mark.with_hints( + # this is a special empty item which will materialize table schema + dlt.mark.materialize_table_schema(), + # emit table schema with the item + dlt.mark.make_hints( + columns=[ + {"name": "id", "data_type": "bigint", "precision": 4, "nullable": False}, + {"name": "name", "data_type": "text", "nullable": False}, + ] + ), + ) + + # # Utils for accessing data in pipelines # @@ -240,6 +255,13 @@ def _sort_list_of_dicts(list_: List[Dict[str, Any]], sortkey: str) -> List[Dict[ return result +def assert_records_as_set(actual: List[Dict[str, Any]], expected: List[Dict[str, Any]]) -> None: + """Compares two lists of dicts regardless of order""" + actual_set = set(frozenset(dict_.items()) for dict_ in actual) + expected_set = set(frozenset(dict_.items()) for dict_ in expected) + assert actual_set == expected_set + + def assert_only_table_columns( p: dlt.Pipeline, table_name: str, expected_columns: Sequence[str], schema_name: str = None ) -> None: diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index ed227cd3cd..f5de1ec5da 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,7 +1,7 @@ import os from base64 import b64encode from typing import Any, Dict, cast -from unittest.mock import patch +from unittest.mock import patch, ANY import pytest from requests import PreparedRequest, Request, Response @@ -22,7 +22,7 @@ ) from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator, BaseReferencePaginator +from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator, BaseReferencePaginator from .conftest import DEFAULT_PAGE_SIZE, DEFAULT_TOTAL_PAGES, assert_pagination @@ -82,7 +82,7 @@ def test_get_single_resource(self, rest_client): def test_pagination(self, rest_client: RESTClient): pages_iter = rest_client.paginate( "/posts", - paginator=JSONResponsePaginator(next_url_path="next_page"), + paginator=JSONLinkPaginator(next_url_path="next_page"), ) pages = list(pages_iter) @@ -92,7 +92,7 @@ def test_pagination(self, rest_client: RESTClient): def test_page_context(self, rest_client: RESTClient) -> None: for page in rest_client.paginate( "/posts", - paginator=JSONResponsePaginator(next_url_path="next_page"), + paginator=JSONLinkPaginator(next_url_path="next_page"), ): # response that produced data assert isinstance(page.response, Response) @@ -100,7 +100,7 @@ def test_page_context(self, rest_client: RESTClient) -> None: assert isinstance(page.request, Request) # make request url should be same as next link in paginator if page.paginator.has_next_page: - paginator = cast(JSONResponsePaginator, page.paginator) + paginator = cast(JSONLinkPaginator, page.paginator) assert 
paginator._next_reference == page.request.url def test_default_paginator(self, rest_client: RESTClient): @@ -112,7 +112,7 @@ def test_default_paginator(self, rest_client: RESTClient): def test_excplicit_paginator(self, rest_client: RESTClient): pages_iter = rest_client.paginate( - "/posts", paginator=JSONResponsePaginator(next_url_path="next_page") + "/posts", paginator=JSONLinkPaginator(next_url_path="next_page") ) pages = list(pages_iter) @@ -121,7 +121,7 @@ def test_excplicit_paginator(self, rest_client: RESTClient): def test_excplicit_paginator_relative_next_url(self, rest_client: RESTClient): pages_iter = rest_client.paginate( "/posts_relative_next_url", - paginator=JSONResponsePaginator(next_url_path="next_page"), + paginator=JSONLinkPaginator(next_url_path="next_page"), ) pages = list(pages_iter) @@ -138,7 +138,7 @@ def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: pages_iter = rest_client.paginate( "/posts", - paginator=JSONResponsePaginator(next_url_path="next_page"), + paginator=JSONLinkPaginator(next_url_path="next_page"), hooks=hooks, ) @@ -148,7 +148,7 @@ def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: pages_iter = rest_client.paginate( "/posts/1/some_details_404", - paginator=JSONResponsePaginator(), + paginator=JSONLinkPaginator(), hooks=hooks, ) @@ -430,3 +430,92 @@ def test_post_json_body_without_params(self, rest_client) -> None: assert len(returned_posts) == DEFAULT_PAGE_SIZE # only one page is returned for i in range(DEFAULT_PAGE_SIZE): assert returned_posts[i] == {"id": posts_skip + i, "title": f"Post {posts_skip + i}"} + + def test_configurable_timeout(self, mocker) -> None: + cfg = { + "RUNTIME__REQUEST_TIMEOUT": 42, + } + os.environ.update({key: str(value) for key, value in cfg.items()}) + + rest_client = RESTClient( + base_url="https://api.example.com", + session=Client().session, + ) + + import requests + + original_send = requests.Session.send + requests.Session.send = mocker.Mock() # type: ignore[method-assign] + rest_client.get("/posts/1") + assert requests.Session.send.call_args[1] == { # type: ignore[attr-defined] + "timeout": 42, + "proxies": ANY, + "stream": ANY, + "verify": ANY, + "cert": ANY, + } + # restore, otherwise side-effect on subsequent tests + requests.Session.send = original_send # type: ignore[method-assign] + + def test_request_kwargs(self, mocker) -> None: + def send_spy(*args, **kwargs): + return original_send(*args, **kwargs) + + rest_client = RESTClient( + base_url="https://api.example.com", + session=Client().session, + ) + original_send = rest_client.session.send + mocked_send = mocker.patch.object(rest_client.session, "send", side_effect=send_spy) + + rest_client.get( + path="/posts/1", + proxies={ + "http": "http://10.10.1.10:1111", + "https": "http://10.10.1.10:2222", + }, + stream=True, + verify=False, + cert=("/path/client.cert", "/path/client.key"), + timeout=321, + allow_redirects=False, + ) + + assert mocked_send.call_args[1] == { + "proxies": { + "http": "http://10.10.1.10:1111", + "https": "http://10.10.1.10:2222", + }, + "stream": True, + "verify": False, + "cert": ("/path/client.cert", "/path/client.key"), + "timeout": 321, + "allow_redirects": False, + } + + next( + rest_client.paginate( + path="posts", + proxies={ + "http": "http://10.10.1.10:1234", + "https": "http://10.10.1.10:4321", + }, + stream=True, + verify=False, + cert=("/path/client_2.cert", "/path/client_2.key"), + timeout=432, + allow_redirects=False, + ) + ) + + assert mocked_send.call_args[1] == { + "proxies": 
{ + "http": "http://10.10.1.10:1234", + "https": "http://10.10.1.10:4321", + }, + "stream": True, + "verify": False, + "cert": ("/path/client_2.cert", "/path/client_2.key"), + "timeout": 432, + "allow_redirects": False, + } diff --git a/tests/sources/helpers/rest_client/test_detector.py b/tests/sources/helpers/rest_client/test_detector.py index 6511b472fb..93efe34662 100644 --- a/tests/sources/helpers/rest_client/test_detector.py +++ b/tests/sources/helpers/rest_client/test_detector.py @@ -11,7 +11,7 @@ from dlt.sources.helpers.rest_client.paginators import ( OffsetPaginator, PageNumberPaginator, - JSONResponsePaginator, + JSONLinkPaginator, HeaderLinkPaginator, SinglePagePaginator, JSONResponseCursorPaginator, @@ -106,7 +106,7 @@ "results": [{"id": 1, "name": "Account 1"}, {"id": 2, "name": "Account 2"}], }, "expected": { - "type": JSONResponsePaginator, + "type": JSONLinkPaginator, "records_path": "results", "next_path": ("next",), }, @@ -123,7 +123,7 @@ "page": {"size": 2, "totalElements": 100, "totalPages": 50, "number": 1}, }, "expected": { - "type": JSONResponsePaginator, + "type": JSONLinkPaginator, "records_path": "_embedded.items", "next_path": ("_links", "next", "href"), }, @@ -145,7 +145,7 @@ }, }, "expected": { - "type": JSONResponsePaginator, + "type": JSONLinkPaginator, "records_path": "items", "next_path": ("links", "nextPage"), }, @@ -197,7 +197,7 @@ }, }, "expected": { - "type": JSONResponsePaginator, + "type": JSONLinkPaginator, "records_path": "data", "next_path": ("links", "next"), }, @@ -395,7 +395,7 @@ def test_find_paginator(test_case) -> None: assert type(paginator) is expected_paginator if isinstance(paginator, PageNumberPaginator): assert str(paginator.total_path) == ".".join(test_case["expected"]["total_path"]) - if isinstance(paginator, JSONResponsePaginator): + if isinstance(paginator, JSONLinkPaginator): assert str(paginator.next_url_path) == ".".join(test_case["expected"]["next_path"]) if isinstance(paginator, JSONResponseCursorPaginator): assert str(paginator.cursor_path) == ".".join(test_case["expected"]["next_path"]) diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index a5f9d888a2..8a3c136e09 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -10,7 +10,7 @@ OffsetPaginator, PageNumberPaginator, HeaderLinkPaginator, - JSONResponsePaginator, + JSONLinkPaginator, JSONResponseCursorPaginator, ) @@ -46,7 +46,7 @@ def test_client_pagination(self, rest_client): @pytest.mark.usefixtures("mock_api_server") -class TestJSONResponsePaginator: +class TestJSONLinkPaginator: @pytest.mark.parametrize( "test_case", [ @@ -98,9 +98,9 @@ def test_update_state(self, test_case): next_url_path = test_case["next_url_path"] if next_url_path is None: - paginator = JSONResponsePaginator() + paginator = JSONLinkPaginator() else: - paginator = JSONResponsePaginator(next_url_path=next_url_path) + paginator = JSONLinkPaginator(next_url_path=next_url_path) response = Mock(Response, json=lambda: test_case["response_json"]) paginator.update_state(response) assert paginator._next_reference == test_case["expected"]["next_reference"] @@ -167,14 +167,14 @@ def test_update_state(self, test_case): ], ) def test_update_request(self, test_case): - paginator = JSONResponsePaginator() + paginator = JSONLinkPaginator() paginator._next_reference = test_case["next_reference"] request = Mock(Request, url=test_case["request_url"]) 
paginator.update_request(request) assert request.url == test_case["expected"] def test_no_duplicate_params_on_update_request(self): - paginator = JSONResponsePaginator() + paginator = JSONLinkPaginator() request = Request( method="GET", @@ -200,7 +200,7 @@ def test_no_duplicate_params_on_update_request(self): def test_client_pagination(self, rest_client): pages_iter = rest_client.paginate( "/posts", - paginator=JSONResponsePaginator( + paginator=JSONLinkPaginator( next_url_path="next_page", ), ) diff --git a/tests/sources/helpers/rest_client/test_requests_paginate.py b/tests/sources/helpers/rest_client/test_requests_paginate.py index 43b2a412db..b40f857553 100644 --- a/tests/sources/helpers/rest_client/test_requests_paginate.py +++ b/tests/sources/helpers/rest_client/test_requests_paginate.py @@ -1,7 +1,7 @@ import pytest from dlt.sources.helpers.rest_client import paginate -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator from .conftest import assert_pagination @@ -9,7 +9,7 @@ def test_requests_paginate(): pages_iter = paginate( "https://api.example.com/posts", - paginator=JSONResponsePaginator(next_url_path="next_page"), + paginator=JSONLinkPaginator(next_url_path="next_page"), ) pages = list(pages_iter)