Skip to content

Commit

Permalink
Merge pull request #727 from dlt-hub/d#/platform_connection
Browse files Browse the repository at this point in the history
prototype platform connection
  • Loading branch information
sh-rp authored Nov 24, 2023
2 parents e48813f + d453db3 commit cfb6e66
Show file tree
Hide file tree
Showing 16 changed files with 405 additions and 105 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ dev: has-poetry

lint:
./check-package.sh
poetry run black ./ --diff --exclude=".*syntax_error.py|\.venv.*"
poetry run black ./ --diff --exclude=".*syntax_error.py|\.venv.*|_storage/.*"
# poetry run isort ./ --diff
poetry run mypy --config-file mypy.ini dlt tests
poetry run flake8 --max-line-length=200 dlt
poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases
# $(MAKE) lint-security

format:
poetry run black ./ --exclude=".*syntax_error.py|\.venv.*"
poetry run black ./ --exclude=".*syntax_error.py|\.venv.*|_storage/.*"
# poetry run isort ./

test-and-lint-snippets:
Expand Down
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/run_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class RunConfiguration(BaseConfiguration):
request_max_retry_delay: float = 300
"""Maximum delay between http request retries"""
config_files_storage_path: str = "/run/config/"
"""Platform connection"""
dlthub_dsn: Optional[TSecretStrValue] = None

__section__ = "runtime"

Expand Down
27 changes: 27 additions & 0 deletions dlt/common/managed_thread_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor


class ManagedThreadPool:
    """A lazily-created ThreadPoolExecutor that flushes itself at interpreter exit.

    The underlying pool is only created on first access of `thread_pool`, so
    importing/instantiating this class is cheap. `stop()` shuts the pool down
    and allows a fresh pool to be created on the next access.
    """

    def __init__(self, max_workers: int = 1) -> None:
        self._max_workers = max_workers
        self._thread_pool: Optional[ThreadPoolExecutor] = None

    def _create_thread_pool(self) -> None:
        assert not self._thread_pool, "Thread pool already created"
        self._thread_pool = ThreadPoolExecutor(self._max_workers)
        # flush pool on exit
        atexit.register(self.stop)

    @property
    def thread_pool(self) -> ThreadPoolExecutor:
        """Return the managed executor, creating it on first use."""
        if not self._thread_pool:
            self._create_thread_pool()
        return self._thread_pool

    def stop(self, wait: bool = True) -> None:
        """Shut down the pool (if created) and release it.

        Idempotent: calling `stop` on an already-stopped pool is a no-op.
        """
        if self._thread_pool:
            self._thread_pool.shutdown(wait=wait)
            self._thread_pool = None
            # fix: without this, every create/stop cycle leaves one more
            # atexit callback registered; bound methods compare equal so
            # unregister removes the callback added in _create_thread_pool
            atexit.unregister(self.stop)
7 changes: 6 additions & 1 deletion dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TYPE_CHECKING,
Tuple,
TypedDict,
Mapping,
)
from typing_extensions import NotRequired

Expand Down Expand Up @@ -72,7 +73,7 @@ def asdict(self) -> DictStrAny:
"""A dictionary representation of NormalizeInfo that can be loaded with `dlt`"""
d = self._asdict()
# list representation creates a nice table
d["row_counts"] = [(k, v) for k, v in self.row_counts.items()]
d["row_counts"] = [{"table_name": k, "count": v} for k, v in self.row_counts.items()]
return d

def asstr(self, verbosity: int = 0) -> str:
Expand Down Expand Up @@ -227,6 +228,10 @@ class SupportsPipeline(Protocol):
def state(self) -> TPipelineState:
"""Returns dictionary with pipeline state"""

@property
def schemas(self) -> Mapping[str, Schema]:
"""Mapping of all pipeline schemas"""

def set_local_state_val(self, key: str, value: Any) -> None:
"""Sets value in local state. Local state is not synchronized with destination."""

Expand Down
33 changes: 19 additions & 14 deletions dlt/common/runtime/exec_info.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import io
import os
import contextlib
import sys
import multiprocessing
import platform

from dlt.common.runtime.typing import TExecutionContext, TVersion, TExecInfoNames
from dlt.common.typing import StrStr, StrAny, Literal, List
from dlt.common.utils import filter_env_vars
from dlt.version import __version__


TExecInfoNames = Literal[
"kubernetes",
"docker",
"codespaces",
"github_actions",
"airflow",
"notebook",
"colab",
"aws_lambda",
"gcp_cloud_function",
]
from dlt.version import __version__, DLT_PKG_NAME


# if one of these environment variables is set, we assume to be running in CI env
CI_ENVIRONMENT_TELL = [
"bamboo.buildKey",
Expand Down Expand Up @@ -174,3 +167,15 @@ def is_aws_lambda() -> bool:
def is_gcp_cloud_function() -> bool:
"Return True if the process is running in the serverless platform GCP Cloud Functions"
return os.environ.get("FUNCTION_NAME") is not None


def get_execution_context() -> TExecutionContext:
    """Collect information about the current runtime environment.

    Builds a ``TExecutionContext`` with the CI flag, python interpreter
    version, CPU count, detected execution environments, host OS version
    and the dlt library version.
    """
    interpreter_version = sys.version.split(" ")[0]
    host_os = TVersion(name=platform.system(), version=platform.release())
    dlt_library = TVersion(name=DLT_PKG_NAME, version=__version__)
    return TExecutionContext(
        ci_run=in_continuous_integration(),
        python=interpreter_version,
        cpu=multiprocessing.cpu_count(),
        exec_info=exec_info_names(),
        os=host_os,
        library=dlt_library,
    )
38 changes: 14 additions & 24 deletions dlt/common/runtime/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,31 @@

# several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py
import os
import sys
import multiprocessing

import atexit
import base64
import requests
import platform
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Optional
from dlt.common.configuration.paths import get_dlt_data_dir

from dlt.common.runtime import logger
from dlt.common.managed_thread_pool import ManagedThreadPool

from dlt.common.configuration.specs import RunConfiguration
from dlt.common.runtime.exec_info import exec_info_names, in_continuous_integration
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id
from dlt.version import __version__, DLT_PKG_NAME
from dlt.version import __version__

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ThreadPoolExecutor = None
_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_WRITE_KEY: str = None
_SEGMENT_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
_SEGMENT_ENDPOINT = "https://api.segment.io/v1/track"
_SEGMENT_CONTEXT: DictStrAny = None
_SEGMENT_CONTEXT: TExecutionContext = None


def init_segment(config: RunConfiguration) -> None:
Expand All @@ -36,9 +35,8 @@ def init_segment(config: RunConfiguration) -> None:
), "dlthub_telemetry_segment_write_key not present in RunConfiguration"

