Skip to content

Commit

Permalink
RSDK-9146: Change Python TabularDataBySQL/MQL return type to raw BSON
Browse files Browse the repository at this point in the history
  • Loading branch information
jckras authored Nov 1, 2024
1 parent b932e16 commit e0cec71
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"grpclib>=0.4.7",
"protobuf==5.28.2",
"typing-extensions>=4.12.2",
"pymongo>=4.10.1"
]

[project.urls]
Expand Down
11 changes: 6 additions & 5 deletions src/viam/app/data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import bson

from google.protobuf.struct_pb2 import Struct
from grpclib.client import Channel, Stream
Expand Down Expand Up @@ -241,7 +242,7 @@ async def tabular_data_by_filter(
LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e)
return data, response.count, response.last

async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, ValueTypes]]:
async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, Union[ValueTypes, datetime]]]:
"""Obtain unified tabular data and metadata, queried with SQL.
::
Expand All @@ -264,9 +265,9 @@ async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> Lis
"""
request = TabularDataBySQLRequest(organization_id=organization_id, sql_query=sql_query)
response: TabularDataBySQLResponse = await self._data_client.TabularDataBySQL(request, metadata=self._metadata)
return [struct_to_dict(struct) for struct in response.data]
return [bson.decode(bson_bytes) for bson_bytes in response.raw_data]

async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, ValueTypes]]:
async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, Union[ValueTypes, datetime]]]:
"""Obtain unified tabular data and metadata, queried with MQL.
::
Expand Down Expand Up @@ -303,7 +304,7 @@ async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes
"""
request = TabularDataByMQLRequest(organization_id=organization_id, mql_binary=mql_binary)
response: TabularDataByMQLResponse = await self._data_client.TabularDataByMQL(request, metadata=self._metadata)
return [struct_to_dict(struct) for struct in response.data]
return [bson.decode(bson_bytes) for bson_bytes in response.raw_data]

async def binary_data_by_filter(
self,
Expand Down
11 changes: 6 additions & 5 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
from grpclib.server import Stream
from numpy.typing import NDArray
from datetime import datetime
import bson

from viam.app.data_client import DataClient
from viam.gen.app.v1.app_pb2 import FragmentHistoryEntry, GetFragmentHistoryRequest, GetFragmentHistoryResponse
Expand Down Expand Up @@ -791,12 +793,11 @@ async def SetSmartMachineCredentials(
self.cloud_config = request.cloud
await stream.send_message(SetSmartMachineCredentialsResponse())


class MockData(UnimplementedDataServiceBase):
def __init__(
self,
tabular_response: List[DataClient.TabularData],
tabular_query_response: List[Dict[str, ValueTypes]],
tabular_query_response: List[Dict[str, Union[ValueTypes, datetime]]],
binary_response: List[BinaryData],
delete_remove_response: int,
tags_response: List[str],
Expand Down Expand Up @@ -986,12 +987,12 @@ async def RemoveBinaryDataFromDatasetByIDs(
async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, TabularDataBySQLResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(TabularDataBySQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
await stream.send_message(TabularDataBySQLResponse(raw_data=[bson.encode(dict) for dict in self.tabular_query_response]))

async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, TabularDataByMQLResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(TabularDataByMQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
await stream.send_message(TabularDataByMQLResponse(raw_data=[bson.encode(dict) for dict in self.tabular_query_response]))


class MockDataset(DatasetServiceBase):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from google.protobuf.timestamp_pb2 import Timestamp
from grpclib.testing import ChannelFor

from datetime import datetime
from viam.app.data_client import DataClient
from viam.proto.app.data import Annotations, BinaryData, BinaryID, BinaryMetadata, BoundingBox, CaptureMetadata, Filter, Order
from viam.utils import create_filter
Expand Down Expand Up @@ -101,7 +102,7 @@

TABULAR_RESPONSE = [DataClient.TabularData(TABULAR_DATA, TABULAR_METADATA, START_DATETIME, END_DATETIME)]
TABULAR_QUERY_RESPONSE = [
{"key1": 1, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": 1}},
{"key1": START_DATETIME, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": END_DATETIME}},
]
BINARY_RESPONSE = [BinaryData(binary=BINARY_DATA, metadata=BINARY_METADATA)]
DELETE_REMOVE_RESPONSE = 1
Expand Down Expand Up @@ -153,12 +154,14 @@ async def test_tabular_data_by_sql(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
response = await client.tabular_data_by_sql(ORG_ID, SQL_QUERY)
assert isinstance(response[0]["key1"], datetime)
assert response == TABULAR_QUERY_RESPONSE

async def test_tabular_data_by_mql(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY)
assert isinstance(response[0]["key1"], datetime)
assert response == TABULAR_QUERY_RESPONSE

async def test_binary_data_by_filter(self, service: MockData):
Expand Down
Loading

0 comments on commit e0cec71

Please sign in to comment.