Skip to content

Commit

Permalink
updated datamodels to use dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
anisbhsl committed Oct 13, 2023
1 parent bedb138 commit beb41a1
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 53 deletions.
14 changes: 7 additions & 7 deletions pynequa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def search_query(self, query_params: QueryParams) -> Dict:

payload = {
"app": self.app_name,
"query": query_params._prepare_query_args(
"query": query_params.generate_payload(
query_name=self.query_name)
}

Expand Down Expand Up @@ -160,7 +160,7 @@ def search_profile(self, profile_name: str, query_params: QueryParams,
payload = {
"profile": profile_name,
"responsetype": response_type,
"query": query_params._prepare_query_args(query_name=self.query_name),
"query": query_params.generate_payload(query_name=self.query_name),
}

return self.post(endpoint=endpoint, payload=payload)
Expand Down Expand Up @@ -209,7 +209,7 @@ def search_preview(self, query_params: QueryParams, action: str = "get",
"action": action,
"id": id,
"origin": origin,
"query": query_params._prepare_query_args(
"query": query_params.generate_payload(
query_name=self.query_name)
}

Expand Down Expand Up @@ -281,7 +281,7 @@ def search_similardocuments(self, source_doc_id: str,
payload = {
"app": self.app_name,
"sourceDocumentId": source_doc_id,
"query": query_params._prepare_query_args(
"query": query_params.generate_payload(
query_name=self.query_name)
}

Expand All @@ -300,7 +300,7 @@ def search_query_links(self, web_sevice: str, query_params: QueryParams) -> Dict
endpoint = "search.querylinks"
payload = {
"webService": web_sevice,
"query": query_params._prepare_query_args(
"query": query_params.generate_payload(
query_name=self.query_name)
}
return self.post(endpoint=endpoint, payload=payload)
Expand Down Expand Up @@ -355,9 +355,9 @@ def search_profile_subtree(self, profile: str, query_params: QueryParams,
endpoint = "search.profile.subtree"
payload = {
"profile": profile,
"query": query_params._prepare_query_args(
"query": query_params.generate_payload(
query_name=self.query_name),
"tree": tree_params._generate_tree_params_payload()
"tree": tree_params.generate_payload()
}
return self.post(endpoint=endpoint, payload=payload)

Expand Down
163 changes: 117 additions & 46 deletions pynequa/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,48 @@
from typing import Dict, List, Optional


class TreeParams:
from abc import abstractmethod
from dataclasses import dataclass, field
import logging

logging.basicConfig(level=logging.DEBUG)


class SinequaBaseModel:
"""
Abstract base class for all Sinequa models.
"""

@abstractmethod
def generate_payload(self, **kwargs) -> Dict:
"""
This is abstract method for SinequaBaseModel.
Every child class should implement this method.
"""
raise NotImplementedError()


@dataclass
class TreeParams(SinequaBaseModel):
"""
Represents the parameters for configuring a tree parameters.
Attributes:
box (str): The name of the relevant tree navigation box (required).
column (str): The name of the index column associated with the
navigation box (required).
op (str, optional): The relational operator. Default is 'eq'.
Possible values: '=', '!=', '<', '<=', '>=', '>', 'between', 'not between'.
value (str): The filter value (required).
"""
box: str = ""
column: str = ""
op: str = ""
value: str = ""

def _generate_tree_params_payload(self) -> Dict:
def generate_payload(self, **kwargs) -> Dict:
"""
This method generates payload for
TreeParams.
"""
return {
"box": self.box,
"column": self.column,
Expand All @@ -16,73 +51,94 @@ def _generate_tree_params_payload(self) -> Dict:
}


class SelectParams:
@dataclass
class SelectParams(SinequaBaseModel):
expression: str = ""
facet: str = ""

def _generate_select_params_payload(self) -> Dict:
def generate_payload(self, **kwargs) -> Dict:
"""
This method generates payload for
SelectParams.
"""
return {
"expression": self.expression,
"facet": self.facet,
}


class OpenParams:
@dataclass
class OpenParams(SinequaBaseModel):
expression: str = ""
facet: str = ""

def _generate_open_params_payload(self) -> Dict:
def generate_payload(self, **kwargs) -> Dict:
"""
This method generates payload for
OpenParams.
"""
return {
"expression": self.expression,
"facet": self.facet,
}


class AdvancedParams:
@dataclass
class AdvancedParams(SinequaBaseModel):
col_name: str = ""
col_value: str = ""
value: str = ""
operator: str = ""

def _generate_advanced_params_payload(self) -> Dict:
return {
col_value: str = None
value: str or int = None
operator: str = None
debug: bool = False

def generate_payload(self, **kwargs) -> Dict:
"""
This method generates payload for
AdvancedParams.
"""
payload = {
self.col_name: self.col_value,
"value": self.value,
"operator": self.operator
}

if self.debug:
logging.debug(payload)

return payload

class QueryParams:

@dataclass
class QueryParams(SinequaBaseModel):
name: str = "" # required
action: str = None
action: Optional[str] = None
search_text: str = "" # required
select_params: List[SelectParams] = []
additional_select_clause: str = None
additional_where_clause: str = None
open_params: List[OpenParams] = []
page: int = 0
page_size: int = 0
tab: str = None
scope: str = None
basket: str = None
is_first_page: bool = False
strict_refine: bool = False
global_relevance: int = None
question_language: str = None
question_default_language: str = None
spelling_correction_mode: str = None
spelling_correction_filter: str = None
document_weight: str = None
text_part_weights: str = None
relevance_transforms: str = None
remove_duplicates: bool = False
aggregations: List[str] = []
order_by: str = None
group_by: str = None
select_params: Optional[List[SelectParams]
] = field(default_factory=lambda: [])
additional_select_clause: Optional[str] = None
additional_where_clause: Optional[str] = None
open_params: Optional[List[OpenParams]] = field(default_factory=lambda: [])
page: Optional[int] = 1
page_size: Optional[int] = 10
tab: Optional[str] = None
scope: Optional[str] = None
basket: Optional[str] = None
is_first_page: Optional[bool] = False
strict_refine: Optional[bool] = False
global_relevance: Optional[int] = None
question_language: Optional[str] = None
question_default_language: Optional[str] = None
spelling_correction_mode: Optional[str] = None
spelling_correction_filter: Optional[str] = None
document_weight: Optional[str] = None
text_part_weights: Optional[str] = None
relevance_transforms: Optional[str] = None
remove_duplicates: Optional[bool] = False
aggregations: Optional[List[str]] = field(default_factory=lambda: [])
order_by: Optional[str] = None
group_by: Optional[str] = None
advanced: Optional[AdvancedParams] = None

def __init__(self) -> None:
pass
debug: bool = False

def _prepare_query_args(self, query_name: str) -> Dict:
params = {
Expand All @@ -101,7 +157,7 @@ def _prepare_query_args(self, query_name: str) -> Dict:
if len(self.select_params) > 0:
select_params = []
for item in self.select_params:
select_params.append(item._generate_select_params_payload())
select_params.append(item.generate_payload())
params["select"] = self.select_params

if self.additional_select_clause is not None:
Expand All @@ -113,7 +169,7 @@ def _prepare_query_args(self, query_name: str) -> Dict:
if len(self.open_params) > 0:
open_params = []
for item in self.open_params:
open_params.append(item._generate_open_params_payload())
open_params.append(item.generate_payload())
params["open"] = self.open_params

if self.page is not None:
Expand Down Expand Up @@ -169,6 +225,21 @@ def _prepare_query_args(self, query_name: str) -> Dict:
params["groupBy"] = self.group_by

if self.advanced is not None:
params["advanced"] = self.advanced._generate_advanced_params_payload()
params["advanced"] = self.advanced.generate_payload()

return params

def generate_payload(self, **kwargs) -> Dict:
"""
This method generates payload for
QueryParams.
Args:
query_name(str): Name of query service to query for
"""
query_name = kwargs.get("query_name")
payload = self._prepare_query_args(query_name)
if self.debug:
logging.debug(payload)

return payload
35 changes: 35 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pynequa.models import QueryParams
import unittest
import logging


class TestQueryParams(unittest.TestCase):

def test_query_params_payload(self):
"""
Test if query params payload is correctly
generated or not.
"""
qp = QueryParams(
name="query",
search_text="What was Landsat-9 launched?"
)

payload = qp.generate_payload()
logging.debug(payload)

keys_which_must_be_in_payload = [
"name",
"text",
"isFirstpage",
"strictRefine",
"removeDuplicates"
]

for key in keys_which_must_be_in_payload:
if key not in payload:
self.assertEqual(key, "test", f"{key} is mising in payload")


if __name__ == '__main__':
unittest.main()

0 comments on commit beb41a1

Please sign in to comment.