# create thread pool to send telemetry to segment
global _THREAD_POOL, _WRITE_KEY, _SESSION
if not _THREAD_POOL:
_THREAD_POOL = ThreadPoolExecutor(1)
global _WRITE_KEY, _SESSION
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
Expand Down Expand Up @@ -81,10 +79,9 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:


def _at_exit_cleanup() -> None:
global _THREAD_POOL, _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _THREAD_POOL:
_THREAD_POOL.shutdown(wait=True)
_THREAD_POOL = None
global _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _SESSION:
_THREAD_POOL.stop(True)
_SESSION.close()
_SESSION = None
_WRITE_KEY = None
Expand Down Expand Up @@ -141,7 +138,7 @@ def _segment_request_payload(event_name: str, properties: StrAny, context: StrAn
}


def _default_context_fields() -> DictStrAny:
def _default_context_fields() -> TExecutionContext:
"""Return a dictionary that contains the default context values.
Return:
Expand All @@ -152,14 +149,7 @@ def _default_context_fields() -> DictStrAny:
if not _SEGMENT_CONTEXT:
# Make sure to update the example in docs/docs/telemetry/telemetry.mdx
# if you change / add context
_SEGMENT_CONTEXT = {
"os": {"name": platform.system(), "version": platform.release()},
"ci_run": in_continuous_integration(),
"python": sys.version.split(" ")[0],
"library": {"name": DLT_PKG_NAME, "version": __version__},
"cpu": multiprocessing.cpu_count(),
"exec_info": exec_info_names(),
}
_SEGMENT_CONTEXT = get_execution_context()

# avoid returning the cached dict --> caller could modify the dictionary...
# usually we would use `lru_cache`, but that doesn't return a dict copy and
Expand Down Expand Up @@ -207,4 +197,4 @@ def _future_send() -> None:
if not data.get("success"):
logger.debug(f"Segment telemetry request returned a failure. Response: {data}")

_THREAD_POOL.submit(_future_send)
_THREAD_POOL.thread_pool.submit(_future_send)
46 changes: 46 additions & 0 deletions dlt/common/runtime/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Type,
TypedDict,
NewType,
Union,
get_args,
)


# Closed set of execution environments that dlt can detect at runtime and
# report in the execution context (see TExecutionContext.exec_info).
TExecInfoNames = Literal[
    "kubernetes",
    "docker",
    "codespaces",
    "github_actions",
    "airflow",
    "notebook",
    "colab",
    "aws_lambda",
    "gcp_cloud_function",
]


class TVersion(TypedDict):
    """TypedDict representing a versioned component as a name/version pair
    (used for both the dlt library and the host OS)"""

    name: str
    version: str


class TExecutionContext(TypedDict):
    """TypedDict representing the runtime context info"""

    ci_run: bool  # True when running under a continuous-integration system
    python: str  # python interpreter version string
    cpu: int  # CPU count reported by the host
    exec_info: List[TExecInfoNames]  # detected execution environments
    library: TVersion  # dlt package name and version
    os: TVersion  # operating system name and release
4 changes: 4 additions & 0 deletions dlt/common/storages/load_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dlt.common.configuration.accessors import config
from dlt.common.exceptions import TerminalValueError
from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns
from dlt.common.schema.typing import TStoredSchema
from dlt.common.storages.configuration import LoadStorageConfiguration
from dlt.common.storages.versioned_storage import VersionedStorage
from dlt.common.storages.data_item_storage import DataItemStorage
Expand Down Expand Up @@ -112,6 +113,7 @@ class LoadPackageInfo(NamedTuple):
package_path: str
state: TLoadPackageState
schema_name: str
schema_hash: str
schema_update: TSchemaTables
completed_at: datetime.datetime
jobs: Dict[TJobState, List[LoadJobInfo]]
Expand All @@ -135,6 +137,7 @@ def asdict(self) -> DictStrAny:
table["columns"] = columns
d.pop("schema_update")
d["tables"] = tables
d["schema_hash"] = self.schema_hash
return d

def asstr(self, verbosity: int = 0) -> str:
Expand Down Expand Up @@ -374,6 +377,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo:
self.storage.make_full_path(package_path),
package_state,
schema.name,
schema.version_hash,
applied_update,
package_created_at,
all_jobs,
Expand Down
4 changes: 2 additions & 2 deletions dlt/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,9 @@ def run(


# plug default tracking module
from dlt.pipeline import trace, track
from dlt.pipeline import trace, track, platform

trace.TRACKING_MODULE = track
trace.TRACKING_MODULES = [track, platform]

# setup default pipeline in the container
Container()[PipelineContext] = PipelineContext(pipeline)
Loading

0 comments on commit cfb6e66

Please sign in to comment.