From fb9dabf9e55c77ca28878e9a0f6bcbfcc80d4103 Mon Sep 17 00:00:00 2001 From: David Scharf Date: Sat, 15 Jul 2023 21:45:12 +0200 Subject: [PATCH] 375 filesystem storage staging (#451) * add creation of reference followup jobs * copy job * add parquet format * make staging destination none by default * add automatic resolving of correct file format * renaming staging destination to staging * refactor redshift job and inject fs config info * small cleanup and full parquet file loading test * add redshift test * fix resuming existing jobs * linter fixes and something else i forgot * move reference job follow up creation into job * add bigquery staging with gcs * add jsonl loading for bigquery staging * better supported file format resolution * move to existing bigquery load job * change pipeline args order * add staging run arg * some more pipeline fixes * configure staging via config * enhance staging load tests * fix merge disposition on redshift * add comprehensive staging tests * fix redshift jsonl loading * add doc page (not in hierarchy for now) * move redshift credentials testing to redshift loading location * change rotation test * change timing test * implement snowflake file staging * switch to staging instead of integration * add s3 stage (which does not currently work) * filter out certain combinations for tests * update docs for snowflake staging * forward staging config to supported destination configuration types * move boto out of staging credentials * authentication to s3 with iam role for redshift * verify support for snowflake s3 stage and update docs for this * adds named put stage to snowflake, improves exception messages, fixes and re-enables staging tests --------- Co-authored-by: Marcin Rudolf --- dlt/common/configuration/specs/__init__.py | 2 +- .../configuration/specs/aws_credentials.py | 47 +++--- .../specs/config_providers_context.py | 4 +- dlt/common/data_writers/writers.py | 4 +- dlt/common/destination/capabilities.py | 19 ++- dlt/common/destination/reference.py | 27 +++- dlt/common/exceptions.py | 35 ++++- dlt/common/libs/pyarrow.py | 22 ++- dlt/common/pipeline.py | 11 +- dlt/common/schema/utils.py | 5 + dlt/destinations/bigquery/__init__.py | 5 + dlt/destinations/bigquery/bigquery.py | 58 ++++++-- dlt/destinations/bigquery/sql_client.py | 3 +- dlt/destinations/duckdb/__init__.py | 5 + dlt/destinations/duckdb/duck.py | 24 ++- dlt/destinations/dummy/__init__.py | 3 +- dlt/destinations/exceptions.py | 4 +- dlt/destinations/filesystem/configuration.py | 7 +- dlt/destinations/filesystem/filesystem.py | 30 +++- dlt/destinations/insert_job_client.py | 4 +- dlt/destinations/job_client_impl.py | 31 +++- dlt/destinations/job_impl.py | 22 ++- dlt/destinations/motherduck/__init__.py | 3 + dlt/destinations/postgres/__init__.py | 6 + dlt/destinations/postgres/postgres.py | 17 ++- dlt/destinations/redshift/__init__.py | 5 + dlt/destinations/redshift/configuration.py | 3 +- dlt/destinations/redshift/redshift.py | 92 ++++++++++-- dlt/destinations/snowflake/__init__.py | 5 + dlt/destinations/snowflake/configuration.py | 2 +- dlt/destinations/snowflake/snowflake.py | 100 ++++++++----- dlt/destinations/snowflake/sql_client.py | 7 +- dlt/load/load.py | 75 +++++++--- dlt/pipeline/__init__.py | 12 +- dlt/pipeline/configuration.py | 2 +- dlt/pipeline/pipeline.py | 112 +++++++++++--- .../dlt-ecosystem/destinations/bigquery.md | 30 +++- .../dlt-ecosystem/destinations/redshift.md | 42 ++++++ .../dlt-ecosystem/destinations/snowflake.md | 76 +++++++++- 
docs/website/docs/dlt-ecosystem/staging.md | 13 ++ docs/website/sidebars.js | 3 +- tests/common/utils.py | 2 + tests/load/bigquery/test_bigquery_parquet.py | 38 ----- tests/load/conftest.py | 5 +- tests/load/pipeline/conftest.py | 3 + tests/load/pipeline/test_dbt_helper.py | 5 +- tests/load/pipeline/test_drop.py | 2 - .../load/pipeline/test_filesystem_pipeline.py | 3 +- tests/load/pipeline/test_merge_disposition.py | 7 +- tests/load/pipeline/test_pipelines.py | 79 ++++++++-- tests/load/pipeline/test_restore_state.py | 40 +++-- tests/load/pipeline/test_stage_loading.py | 137 ++++++++++++++++++ tests/load/pipeline/utils.py | 31 +++- tests/load/test_job_client.py | 23 +-- tests/load/utils.py | 122 ++++++++++++++-- tests/pipeline/conftest.py | 2 + tests/pipeline/test_pipeline.py | 8 +- .../test_pipeline_file_format_resolver.py | 62 ++++++++ tests/pipeline/test_pipeline_state.py | 22 ++- tests/pipeline/test_pipeline_trace.py | 3 +- tests/utils.py | 5 +- 61 files changed, 1259 insertions(+), 317 deletions(-) create mode 100644 docs/website/docs/dlt-ecosystem/staging.md delete mode 100644 tests/load/bigquery/test_bigquery_parquet.py create mode 100644 tests/load/pipeline/conftest.py create mode 100644 tests/load/pipeline/test_stage_loading.py create mode 100644 tests/pipeline/conftest.py create mode 100644 tests/pipeline/test_pipeline_file_format_resolver.py diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index 01e8a67751..675b0a0bec 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -5,7 +5,7 @@ from .gcp_credentials import GcpServiceAccountCredentialsWithoutDefaults, GcpServiceAccountCredentials, GcpOAuthCredentialsWithoutDefaults, GcpOAuthCredentials, GcpCredentials # noqa: F401 from .connection_string_credentials import ConnectionStringCredentials # noqa: F401 from .api_credentials import OAuth2Credentials # noqa: F401 -from .aws_credentials import AwsCredentials # noqa: F401 +from .aws_credentials import AwsCredentials, AwsCredentialsWithoutDefaults # noqa: F401 # backward compatibility for service account credentials diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index 8e4e510f1f..62c14a9558 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -1,22 +1,38 @@ -from typing import Optional, TYPE_CHECKING, Dict +from typing import Optional, TYPE_CHECKING, Dict, Any from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import CredentialsConfiguration, CredentialsWithDefault, configspec from dlt import version -if TYPE_CHECKING: - from botocore.credentials import Credentials - from boto3 import Session - @configspec -class AwsCredentials(CredentialsConfiguration, CredentialsWithDefault): +class AwsCredentialsWithoutDefaults(CredentialsConfiguration): + # credentials without boto implementation aws_access_key_id: str = None aws_secret_access_key: TSecretStrValue = None aws_session_token: Optional[TSecretStrValue] = None aws_profile: Optional[str] = None + def to_s3fs_credentials(self) -> Dict[str, Optional[str]]: + """Dict of keyword arguments that can be passed to s3fs""" + return dict( + key=self.aws_access_key_id, + secret=self.aws_secret_access_key, + token=self.aws_session_token, + profile=self.aws_profile + ) + + def to_native_representation(self) -> Dict[str, 
Optional[str]]: + """Return a dict that can be passed as kwargs to boto3 session""" + d = dict(self) + d['profile_name'] = d.pop('aws_profile') # boto3 argument doesn't match env var name + return d + + +@configspec +class AwsCredentials(AwsCredentialsWithoutDefaults, CredentialsWithDefault): + def on_partial(self) -> None: # Try get default credentials session = self._to_session() @@ -30,27 +46,12 @@ def on_partial(self) -> None: if not self.is_partial(): self.resolve() - def _to_session(self) -> "Session": + def _to_session(self) -> Any: try: import boto3 except ImportError: raise MissingDependencyException(self.__class__.__name__, [f"{version.DLT_PKG_NAME}[s3]"]) return boto3.Session(**self.to_native_representation()) - def to_native_credentials(self) -> Optional["Credentials"]: + def to_native_credentials(self) -> Optional[Any]: return self._to_session().get_credentials() - - def to_s3fs_credentials(self) -> Dict[str, Optional[str]]: - """Dict of keyword arguments that can be passed to s3fs""" - return dict( - key=self.aws_access_key_id, - secret=self.aws_secret_access_key, - token=self.aws_session_token, - profile=self.aws_profile - ) - - def to_native_representation(self) -> Dict[str, Optional[str]]: - """Return a dict that can be passed as kwargs to boto3 session""" - d = dict(self) - d['profile_name'] = d.pop('aws_profile') # boto3 argument doesn't match env var name - return d diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 4d60e2e3c7..b575ea9756 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -1,8 +1,6 @@ - - import contextlib import io -from typing import Any, List +from typing import List from dlt.common.configuration.exceptions import DuplicateConfigProviderException from dlt.common.configuration.providers import ConfigProvider, EnvironProvider, ContextProvider, SecretsTomlProvider, ConfigTomlProvider, GoogleSecretsProvider from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 13259f57bc..9cb01f1eb5 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -211,7 +211,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: from dlt.common.libs.pyarrow import pyarrow, get_py_arrow_datatype # build schema - self.schema = pyarrow.schema([pyarrow.field(name, get_py_arrow_datatype(schema_item["data_type"]), nullable=schema_item["nullable"]) for name, schema_item in columns_schema.items()]) + self.schema = pyarrow.schema( + [pyarrow.field(name, get_py_arrow_datatype(schema_item["data_type"], self._caps), nullable=schema_item["nullable"]) for name, schema_item in columns_schema.items()] + ) # find row items that are of the complex type (could be abstracted out for use in other writers?) 
self.complex_indices = [i for i, field in columns_schema.items() if field["data_type"] == "complex"] self.writer = pyarrow.parquet.ParquetWriter(self._f, self.schema, flavor=self.parquet_flavor, version=self.parquet_version, data_page_size=self.parquet_data_page_size) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 7c2665216c..8c35796fea 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -1,17 +1,24 @@ -from typing import Any, Callable, ClassVar, List, Literal +from typing import Any, Callable, ClassVar, List, Literal, Optional, Tuple, Set, get_args from dlt.common.configuration.utils import serialize_value from dlt.common.configuration import configspec from dlt.common.configuration.specs import ContainerInjectableContext from dlt.common.utils import identity +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + +from dlt.common.wei import EVM_DECIMAL_PRECISION # known loader file formats # jsonl - new line separated json documents # puae-jsonl - internal extract -> normalize format bases on jsonl # insert_values - insert SQL statements # sql - any sql statement -TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values", "sql", "parquet"] +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values", "sql", "parquet", "reference"] +# file formats used internally by dlt +INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference"} +# file formats that may be chosen by the user +EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS @configspec(init=True) @@ -19,8 +26,12 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Injectable destination capabilities required for many Pipeline stages ie. 
normalize""" preferred_loader_file_format: TLoaderFileFormat supported_loader_file_formats: List[TLoaderFileFormat] + preferred_staging_file_format: Optional[TLoaderFileFormat] + supported_staging_file_formats: List[TLoaderFileFormat] escape_identifier: Callable[[str], str] escape_literal: Callable[[Any], Any] + decimal_precision: Tuple[int, int] + wei_precision: Tuple[int, int] max_identifier_length: int max_column_identifier_length: int max_query_length: int @@ -39,8 +50,12 @@ def generic_capabilities(preferred_loader_file_format: TLoaderFileFormat = None) caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = preferred_loader_file_format caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] caps.escape_identifier = identity caps.escape_literal = serialize_value + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (EVM_DECIMAL_PRECISION, 0) caps.max_identifier_length = 65536 caps.max_column_identifier_length = 65536 caps.max_query_length = 32 * 1024 * 1024 diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index f282d546d9..369cc51550 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from importlib import import_module from types import TracebackType, ModuleType -from typing import ClassVar, Final, Optional, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast +from typing import ClassVar, Final, Optional, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast, List from dlt.common import logger from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule @@ -15,6 +15,7 @@ from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import get_module_name +from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults @configspec(init=True) @@ -27,7 +28,8 @@ def __str__(self) -> str: return str(self.credentials) if TYPE_CHECKING: - def __init__(self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None) -> None: + def __init__(self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None +) -> None: ... @@ -38,6 +40,7 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" default_schema_name: Optional[str] = None """name of default schema to be used to name effective dataset to load data to""" + staging_credentials: Optional[CredentialsConfiguration] = None if TYPE_CHECKING: def __init__( @@ -45,10 +48,25 @@ def __init__( destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None, dataset_name: str = None, - default_schema_name: Optional[str] = None + default_schema_name: Optional[str] = None, + staging_credentials: Optional[CredentialsConfiguration] = None ) -> None: ... 
+@configspec(init=True) +class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): + as_staging: bool = False + + if TYPE_CHECKING: + def __init__( + self, + destination_name: str = None, + credentials: Union[AwsCredentialsWithoutDefaults, GcpCredentials] = None, + dataset_name: str = None, + default_schema_name: Optional[str] = None, + as_staging: bool = False, + ) -> None: + ... TLoadJobState = Literal["running", "failed", "retry", "completed"] @@ -106,7 +124,8 @@ def new_file_path(self) -> str: class FollowupJob: """Adds a trait that allows to create a followup job""" - pass + def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: + return [] class JobClientBase(ABC): diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 4dfc94da08..fbf3a49dfc 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, AnyStr, Sequence, Optional +from typing import Any, AnyStr, List, Sequence, Optional, Iterable class DltException(Exception): @@ -119,6 +119,35 @@ class DestinationTransientException(DestinationException, TransientException): pass +class DestinationLoadingViaStagingNotSupported(DestinationTerminalException): + def __init__(self, destination: str) -> None: + self.destination = destination + super().__init__(f"Destination {destination} does not support loading via staging.") + + +class DestinationNoStagingMode(DestinationTerminalException): + def __init__(self, destination: str) -> None: + self.destination = destination + super().__init__(f"Destination {destination} cannot be used as a staging destination.") + + +class DestinationIncompatibleLoaderFileFormatException(DestinationTerminalException): + def __init__(self, destination: str, staging: str, file_format: str, supported_formats: Iterable[str]) -> None: + self.destination = destination + self.staging = staging + self.file_format = file_format + self.supported_formats = supported_formats + supported_formats_str = ", ".join(supported_formats) + if self.staging: + if not supported_formats: + msg = f"Staging {staging} cannot be used with destination {destination} because they have no file formats in common." + else: + msg = f"Unsupported file format {file_format} for destination {destination} in combination with staging destination {staging}. Supported formats: {supported_formats_str}" + else: + msg = f"Unsupported file format {file_format} for destination {destination}. Supported formats: {supported_formats_str}. Check the staging option in the dlt.pipeline for additional formats." + super().__init__(msg) + + class IdentifierTooLongException(DestinationTerminalException): def __init__(self, destination_name: str, identifier_type: str, identifier_name: str, max_identifier_length: int) -> None: self.destination_name = destination_name @@ -129,13 +158,13 @@ def __init__(self, destination_name: str, identifier_type: str, identifier_name: class DestinationHasFailedJobs(DestinationTerminalException): - def __init__(self, destination_name: str, load_id: str) -> None: + def __init__(self, destination_name: str, load_id: str, failed_jobs: List[Any]) -> None: self.destination_name = destination_name self.load_id = load_id + self.failed_jobs = failed_jobs super().__init__(f"Destination {destination_name} has failed jobs in load package {load_id}") - class PipelineException(DltException): def __init__(self, pipeline_name: str, msg: str) -> None: """Base class for all pipeline exceptions.
Should not be raised.""" diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 933a605f2c..3dbd338beb 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,5 +1,7 @@ from dlt.common.exceptions import MissingDependencyException -from typing import Any +from typing import Any, Tuple + +from dlt.common.destination.capabilities import DestinationCapabilitiesContext try: import pyarrow @@ -8,7 +10,7 @@ raise MissingDependencyException("DLT parquet Helpers", ["parquet"], "DLT Helpers for for parquet.") -def get_py_arrow_datatype(column_type: str) -> Any: +def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext) -> Any: if column_type == "text": return pyarrow.string() elif column_type == "double": @@ -22,12 +24,22 @@ def get_py_arrow_datatype(column_type: str) -> Any: elif column_type == "binary": return pyarrow.binary() elif column_type == "complex": + # return pyarrow.struct([pyarrow.field('json', pyarrow.string())]) return pyarrow.string() elif column_type == "decimal": - return pyarrow.decimal128(38, 18) + return get_py_arrow_numeric(caps.decimal_precision) elif column_type == "wei": - return pyarrow.decimal128(38, 0) + return get_py_arrow_numeric(caps.wei_precision) elif column_type == "date": return pyarrow.date32() else: - raise ValueError(column_type) \ No newline at end of file + raise ValueError(column_type) + + +def get_py_arrow_numeric(precision: Tuple[int, int]) -> Any: + if precision[0] <= 38: + return pyarrow.decimal128(*precision) + if precision[0] <= 76: + return pyarrow.decimal256(*precision) + # for higher precision use max precision and trim scale to leave the most significant part + return pyarrow.decimal256(76, max(0, 76 - (precision[0] - precision[1]))) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 181abf1406..2df1590ae1 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -54,6 +54,8 @@ class LoadInfo(NamedTuple): pipeline: "SupportsPipeline" destination_name: str destination_displayable_credentials: str + staging_name: str + staging_displayable_credentials: str dataset_name: str loads_ids: List[str] """ids of the loaded packages""" @@ -79,6 +81,9 @@ def asstr(self, verbosity: int = 0) -> str: else: msg += "---" msg += f"\n{len(self.loads_ids)} load package(s) were loaded to destination {self.destination_name} and into dataset {self.dataset_name}\n" + if self.staging_name: + msg += f"The {self.staging_name} staging destination used {self.staging_displayable_credentials} location to stage data\n" + msg += f"The {self.destination_name} destination used {self.destination_displayable_credentials} location to store data" for load_package in self.load_packages: cstr = load_package.state.upper() if load_package.completed_at else "NOT COMPLETED" @@ -106,8 +111,9 @@ def has_failed_jobs(self) -> bool: def raise_on_failed_jobs(self) -> None: """Raises `DestinationHasFailedJobs` exception if any of the load packages has a failed job.""" for load_package in self.load_packages: - if len(load_package.jobs["failed_jobs"]): - raise DestinationHasFailedJobs(self.destination_name, load_package.load_id) + failed_jobs = load_package.jobs["failed_jobs"] + if len(failed_jobs): + raise DestinationHasFailedJobs(self.destination_name, load_package.load_id, failed_jobs) def __str__(self) -> str: return self.asstr(verbosity=1) @@ -128,6 +134,7 @@ class TPipelineState(TypedDict, total=False): schema_names: Optional[List[str]] """All the schemas present within the pipeline working directory""" 
destination: Optional[str] + staging: Optional[str] # properties starting with _ are not automatically applied to pipeline object when state is restored _state_version: int diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 636a0660cb..d9ec22d1b0 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -418,6 +418,11 @@ def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDispo raise ValueError(f"No write disposition found in the chain of tables for '{table_name}'.") +def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: + """Checks if `table` schema contains column with type _typ""" + return any(c["data_type"] == _typ for c in table["columns"].values()) + + def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema: """Finds top level (without parent) of a `table_name` following the ancestry hierarchy.""" table = tables[table_name] diff --git a/dlt/destinations/bigquery/__init__.py b/dlt/destinations/bigquery/__init__.py index 0f5d4da04e..d7743c1b4b 100644 --- a/dlt/destinations/bigquery/__init__.py +++ b/dlt/destinations/bigquery/__init__.py @@ -6,6 +6,7 @@ from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration @@ -19,8 +20,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl", "sql", "parquet"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] caps.escape_identifier = escape_bigquery_identifier caps.escape_literal = None + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (76, 38) caps.max_identifier_length = 1024 caps.max_column_identifier_length = 300 caps.max_query_length = 1024 * 1024 diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py index 67ed83ef1f..099229778f 100644 --- a/dlt/destinations/bigquery/bigquery.py +++ b/dlt/destinations/bigquery/bigquery.py @@ -1,12 +1,11 @@ +import os from pathlib import Path -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, cast +from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, cast, Type import google.cloud.bigquery as bigquery # noqa: I250 from google.cloud import exceptions as gcp_exceptions from google.api_core import exceptions as api_core_exceptions from dlt.common import json, logger -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob from dlt.common.data_types import TDataType @@ -21,7 +20,9 @@ from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS from dlt.destinations.sql_merge_job import SqlMergeJob +from dlt.destinations.job_impl import NewReferenceJob +from dlt.common.schema.utils import table_schema_has_type 
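The `table_schema_has_type` helper added above in dlt/common/schema/utils.py is what the BigQuery load job below and the Redshift COPY job later in this diff use to refuse column/file-format combinations the engine cannot load. A minimal, self-contained usage sketch; the `events` table definition is invented for this example.

from dlt.common.schema.typing import TTableSchema
from dlt.common.schema.utils import table_schema_has_type

# a hand-built table schema with one complex (JSON) column
table: TTableSchema = {
    "name": "events",
    "write_disposition": "append",
    "columns": {
        "id": {"name": "id", "data_type": "bigint", "nullable": False},
        "payload": {"name": "payload", "data_type": "complex", "nullable": True},
    },
}

# loader jobs use this check to pick or reject a staging file format, e.g. BigQuery
# cannot load parquet into a JSON column and Redshift adds SERIALIZETOJSON for SUPER columns
assert table_schema_has_type(table, "complex")
assert not table_schema_has_type(table, "binary")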
SCT_TO_BQT: Dict[TDataType, str] = { "complex": "JSON", @@ -32,7 +33,7 @@ "timestamp": "TIMESTAMP", "bigint": "INTEGER", "binary": "BYTES", - "decimal": f"NUMERIC({DEFAULT_NUMERIC_PRECISION},{DEFAULT_NUMERIC_SCALE})", + "decimal": "NUMERIC(%i,%i)", "wei": "BIGNUMERIC" # non parametrized should hold wei values } @@ -157,18 +158,19 @@ def restore_file_load(self, file_path: str) -> LoadJob: if reason == "notFound": raise LoadJobNotExistsException(file_path) elif reason in BQ_TERMINAL_REASONS: - raise LoadJobTerminalException(file_path) + raise LoadJobTerminalException(file_path, f"The server reason was: {reason}") else: raise DestinationTransientException(gace) return job def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) + if not job: try: job = BigQueryLoadJob( FileStorage.get_file_name_from_file_path(file_path), - self._create_load_job(table["name"], table["write_disposition"], file_path), + self._create_load_job(table, file_path), self.config.http_timeout, self.config.retry_deadline ) @@ -182,7 +184,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> return self.restore_file_load(file_path) elif reason in BQ_TERMINAL_REASONS: # google.api_core.exceptions.BadRequest - will not be processed ie bad job name - raise LoadJobTerminalException(file_path) + raise LoadJobTerminalException(file_path, f"The server reason was: {reason}") else: raise DestinationTransientException(gace) return job @@ -235,14 +237,29 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] except gcp_exceptions.NotFound: return False, schema_table - def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition, file_path: str) -> bigquery.LoadJob: + def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.LoadJob: + table_name = table["name"] + write_disposition = table["write_disposition"] # append to table for merge loads (append to stage) and regular appends bq_wd = bigquery.WriteDisposition.WRITE_TRUNCATE if write_disposition == "replace" else bigquery.WriteDisposition.WRITE_APPEND + # determine whether we load from local or uri + bucket_path = None + ext: str = os.path.splitext(file_path)[1][1:] + if NewReferenceJob.is_reference_job(file_path): + bucket_path = NewReferenceJob.resolve_reference(file_path) + ext = os.path.splitext(bucket_path)[1][1:] + # choose correct source format source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON + decimal_target_types: List[str] = None + if ext == "parquet": + # if table contains complex types, we cannot load with parquet + if table_schema_has_type(table, "complex"): + raise LoadJobTerminalException(file_path, "BigQuery cannot load into JSON data type from parquet.
Use jsonl instead.") source_format = bigquery.SourceFormat.PARQUET + # parquet needs NUMERIC type autodetection + decimal_target_types = ["NUMERIC", "BIGNUMERIC"] # if merge then load to staging with self.sql_client.with_staging_dataset(write_disposition == "merge"): @@ -252,8 +269,19 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition write_disposition=bq_wd, create_disposition=bigquery.CreateDisposition.CREATE_NEVER, source_format=source_format, + decimal_target_types=decimal_target_types, ignore_unknown_values=False, max_bad_records=0) + + if bucket_path: + return self.sql_client.native_connection.load_table_from_uri( + bucket_path, + self.sql_client.make_qualified_table_name(table_name, escape=False), + job_id=job_id, + job_config=job_config, + timeout=self.config.file_upload_timeout + ) + with open(file_path, "rb") as f: return self.sql_client.native_connection.load_table_from_file( f, @@ -267,13 +295,17 @@ def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: job_id = BigQueryLoadJob.get_job_id_from_file_path(file_path) return cast(bigquery.LoadJob, self.sql_client.native_connection.get_job(job_id)) - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: + @classmethod + def _to_db_type(cls, sc_t: TDataType) -> str: + if sc_t == "decimal": + return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision return SCT_TO_BQT[sc_t] - @staticmethod - def _from_db_type(bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + @classmethod + def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if bq_t == "BIGNUMERIC": if precision is None: # biggest numeric possible return "wei" return BQT_TO_SCT.get(bq_t, "text") + + diff --git a/dlt/destinations/bigquery/sql_client.py b/dlt/destinations/bigquery/sql_client.py index 4851240d8e..79e27d61d4 100644 --- a/dlt/destinations/bigquery/sql_client.py +++ b/dlt/destinations/bigquery/sql_client.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence +from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence, Type import google.cloud.bigquery as bigquery # noqa: I250 from google.cloud.bigquery import dbapi as bq_dbapi @@ -65,6 +65,7 @@ def __init__( self._default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name(escape=False)) self._session_query: bigquery.QueryJobConfig = None + @raise_open_connection_error def open_connection(self) -> bigquery.Client: self._client = bigquery.Client( diff --git a/dlt/destinations/duckdb/__init__.py b/dlt/destinations/duckdb/__init__.py index c4b831e5f2..7c4e582323 100644 --- a/dlt/destinations/duckdb/__init__.py +++ b/dlt/destinations/duckdb/__init__.py @@ -6,6 +6,7 @@ from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration @@ -19,8 +20,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "insert_values" caps.supported_loader_file_formats = ["insert_values", "parquet", "sql"] + caps.preferred_staging_file_format = None + 
caps.supported_staging_file_formats = [] caps.escape_identifier = escape_postgres_identifier caps.escape_literal = escape_duckdb_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) caps.max_identifier_length = 65536 caps.max_column_identifier_length = 65536 caps.naming_convention = "duck_case" diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py index ff6afd7986..2bcb0e8900 100644 --- a/dlt/destinations/duckdb/duck.py +++ b/dlt/destinations/duckdb/duck.py @@ -1,9 +1,11 @@ from typing import ClassVar, Dict, Optional -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema +from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.storages.file_storage import FileStorage from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -11,12 +13,6 @@ from dlt.destinations.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState - -from dlt.common.schema.typing import TTableSchema, TWriteDisposition - -from dlt.common.storages.file_storage import FileStorage - SCT_TO_PGT: Dict[TDataType, str] = { "complex": "JSON", @@ -27,7 +23,7 @@ "timestamp": "TIMESTAMP WITH TIME ZONE", "bigint": "BIGINT", "binary": "BLOB", - "decimal": f"DECIMAL({DEFAULT_NUMERIC_PRECISION},{DEFAULT_NUMERIC_SCALE})" + "decimal": "DECIMAL(%i,%i)" } PGT_TO_SCT: Dict[str, TDataType] = { @@ -89,14 +85,16 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: + @classmethod + def _to_db_type(cls, sc_t: TDataType) -> str: if sc_t == "wei": - return "DECIMAL(38,0)" + return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + if sc_t == "decimal": + return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision return SCT_TO_PGT[sc_t] - @staticmethod - def _from_db_type(pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + @classmethod + def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: # duckdb provides the types with scale and precision pq_t = pq_t.split("(")[0].upper() if pq_t == "DECIMAL": diff --git a/dlt/destinations/dummy/__init__.py b/dlt/destinations/dummy/__init__.py index d55f713178..7131f0109a 100644 --- a/dlt/destinations/dummy/__init__.py +++ b/dlt/destinations/dummy/__init__.py @@ -19,7 +19,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = config.loader_file_format caps.supported_loader_file_formats = [config.loader_file_format] - + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] caps.max_identifier_length = 127 caps.max_column_identifier_length = 127 caps.max_query_length = 8 * 1024 * 1024 diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index fdbfc0f896..8a81120a17 100644 --- 
a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -51,8 +51,8 @@ def __init__(self, job_id: str) -> None: class LoadJobTerminalException(DestinationTerminalException): - def __init__(self, file_path: str) -> None: - super().__init__(f"Job with id/file name {file_path} encountered unrecoverable problem") + def __init__(self, file_path: str, message: str) -> None: + super().__init__(f"Job with id/file name {file_path} encountered unrecoverable problem: {message}") class LoadJobUnknownTableException(DestinationTerminalException): diff --git a/dlt/destinations/filesystem/configuration.py b/dlt/destinations/filesystem/configuration.py index 8629288c72..8cb8a0d417 100644 --- a/dlt/destinations/filesystem/configuration.py +++ b/dlt/destinations/filesystem/configuration.py @@ -3,7 +3,7 @@ from typing import Final, Type, Optional, Union from dlt.common.configuration import configspec, resolve_type -from dlt.common.destination.reference import CredentialsConfiguration, DestinationClientDwhConfiguration +from dlt.common.destination.reference import CredentialsConfiguration, DestinationClientStagingConfiguration from dlt.common.configuration.specs import GcpCredentials, GcpServiceAccountCredentials, AwsCredentials, GcpOAuthCredentials from dlt.common.configuration.exceptions import ConfigurationValueError @@ -18,10 +18,9 @@ @configspec(init=True) -class FilesystemClientConfiguration(DestinationClientDwhConfiguration): - credentials: Union[AwsCredentials, GcpCredentials] - +class FilesystemClientConfiguration(DestinationClientStagingConfiguration): destination_name: Final[str] = "filesystem" # type: ignore + credentials: Union[AwsCredentials, GcpCredentials] bucket_url: str @property diff --git a/dlt/destinations/filesystem/filesystem.py b/dlt/destinations/filesystem/filesystem.py index 3096ab0b2d..5f98f4b040 100644 --- a/dlt/destinations/filesystem/filesystem.py +++ b/dlt/destinations/filesystem/filesystem.py @@ -1,19 +1,22 @@ import posixpath import threading +import os from types import TracebackType -from typing import ClassVar, List, Sequence, Type, Iterable +from typing import ClassVar, List, Sequence, Type, Iterable, cast from fsspec import AbstractFileSystem from dlt.common.schema import Schema, TTableSchema from dlt.common.schema.typing import TWriteDisposition, LOADS_TABLE_NAME from dlt.common.storages import FileStorage from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase +from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.filesystem import capabilities from dlt.destinations.filesystem.configuration import FilesystemClientConfiguration from dlt.destinations.filesystem.filesystem_client import client_from_config from dlt.common.storages import LoadStorage +from dlt.destinations.job_impl import NewLoadJobImpl +from dlt.destinations.job_impl import NewReferenceJob class LoadFilesystemJob(LoadJob): @@ -29,6 +32,9 @@ def __init__( load_id: str ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) + self.config = config + self.dataset_path = dataset_path + super().__init__(file_name) fs_client, _ = client_from_config(config) @@ -49,14 +55,17 @@ def __init__( for item in items: fs_client.rm_file(item) - destination_file_name = LoadFilesystemJob.make_destination_filename(file_name, schema_name, load_id) - 
fs_client.put_file(local_path, posixpath.join(dataset_path, destination_file_name)) + self.destination_file_name = LoadFilesystemJob.make_destination_filename(file_name, schema_name, load_id) + fs_client.put_file(local_path, self.make_remote_path()) @staticmethod def make_destination_filename(file_name: str, schema_name: str, load_id: str) -> str: job_info = LoadStorage.parse_job_file_name(file_name) return f"{schema_name}.{job_info.table_name}.{load_id}.{job_info.file_id}.{job_info.file_format}" + def make_remote_path(self) -> str: + return f"{self.config.protocol}://{posixpath.join(self.dataset_path, self.destination_file_name)}" + def state(self) -> TLoadJobState: return "completed" @@ -64,6 +73,15 @@ def exception(self) -> str: raise NotImplementedError() +class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): + def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: + jobs = super().create_followup_jobs(next_state) + if next_state == "completed": + ref_job = NewReferenceJob(file_name=self.file_name(), status="running", remote_path=self.make_remote_path()) + jobs.append(ref_job) + return jobs + + class FilesystemClient(JobClientBase): """filesystem client storing jobs in memory""" @@ -87,8 +105,9 @@ def is_storage_initialized(self, staging: bool = False) -> bool: return self.fs_client.isdir(self.dataset_path) # type: ignore[no-any-return] def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + cls = FollowupFilesystemJob if self.config.as_staging else LoadFilesystemJob has_merge_keys = any(col['merge_key'] or col['primary_key'] for col in table['columns'].values()) - return LoadFilesystemJob( + return cls( file_path, self.dataset_path, config=self.config, @@ -104,6 +123,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: def create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return None + def complete_load(self, load_id: str) -> None: schema_name = self.schema.name table_name = LOADS_TABLE_NAME diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index e0400b39e1..10ceb32675 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -1,5 +1,6 @@ import os -from typing import Any, Iterator, List +import abc +from typing import Any, Iterator, List, Type from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState from dlt.common.schema.typing import TTableSchema, TWriteDisposition @@ -102,3 +103,4 @@ def _get_in_table_constraints_sql(self, t: TTableSchema) -> str: def _get_out_table_constrains_sql(self, t: TTableSchema) -> str: # set non unique indexes pass + diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index eed1f04224..781bf39500 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -15,14 +15,15 @@ from dlt.common.schema.utils import add_missing_hints from dlt.common.storages import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables -from dlt.common.destination.reference import DestinationClientConfiguration, DestinationClientDwhConfiguration, NewLoadJob, TLoadJobState, LoadJob, JobClientBase +from dlt.common.destination.reference import DestinationClientConfiguration, DestinationClientDwhConfiguration, NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob, DestinationClientStagingConfiguration, CredentialsConfiguration from dlt.common.utils import 
concat_strings_with_limit from dlt.destinations.exceptions import DatabaseUndefinedRelation, DestinationSchemaWillNotUpdate -from dlt.destinations.job_impl import EmptyLoadJobWithoutFollowup +from dlt.destinations.job_impl import EmptyLoadJobWithoutFollowup, NewReferenceJob from dlt.destinations.sql_merge_job import SqlMergeJob from dlt.destinations.typing import TNativeConn from dlt.destinations.sql_client import SqlClientBase +from dlt.common.configuration import with_config, known_sections class StorageSchemaInfo(NamedTuple): @@ -58,6 +59,23 @@ def is_sql_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "sql" +class CopyRemoteFileLoadJob(LoadJob, FollowupJob): + def __init__(self, table: TTableSchema, file_path: str, sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None) -> None: + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + self._sql_client = sql_client + self._staging_credentials = staging_credentials + + self.execute(table, NewReferenceJob.resolve_reference(file_path)) + + def execute(self, table: TTableSchema, bucket_path: str) -> None: + # implement in child implementations + raise NotImplementedError() + + def state(self) -> TLoadJobState: + # this job is always done + return "completed" + + class SqlJobClientBase(JobClientBase): VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[str] = "version_hash, schema_name, version, engine_version, inserted_at, schema" @@ -186,14 +204,14 @@ def _null_to_bool(v: str) -> bool: schema_table[c[0]] = add_missing_hints(schema_c) return True, schema_table - @staticmethod + @classmethod @abstractmethod - def _to_db_type(schema_type: TDataType) -> str: + def _to_db_type(cls, schema_type: TDataType) -> str: pass - @staticmethod + @classmethod @abstractmethod - def _from_db_type(db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType: pass def get_newest_schema_from_storage(self) -> StorageSchemaInfo: @@ -342,3 +360,4 @@ def _update_schema_in_storage(self, schema: Schema) -> None: self.sql_client.execute_sql( f"INSERT INTO {name}({self.VERSION_TABLE_SCHEMA_COLUMNS}) VALUES (%s, %s, %s, %s, %s, %s);", schema.stored_version_hash, schema.name, schema.version, schema.ENGINE_VERSION, now_ts, schema_str ) + diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 81090b7db4..fb3ba48b6d 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -4,7 +4,7 @@ from dlt.common.storages import FileStorage from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob - +from dlt.common.storages.load_storage import ParsedLoadJobFileName class EmptyLoadJobWithoutFollowup(LoadJob): def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: @@ -36,4 +36,22 @@ def _save_text_file(self, data: str) -> None: def new_file_path(self) -> str: """Path to a newly created temporary job file""" - return self._new_file_path \ No newline at end of file + return self._new_file_path + +class NewReferenceJob(NewLoadJobImpl): + + def __init__(self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None) -> None: + file_name = os.path.splitext(file_name)[0] + ".reference" + super().__init__(file_name, status, exception) + self._remote_path = remote_path + self._save_text_file(remote_path) + + @staticmethod + def is_reference_job(file_path: str) -> bool: 
+ return os.path.splitext(file_path)[1][1:] == "reference" + + @staticmethod + def resolve_reference(file_path: str) -> str: + with open(file_path, "r+", encoding="utf-8") as f: + # Reading from a file + return f.read() diff --git a/dlt/destinations/motherduck/__init__.py b/dlt/destinations/motherduck/__init__.py index 5647f41023..c3e43b0c0f 100644 --- a/dlt/destinations/motherduck/__init__.py +++ b/dlt/destinations/motherduck/__init__.py @@ -6,6 +6,7 @@ from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.motherduck.configuration import MotherDuckClientConfiguration @@ -21,6 +22,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["parquet", "insert_values", "sql"] caps.escape_identifier = escape_postgres_identifier caps.escape_literal = escape_duckdb_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) caps.max_identifier_length = 65536 caps.max_column_identifier_length = 65536 caps.naming_convention = "duck_case" diff --git a/dlt/destinations/postgres/__init__.py b/dlt/destinations/postgres/__init__.py index 492ada6256..346e016cee 100644 --- a/dlt/destinations/postgres/__init__.py +++ b/dlt/destinations/postgres/__init__.py @@ -6,6 +6,8 @@ from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.destinations.postgres.configuration import PostgresClientConfiguration @@ -20,8 +22,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "insert_values" caps.supported_loader_file_formats = ["insert_values", "sql"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] caps.escape_identifier = escape_postgres_identifier caps.escape_literal = escape_postgres_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (2*EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) caps.max_identifier_length = 63 caps.max_column_identifier_length = 63 caps.max_query_length = 32 * 1024 * 1024 diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py index 910d767ab5..b12a13fb7e 100644 --- a/dlt/destinations/postgres/postgres.py +++ b/dlt/destinations/postgres/postgres.py @@ -32,7 +32,7 @@ "date": "date", "bigint": "bigint", "binary": "bytea", - "decimal": f"numeric({DEFAULT_NUMERIC_PRECISION},{DEFAULT_NUMERIC_SCALE})" + "decimal": "numeric(%i,%i)" } PGT_TO_SCT: Dict[str, TDataType] = { @@ -70,15 +70,20 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: + @classmethod + def _to_db_type(cls, sc_t: TDataType) -> str: + if sc_t == 
"wei": + return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + if sc_t == "decimal": + return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision + if sc_t == "wei": return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})" return SCT_TO_PGT[sc_t] - @staticmethod - def _from_db_type(pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + @classmethod + def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if pq_t == "numeric": - if precision == 2*EVM_DECIMAL_PRECISION and scale == EVM_DECIMAL_PRECISION: + if (precision, scale) == cls.capabilities.wei_precision: return "wei" return PGT_TO_SCT.get(pq_t, "text") diff --git a/dlt/destinations/redshift/__init__.py b/dlt/destinations/redshift/__init__.py index bbe0a84430..6ca3423035 100644 --- a/dlt/destinations/redshift/__init__.py +++ b/dlt/destinations/redshift/__init__.py @@ -6,6 +6,7 @@ from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.redshift.configuration import RedshiftClientConfiguration @@ -19,8 +20,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "insert_values" caps.supported_loader_file_formats = ["insert_values", "sql"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.escape_identifier = escape_redshift_identifier caps.escape_literal = escape_redshift_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) caps.max_identifier_length = 127 caps.max_column_identifier_length = 127 caps.max_query_length = 16 * 1024 * 1024 diff --git a/dlt/destinations/redshift/configuration.py b/dlt/destinations/redshift/configuration.py index d41d04d51a..b35a1a0922 100644 --- a/dlt/destinations/redshift/configuration.py +++ b/dlt/destinations/redshift/configuration.py @@ -1,4 +1,4 @@ -from typing import Final +from typing import Final, Optional from dlt.common.typing import TSecretValue from dlt.common.configuration import configspec @@ -18,3 +18,4 @@ class RedshiftCredentials(PostgresCredentials): class RedshiftClientConfiguration(PostgresClientConfiguration): destination_name: Final[str] = "redshift" # type: ignore credentials: RedshiftCredentials + staging_iam_role: Optional[str] = None diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py index d39215a96d..ad74445849 100644 --- a/dlt/destinations/redshift/redshift.py +++ b/dlt/destinations/redshift/redshift.py @@ -1,6 +1,9 @@ import platform +import os from dlt.destinations.postgres.sql_client import Psycopg2SqlClient + +from dlt.common.schema.utils import table_schema_has_type if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 # from psycopg2cffi.sql import SQL, Composed @@ -8,21 +11,24 @@ import psycopg2 # from psycopg2.sql import SQL, Composed -from typing import ClassVar, Dict, List, Optional, Sequence +from typing import ClassVar, Dict, List, Optional, Sequence, Any -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.destination import 
DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob +from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema +from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_merge_job import SqlMergeJob -from dlt.destinations.exceptions import DatabaseTerminalException +from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException +from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob from dlt.destinations.redshift import capabilities from dlt.destinations.redshift.configuration import RedshiftClientConfiguration +from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.sql_client import SqlClientBase @@ -35,7 +41,7 @@ "timestamp": "timestamp with time zone", "bigint": "bigint", "binary": "varbinary", - "decimal": f"numeric({DEFAULT_NUMERIC_PRECISION},{DEFAULT_NUMERIC_SCALE})" + "decimal": "numeric(%i,%i)" } PGT_TO_SCT: Dict[str, TDataType] = { @@ -73,6 +79,61 @@ def _maybe_make_terminal_exception_from_data_error(pg_ex: psycopg2.DataError) -> return DatabaseTerminalException(pg_ex) return None +class RedshiftCopyFileLoadJob(CopyRemoteFileLoadJob): + + def __init__(self, table: TTableSchema, file_path: str, sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None, staging_iam_role: str = None) -> None: + self._staging_iam_role = staging_iam_role + super().__init__(table, file_path, sql_client, staging_credentials) + + def execute(self, table: TTableSchema, bucket_path: str) -> None: + + # we assume s3 credentials were provided for the staging + credentials = "" + if self._staging_iam_role: + credentials = f"IAM_ROLE '{self._staging_iam_role}'" + elif self._staging_credentials and isinstance(self._staging_credentials, AwsCredentialsWithoutDefaults): + aws_access_key = self._staging_credentials.aws_access_key_id + aws_secret_key = self._staging_credentials.aws_secret_access_key + credentials = f"CREDENTIALS 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'" + table_name = table["name"] + + # get format + ext = os.path.splitext(bucket_path)[1][1:] + file_type = "" + dateformat = "" + compression = "" + if ext == "jsonl": + if table_schema_has_type(table, "binary"): + raise LoadJobTerminalException(self.file_name(), "Redshift cannot load VARBYTE columns from json files. Switch to parquet to load binaries.") + file_type = "FORMAT AS JSON 'auto'" + dateformat = "dateformat 'auto' timeformat 'auto'" + compression = "GZIP" + elif ext == "parquet": + file_type = "PARQUET" + # if table contains complex types then SUPER field will be used.
+ # https://docs.aws.amazon.com/redshift/latest/dg/ingest-super.html + if table_schema_has_type(table, "complex"): + file_type += " SERIALIZETOJSON" + else: + raise ValueError(f"Unsupported file type {ext} for Redshift.") + + with self._sql_client.with_staging_dataset(table["write_disposition"]=="merge"): + with self._sql_client.begin_transaction(): + if table["write_disposition"]=="replace": + self._sql_client.execute_sql(f"""TRUNCATE TABLE {table_name}""") + dataset_name = self._sql_client.dataset_name + # TODO: if we ever support csv here remember to add column names to COPY + self._sql_client.execute_sql(f""" + COPY {dataset_name}.{table_name} + FROM '{bucket_path}' + {file_type} + {dateformat} + {compression} + {credentials} MAXERROR 0;""") + + def exception(self) -> str: + # this part of code should be never reached + raise NotImplementedError() class RedshiftMergeJob(SqlMergeJob): @@ -108,15 +169,24 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" + if NewReferenceJob.is_reference_job(file_path): + return RedshiftCopyFileLoadJob(table, file_path, self.sql_client, staging_credentials=self.config.staging_credentials, staging_iam_role=self.config.staging_iam_role) + return super().start_file_load(table, file_path, load_id) + + @classmethod + def _to_db_type(cls, sc_t: TDataType) -> str: if sc_t == "wei": - return f"numeric({DEFAULT_NUMERIC_PRECISION},0)" + return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + if sc_t == "decimal": + return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision return SCT_TO_PGT[sc_t] - @staticmethod - def _from_db_type(pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + @classmethod + def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if pq_t == "numeric": - if precision == DEFAULT_NUMERIC_PRECISION and scale == 0: + if (precision, scale) == cls.capabilities.wei_precision: return "wei" return PGT_TO_SCT.get(pq_t, "text") + diff --git a/dlt/destinations/snowflake/__init__.py b/dlt/destinations/snowflake/__init__.py index 226f18274f..d901f89550 100644 --- a/dlt/destinations/snowflake/__init__.py +++ b/dlt/destinations/snowflake/__init__.py @@ -7,6 +7,7 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.data_writers.escape import escape_snowflake_identifier +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration @@ -20,7 +21,11 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl", "parquet", "sql"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet", "sql"] caps.escape_identifier = escape_snowflake_identifier + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = 
(DEFAULT_NUMERIC_PRECISION, 0) caps.max_identifier_length = 255 caps.max_column_identifier_length = 255 caps.max_query_length = 2 * 1024 * 1024 diff --git a/dlt/destinations/snowflake/configuration.py b/dlt/destinations/snowflake/configuration.py index 63fd66b389..6041e59c70 100644 --- a/dlt/destinations/snowflake/configuration.py +++ b/dlt/destinations/snowflake/configuration.py @@ -90,4 +90,4 @@ class SnowflakeClientConfiguration(DestinationClientDwhConfiguration): stage_name: Optional[str] = None """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" keep_staged_files: bool = True - """Whether to keep or delete the staged files after COPY INTO succeeds""" + """Whether to keep or delete the staged files after COPY INTO succeeds""" \ No newline at end of file diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py index 5f8d69a4c2..cf430cdfbe 100644 --- a/dlt/destinations/snowflake/snowflake.py +++ b/dlt/destinations/snowflake/snowflake.py @@ -1,32 +1,31 @@ -from pathlib import Path -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, cast, Iterable +from typing import ClassVar, Dict, Optional, Sequence, Tuple, List +from urllib.parse import urlparse -from dlt.common import json, logger from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob +from dlt.common.destination.reference import FollowupJob, TLoadJobState, LoadJob, CredentialsConfiguration +from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables +from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TWriteDisposition -from dlt.common.wei import EVM_DECIMAL_PRECISION + from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, LoadJobUnknownTableException +from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.snowflake import capabilities from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_merge_job import SqlMergeJob from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient +from dlt.destinations.job_impl import NewReferenceJob BIGINT_PRECISION = 19 MAX_NUMERIC_PRECISION = 38 -SCT_TO_BQT: Dict[TDataType, str] = { +SCT_TO_SNOW: Dict[TDataType, str] = { "complex": "VARIANT", "text": "VARCHAR", "double": "FLOAT", @@ -35,10 +34,10 @@ "timestamp": "TIMESTAMP_TZ", "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types "binary": "BINARY", - "decimal": f"NUMBER({DEFAULT_NUMERIC_PRECISION},{DEFAULT_NUMERIC_SCALE})", + "decimal": "NUMBER(%i,%i)", } -BQT_TO_SCT: Dict[str, TDataType] = { +SNOW_TO_SCT: Dict[str, TDataType] = { "VARCHAR": "text", "FLOAT": "double", "BOOLEAN": "bool", @@ -48,10 +47,11 @@ 
"VARIANT": "complex" } + class SnowflakeLoadJob(LoadJob, FollowupJob): def __init__( self, file_path: str, table_name: str, write_disposition: TWriteDisposition, load_id: str, client: SnowflakeSqlClient, - stage_name: Optional[str] = None, keep_staged_files: bool = True + stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) @@ -59,33 +59,54 @@ def __init__( with client.with_staging_dataset(write_disposition == "merge"): qualified_table_name = client.make_qualified_table_name(table_name) - if stage_name: - # Concat "SCHEMA_NAME".stage_name - stage_name = client.make_qualified_table_name(stage_name) - # Create the stage if it doesn't exist - client.execute_sql(f"CREATE STAGE IF NOT EXISTS {stage_name}") + # extract and prepare some vars + bucket_path = NewReferenceJob.resolve_reference(file_path) if NewReferenceJob.is_reference_job(file_path) else "" + file_name = FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + from_clause = "" + credentials_clause = "" + files_clause = "" + stage_file_path = "" + + if bucket_path: + # s3 credentials case + if bucket_path.startswith("s3://") and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" + from_clause = f"FROM '{bucket_path}'" + else: + if not stage_name: + # when loading from bucket stage must be given + raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. 
See https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for instructions on setting up the `stage_name`") + from_clause = f"FROM @{stage_name}/" + files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" else: - # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name('%'+table_name) - + # this means we have a local file + if not stage_name: + # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" + stage_name = client.make_qualified_table_name('%'+table_name) + stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' + from_clause = f"FROM {stage_file_path}" + + # decide on source format, stage_file_path will either be a local file or a bucket path source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" - if file_path.endswith("parquet"): - source_format = "(TYPE = 'PARQUET')" + if file_name.endswith("parquet"): + source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE)" - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' with client.begin_transaction(): - # PUT and copy files in one transaction - client.execute_sql(f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE, AUTO_COMPRESS = FALSE') if write_disposition == "replace": client.execute_sql(f"TRUNCATE TABLE IF EXISTS {qualified_table_name}") + # PUT and COPY in one tx if local file, otherwise only copy + if not bucket_path: + client.execute_sql(f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE, AUTO_COMPRESS = FALSE') client.execute_sql( f"""COPY INTO {qualified_table_name} - FROM {stage_file_path} + {from_clause} + {files_clause} + {credentials_clause} FILE_FORMAT = {source_format} MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' """ ) - if not keep_staged_files: + if stage_file_path and not keep_staged_files: client.execute_sql(f'REMOVE {stage_file_path}') @@ -115,7 +136,8 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = SnowflakeLoadJob( file_path, table['name'], table['write_disposition'], load_id, self.sql_client, - stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files + stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, + staging_credentials=self.config.staging_credentials ) return job @@ -136,21 +158,23 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc return sql - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: - if sc_t == 'wei': - return "NUMBER(38,0)" - return SCT_TO_BQT[sc_t] + @classmethod + def _to_db_type(cls, sc_t: TDataType) -> str: + if sc_t == "wei": + return SCT_TO_SNOW["decimal"] % cls.capabilities.wei_precision + if sc_t == "decimal": + return SCT_TO_SNOW["decimal"] % cls.capabilities.decimal_precision + return SCT_TO_SNOW[sc_t] - @staticmethod - def _from_db_type(bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + @classmethod + def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if bq_t == "NUMBER": if precision == BIGINT_PRECISION and scale == 0: return 'bigint' - elif precision == MAX_NUMERIC_PRECISION and scale == 0: + elif (precision, scale) == cls.capabilities.wei_precision: return 'wei' return 'decimal' - return BQT_TO_SCT.get(bq_t, "text") + return SNOW_TO_SCT.get(bq_t, "text") def _get_column_def_sql(self, c: TColumnSchema) -> str: name = self.capabilities.escape_identifier(c["name"]) diff --git a/dlt/destinations/snowflake/sql_client.py 
b/dlt/destinations/snowflake/sql_client.py index d9950b934d..6a4b3b5577 100644 --- a/dlt/destinations/snowflake/sql_client.py +++ b/dlt/destinations/snowflake/sql_client.py @@ -30,9 +30,14 @@ def __init__(self, dataset_name: str, credentials: SnowflakeCredentials) -> None self.credentials = credentials def open_connection(self) -> snowflake_lib.SnowflakeConnection: + conn_params = self.credentials.to_connector_params() + # set the timezone to UTC so when loading from file formats that do not have timezones + # we get dlt expected UTC + if "timezone" not in conn_params: + conn_params["timezone"] = "UTC" self._conn = snowflake_lib.connect( schema=self.fully_qualified_dataset_name(), - **self.credentials.to_connector_params() + **conn_params ) return self._conn diff --git a/dlt/load/load.py b/dlt/load/load.py index 5a464708bd..fa70e15e2d 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -2,6 +2,7 @@ import datetime # noqa: 251 from typing import Dict, List, Optional, Tuple from multiprocessing.pool import ThreadPool +import os from dlt.common import sleep, logger from dlt.common.configuration import with_config, known_sections @@ -15,9 +16,10 @@ from dlt.common.runtime.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema -from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.schema.typing import VERSION_TABLE_NAME, TTableSchema, TWriteDisposition from dlt.common.storages import LoadStorage -from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration +from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration, DestinationClientStagingConfiguration +from dlt.destinations.filesystem.filesystem import LoadFilesystemJob from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import DestinationTerminalException, DestinationTransientException, LoadJobUnknownTableException @@ -32,26 +34,33 @@ class Load(Runnable[ThreadPool]): def __init__( self, destination: DestinationReference, + staging: DestinationReference = None, collector: Collector = NULL_COLLECTOR, is_storage_owner: bool = False, config: LoaderConfiguration = config.value, - initial_client_config: DestinationClientConfiguration = config.value + initial_client_config: DestinationClientConfiguration = config.value, + initial_staging_client_config: DestinationClientConfiguration = config.value ) -> None: self.config = config self.collector = collector self.initial_client_config = initial_client_config + self.initial_staging_client_config = initial_staging_client_config self.destination = destination self.capabilities = destination.capabilities() + self.staging = staging self.pool: ThreadPool = None self.load_storage: LoadStorage = self.create_storage(is_storage_owner) self._processed_load_ids: Dict[str, int] = {} def create_storage(self, is_storage_owner: bool) -> LoadStorage: + supported_file_formats = self.capabilities.supported_loader_file_formats + if self.staging: + supported_file_formats = self.staging.capabilities().supported_loader_file_formats + ["reference", "sql"] load_storage = LoadStorage( is_storage_owner, self.capabilities.preferred_loader_file_format, - self.capabilities.supported_loader_file_formats, + 
supported_file_formats, config=self.config._load_storage_config ) return load_storage @@ -73,9 +82,11 @@ def get_load_table(schema: Schema, file_name: str) -> TTableSchema: def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> Optional[LoadJob]: job: LoadJob = None try: - with self.destination.client(schema, self.initial_client_config) as client: + # if we have a staging destination and the file is not a reference, send to staging + client = self.get_staging_client(schema) if self.is_staging_job(file_path) else self.get_destination_client(schema) + with client as client: job_info = self.load_storage.parse_job_file_name(file_path) - if job_info.file_format not in self.capabilities.supported_loader_file_formats: + if job_info.file_format not in self.load_storage.supported_file_formats: raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = self.get_load_table(schema, file_path) @@ -112,11 +123,15 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # remove None jobs and check the rest return file_count, [job for job in jobs if job is not None] - def retrieve_jobs(self, client: JobClientBase, load_id: str) -> Tuple[int, List[LoadJob]]: + def is_staging_job(self, file_path: str) -> bool: + return self.staging is not None and os.path.splitext(file_path)[1][1:] in self.staging.capabilities().supported_loader_file_formats + + def retrieve_jobs(self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None) -> Tuple[int, List[LoadJob]]: jobs: List[LoadJob] = [] # list all files that were started but not yet completed started_jobs = self.load_storage.list_started_jobs(load_id) + logger.info(f"Found {len(started_jobs)} that are already started and should be continued") if len(started_jobs) == 0: return 0, jobs @@ -124,6 +139,7 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str) -> Tuple[int, List[ for file_path in started_jobs: try: logger.info(f"Will retrieve {file_path}") + client = staging_client if self.is_staging_job(file_path) else client job = client.restore_file_load(file_path) except DestinationTerminalException: logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") @@ -144,7 +160,8 @@ def get_new_jobs_info(self, load_id: str, schema: Schema, disposition: TWriteDis jobs_info.append(LoadStorage.parse_job_file_name(job_file)) return jobs_info - def create_merge_job(self, load_id: str, schema: Schema, top_merged_table: TTableSchema, starting_job: LoadJob) -> NewLoadJob: + def get_completed_table_chain(self, load_id: str, schema: Schema, top_merged_table: TTableSchema, starting_job_id: str) -> List[TTableSchema]: + """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed""" # returns ordered list of tables from parent to child leaf tables table_chain: List[TTableSchema] = [] # make sure all the jobs for the table chain is completed @@ -155,23 +172,26 @@ def create_merge_job(self, load_id: str, schema: Schema, top_merged_table: TTabl if not table_jobs: continue # all jobs must be completed in order for merge to be created - if any(job.state not in ("failed_jobs", "completed_jobs") and job.job_file_info.job_id() != starting_job.job_file_info().job_id() for job in table_jobs): + if any(job.state not in ("failed_jobs", "completed_jobs") and job.job_file_info.job_id() != starting_job_id for job in table_jobs): return None table_chain.append(table) # there must be at least 1 job assert len(table_chain) > 0 - # all tables completed, create merge sql job - return self.destination.client(schema, self.initial_client_config).create_merge_job(table_chain) + return table_chain def create_followup_jobs(self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema) -> List[NewLoadJob]: jobs: List[NewLoadJob] = [] if isinstance(starting_job, FollowupJob): - if state == "completed": - top_merged_table = get_top_level_table(schema.tables, self.get_load_table(schema, starting_job.file_name())["name"]) - if top_merged_table["write_disposition"] == "merge": - job = self.create_merge_job(load_id, schema, top_merged_table, starting_job) - if job: - jobs.append(job) + # check for merge jobs only for non-staging jobs. we may move that logic to the interface + starting_job_file_name = starting_job.file_name() + if state == "completed" and not self.is_staging_job(starting_job_file_name): + top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, starting_job_file_name)["name"]) + if top_job_table["write_disposition"] == "merge": + # if all tables completed, create merge sql job on destination client + if table_chain := self.get_completed_table_chain(load_id, schema, top_job_table, starting_job.job_file_info().job_id()): + if job := self.destination.client(schema, self.initial_client_config).create_merge_job(table_chain): + jobs.append(job) + jobs = jobs + starting_job.create_followup_jobs(state) return jobs def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: @@ -220,10 +240,13 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li return remaining_jobs + def get_destination_client(self, schema: Schema) -> JobClientBase: + return self.destination.client(schema, self.initial_client_config) + def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) -> None: # do not commit load id for aborted packages if not aborted: - with self.destination.client(schema, self.initial_client_config) as job_client: + with self.get_destination_client(schema) as job_client: job_client.complete_load(load_id) self.load_storage.complete_load_package(load_id, aborted) logger.info(f"All jobs completed, archiving package {load_id} with aborted set to {aborted}") @@ -232,7 +255,7 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) def load_single_package(self, load_id: str, schema: Schema) -> None: # initialize analytical storage ie. 
create dataset required by passed schema job_client: JobClientBase - with self.destination.client(schema, self.initial_client_config) as job_client: + with self.get_destination_client(schema) as job_client: expected_update = self.load_storage.begin_schema_update(load_id) if expected_update is not None: # update the default dataset @@ -251,12 +274,16 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: job_client.initialize_storage(staging=True) logger.info(f"Client for {job_client.config.destination_name} will UPDATE STAGING SCHEMA to package schema") merge_tables = set(job.table_name for job in merge_jobs) - job_client.update_storage_schema(staging=True, only_tables=merge_tables | dlt_tables, expected_update=expected_update) + job_client.update_storage_schema(staging=True, only_tables=merge_tables | {VERSION_TABLE_NAME}, expected_update=expected_update) logger.info(f"Client for {job_client.config.destination_name} will TRUNCATE STAGING TABLES: {merge_tables}") job_client.initialize_storage(staging=True, truncate_tables=merge_tables) self.load_storage.commit_schema_update(load_id, applied_update) # spool or retrieve unfinished jobs - jobs_count, jobs = self.retrieve_jobs(job_client, load_id) + if self.staging: + with self.get_staging_client(schema) as staging_client: + jobs_count, jobs = self.retrieve_jobs(job_client, load_id, staging_client) + else: + jobs_count, jobs = self.retrieve_jobs(job_client, load_id) if not jobs: # jobs count is a total number of jobs including those that could not be initialized @@ -301,6 +328,9 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: self.complete_package(load_id, schema, True) raise + def get_staging_client(self, schema: Schema) -> JobClientBase: + return self.staging.client(schema, self.initial_staging_client_config) + def run(self, pool: ThreadPool) -> TRunMetrics: # store pool self.pool = pool @@ -323,6 +353,7 @@ def run(self, pool: ThreadPool) -> TRunMetrics: self._processed_load_ids[load_id] = None with self.collector(f"Load {schema.name} in {load_id}"): self.load_single_package(load_id, schema) + return TRunMetrics(False, len(self.load_storage.list_packages())) def get_load_info(self, pipeline: SupportsPipeline, started_at: datetime.datetime = None) -> LoadInfo: @@ -340,6 +371,8 @@ def get_load_info(self, pipeline: SupportsPipeline, started_at: datetime.datetim pipeline, self.initial_client_config.destination_name, str(self.initial_client_config), + self.initial_staging_client_config.destination_name if self.initial_staging_client_config else None, + str(self.initial_staging_client_config) if self.initial_staging_client_config else None, dataset_name, list(load_ids), load_packages, diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index c107e608a2..3c30dd43e0 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -22,12 +22,13 @@ def pipeline( pipelines_dir: str = None, pipeline_salt: TSecretValue = None, destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, dataset_name: str = None, import_schema_path: str = None, export_schema_path: str = None, full_refresh: bool = False, credentials: Any = None, - progress: TCollectorArg = _NULL_COLLECTOR + progress: TCollectorArg = _NULL_COLLECTOR, ) -> Pipeline: """Creates a new instance of `dlt` pipeline, which moves the data from the source ie. a REST API to a destination ie. database or a data lake. 
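The new `staging` argument accepts a destination name or module just like `destination` and is wired through to the loader. As a quick illustration, a minimal sketch mirroring the examples added to the destination docs later in this patch (pipeline, destination and dataset names are illustrative):

```python
import dlt

# load packages are first written to the bucket configured for the `filesystem`
# destination, then the final destination copies the files from there
pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline",
    destination="redshift",
    staging="filesystem",  # activates the staging location
    dataset_name="player_data",
)
```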
@@ -52,6 +53,9 @@ def pipeline( destination (str | DestinationReference, optional): A name of the destination to which dlt will load the data, or a destination module imported from `dlt.destination`. May also be provided to `run` method of the `pipeline`. + staging (str | DestinationReference, optional): A name of the destination where dlt will stage the data before final loading, or a destination module imported from `dlt.destination`. + May also be provided to `run` method of the `pipeline`. + dataset_name (str, optional): A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. May also be provided later to the `run` or `load` methods of the `Pipeline`. If not provided at all then defaults to the `pipeline_name` @@ -86,6 +90,7 @@ def pipeline( pipelines_dir: str = None, pipeline_salt: TSecretValue = None, destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, dataset_name: str = None, import_schema_path: str = None, export_schema_path: str = None, @@ -113,6 +118,8 @@ def pipeline( pipelines_dir = get_dlt_pipelines_dir() destination = DestinationReference.from_name(destination or kwargs["destination_name"]) + staging = DestinationReference.from_name(staging or kwargs.get("staging_name", None)) if staging is not None else None + progress = collector_from_name(progress) # create new pipeline instance p = Pipeline( @@ -120,6 +127,7 @@ def pipeline( pipelines_dir, pipeline_salt, destination, + staging, dataset_name, credentials, import_schema_path, @@ -149,7 +157,7 @@ def attach( pipelines_dir = get_dlt_pipelines_dir() progress = collector_from_name(progress) # create new pipeline instance - p = Pipeline(pipeline_name, pipelines_dir, pipeline_salt, None, None, None, None, None, full_refresh, progress, True, last_config(**kwargs), kwargs["runtime"]) + p = Pipeline(pipeline_name, pipelines_dir, pipeline_salt, None, None, None, None, None, None, full_refresh, progress, True, last_config(**kwargs), kwargs["runtime"]) # set it as current pipeline p.activate() return p diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index da98f88a4e..3d0c70f4b1 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -7,12 +7,12 @@ from dlt.common.data_writers import TLoaderFileFormat - @configspec class PipelineConfiguration(BaseConfiguration): pipeline_name: Optional[str] = None pipelines_dir: Optional[str] = None destination_name: Optional[str] = None + staging_name: Optional[str] = None loader_file_format: Optional[TLoaderFileFormat] = None dataset_name: Optional[str] = None pipeline_salt: Optional[TSecretValue] = None diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index a53c0529dd..ce82e9be3b 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -13,7 +13,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ConfigFieldMissingException, ContextDefaultCannotBeCreated from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.exceptions import MissingDependencyException +from dlt.common.exceptions import DestinationLoadingViaStagingNotSupported, DestinationNoStagingMode, MissingDependencyException, DestinationIncompatibleLoaderFileFormatException from dlt.common.normalizers import default_normalizers, import_normalizers from dlt.common.runtime import signals, initialize_runtime from 
dlt.common.schema.exceptions import InvalidDatasetName @@ -23,7 +23,7 @@ from dlt.common.runners import pool_runner as runner from dlt.common.storages import LiveSchemaStorage, NormalizeStorage, LoadStorage, SchemaStorage, FileStorage, NormalizeStorageConfiguration, SchemaStorageConfiguration, LoadStorageConfiguration from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import DestinationReference, JobClientBase, DestinationClientConfiguration, DestinationClientDwhConfiguration, TDestinationReferenceArg +from dlt.common.destination.reference import DestinationReference, JobClientBase, DestinationClientConfiguration, DestinationClientDwhConfiguration, TDestinationReferenceArg, DestinationClientStagingConfiguration, DestinationClientDwhConfiguration from dlt.common.pipeline import ExtractInfo, LoadInfo, NormalizeInfo, PipelineContext, SupportsPipeline, TPipelineLocalState, TPipelineState, StateInjectableContext from dlt.common.schema import Schema from dlt.common.utils import is_interactive @@ -48,6 +48,8 @@ from dlt.pipeline.typing import TPipelineStep from dlt.pipeline.state_sync import STATE_ENGINE_VERSION, load_state_from_destination, merge_state_if_changed, migrate_state, state_resource, json_encode_state, json_decode_state +from dlt.common.destination.capabilities import INTERNAL_LOADER_FILE_FORMATS + def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: @@ -178,6 +180,7 @@ def __init__( pipelines_dir: str, pipeline_salt: TSecretValue, destination: DestinationReference, + staging: DestinationReference, dataset_name: str, credentials: Any, import_schema_path: str, @@ -186,15 +189,16 @@ def __init__( progress: _Collector, must_attach_to_local_pipeline: bool, config: PipelineConfiguration, - runtime: RunConfiguration + runtime: RunConfiguration, ) -> None: """Initializes the Pipeline class which implements `dlt` pipeline. Please use `pipeline` function in `dlt` module to create a new Pipeline instance.""" - self.pipeline_salt = pipeline_salt self.config = config self.runtime_config = runtime self.full_refresh = full_refresh self.collector = progress or _NULL_COLLECTOR + self.destination = None + self.staging = None self._container = Container() self._pipeline_instance_id = self._create_pipeline_instance_id() @@ -214,8 +218,12 @@ def __init__( with self.managed_state() as state: # set the pipeline properties from state self._state_to_props(state) + # we overwrite the state with the values from init + if staging: + self._set_staging(staging) self._set_destination(destination) # changing the destination could be dangerous if pipeline has not loaded items + self._set_dataset_name(dataset_name) self.credentials = credentials self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) @@ -230,6 +238,7 @@ def drop(self) -> "Pipeline": self.pipelines_dir, self.pipeline_salt, self.destination, + self.staging, self.dataset_name, self.credentials, self._schema_storage.config.import_schema_path, @@ -288,6 +297,8 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No """Normalizes the data prepared with `extract` method, infers the schema and creates load packages for the `load` method. Requires `destination` to be known.""" if is_interactive() and workers > 1: raise NotImplementedError("Do not use normalize workers in interactive mode ie. 
in notebook") + if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: + raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") # check if any schema is present, if not then no data was extracted if not self.default_schema_name: return None @@ -338,6 +349,12 @@ def load( # make sure that destination is set and client is importable and can be instantiated client = self._get_destination_client(self.default_schema) + staging_client = None + if self.staging: + staging_client = self._get_staging_client(self.default_schema) + # inject staging config into destination config, TODO: Not super clean I think? + if isinstance(client.config, DestinationClientDwhConfiguration) and not client.config.staging_credentials: + client.config.staging_credentials = staging_client.config.credentials # create default loader config and the loader load_config = LoaderConfiguration( @@ -345,7 +362,7 @@ def load( raise_on_failed_jobs=raise_on_failed_jobs, _load_storage_config=self._load_storage_config ) - load = Load(self.destination, collector=self.collector, is_storage_owner=False, config=load_config, initial_client_config=client.config) + load = Load(self.destination, staging=self.staging, collector=self.collector, is_storage_owner=False, config=load_config, initial_client_config=client.config, initial_staging_client_config=staging_client.config if staging_client else None) try: with signals.delayed_signals(): runner.run_pool(load.config, load) @@ -362,6 +379,7 @@ def run( data: Any = None, *, destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, dataset_name: str = None, credentials: Any = None, table_name: str = None, @@ -417,7 +435,7 @@ def run( schema (Schema, optional): An explicit `Schema` object in which all table schemas will be grouped. By default `dlt` takes the schema from the source (if passed in `data` argument) or creates a default one itself. - loader_file_format (Literal["jsonl", "puae-jsonl", "insert_values", "sql", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. + loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. ### Raises: PipelineStepFailed when a problem happened during `extract`, `normalize` or `load` steps. 
@@ -426,6 +444,7 @@ def run( """ signals.raise_if_signalled() self._set_destination(destination) + self._set_staging(staging) self._set_dataset_name(dataset_name) # sync state with destination @@ -844,8 +863,9 @@ def _extract_source(self, storage: ExtractorStorage, source: DltSource, max_para return extract_id - def _get_destination_client_initial_config(self, credentials: Any = None) -> DestinationClientConfiguration: - if not self.destination: + def _get_destination_client_initial_config(self, destination: DestinationReference = None, credentials: Any = None, as_staging: bool = False) -> DestinationClientConfiguration: + destination = destination or self.destination + if not destination: raise PipelineConfigMissing( self.pipeline_name, "destination", @@ -853,25 +873,27 @@ def _get_destination_client_initial_config(self, credentials: Any = None) -> Des "Please provide `destination` argument to `pipeline`, `run` or `load` method directly or via .dlt config.toml file or environment variable." ) # create initial destination client config - client_spec = self.destination.spec() + client_spec = destination.spec() # initialize explicit credentials credentials = credentials or self.credentials if credentials is not None and not isinstance(credentials, CredentialsConfiguration): # use passed credentials as initial value. initial value may resolve credentials credentials = client_spec.get_resolvable_fields()["credentials"](credentials) # this client support schemas and datasets - if issubclass(client_spec, DestinationClientDwhConfiguration): + default_schema_name = None if self.config.use_single_dataset else self.default_schema_name + + if issubclass(client_spec, DestinationClientStagingConfiguration): + return client_spec(dataset_name=self.dataset_name, default_schema_name=default_schema_name, credentials=credentials, as_staging=as_staging) + elif issubclass(client_spec, DestinationClientDwhConfiguration): # set default schema name to load all incoming data to a single dataset, no matter what is the current schema name - default_schema_name = None if self.config.use_single_dataset else self.default_schema_name return client_spec(dataset_name=self.dataset_name, default_schema_name=default_schema_name, credentials=credentials) - else: - return client_spec(credentials=credentials) + return client_spec(credentials=credentials) def _get_destination_client(self, schema: Schema, initial_config: DestinationClientConfiguration = None) -> JobClientBase: try: # config is not provided then get it with injected credentials if not initial_config: - initial_config = self._get_destination_client_initial_config() + initial_config = self._get_destination_client_initial_config(self.destination) return self.destination.client(schema, initial_config) except ImportError: client_spec = self.destination.spec() @@ -881,6 +903,20 @@ def _get_destination_client(self, schema: Schema, initial_config: DestinationCli "Dependencies for specific destinations are available as extras of dlt" ) + def _get_staging_client(self, schema: Schema, initial_config: DestinationClientConfiguration = None) -> JobClientBase: + try: + # config is not provided then get it with injected credentials + if not initial_config: + initial_config = self._get_destination_client_initial_config(self.staging, as_staging=True) + return self.staging.client(schema, initial_config) # type: ignore + except ImportError: + client_spec = self.destination.spec() + raise MissingDependencyException( + f"{client_spec.destination_name} destination", + 
[f"{version.DLT_PKG_NAME}[{client_spec.destination_name}]"], + "Dependencies for specific destinations are available as extras of dlt" + ) + def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: if not self.destination: raise PipelineConfigMissing( @@ -891,6 +927,9 @@ def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: ) return self.destination.capabilities() + def _get_staging_capabilities(self) -> DestinationCapabilitiesContext: + return self.staging.capabilities() if self.staging is not None else None # type: ignore + def _validate_pipeline_name(self) -> None: try: FileStorage.validate_file_name_component(self.pipeline_name) @@ -929,21 +968,56 @@ def _set_destination(self, destination: TDestinationReferenceArg) -> None: # default normalizers must match the destination self._set_default_normalizers() + def _set_staging(self, staging: TDestinationReferenceArg) -> None: + staging_module = DestinationReference.from_name(staging) + if staging_module and not issubclass(staging_module.spec(), DestinationClientStagingConfiguration): + raise DestinationNoStagingMode(staging_module.__name__) + self.staging = staging_module or self.staging + @contextmanager def _maybe_destination_capabilities(self, loader_file_format: TLoaderFileFormat = None) -> Iterator[DestinationCapabilitiesContext]: try: caps: DestinationCapabilitiesContext = None injected_caps: ContextManager[DestinationCapabilitiesContext] = None if self.destination: - injected_caps = self._container.injectable_context(self._get_destination_capabilities()) + destination_caps = self._get_destination_capabilities() + stage_caps = self._get_staging_capabilities() + injected_caps = self._container.injectable_context(destination_caps) caps = injected_caps.__enter__() - if loader_file_format: - caps.preferred_loader_file_format = loader_file_format + + caps.preferred_loader_file_format = self._resolve_loader_file_format( + DestinationReference.to_name(self.destination), + DestinationReference.to_name(self.staging) if self.staging else None, + destination_caps, stage_caps, loader_file_format) yield caps finally: if injected_caps: injected_caps.__exit__(None, None, None) + @staticmethod + def _resolve_loader_file_format( + destination: str, + staging: str, + dest_caps: DestinationCapabilitiesContext, + stage_caps: DestinationCapabilitiesContext, + file_format: TLoaderFileFormat) -> TLoaderFileFormat: + + possible_file_formats = dest_caps.supported_loader_file_formats + if stage_caps: + if not dest_caps.supported_staging_file_formats: + raise DestinationLoadingViaStagingNotSupported(destination) + possible_file_formats = [f for f in dest_caps.supported_staging_file_formats if f in stage_caps.supported_loader_file_formats] + if not file_format: + if not stage_caps: + file_format = dest_caps.preferred_loader_file_format + elif stage_caps and dest_caps.preferred_staging_file_format in possible_file_formats: + file_format = dest_caps.preferred_staging_file_format + else: + file_format = possible_file_formats[0] if len(possible_file_formats) > 0 else None + if file_format not in possible_file_formats: + raise DestinationIncompatibleLoaderFileFormatException(destination, staging, file_format, set(possible_file_formats) - INTERNAL_LOADER_FILE_FORMATS) + return file_format + def _set_default_normalizers(self) -> None: self._default_naming, _ = import_normalizers(default_normalizers()) @@ -1124,6 +1198,8 @@ def _state_to_props(self, state: TPipelineState) -> None: for prop in Pipeline.LOCAL_STATE_PROPS: if prop in 
state["_local"] and not prop.startswith("_"): setattr(self, prop, state["_local"][prop]) # type: ignore + if "staging" in state: + self._set_staging(DestinationReference.from_name(self.staging)) if "destination" in state: self._set_destination(DestinationReference.from_name(self.destination)) @@ -1137,6 +1213,8 @@ def _props_to_state(self, state: TPipelineState) -> None: state["_local"][prop] = getattr(self, prop) # type: ignore if self.destination: state["destination"] = self.destination.__name__ + if self.staging: + state["staging"] = self.staging.__name__ state["schema_names"] = self._schema_storage.list_schemas() def _save_state(self, state: TPipelineState) -> None: diff --git a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md index 405eeb5df9..868034f400 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md +++ b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md @@ -102,11 +102,39 @@ You can configure the following file formats to load data to BigQuery * [jsonl](../file-formats/jsonl.md) is used by default * [parquet](../file-formats/parquet.md) is supported +When staging is enabled: +* [jsonl](../file-formats/jsonl.md) is used by default +* [parquet](../file-formats/parquet.md) is supported + +> ❗ **Bigquery cannot load JSON columns from `parquet` files**. `dlt` will fail such jobs permanently. Switch to `jsonl` to load and parse JSON properly. + ## Supported column hints BigQuery supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): * `partition` - creates a partition with a day granularity on decorated column (`PARTITION BY DATE`). May be used with `datetime`, `date` data types and `bigint` and `double` if they contain valid UNIX timestamps. Only one column per table is supported and only when a new table is created. * `cluster` - creates a cluster column(s). Many column per table are supported and only when a new table is created. +## Staging Support + +BigQuery supports gcs as a file staging destination. DLT will upload files in the parquet format to gcs and ask BigQuery to copy their data directly into the db. Please refer to the [Google Storage filesystem documentation](./filesystem.md#google-storage) to learn how to set up your gcs bucket with the bucket_url and credentials. If you use the same service account for gcs and your redshift deployment, you do not need to provide additional authentication for BigQuery to be able to read from your bucket. +```toml +``` + +Alternatively to parquet files, you can also specify jsonl as the staging file format. For this set the `loader_file_format` argument of the `run` command of the pipeline to `jsonl`. + +### BigQuery/GCS staging Example Code + +```python +# Create a dlt pipeline that will load +# chess player data to the BigQuery destination +# via a gcs bucket. +pipeline = dlt.pipeline( + pipeline_name='chess_pipeline', + destination='biquery', + staging='filesystem', # add this to activate the staging location + dataset_name='player_data' +) +``` + ## Additional destination options You can configure the data location and various timeouts as shown below. This information is not a secret so can be placed in `config.toml` as well. ```toml @@ -125,4 +153,4 @@ retry_deadline=60.0 This destination [integrates with dbt](../transformations/dbt.md) via [dbt-bigquery](https://github.com/dbt-labs/dbt-bigquery). 
Credentials, if explicitly defined, are shared with `dbt` along with other settings like **location** and retries and timeouts. In case of implicit credentials (ie. available in cloud function), `dlt` shares the `project_id` and delegates obtaining credentials to `dbt` adapter. ### Syncing of `dlt` state -This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination) \ No newline at end of file +This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination) diff --git a/docs/website/docs/dlt-ecosystem/destinations/redshift.md b/docs/website/docs/dlt-ecosystem/destinations/redshift.md index 84d6da9666..ee8d9cdca6 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/redshift.md +++ b/docs/website/docs/dlt-ecosystem/destinations/redshift.md @@ -62,6 +62,17 @@ All [write dispositions](../../general-usage/incremental-loading#choosing-a-writ ## Supported file formats [SQL Insert](../file-formats/insert-format) is used by default. +When staging is enabled: +* [jsonl](../file-formats/jsonl.md) is used by default +* [parquet](../file-formats/parquet.md) is supported + +> ❗ **Redshift cannot load VARBYTE columns from `json` files**. `dlt` will fail such jobs permanently. Switch to `parquet` to load binaries. + +> ❗ **Redshift cannot detect compression type from `json` files**. `dlt` assumes that `jsonl` files are gzip compressed, which is the default. + +> ❗ **Redshift loads `complex` types as strings into SUPER with `parquet`**. Use `jsonl` format to store JSON in SUPER natively or transform your SUPER columns with `PARSE_JSON`. + + ## Supported column hints Amazon Redshift supports the following column hints: @@ -69,6 +80,33 @@ - `cluster` - hint is a Redshift term for table distribution. Applying it to a column makes it the "DISTKEY," affecting query and join performance. Check the following [documentation](https://docs.aws.amazon.com/redshift/latest/dg/c_best-practices-best-dist-key.html) for more info. - `sort` - creates SORTKEY to order rows on disk physically. It is used to improve a query and join speed in Redshift, please read the [sort key docs](https://docs.aws.amazon.com/redshift/latest/dg/c_best-practices-sort-key.html) to learn more. +## Staging support + +Redshift supports s3 as a file staging destination. DLT will upload files in the parquet format to s3 and ask redshift to copy their data directly into the database. Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your s3 bucket with the bucket_url and credentials. The `dlt` Redshift loader will use the aws credentials provided for s3 to access the s3 bucket if not specified otherwise (see config options below). As an alternative to parquet files, you can also specify jsonl as the staging file format. To do so, set the `loader_file_format` argument of the pipeline's `run` method to `jsonl`. + +### Authentication via IAM role + +If you would like to load from s3 without forwarding the aws staging credentials, but instead authorize with an IAM role connected to Redshift, follow the [Redshift documentation](https://docs.aws.amazon.com/redshift/latest/mgmt/authorizing-redshift-service.html) to create a role with access to s3 linked to your redshift cluster and change your destination settings to use the IAM role: + +```toml +[destination] +staging_iam_role="arn:aws:iam::..."
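+# note: the ARN above is a placeholder for your own role; when staging_iam_role is set,
+# dlt authorizes the COPY command with IAM_ROLE instead of forwarding the s3 access keys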
+``` + +### Redshift/S3 staging example code + +```python +# Create a dlt pipeline that will load +# chess player data to the redshift destination +# via staging on s3 +pipeline = dlt.pipeline( + pipeline_name='chess_pipeline', + destination='redshift', + staging='filesystem', # add this to activate the staging location + dataset_name='player_data' +) +``` + ## Additional destination options ### dbt support @@ -76,3 +114,7 @@ ### Syncing of `dlt` state - This destination fully supports [dlt state sync.](../../general-usage/state#syncing-state-with-destination) + +## Supported loader file formats + +Supported loader file formats for Redshift are `sql` and `insert_values` (default). When using a staging location, Redshift supports `parquet` and `jsonl`. diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 9e7687ad40..7ce2345769 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -107,6 +107,13 @@ The data is loaded using internal Snowflake stage. We use `PUT` command and per- ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default * [parquet](../file-formats/parquet.md) is supported +* [jsonl](../file-formats/jsonl.md) is supported + +When staging is enabled: +* [jsonl](../file-formats/jsonl.md) is used by default +* [parquet](../file-formats/parquet.md) is supported + +> ❗ When loading from `parquet`, Snowflake will store `complex` types (JSON) in `VARIANT` as string. Use `jsonl` format instead or use `PARSE_JSON` to update the `VARIANT` field after loading. ## Supported column hints Snowflake supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): @@ -117,6 +124,73 @@ Snowflake makes all unquoted identifiers uppercase and then resolves them case-i Names of tables and columns in [schemas](../../general-usage/schema.md) are kept in lower case like for all other destinations. This is the pattern we observed in other tools ie. `dbt`. In case of `dlt` it is however trivial to define your own uppercase [naming convention](../../general-usage/schema.md#naming-convention) +## Staging support + +Snowflake supports s3 and gcs as file staging destinations. DLT will upload files in the parquet format to the bucket provider and will ask snowflake to copy their data directly into the database. + +As an alternative to parquet files, you can also specify jsonl as the staging file format. To do so, set the `loader_file_format` argument of the pipeline's `run` method to `jsonl`. + +### Snowflake and Amazon S3 + +Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your bucket with the bucket_url and credentials. For s3, the `dlt` Snowflake loader will use the aws credentials provided for s3 to access the s3 bucket if not specified otherwise (see config options below). Alternatively you can create a stage for your S3 Bucket by following the instructions provided in the [Snowflake S3 documentation](https://docs.snowflake.com/en/user-guide/data-load-s3-config-storage-integration). +The basic steps are as follows: + +* Create a storage integration linked to S3 and the right bucket +* Grant access to this storage integration to the snowflake role you are using to load the data into snowflake.
+* Create a stage from this storage integration in the PUBLIC namespace, or the namespace of the schema of your data. +* Also grant access to this stage for the role you are using to load data into snowflake. +* Provide the name of your stage (including the namespace) to dlt like so: + +To prevent dlt from forwarding the s3 bucket credentials on every command and to set your s3 stage, change these settings: + +```toml +[destination] +stage_name=PUBLIC.my_s3_stage +``` + +To run Snowflake with s3 as staging destination: + +```python +# Create a dlt pipeline that will load +# chess player data to the snowflake destination +# via staging on s3 +pipeline = dlt.pipeline( + pipeline_name='chess_pipeline', + destination='snowflake', + staging='filesystem', # add this to activate the staging location + dataset_name='player_data' +) +``` + +### Snowflake and Google Cloud Storage + +Please refer to the [Google Storage filesystem documentation](./filesystem.md#google-storage) to learn how to set up your bucket with the bucket_url and credentials. For gcs you can define a stage in Snowflake and provide the stage identifier in the configuration (see config options below). Please consult the Snowflake documentation on [how to create a stage for your GCS Bucket](https://docs.snowflake.com/en/user-guide/data-load-gcs-config). The basic steps are as follows: + +* Create a storage integration linked to GCS and the right bucket +* Grant access to this storage integration to the snowflake role you are using to load the data into snowflake. +* Create a stage from this storage integration in the PUBLIC namespace, or the namespace of the schema of your data. +* Also grant access to this stage for the role you are using to load data into snowflake. +* Provide the name of your stage (including the namespace) to dlt like so: + +```toml +[destination] +stage_name=PUBLIC.my_gcs_stage +``` + +To run Snowflake with gcs as staging destination: + +```python +# Create a dlt pipeline that will load +# chess player data to the snowflake destination +# via staging on gcs +pipeline = dlt.pipeline( + pipeline_name='chess_pipeline', + destination='snowflake', + staging='filesystem', # add this to activate the staging location + dataset_name='player_data' +) +``` + ## Additional destination options You can define your own stage to PUT files and disable removing of the staged files after loading. ```toml @@ -131,4 +205,4 @@ keep_staged_files=true This destination [integrates with dbt](../transformations/dbt.md) via [dbt-snowflake](https://github.com/dbt-labs/dbt-snowflake). Both password and key pair authentication is supported and shared with dbt runners. ### Syncing of `dlt` state -This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination) \ No newline at end of file +This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination) diff --git a/docs/website/docs/dlt-ecosystem/staging.md b/docs/website/docs/dlt-ecosystem/staging.md new file mode 100644 index 0000000000..f8ef739060 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/staging.md @@ -0,0 +1,13 @@ +--- +title: Staging +description: Configure an s3 or gcs bucket for staging before copying into the destination +keywords: [staging, destination] +--- + +# Staging + +dlt supports a staging location for some destinations. Currently it is possible to copy files from an s3 bucket into redshift, from a gcs bucket into bigquery and from gcs and s3 into snowflake.
dlt will automatically select an appropriate loader file format for the staging files. For this to work, you have to set the `staging` argument of the pipeline to `filesystem` and provide credentials for both the staging location and the destination. You may also define an alternative staging file format in the `run` command of the pipeline; dlt will check whether the format is compatible with the final destination. Please refer to the documentation of each destination to learn how to use staging in the respective environment. + +## Why staging? + +By staging the data, you can leverage parallel processing capabilities of many modern cloud-based storage solutions. This can greatly reduce the total time it takes to load your data compared to uploading via the SQL interface. If you wish, you can also retain a history of all imported data files in your bucket for auditing and troubleshooting purposes. \ No newline at end of file diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 802b97470b..6492b108c6 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -76,7 +76,7 @@ const sidebars = { 'dlt-ecosystem/verified-sources/strapi', 'dlt-ecosystem/verified-sources/stripe', 'dlt-ecosystem/verified-sources/workable', - 'dlt-ecosystem/verified-sources/zendesk', + 'dlt-ecosystem/verified-sources/zendesk' ] }, { @@ -112,6 +112,7 @@ const sidebars = { 'dlt-ecosystem/destinations/motherduck', ] }, + 'dlt-ecosystem/staging', { type: 'category', label: 'Deployments', diff --git a/tests/common/utils.py b/tests/common/utils.py index f2274b1428..155d31767f 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -4,7 +4,9 @@ from git import Repo, Commit from pathlib import Path from typing import Mapping, Tuple, cast, Any +import datetime # noqa: 251 +from dlt.common.arithmetics import Decimal from dlt.common import json from dlt.common.typing import StrAny from dlt.common.schema import utils diff --git a/tests/load/bigquery/test_bigquery_parquet.py b/tests/load/bigquery/test_bigquery_parquet.py deleted file mode 100644 index 46bae82938..0000000000 --- a/tests/load/bigquery/test_bigquery_parquet.py +++ /dev/null @@ -1,38 +0,0 @@ - - - -import dlt - -from dlt.common.utils import uniq_id -from dlt.destinations.bigquery.bigquery import BigQueryClient - -def test_pipeline_parquet_bigquery_destination() -> None: - """Run pipeline twice with merge write disposition - Resource with primary key falls back to append. Resource without keys falls back to replace.
- """ - pipeline = dlt.pipeline(pipeline_name='parquet_test_' + uniq_id(), destination="bigquery", dataset_name='parquet_test_' + uniq_id()) - - @dlt.resource(primary_key='id') - def some_data(): # type: ignore[no-untyped-def] - yield [{'id': 1}, {'id': 2}, {'id': 3}] - - @dlt.resource - def other_data(): # type: ignore[no-untyped-def] - yield [1, 2, 3, 4, 5] - - @dlt.source - def some_source(): # type: ignore[no-untyped-def] - return [some_data(), other_data()] - - info = pipeline.run(some_source()) - package_info = pipeline.get_load_package_info(info.loads_ids[0]) - assert package_info.state == "loaded" - # all three jobs succeeded - assert len(package_info.jobs["failed_jobs"]) == 0 - assert len(package_info.jobs["completed_jobs"]) == 3 - - client: BigQueryClient = pipeline._destination_client() # type: ignore[assignment] - with client.sql_client as sql_client: - assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data")] == [1, 2, 3, 4, 5] - assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data")] == [1, 2, 3] - diff --git a/tests/load/conftest.py b/tests/load/conftest.py index 27d3a91ac3..cd04a1c09b 100644 --- a/tests/load/conftest.py +++ b/tests/load/conftest.py @@ -1,12 +1,9 @@ from typing import Iterator, Tuple import os - import pytest -import dlt -from tests.utils import preserve_environ -from tests.utils import ALL_DESTINATIONS +from tests.utils import ALL_DESTINATIONS, preserve_environ from tests.load.utils import ALL_BUCKETS diff --git a/tests/load/pipeline/conftest.py b/tests/load/pipeline/conftest.py new file mode 100644 index 0000000000..97a4d72c04 --- /dev/null +++ b/tests/load/pipeline/conftest.py @@ -0,0 +1,3 @@ +from tests.utils import patch_home_dir, preserve_environ, autouse_test_storage +from tests.pipeline.utils import drop_dataset_from_env +from tests.load.pipeline.utils import drop_pipeline diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 47e6a8ee41..58684a071f 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -10,9 +10,8 @@ from dlt.helpers.dbt import create_venv from dlt.helpers.dbt.exceptions import DBTProcessingError, PrerequisitesException -from tests.pipeline.utils import drop_dataset_from_env -from tests.load.pipeline.utils import select_data, drop_pipeline -from tests.utils import ALL_DESTINATIONS, autouse_test_storage, preserve_environ, patch_home_dir +from tests.load.pipeline.utils import select_data +from tests.utils import ALL_DESTINATIONS # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index 8e6b8f7338..057b2546ad 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -1,5 +1,4 @@ from typing import Any, Iterator, Dict, Any, List -import secrets from unittest import mock from itertools import chain @@ -13,7 +12,6 @@ from dlt.pipeline.exceptions import PipelineStepFailed from dlt.destinations.job_client_impl import SqlJobClientBase -from tests.load.pipeline.utils import drop_pipeline from tests.utils import ALL_DESTINATIONS diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index c3bf506433..fcfc21a757 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -9,8 +9,7 @@ import pyarrow.parquet as pq -from 
tests.utils import autouse_test_storage, init_test_logging, preserve_environ, patch_home_dir -from tests.load.pipeline.utils import drop_pipeline +from tests.utils import init_test_logging diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 2b44e87e97..04152e4853 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -15,10 +15,9 @@ from dlt.extract.source import DltResource from dlt.sources.helpers.transform import skip_first, take_first -from tests.utils import ALL_DESTINATIONS, patch_home_dir, preserve_environ, autouse_test_storage -from tests.pipeline.utils import drop_dataset_from_env, assert_load_info -from tests.load.utils import delete_dataset -from tests.load.pipeline.utils import drop_pipeline, load_table_counts, select_data +from tests.utils import ALL_DESTINATIONS +from tests.pipeline.utils import assert_load_info +from tests.load.pipeline.utils import load_table_counts, select_data # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 1d0f3af48a..c6ac87a863 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -16,12 +16,12 @@ from dlt.extract.source import DltSource from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineConfigMissing, PipelineStepFailed from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.common.exceptions import DestinationHasFailedJobs -from tests.utils import ALL_DESTINATIONS, patch_home_dir, preserve_environ, autouse_test_storage, TEST_STORAGE_ROOT -# from tests.common.configuration.utils import environment -from tests.pipeline.utils import drop_dataset_from_env, assert_load_info -from tests.load.utils import delete_dataset -from tests.load.pipeline.utils import drop_active_pipeline_data, drop_pipeline, assert_query_data, assert_table, load_table_counts, select_data +from tests.utils import ALL_DESTINATIONS, TEST_STORAGE_ROOT +from tests.pipeline.utils import assert_load_info +from tests.load.utils import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, delete_dataset +from tests.load.pipeline.utils import drop_active_pipeline_data, assert_query_data, assert_table, load_table_counts, select_data @pytest.mark.parametrize('use_single_dataset', [True, False]) @@ -532,10 +532,21 @@ def gen2(): @pytest.mark.parametrize('destination_name', ["snowflake"]) def test_snowflake_custom_stage(destination_name: str) -> None: """Using custom stage name instead of the table stage""" - os.environ['DESTINATION__SNOWFLAKE__STAGE_NAME'] = 'my_custom_stage' - + os.environ['DESTINATION__SNOWFLAKE__STAGE_NAME'] = 'my_non_existing_stage' pipeline, data = simple_nested_pipeline(destination_name, f"custom_stage_{uniq_id()}", False) + info = pipeline.run(data()) + with pytest.raises(DestinationHasFailedJobs) as f_jobs: + info.raise_on_failed_jobs() + assert "MY_NON_EXISTING_STAGE" in f_jobs.value.failed_jobs[0].failed_message + + drop_active_pipeline_data() + # NOTE: this stage must be created in DLT_DATA database for this test to pass! 
+ # CREATE STAGE MY_CUSTOM_LOCAL_STAGE; + # GRANT READ, WRITE ON STAGE DLT_DATA.PUBLIC.MY_CUSTOM_LOCAL_STAGE TO ROLE DLT_LOADER_ROLE; + stage_name = 'PUBLIC.MY_CUSTOM_LOCAL_STAGE' + os.environ['DESTINATION__SNOWFLAKE__STAGE_NAME'] = stage_name + pipeline, data = simple_nested_pipeline(destination_name, f"custom_stage_{uniq_id()}", False) info = pipeline.run(data()) assert_load_info(info) @@ -543,7 +554,6 @@ def test_snowflake_custom_stage(destination_name: str) -> None: # Get a list of the staged files and verify correct number of files in the "load_id" dir with pipeline.sql_client() as client: - stage_name = client.make_qualified_table_name('my_custom_stage') staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') assert len(staged_files) == 3 # check data of one table to ensure copy was done successfully @@ -575,6 +585,57 @@ def test_snowflake_delete_file_after_copy(destination_name: str) -> None: assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ['a', None, None]) +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize('destination_name', ["bigquery", "snowflake", "duckdb"]) +def test_parquet_loading(destination_name: str) -> None: + """Loads a simple source and a table with all data types as parquet files and verifies the loaded data. + """ + pipeline = dlt.pipeline(pipeline_name='parquet_test_' + uniq_id(), destination=destination_name, dataset_name='parquet_test_' + uniq_id()) + + @dlt.resource(primary_key='id') + def some_data(): # type: ignore[no-untyped-def] + yield [{'id': 1}, {'id': 2}, {'id': 3}] + + @dlt.resource(write_disposition="replace") + def other_data(): # type: ignore[no-untyped-def] + yield [1, 2, 3, 4, 5] + + data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + # parquet on bigquery does not support JSON but we still want to run the test + if destination_name == "bigquery": + column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" + + # apply the exact column definitions so we process complex and wei types correctly!
+ @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) + def my_resource(): + nonlocal data_types + yield [data_types]*10 + + @dlt.source(max_table_nesting=0) + def some_source(): # type: ignore[no-untyped-def] + return [some_data(), other_data(), my_resource()] + + info = pipeline.run(some_source(), loader_file_format="parquet") + package_info = pipeline.get_load_package_info(info.loads_ids[0]) + assert package_info.state == "loaded" + # all five jobs succeeded + assert len(package_info.jobs["failed_jobs"]) == 0 + assert len(package_info.jobs["completed_jobs"]) == 5 # 3 tables + 1 state + 1 sql merge job + + client = pipeline._destination_client() # type: ignore[assignment] + with client.sql_client as sql_client: + assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data")] == [1, 2, 3, 4, 5] + assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data")] == [1, 2, 3] + db_rows = sql_client.execute_sql("SELECT * FROM data_types") + assert len(db_rows) == 10 + db_row = list(db_rows[0]) + # "snowflake" and "bigquery" do not parse JSON from the parquet string so double parse + assert_all_data_types_row(db_row[:-2], parse_complex_strings=destination_name in ["snowflake", "bigquery"]) + + def simple_nested_pipeline(destination_name: str, dataset_name: str, full_refresh: bool) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: data = ["a", ["a", "b", "c"], ["a", "b", "c"]] @@ -585,6 +646,6 @@ def d(): def _data(): return dlt.resource(d(), name="lists", write_disposition="append") - p = dlt.pipeline(full_refresh=full_refresh, destination=destination_name, dataset_name=dataset_name) + p = dlt.pipeline(pipeline_name=f"pipeline_{dataset_name}", full_refresh=full_refresh, destination=destination_name, dataset_name=dataset_name) return p, _data diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 0bc8eac1c1..0e5bbf14a3 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -1,7 +1,7 @@ import itertools import os import shutil -from typing import Any +from typing import Any, Dict import pytest import dlt @@ -13,17 +13,26 @@ from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import STATE_TABLE_COLUMNS, STATE_TABLE_NAME, load_state_from_destination, state_resource -from tests.utils import ALL_DESTINATIONS, patch_home_dir, preserve_environ, autouse_test_storage, TEST_STORAGE_ROOT +from tests.utils import ALL_DESTINATIONS, TEST_STORAGE_ROOT +from tests.cases import JSON_TYPED_DICT from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V5, yml_case_path as common_yml_case_path from tests.common.configuration.utils import environment -from tests.pipeline.utils import drop_dataset_from_env -from tests.load.pipeline.utils import assert_query_data, drop_pipeline -from tests.cases import JSON_TYPED_DICT +from tests.load.pipeline.utils import assert_query_data, drop_active_pipeline_data, STAGING_AND_NON_STAGING_COMBINATIONS, STAGING_COMBINATION_FIELDS -@pytest.mark.parametrize('destination_name', ALL_DESTINATIONS) -def test_restore_state_utils(destination_name: str) -> None: - p = dlt.pipeline(pipeline_name="pipe_" + uniq_id(), destination=destination_name, dataset_name="state_test_" + uniq_id()) +@pytest.mark.parametrize(STAGING_COMBINATION_FIELDS, STAGING_AND_NON_STAGING_COMBINATIONS) +def test_restore_state_utils(destination: str, staging: str, file_format: str, bucket: str, settings: Dict[str, Any]) -> None: + + # 
snowflake requires gcs prefix instead of gs in bucket path + if destination == "snowflake" and bucket: + bucket = bucket.replace("gs://", "gcs://") + + # set env vars + os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = bucket + os.environ['DESTINATION__STAGE_NAME'] = settings.get("stage_name", "") + os.environ["RAISE_ON_FAILED_JOBS"] = "true" + + p = dlt.pipeline(pipeline_name="pipe_" + uniq_id(), destination=destination, staging=staging, dataset_name="state_test_" + uniq_id()) schema = Schema("state") # inject schema into pipeline, don't do it in production p._inject_schema(schema) @@ -59,7 +68,7 @@ def test_restore_state_utils(destination_name: str) -> None: with p.managed_state(extract_state=True): pass # just run the existing extract - p.normalize() + p.normalize(loader_file_format=file_format) p.load() stored_state = load_state_from_destination(p.pipeline_name, job_client.sql_client) local_state = p._get_state() @@ -69,7 +78,7 @@ def test_restore_state_utils(destination_name: str) -> None: with p.managed_state(extract_state=True) as managed_state: # this will be saved managed_state["sources"] = {"source": dict(JSON_TYPED_DICT)} - p.normalize() + p.normalize(loader_file_format=file_format) p.load() stored_state = load_state_from_destination(p.pipeline_name, job_client.sql_client) assert stored_state["sources"] == {"source": JSON_TYPED_DICT} @@ -83,7 +92,7 @@ def test_restore_state_utils(destination_name: str) -> None: new_local_state = p._get_state() new_local_state.pop("_local") assert local_state == new_local_state - p.normalize() + p.normalize(loader_file_format=file_format) info = p.load() assert len(info.loads_ids) == 0 new_stored_state = load_state_from_destination(p.pipeline_name, job_client.sql_client) @@ -112,7 +121,7 @@ def test_restore_state_utils(destination_name: str) -> None: assert "_last_extracted_at" in new_local_state_2_local # but the version didn't change assert new_local_state["_state_version"] == new_local_state_2["_state_version"] - p.normalize() + p.normalize(loader_file_format=file_format) info = p.load() assert len(info.loads_ids) == 1 new_stored_state_2 = load_state_from_destination(p.pipeline_name, job_client.sql_client) @@ -279,7 +288,8 @@ def some_data(): p = dlt.pipeline(pipeline_name=pipeline_name, destination=destination_name, dataset_name=dataset_name, full_refresh=True) p.run() assert p.default_schema_name is None - p._wipe_working_folder() + drop_active_pipeline_data() + # create pipeline without restore os.environ["RESTORE_FROM_DESTINATION"] = "False" p = dlt.pipeline(pipeline_name=pipeline_name, destination=destination_name, dataset_name=dataset_name) @@ -438,11 +448,11 @@ def some_data(param: str) -> Any: production_p.run(data3) assert production_p.state["_state_version"] > prod_state["_state_version"] # and will be detected locally - print(p.default_schema) + # print(p.default_schema) p.sync_destination() # existing schema got overwritten assert "state1_data2" in p._schema_storage.load_schema(p.default_schema_name).tables - print(p.default_schema) + # print(p.default_schema) assert "state1_data2" in p.default_schema.tables # change state locally diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py new file mode 100644 index 0000000000..524c96c4ae --- /dev/null +++ b/tests/load/pipeline/test_stage_loading.py @@ -0,0 +1,137 @@ +import pytest +from typing import Dict, Any + +import dlt, os +from dlt.common import json, sleep +from copy import deepcopy + +from tests.load.pipeline.test_merge_disposition 
import github +from tests.load.pipeline.utils import load_table_counts +from tests.pipeline.utils import assert_load_info +from tests.load.utils import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row +from tests.load.pipeline.utils import ALL_STAGING_COMBINATIONS, STAGING_COMBINATION_FIELDS + + +@dlt.resource(table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url")) +def load_modified_issues(): + with open("tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as f: + issues = json.load(f) + + # change 2 issues + issue = next(filter(lambda i: i["id"] == 1232152492, issues)) + issue["number"] = 105 + + issue = next(filter(lambda i: i["id"] == 1142699354, issues)) + issue["number"] = 300 + + yield from issues + + +@pytest.mark.parametrize(STAGING_COMBINATION_FIELDS, ALL_STAGING_COMBINATIONS) +def test_staging_load(destination: str, staging: str, file_format: str, bucket: str, settings: Dict[str, Any]) -> None: + + # snowflake requires gcs prefix instead of gs in bucket path + if destination == "snowflake": + bucket = bucket.replace("gs://", "gcs://") + + # set env vars + os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = bucket + os.environ['DESTINATION__STAGE_NAME'] = settings.get("stage_name", "") + os.environ['DESTINATION__STAGING_IAM_ROLE'] = settings.get("staging_iam_role", "") + + pipeline = dlt.pipeline(pipeline_name='test_stage_loading_5', destination=destination, staging=staging, dataset_name='staging_test', full_refresh=True) + + info = pipeline.run(github(), loader_file_format=file_format) + assert_load_info(info) + package_info = pipeline.get_load_package_info(info.loads_ids[0]) + assert package_info.state == "loaded" + + assert len(package_info.jobs["failed_jobs"]) == 0 + # we have 4 jobs in the staging file format and 4 reference jobs plus one sql merge job + assert len(package_info.jobs["completed_jobs"]) == 9 + assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == "reference"]) == 4 + assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == file_format]) == 4 + assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == "sql"]) == 1 + + initial_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + assert initial_counts["issues"] == 100 + + # check item of first row in db + with pipeline._get_destination_client(pipeline.default_schema) as client: + rows = client.sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1") + assert rows[0][0] == "https://api.github.com/repos/duckdb/duckdb/issues/71" + + # test merging in some changed values + info = pipeline.run(load_modified_issues, loader_file_format=file_format) + assert_load_info(info) + assert pipeline.default_schema.tables["issues"]["write_disposition"] == "merge" + merge_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + assert merge_counts == initial_counts + + # check changes were merged in + with pipeline._get_destination_client(pipeline.default_schema) as client: + rows = client.sql_client.execute_sql("SELECT number FROM issues WHERE id = 1232152492 LIMIT 1") + assert rows[0][0] == 105 + rows = client.sql_client.execute_sql("SELECT number FROM issues WHERE id = 1142699354 LIMIT 1") + assert rows[0][0] == 300 + + # test append + info = pipeline.run(github().load_issues, write_disposition="append", loader_file_format=file_format)
+ assert_load_info(info) + assert pipeline.default_schema.tables["issues"]["write_disposition"] == "append" + # the counts of all tables must be double + append_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + assert {k:v*2 for k, v in initial_counts.items()} == append_counts + + # test replace + info = pipeline.run(github().load_issues, write_disposition="replace", loader_file_format=file_format) + assert_load_info(info) + assert pipeline.default_schema.tables["issues"]["write_disposition"] == "replace" + # the counts of all tables must be back to the initial counts + replace_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + assert replace_counts == initial_counts + + +@pytest.mark.parametrize(STAGING_COMBINATION_FIELDS, ALL_STAGING_COMBINATIONS) +def test_all_data_types(destination: str, staging: str, file_format: str, bucket: str, settings: Dict[str, Any]) -> None: + # set env vars + os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = bucket + os.environ['DESTINATION__STAGE_NAME'] = settings.get("stage_name", "") + pipeline = dlt.pipeline(pipeline_name='test_stage_loading', destination=destination, dataset_name='staging_test', full_refresh=True, staging=staging) + + data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + # bigquery cannot load into JSON fields from parquet + if file_format == "parquet": + if destination == "bigquery": + # change datatype to text and then allow for it in the assert (parse_complex_strings) + column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" + # redshift cannot load from json into VARBYTE + if file_format == "jsonl": + if destination == "redshift": + # change the datatype to text which will result in inserting base64 (allow_base64_binary) + column_schemas["col7_null"]["data_type"] = column_schemas["col7"]["data_type"] = "text" + + # apply the exact column definitions so we process complex and wei types correctly!
+ @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) + def my_resource(): + nonlocal data_types + yield [data_types]*10 + + @dlt.source(max_table_nesting=0) + def my_source(): + return my_resource + + info = pipeline.run(my_source(), loader_file_format=file_format) + assert_load_info(info) + + with pipeline.sql_client() as sql_client: + db_rows = sql_client.execute_sql("SELECT * FROM data_types") + assert len(db_rows) == 10 + db_row = list(db_rows[0]) + # parquet is not really good at inserting json, best we get are strings in JSON columns + parse_complex_strings = file_format == "parquet" and destination in ["redshift", "bigquery", "snowflake"] + allow_base64_binary = file_format == "jsonl" and destination in ["redshift"] + # content must equal + assert_all_data_types_row(db_row[:-2], parse_complex_strings=parse_complex_strings, allow_base64_binary=allow_base64_binary) diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 3e0a88ec09..45743fe9e3 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -11,6 +11,35 @@ if TYPE_CHECKING: from dlt.destinations.filesystem.filesystem import FilesystemClient +from tests.load.utils import ALL_DESTINATIONS, AWS_BUCKET, GCS_BUCKET + + +# destination configs including staging +STAGING_COMBINATION_FIELDS = "destination,staging,file_format,bucket,settings" + +ALL_DEFAULT_FILETYPE_STAGING_COMBINATIONS = [ + # redshift with iam role + ("redshift","filesystem","parquet",AWS_BUCKET,{"staging_iam_role": "arn:aws:iam::267388281016:role/redshift_s3_read"}), + ("bigquery","filesystem","parquet",GCS_BUCKET, {}), + ("snowflake","filesystem","jsonl",GCS_BUCKET, {"stage_name": "PUBLIC.dlt_gcs_stage"}), + ("snowflake","filesystem","jsonl",AWS_BUCKET, {"stage_name":"PUBLIC.dlt_s3_stage"}) + ] +# filter out destinations not set for this run +ALL_DEFAULT_FILETYPE_STAGING_COMBINATIONS = [item for item in ALL_DEFAULT_FILETYPE_STAGING_COMBINATIONS if item[0] in ALL_DESTINATIONS] + +ALL_STAGING_COMBINATIONS = ALL_DEFAULT_FILETYPE_STAGING_COMBINATIONS + [ + ("redshift","filesystem","parquet",AWS_BUCKET,{}), # redshift with credential forwarding + ("snowflake","filesystem","parquet",AWS_BUCKET, {}), # snowflake with credential forwarding + ("redshift","filesystem","jsonl",AWS_BUCKET, {}), + ("bigquery","filesystem","jsonl",GCS_BUCKET, {}) +] +# filter out destinations not set for this run +ALL_STAGING_COMBINATIONS = [item for item in ALL_STAGING_COMBINATIONS if item[0] in ALL_DESTINATIONS] + +STAGING_AND_NON_STAGING_COMBINATIONS = ALL_DEFAULT_FILETYPE_STAGING_COMBINATIONS + [ + (destination, None, None, "", {}) for destination in ALL_DESTINATIONS +] + @pytest.fixture(autouse=True) def drop_pipeline() -> Iterator[None]: @@ -19,7 +48,7 @@ def drop_pipeline() -> Iterator[None]: def drop_active_pipeline_data() -> None: - """Drops all the datasets for currently active pipeline and then deactivated it. 
Does not drop the working dir - see new_test_storage""" + """Drops all the datasets for currently active pipeline, wipes the working folder and then deactivates it.""" if Container()[PipelineContext].is_active(): # take existing pipeline p = dlt.pipeline() diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index f31fb7c84a..25ba2fdcac 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -20,7 +20,7 @@ from tests.utils import TEST_STORAGE_ROOT, ALL_DESTINATIONS, autouse_test_storage from tests.common.utils import load_json_case -from tests.load.utils import (ALL_CLIENTS_SUBSET, TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW, expect_load_file, load_table, yield_client_with_storage, +from tests.load.utils import (ALL_CLIENTS_SUBSET, TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES, assert_all_data_types_row, expect_load_file, load_table, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table, ALL_CLIENTS) @@ -431,27 +431,12 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, f canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row with io.BytesIO() as f: - write_dataset(client, f, [TABLE_ROW], TABLE_UPDATE_COLUMNS_SCHEMA) + write_dataset(client, f, [TABLE_ROW_ALL_DATA_TYPES], TABLE_UPDATE_COLUMNS_SCHEMA) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]) # content must equal - db_row[3] = str(pendulum.instance(db_row[3])) # serialize date - if isinstance(db_row[6], str): - db_row[6] = bytes.fromhex(db_row[6]) # redshift returns binary as hex string - else: - db_row[6] = bytes(db_row[6]) - # redshift and bigquery return strings from structured fields - if isinstance(db_row[8], str): - # then it must be json - db_row[8] = json.loads(db_row[8]) - - db_row[9] = db_row[9].isoformat() - - expected_rows = list(TABLE_ROW.values()) - # expected_rows[8] = COL_9_DICT - - assert db_row == expected_rows + assert_all_data_types_row(db_row) @pytest.mark.parametrize('write_disposition', ["append", "replace", "merge"]) @@ -477,7 +462,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, fi for idx in range(2): for t in [table_name, child_table]: # write row, use col1 (INT) as row number - table_row = deepcopy(TABLE_ROW) + table_row = deepcopy(TABLE_ROW_ALL_DATA_TYPES) table_row["col1"] = idx with io.BytesIO() as f: write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA) diff --git a/tests/load/utils.py b/tests/load/utils.py index df134b9080..b11e89f1e4 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,3 +1,4 @@ +import base64 import contextlib from importlib import import_module import codecs @@ -7,7 +8,7 @@ from pathlib import Path import dlt -from dlt.common import json, Decimal, sleep +from dlt.common import json, Decimal, sleep, pendulum from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext @@ -26,13 +27,12 @@ from tests.utils import ALL_DESTINATIONS - -# env variables for URLs for all test buckets, e.g.
"gcs://bucket_name", "s3://bucket_name", "file://bucket_name" -bucket_env_vars = [ - "tests.bucket_url_gcs", "tests.bucket_url_aws", "tests.bucket_url_file", "tests.memory" # , "tests.gdrive_url" -] - -ALL_BUCKETS = [b for b in (dlt.config.get(var, str) for var in bucket_env_vars) if b] +# bucket urls +AWS_BUCKET = dlt.config.get("tests.bucket_url_aws", str) +GCS_BUCKET = dlt.config.get("tests.bucket_url_gcs", str) +FILE_BUCKET = dlt.config.get("tests.bucket_url_file", str) +MEMORY_BUCKET = dlt.config.get("tests.memory", str) +ALL_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET] ALL_CLIENTS = [f"{name}_client" for name in ALL_DESTINATIONS] @@ -66,7 +66,7 @@ def ALL_CLIENTS_SUBSET(subset: Sequence[str]) -> List[str]: { "name": "col5", "data_type": "text", - "nullable": True + "nullable": False }, { "name": "col6", @@ -76,12 +76,12 @@ def ALL_CLIENTS_SUBSET(subset: Sequence[str]) -> List[str]: { "name": "col7", "data_type": "binary", - "nullable": True + "nullable": False }, { "name": "col8", "data_type": "wei", - "nullable": True + "nullable": False }, { "name": "col9", @@ -93,11 +93,62 @@ def ALL_CLIENTS_SUBSET(subset: Sequence[str]) -> List[str]: "name": "col10", "data_type": "date", "nullable": False + }, + { + "name": "col1_null", + "data_type": "bigint", + "nullable": True + }, + { + "name": "col2_null", + "data_type": "double", + "nullable": True + }, + { + "name": "col3_null", + "data_type": "bool", + "nullable": True + }, + { + "name": "col4_null", + "data_type": "timestamp", + "nullable": True + }, + { + "name": "col5_null", + "data_type": "text", + "nullable": True + }, + { + "name": "col6_null", + "data_type": "decimal", + "nullable": True + }, + { + "name": "col7_null", + "data_type": "binary", + "nullable": True + }, + { + "name": "col8_null", + "data_type": "wei", + "nullable": True + }, + { + "name": "col9_null", + "data_type": "complex", + "nullable": True, + "variant": True + }, + { + "name": "col10_null", + "data_type": "date", + "nullable": True } ] TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]:t for t in TABLE_UPDATE} -TABLE_ROW = { +TABLE_ROW_ALL_DATA_TYPES = { "col1": 989127831, "col2": 898912.821982, "col3": True, @@ -107,9 +158,49 @@ def ALL_CLIENTS_SUBSET(subset: Sequence[str]) -> List[str]: "col7": b'binary data \n \r \x8e', "col8": 2**56 + 92093890840, "col9": {"complex":[1,2,3,"a"], "link": "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6 \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085"}, - "col10": "2023-02-27" + "col10": "2023-02-27", + "col1_null": None, + "col2_null": None, + "col3_null": None, + "col4_null": None, + "col5_null": None, + "col6_null": None, + "col7_null": None, + "col8_null": None, + "col9_null": None, + "col10_null": None } + +def assert_all_data_types_row(db_row: List[Any], parse_complex_strings: bool = False, allow_base64_binary: bool = False) -> None: + print(db_row) + # content must equal + db_row[3] = str(pendulum.instance(db_row[3])) # serialize date + if isinstance(db_row[6], str): + try: + db_row[6] = bytes.fromhex(db_row[6]) # redshift returns binary as hex string + except ValueError: + if not allow_base64_binary: + raise + db_row[6] = base64.b64decode(db_row[6], validate=True) + else: + db_row[6] = bytes(db_row[6]) + # redshift and bigquery return strings from structured fields + if isinstance(db_row[8], str): + # then it must be json + db_row[8] = json.loads(db_row[8]) + # parse again + if parse_complex_strings and isinstance(db_row[8], str): + # then it must be json + db_row[8] = 
json.loads(db_row[8]) + + db_row[9] = db_row[9].isoformat() + + expected_rows = list(TABLE_ROW_ALL_DATA_TYPES.values()) + print(expected_rows) + assert db_row == expected_rows + + def load_table(name: str) -> TTableSchemaColumns: with open(f"./tests/load/cases/{name}.json", "rb") as f: return cast(TTableSchemaColumns, json.load(f)) @@ -220,12 +311,15 @@ def cm_yield_client_with_storage( return yield_client_with_storage(destination_name, default_config_values, schema_name) -def write_dataset(client: JobClientBase, f: IO[bytes], rows: Sequence[StrAny], columns_schema: TTableSchemaColumns) -> None: +def write_dataset(client: JobClientBase, f: IO[bytes], rows: List[StrAny], columns_schema: TTableSchemaColumns) -> None: data_format = DataWriter.data_format_from_file_format(client.capabilities.preferred_loader_file_format) # adapt bytes stream to text file format if not data_format.is_binary_format and isinstance(f.read(0), bytes): f = codecs.getwriter("utf-8")(f) writer = DataWriter.from_destination_capabilities(client.capabilities, f) + # remove None values + for idx, row in enumerate(rows): + rows[idx] = {k:v for k, v in row.items() if v is not None} writer.write_all(columns_schema, rows) diff --git a/tests/pipeline/conftest.py b/tests/pipeline/conftest.py new file mode 100644 index 0000000000..144155fee7 --- /dev/null +++ b/tests/pipeline/conftest.py @@ -0,0 +1,2 @@ +from tests.utils import preserve_environ, autouse_test_storage, patch_home_dir +from tests.pipeline.utils import drop_dataset_from_env, drop_pipeline \ No newline at end of file diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index b31fad290b..8e4484d3d5 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -10,13 +10,11 @@ from dlt.common import json, sleep from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.exceptions import DestinationHasFailedJobs, DestinationTerminalException, PipelineStateNotAvailable, TerminalException, UnknownDestinationModule +from dlt.common.exceptions import DestinationHasFailedJobs, DestinationTerminalException, PipelineStateNotAvailable, UnknownDestinationModule from dlt.common.pipeline import PipelineContext from dlt.common.runtime.collector import AliveCollector, EnlightenCollector, LogCollector, TqdmCollector from dlt.common.schema.exceptions import InvalidDatasetName -from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id -from dlt.destinations.redshift.configuration import RedshiftCredentials from dlt.extract.exceptions import SourceExhausted from dlt.extract.extract import ExtractorStorage from dlt.extract.source import DltResource, DltSource @@ -26,10 +24,10 @@ from dlt.pipeline.state_sync import STATE_TABLE_NAME from tests.common.utils import TEST_SENTRY_DSN -from tests.utils import ALL_DESTINATIONS, TEST_STORAGE_ROOT, preserve_environ, autouse_test_storage, patch_home_dir +from tests.utils import ALL_DESTINATIONS, TEST_STORAGE_ROOT from tests.common.configuration.utils import environment from tests.extract.utils import expect_extracted_file -from tests.pipeline.utils import assert_load_info, drop_dataset_from_env, drop_pipeline +from tests.pipeline.utils import assert_load_info def test_default_pipeline() -> None: diff --git a/tests/pipeline/test_pipeline_file_format_resolver.py b/tests/pipeline/test_pipeline_file_format_resolver.py new file mode 100644 index 0000000000..0a9ecacd2a --- /dev/null +++ 
b/tests/pipeline/test_pipeline_file_format_resolver.py @@ -0,0 +1,62 @@ + +import dlt +import pytest + +from dlt.common.exceptions import DestinationIncompatibleLoaderFileFormatException, DestinationLoadingViaStagingNotSupported, DestinationNoStagingMode + +def test_file_format_resolution() -> None: + # raise on destinations that does not support staging + with pytest.raises(DestinationLoadingViaStagingNotSupported): + p = dlt.pipeline(pipeline_name="managed_state_pipeline", destination="postgres", staging="filesystem") + + # raise on staging that does not support staging interface + with pytest.raises(DestinationNoStagingMode): + p = dlt.pipeline(pipeline_name="managed_state_pipeline", staging="postgres") + + p = dlt.pipeline(pipeline_name="managed_state_pipeline") + + class cp(): + def __init__(self) -> None: + self.preferred_loader_file_format = None + self.supported_loader_file_formats = [] + self.preferred_staging_file_format = None + self.supported_staging_file_formats = [] + + destcp = cp() + stagecp = cp() + + # check regular resolution + destcp.preferred_loader_file_format = "jsonl" + destcp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] + assert p._resolve_loader_file_format("some", "some", destcp, None, None) == "jsonl" + + # check resolution with input + assert p._resolve_loader_file_format("some", "some", destcp, None, "parquet") == "parquet" + + # check invalid input + with pytest.raises(DestinationIncompatibleLoaderFileFormatException): + assert p._resolve_loader_file_format("some", "some", destcp, None, "csv") + + # check staging resolution with clear preference + destcp.supported_staging_file_formats = ["jsonl", "insert_values", "parquet"] + destcp.preferred_staging_file_format = "insert_values" + stagecp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] + assert p._resolve_loader_file_format("some", "some", destcp, stagecp, None) == "insert_values" + + # check invalid input + with pytest.raises(DestinationIncompatibleLoaderFileFormatException): + p._resolve_loader_file_format("some", "some", destcp, stagecp, "csv") + + # check staging resolution where preference does not match + destcp.supported_staging_file_formats = ["insert_values", "parquet"] + destcp.preferred_staging_file_format = "csv" + stagecp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] + assert p._resolve_loader_file_format("some", "some", destcp, stagecp, None) == "insert_values" + assert p._resolve_loader_file_format("some", "some", destcp, stagecp, "parquet") == "parquet" + + # check incompatible staging + destcp.supported_staging_file_formats = ["insert_values", "csv"] + destcp.preferred_staging_file_format = "csv" + stagecp.supported_loader_file_formats = ["jsonl", "parquet"] + with pytest.raises(DestinationIncompatibleLoaderFileFormatException): + p._resolve_loader_file_format("some", "some", destcp, stagecp, None) \ No newline at end of file diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index 228c2ef8f4..71e8d90406 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -15,8 +15,8 @@ from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import migrate_state, STATE_ENGINE_VERSION -from tests.utils import autouse_test_storage, test_storage, patch_home_dir, preserve_environ -from tests.pipeline.utils import drop_dataset_from_env, json_case_path, load_json_case, drop_pipeline +from tests.utils import test_storage +from 
tests.pipeline.utils import json_case_path, load_json_case @dlt.resource() @@ -33,6 +33,24 @@ def some_data_resource_state(): dlt.current.resource_state()["last_value"] = last_value + 1 +def test_restore_state_props() -> None: + p = dlt.pipeline(pipeline_name="restore_state_props", destination="redshift", staging="filesystem", dataset_name="the_dataset") + p.extract(some_data()) + state = p.state + assert state["dataset_name"] == "the_dataset" + assert state["destination"].endswith("redshift") + assert state["staging"].endswith("filesystem") + + p = dlt.pipeline(pipeline_name="restore_state_props") + state = p.state + assert state["dataset_name"] == "the_dataset" + assert state["destination"].endswith("redshift") + assert state["staging"].endswith("filesystem") + # also instances are restored + assert p.destination.__name__.endswith("redshift") + assert p.staging.__name__.endswith("filesystem") + + def test_managed_state() -> None: p = dlt.pipeline(pipeline_name="managed_state_pipeline") p.extract(some_data()) diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index 76cd0d34c5..174785d308 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -21,9 +21,8 @@ from dlt.pipeline.trace import PipelineTrace, SerializableResolvedValueTrace, load_trace from dlt.pipeline.track import slack_notify_load_success -from tests.utils import preserve_environ, patch_home_dir, start_test_telemetry +from tests.utils import start_test_telemetry from tests.common.configuration.utils import toml_providers, environment -from tests.pipeline.utils import drop_dataset_from_env, drop_pipeline def test_create_trace(toml_providers: ConfigProvidersContext) -> None: diff --git a/tests/utils.py b/tests/utils.py index 8fb8caa4c8..cf0926ddfa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,12 +21,11 @@ from dlt.common.typing import StrAny from dlt.common.utils import uniq_id - TEST_STORAGE_ROOT = "_storage" + +# destination configs ALL_DESTINATIONS = dlt.config.get("ALL_DESTINATIONS", list) or ["bigquery", "redshift", "postgres", "duckdb", "snowflake"] ALL_LOCAL_DESTINATIONS = set(ALL_DESTINATIONS).intersection("postgres", "duckdb") -# ALL_DESTINATIONS = ["duckdb", "postgres"] - def TEST_DICT_CONFIG_PROVIDER(): # add test dictionary provider
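
Editor's note: for readers skimming this patch, the snippet below is a minimal sketch of how the staging workflow introduced here is exercised by the tests above: the pipeline gets a `staging="filesystem"` argument, the staging bucket is configured via `DESTINATION__FILESYSTEM__BUCKET_URL`, and an explicit `loader_file_format` is passed to `run`. The resource, bucket URL and dataset name are illustrative placeholders, and credentials for both the bucket and the destination are assumed to be configured separately via dlt config or environment variables.

import os
import dlt

# illustrative staging bucket - the tests read this value from dlt config instead
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://my-staging-bucket"

@dlt.resource(primary_key="id")
def issues():
    # placeholder rows standing in for the github issues data used in the tests
    yield [{"id": 1, "title": "first"}, {"id": 2, "title": "second"}]

# the destination loads from files that dlt first stages in the filesystem destination (the bucket above)
pipeline = dlt.pipeline(
    pipeline_name="staging_example",
    destination="redshift",
    staging="filesystem",
    dataset_name="staging_example_data",
)

# request parquet as the staging file format; dlt checks that the destination supports it
info = pipeline.run(issues(), loader_file_format="parquet")
print(info)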