diff --git a/Makefile b/Makefile index 1059cfdf0a..d9d92ec799 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,7 @@ dev: has-poetry lint: ./check-package.sh - poetry run black ./ --diff --exclude=".*syntax_error.py|\.venv.*" + poetry run black ./ --diff --exclude=".*syntax_error.py|\.venv.*|_storage/.*" # poetry run isort ./ --diff poetry run mypy --config-file mypy.ini dlt tests poetry run flake8 --max-line-length=200 dlt @@ -56,7 +56,7 @@ lint: # $(MAKE) lint-security format: - poetry run black ./ --exclude=".*syntax_error.py|\.venv.*" + poetry run black ./ --exclude=".*syntax_error.py|\.venv.*|_storage/.*" # poetry run isort ./ test-and-lint-snippets: diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 4ca58c20db..78cca1fbad 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -29,6 +29,8 @@ class RunConfiguration(BaseConfiguration): request_max_retry_delay: float = 300 """Maximum delay between http request retries""" config_files_storage_path: str = "/run/config/" + """Platform connection""" + dlthub_dsn: Optional[TSecretStrValue] = None __section__ = "runtime" diff --git a/dlt/common/managed_thread_pool.py b/dlt/common/managed_thread_pool.py new file mode 100644 index 0000000000..ea2a0e6b47 --- /dev/null +++ b/dlt/common/managed_thread_pool.py @@ -0,0 +1,27 @@ +from typing import Optional + +import atexit +from concurrent.futures import ThreadPoolExecutor + + +class ManagedThreadPool: + def __init__(self, max_workers: int = 1) -> None: + self._max_workers = max_workers + self._thread_pool: Optional[ThreadPoolExecutor] = None + + def _create_thread_pool(self) -> None: + assert not self._thread_pool, "Thread pool already created" + self._thread_pool = ThreadPoolExecutor(self._max_workers) + # flush pool on exit + atexit.register(self.stop) + + @property + def thread_pool(self) -> ThreadPoolExecutor: + if not self._thread_pool: + self._create_thread_pool() + return self._thread_pool + + def stop(self, wait: bool = True) -> None: + if self._thread_pool: + self._thread_pool.shutdown(wait=wait) + self._thread_pool = None diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index fc632003c1..6a20488293 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -15,6 +15,7 @@ TYPE_CHECKING, Tuple, TypedDict, + Mapping, ) from typing_extensions import NotRequired @@ -72,7 +73,7 @@ def asdict(self) -> DictStrAny: """A dictionary representation of NormalizeInfo that can be loaded with `dlt`""" d = self._asdict() # list representation creates a nice table - d["row_counts"] = [(k, v) for k, v in self.row_counts.items()] + d["row_counts"] = [{"table_name": k, "count": v} for k, v in self.row_counts.items()] return d def asstr(self, verbosity: int = 0) -> str: @@ -227,6 +228,10 @@ class SupportsPipeline(Protocol): def state(self) -> TPipelineState: """Returns dictionary with pipeline state""" + @property + def schemas(self) -> Mapping[str, Schema]: + """Mapping of all pipeline schemas""" + def set_local_state_val(self, key: str, value: Any) -> None: """Sets value in local state. 
Local state is not synchronized with destination.""" diff --git a/dlt/common/runtime/exec_info.py b/dlt/common/runtime/exec_info.py index d365156ad2..3aa19c83ab 100644 --- a/dlt/common/runtime/exec_info.py +++ b/dlt/common/runtime/exec_info.py @@ -1,23 +1,16 @@ import io import os import contextlib +import sys +import multiprocessing +import platform +from dlt.common.runtime.typing import TExecutionContext, TVersion, TExecInfoNames from dlt.common.typing import StrStr, StrAny, Literal, List from dlt.common.utils import filter_env_vars -from dlt.version import __version__ - - -TExecInfoNames = Literal[ - "kubernetes", - "docker", - "codespaces", - "github_actions", - "airflow", - "notebook", - "colab", - "aws_lambda", - "gcp_cloud_function", -] +from dlt.version import __version__, DLT_PKG_NAME + + # if one of these environment variables is set, we assume to be running in CI env CI_ENVIRONMENT_TELL = [ "bamboo.buildKey", @@ -174,3 +167,15 @@ def is_aws_lambda() -> bool: def is_gcp_cloud_function() -> bool: "Return True if the process is running in the serverless platform GCP Cloud Functions" return os.environ.get("FUNCTION_NAME") is not None + + +def get_execution_context() -> TExecutionContext: + "Get execution context information" + return TExecutionContext( + ci_run=in_continuous_integration(), + python=sys.version.split(" ")[0], + cpu=multiprocessing.cpu_count(), + exec_info=exec_info_names(), + os=TVersion(name=platform.system(), version=platform.release()), + library=TVersion(name=DLT_PKG_NAME, version=__version__), + ) diff --git a/dlt/common/runtime/segment.py b/dlt/common/runtime/segment.py index d06ef80607..e302767fcc 100644 --- a/dlt/common/runtime/segment.py +++ b/dlt/common/runtime/segment.py @@ -2,32 +2,31 @@ # several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py import os -import sys -import multiprocessing + import atexit import base64 import requests -import platform from concurrent.futures import ThreadPoolExecutor from typing import Literal, Optional from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.runtime import logger +from dlt.common.managed_thread_pool import ManagedThreadPool from dlt.common.configuration.specs import RunConfiguration -from dlt.common.runtime.exec_info import exec_info_names, in_continuous_integration +from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id -from dlt.version import __version__, DLT_PKG_NAME +from dlt.version import __version__ TEventCategory = Literal["pipeline", "command", "helper"] -_THREAD_POOL: ThreadPoolExecutor = None +_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1) _SESSION: requests.Session = None _WRITE_KEY: str = None _SEGMENT_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts _SEGMENT_ENDPOINT = "https://api.segment.io/v1/track" -_SEGMENT_CONTEXT: DictStrAny = None +_SEGMENT_CONTEXT: TExecutionContext = None def init_segment(config: RunConfiguration) -> None: @@ -36,9 +35,8 @@ def init_segment(config: RunConfiguration) -> None: ), "dlthub_telemetry_segment_write_key not present in RunConfiguration" # create thread pool to send telemetry to segment - global _THREAD_POOL, _WRITE_KEY, _SESSION - if not _THREAD_POOL: - _THREAD_POOL = ThreadPoolExecutor(1) + global _WRITE_KEY, _SESSION + if not _SESSION: _SESSION = requests.Session() # flush pool on exit atexit.register(_at_exit_cleanup) @@ -81,10 +79,9 @@ def before_send(event: 
DictStrAny) -> Optional[DictStrAny]: def _at_exit_cleanup() -> None: - global _THREAD_POOL, _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT - if _THREAD_POOL: - _THREAD_POOL.shutdown(wait=True) - _THREAD_POOL = None + global _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT + if _SESSION: + _THREAD_POOL.stop(True) _SESSION.close() _SESSION = None _WRITE_KEY = None @@ -141,7 +138,7 @@ def _segment_request_payload(event_name: str, properties: StrAny, context: StrAn } -def _default_context_fields() -> DictStrAny: +def _default_context_fields() -> TExecutionContext: """Return a dictionary that contains the default context values. Return: @@ -152,14 +149,7 @@ def _default_context_fields() -> DictStrAny: if not _SEGMENT_CONTEXT: # Make sure to update the example in docs/docs/telemetry/telemetry.mdx # if you change / add context - _SEGMENT_CONTEXT = { - "os": {"name": platform.system(), "version": platform.release()}, - "ci_run": in_continuous_integration(), - "python": sys.version.split(" ")[0], - "library": {"name": DLT_PKG_NAME, "version": __version__}, - "cpu": multiprocessing.cpu_count(), - "exec_info": exec_info_names(), - } + _SEGMENT_CONTEXT = get_execution_context() # avoid returning the cached dict --> caller could modify the dictionary... # usually we would use `lru_cache`, but that doesn't return a dict copy and @@ -207,4 +197,4 @@ def _future_send() -> None: if not data.get("success"): logger.debug(f"Segment telemetry request returned a failure. Response: {data}") - _THREAD_POOL.submit(_future_send) + _THREAD_POOL.thread_pool.submit(_future_send) diff --git a/dlt/common/runtime/typing.py b/dlt/common/runtime/typing.py new file mode 100644 index 0000000000..88707d387e --- /dev/null +++ b/dlt/common/runtime/typing.py @@ -0,0 +1,46 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Type, + TypedDict, + NewType, + Union, + get_args, +) + + +TExecInfoNames = Literal[ + "kubernetes", + "docker", + "codespaces", + "github_actions", + "airflow", + "notebook", + "colab", + "aws_lambda", + "gcp_cloud_function", +] + + +class TVersion(TypedDict): + """TypeDict representing a library version""" + + name: str + version: str + + +class TExecutionContext(TypedDict): + """TypeDict representing the runtime context info""" + + ci_run: bool + python: str + cpu: int + exec_info: List[TExecInfoNames] + library: TVersion + os: TVersion diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 63573b9f18..0518b07232 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -29,6 +29,7 @@ from dlt.common.configuration.accessors import config from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns +from dlt.common.schema.typing import TStoredSchema from dlt.common.storages.configuration import LoadStorageConfiguration from dlt.common.storages.versioned_storage import VersionedStorage from dlt.common.storages.data_item_storage import DataItemStorage @@ -112,6 +113,7 @@ class LoadPackageInfo(NamedTuple): package_path: str state: TLoadPackageState schema_name: str + schema_hash: str schema_update: TSchemaTables completed_at: datetime.datetime jobs: Dict[TJobState, List[LoadJobInfo]] @@ -135,6 +137,7 @@ def asdict(self) -> DictStrAny: table["columns"] = columns d.pop("schema_update") d["tables"] = tables + d["schema_hash"] = self.schema_hash return d def asstr(self, verbosity: int = 0) -> str: @@ -374,6 +377,7 @@ def 
get_load_package_info(self, load_id: str) -> LoadPackageInfo: self.storage.make_full_path(package_path), package_state, schema.name, + schema.version_hash, applied_update, package_created_at, all_jobs, diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index ddb7d6d489..1a9503c287 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -266,9 +266,9 @@ def run( # plug default tracking module -from dlt.pipeline import trace, track +from dlt.pipeline import trace, track, platform -trace.TRACKING_MODULE = track +trace.TRACKING_MODULES = [track, platform] # setup default pipeline in the container Container()[PipelineContext] = PipelineContext(pipeline) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 4c45f0e486..836442f5bb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -16,6 +16,7 @@ cast, get_type_hints, ContextManager, + Mapping, ) from dlt import version @@ -167,51 +168,54 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore -def with_runtime_trace(f: TFun) -> TFun: - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - trace: PipelineTrace = self._trace - trace_step: PipelineStepTrace = None - step_info: Any = None - is_new_trace = self._trace is None and self.config.enable_runtime_trace - - # create a new trace if we enter a traced function and there's no current trace - if is_new_trace: - self._trace = trace = start_trace(cast(TPipelineStep, f.__name__), self) +def with_runtime_trace(send_state: bool = False) -> Callable[[TFun], TFun]: + def decorator(f: TFun) -> TFun: + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + trace: PipelineTrace = self._trace + trace_step: PipelineStepTrace = None + step_info: Any = None + is_new_trace = self._trace is None and self.config.enable_runtime_trace - try: - # start a trace step for wrapped function - if trace: - trace_step = start_trace_step(trace, cast(TPipelineStep, f.__name__), self) + # create a new trace if we enter a traced function and there's no current trace + if is_new_trace: + self._trace = trace = start_trace(cast(TPipelineStep, f.__name__), self) - step_info = f(self, *args, **kwargs) - return step_info - except Exception as ex: - step_info = ex # step info is an exception - raise - finally: try: - if trace_step: - # if there was a step, finish it - end_trace_step(self._trace, trace_step, self, step_info) - if is_new_trace: - assert ( - trace is self._trace - ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" - end_trace(trace, self, self._pipeline_storage.storage_path) + # start a trace step for wrapped function + if trace: + trace_step = start_trace_step(trace, cast(TPipelineStep, f.__name__), self) + + step_info = f(self, *args, **kwargs) + return step_info + except Exception as ex: + step_info = ex # step info is an exception + raise finally: - # always end trace - if is_new_trace: - assert ( - self._trace == trace - ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" - # if we end new trace that had only 1 step, add it to previous trace - # this way we combine several separate calls to extract, normalize, load as single trace - # the trace of "run" has many steps and will not be merged - self._last_trace = merge_traces(self._last_trace, trace) - self._trace = None + try: + if trace_step: + # if there was a step, finish it + end_trace_step(self._trace, trace_step, self, step_info, send_state) + if is_new_trace: + assert ( + trace is 
self._trace + ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" + end_trace(trace, self, self._pipeline_storage.storage_path, send_state) + finally: + # always end trace + if is_new_trace: + assert ( + self._trace == trace + ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" + # if we end new trace that had only 1 step, add it to previous trace + # this way we combine several separate calls to extract, normalize, load as single trace + # the trace of "run" has many steps and will not be merged + self._last_trace = merge_traces(self._last_trace, trace) + self._trace = None - return _wrap # type: ignore + return _wrap # type: ignore + + return decorator def with_config_section(sections: Tuple[str, ...]) -> Callable[[TFun], TFun]: @@ -335,7 +339,7 @@ def drop(self) -> "Pipeline": self.runtime_config, ) - @with_runtime_trace + @with_runtime_trace() @with_schemas_sync # this must precede with_state_sync @with_state_sync(may_extract_state=True) @with_config_section((known_sections.EXTRACT,)) @@ -388,7 +392,7 @@ def extract( self, "extract", exc, ExtractInfo(describe_extract_data(data)) ) from exc - @with_runtime_trace + @with_runtime_trace() @with_schemas_sync @with_config_section((known_sections.NORMALIZE,)) def normalize( @@ -429,7 +433,7 @@ def normalize( self, "normalize", n_ex, normalize.get_normalize_info() ) from n_ex - @with_runtime_trace + @with_runtime_trace(send_state=True) @with_schemas_sync @with_state_sync() @with_config_section((known_sections.LOAD,)) @@ -482,7 +486,7 @@ def load( except Exception as l_ex: raise PipelineStepFailed(self, "load", l_ex, self._get_load_info(load)) from l_ex - @with_runtime_trace + @with_runtime_trace() @with_config_section(("run",)) def run( self, diff --git a/dlt/pipeline/platform.py b/dlt/pipeline/platform.py new file mode 100644 index 0000000000..c8014d5ae7 --- /dev/null +++ b/dlt/pipeline/platform.py @@ -0,0 +1,118 @@ +"""Implements SupportsTracking""" +from typing import Any, cast, TypedDict, List +import requests +from dlt.common.managed_thread_pool import ManagedThreadPool +from urllib.parse import urljoin + +from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline +from dlt.common import json +from dlt.common.runtime import logger +from dlt.common.pipeline import LoadInfo +from dlt.common.schema.typing import TStoredSchema + +_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1) +TRACE_URL_SUFFIX = "/trace" +STATE_URL_SUFFIX = "/state" + + +class TPipelineSyncPayload(TypedDict): + pipeline_name: str + destination_name: str + destination_displayable_credentials: str + destination_fingerprint: str + dataset_name: str + schemas: List[TStoredSchema] + + +def _send_trace_to_platform(trace: PipelineTrace, pipeline: SupportsPipeline) -> None: + """ + Send the full trace after a run operation to the platform + TODO: Migrate this to open telemetry in the next iteration + """ + if not pipeline.runtime_config.dlthub_dsn: + return + + def _future_send() -> None: + try: + trace_dump = json.dumps(trace.asdict()) + url = pipeline.runtime_config.dlthub_dsn + TRACE_URL_SUFFIX + response = requests.put(url, data=trace_dump) + if response.status_code != 200: + logger.debug( + f"Failed to send trace to platform, response code: {response.status_code}" + ) + except Exception as e: + logger.debug(f"Exception while sending trace to platform: {e}") + + _THREAD_POOL.thread_pool.submit(_future_send) + + # trace_dump = json.dumps(trace.asdict(), pretty=True) + # with open(f"trace.json", "w") as f: + # 
f.write(trace_dump) + + +def _sync_schemas_to_platform(trace: PipelineTrace, pipeline: SupportsPipeline) -> None: + if not pipeline.runtime_config.dlthub_dsn: + return + + # sync only if load step was processed + load_info: LoadInfo = None + for step in trace.steps: + if step.step == "load": + load_info = cast(LoadInfo, step.step_info) + + if not load_info: + return + + payload = TPipelineSyncPayload( + pipeline_name=pipeline.pipeline_name, + destination_name=load_info.destination_name, + destination_displayable_credentials=load_info.destination_displayable_credentials, + destination_fingerprint=load_info.destination_fingerprint, + dataset_name=load_info.dataset_name, + schemas=[], + ) + + # attach all schemas + for schema_name in pipeline.schemas: + schema = pipeline.schemas[schema_name] + payload["schemas"].append(schema.to_dict()) + + def _future_send() -> None: + try: + url = pipeline.runtime_config.dlthub_dsn + STATE_URL_SUFFIX + response = requests.put(url, data=json.dumps(payload)) + if response.status_code != 200: + logger.debug( + f"Failed to send state to platform, response code: {response.status_code}" + ) + except Exception as e: + logger.debug(f"Exception while sending state to platform: {e}") + + _THREAD_POOL.thread_pool.submit(_future_send) + + +def on_start_trace(trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: + pass + + +def on_start_trace_step( + trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline +) -> None: + pass + + +def on_end_trace_step( + trace: PipelineTrace, + step: PipelineStepTrace, + pipeline: SupportsPipeline, + step_info: Any, + send_state: bool, +) -> None: + if send_state: + # also sync schemas to dlthub + _sync_schemas_to_platform(trace, pipeline) + + +def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool) -> None: + _send_trace_to_platform(trace, pipeline) diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index e725a2f726..88a4d185fb 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -8,6 +8,7 @@ from dlt.common import pendulum from dlt.common.runtime.logger import suppress_and_warn +from dlt.common.runtime.exec_info import TExecutionContext, get_execution_context from dlt.common.configuration import is_secret_hint from dlt.common.configuration.utils import _RESOLVED_TRACES from dlt.common.pipeline import ( @@ -102,6 +103,8 @@ class PipelineTrace: """Pipeline runtime trace containing data on "extract", "normalize" and "load" steps and resolved config and secret values.""" transaction_id: str + pipeline_name: str + execution_context: TExecutionContext started_at: datetime.datetime steps: List[PipelineStepTrace] """A list of steps in the trace""" @@ -136,6 +139,10 @@ def last_pipeline_step_trace(self, step_name: TPipelineStep) -> PipelineStepTrac return step return None + def asdict(self) -> DictStrAny: + """A dictionary representation of PipelineTrace that can be loaded with `dlt`""" + return dataclasses.asdict(self) + @property def last_extract_info(self) -> ExtractInfo: step_trace = self.last_pipeline_step_trace("extract") @@ -176,20 +183,25 @@ def on_end_trace_step( step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any, + send_state: bool, ) -> None: ... - def on_end_trace(self, trace: PipelineTrace, pipeline: SupportsPipeline) -> None: ... + def on_end_trace( + self, trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool + ) -> None: ... 
-# plug in your own tracking module here -# TODO: that probably should be a list of modules / classes with all of them called -TRACKING_MODULE: SupportsTracking = None +# plug in your own tracking modules here +TRACKING_MODULES: List[SupportsTracking] = None def start_trace(step: TPipelineStep, pipeline: SupportsPipeline) -> PipelineTrace: - trace = PipelineTrace(uniq_id(), pendulum.now(), steps=[]) - with suppress_and_warn(): - TRACKING_MODULE.on_start_trace(trace, step, pipeline) + trace = PipelineTrace( + uniq_id(), pipeline.pipeline_name, get_execution_context(), pendulum.now(), steps=[] + ) + for module in TRACKING_MODULES: + with suppress_and_warn(): + module.on_start_trace(trace, step, pipeline) return trace @@ -197,13 +209,18 @@ def start_trace_step( trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline ) -> PipelineStepTrace: trace_step = PipelineStepTrace(uniq_id(), step, pendulum.now()) - with suppress_and_warn(): - TRACKING_MODULE.on_start_trace_step(trace, step, pipeline) + for module in TRACKING_MODULES: + with suppress_and_warn(): + module.on_start_trace_step(trace, step, pipeline) return trace_step def end_trace_step( - trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any + trace: PipelineTrace, + step: PipelineStepTrace, + pipeline: SupportsPipeline, + step_info: Any, + send_state: bool, ) -> None: # saves runtime trace of the pipeline if isinstance(step_info, PipelineStepFailed): @@ -237,16 +254,20 @@ def end_trace_step( trace.resolved_config_values = list(resolved_values) trace.steps.append(step) - with suppress_and_warn(): - TRACKING_MODULE.on_end_trace_step(trace, step, pipeline, step_info) + for module in TRACKING_MODULES: + with suppress_and_warn(): + module.on_end_trace_step(trace, step, pipeline, step_info, send_state) -def end_trace(trace: PipelineTrace, pipeline: SupportsPipeline, trace_path: str) -> None: +def end_trace( + trace: PipelineTrace, pipeline: SupportsPipeline, trace_path: str, send_state: bool +) -> None: trace.finished_at = pendulum.now() if trace_path: save_trace(trace_path, trace) - with suppress_and_warn(): - TRACKING_MODULE.on_end_trace(trace, pipeline) + for module in TRACKING_MODULES: + with suppress_and_warn(): + module.on_end_trace(trace, pipeline, send_state) def merge_traces(last_trace: PipelineTrace, new_trace: PipelineTrace) -> PipelineTrace: diff --git a/dlt/pipeline/track.py b/dlt/pipeline/track.py index 7670c95163..cfaf411a44 100644 --- a/dlt/pipeline/track.py +++ b/dlt/pipeline/track.py @@ -79,7 +79,11 @@ def on_start_trace_step( def on_end_trace_step( - trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any + trace: PipelineTrace, + step: PipelineStepTrace, + pipeline: SupportsPipeline, + step_info: Any, + send_state: bool, ) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"---END SENTRY SPAN {trace.transaction_id}:{step.span_id}: {step} SCOPE: {Hub.current.scope}") @@ -110,7 +114,7 @@ def on_end_trace_step( dlthub_telemetry_track("pipeline", step.step, props) -def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline) -> None: +def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"---END SENTRY TX: {trace.transaction_id} SCOPE: {Hub.current.scope}") with contextlib.suppress(Exception): diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index d8da21503d..36892ad260 100644 --- 
a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -564,6 +564,7 @@ class _SecretCredentials(RunConfiguration): "request_backoff_factor": 1, "request_max_retry_delay": 300, "config_files_storage_path": "storage", + "dlthub_dsn": None, "secret_value": None, } assert dict(_SecretCredentials()) == expected_dict diff --git a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py index b63fc3d472..a26e9b774d 100644 --- a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py +++ b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py @@ -57,7 +57,7 @@ def test_multiple_resources_pipeline(): ) load_info = pipeline.run([source1(10), source2(20)]) - source1_schema = load_info.pipeline.schemas.get("source1") # type: ignore[attr-defined] + source1_schema = load_info.pipeline.schemas.get("source1") assert load_info.pipeline.schema_names == ["source2", "source1"] # type: ignore[attr-defined] diff --git a/tests/pipeline/test_platform_connection.py b/tests/pipeline/test_platform_connection.py new file mode 100644 index 0000000000..a0893cfc93 --- /dev/null +++ b/tests/pipeline/test_platform_connection.py @@ -0,0 +1,73 @@ +import dlt +import os +import time +import requests_mock + +TRACE_URL_SUFFIX = "/trace" +STATE_URL_SUFFIX = "/state" + + +def test_platform_connection() -> None: + mock_platform_url = "http://platform.com/endpoint" + + os.environ["RUNTIME__DLTHUB_DSN"] = mock_platform_url + + trace_url = mock_platform_url + TRACE_URL_SUFFIX + state_url = mock_platform_url + STATE_URL_SUFFIX + + # simple pipeline + @dlt.source(name="first_source") + def my_source(): + @dlt.resource(name="test_resource") + def data(): + yield [1, 2, 3] + + return data() + + @dlt.source(name="second_source") + def my_source_2(): + @dlt.resource(name="test_resource") + def data(): + yield [1, 2, 3] + + return data() + + p = dlt.pipeline( + destination="duckdb", + pipeline_name="platform_test_pipeline", + dataset_name="platform_test_dataset", + ) + + with requests_mock.mock() as m: + m.put(mock_platform_url, json={}, status_code=200) + p.run([my_source(), my_source_2()]) + + # sleep a bit and find trace in mock requests + time.sleep(2) + + trace_result = None + state_result = None + for call in m.request_history: + if call.url == trace_url: + assert not trace_result, "Multiple calls to trace endpoint" + trace_result = call.json() + + if call.url == state_url: + assert not state_result, "Multiple calls to state endpoint" + state_result = call.json() + + # basic check of trace result + assert trace_result, "no trace" + assert trace_result["pipeline_name"] == "platform_test_pipeline" + assert len(trace_result["steps"]) == 4 + assert trace_result["execution_context"]["library"]["name"] == "dlt" + + # basic check of state result + assert state_result, "no state update" + assert state_result["pipeline_name"] == "platform_test_pipeline" + assert state_result["dataset_name"] == "platform_test_dataset" + assert len(state_result["schemas"]) == 2 + assert {state_result["schemas"][0]["name"], state_result["schemas"][1]["name"]} == { + "first_source", + "second_source", + }
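
A minimal sketch of how the platform connection added in this changeset can be exercised end to end, mirroring `tests/pipeline/test_platform_connection.py`; the DSN value, pipeline name and dataset name below are placeholders, not real endpoints or required names:

    import os
    import dlt

    # The new RunConfiguration.dlthub_dsn field lives under the "runtime" section,
    # so it can be supplied via environment variable (placeholder URL here).
    os.environ["RUNTIME__DLTHUB_DSN"] = "https://platform.example.com/endpoint"

    @dlt.resource(name="items")
    def items():
        yield [1, 2, 3]

    pipeline = dlt.pipeline(
        pipeline_name="platform_demo",
        destination="duckdb",
        dataset_name="platform_demo_dataset",
    )
    pipeline.run(items())

    # When dlthub_dsn is set, dlt.pipeline.platform PUTs the full trace to <dsn>/trace
    # at the end of the run and, because load() is decorated with
    # with_runtime_trace(send_state=True), the pipeline name, dataset name and schemas
    # to <dsn>/state after the load step. Both requests are submitted to a background
    # ManagedThreadPool and failures are only logged at debug level.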