Skip to content

Commit

Permalink
Improve highlighting for text annotation search results
Browse files Browse the repository at this point in the history
Close #264
  • Loading branch information
bkis committed Sep 3, 2024
1 parent 4ac4487 commit 448f6ab
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 42 deletions.
21 changes: 18 additions & 3 deletions Tekst-API/tekst/models/search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from typing import Annotated, Any, Literal

Expand Down Expand Up @@ -27,19 +28,31 @@ class SearchResults(ModelBase):
max_score: float | None

@classmethod
def __transform_highlights(cls, hit: dict[str, Any]) -> dict[str, set[str]]:
def __transform_highlights(
cls,
hit: dict[str, Any],
highlights_generators: dict[str, Callable[[dict[str, Any]], list[str]]],
) -> dict[str, set[str]]:
if not hit.get("highlight"):
return {}
highlights = {}
highlights_generators = highlights_generators or {}
for k, v in hit["highlight"].items():
hl_res_id = k.split(".")[1]
if hl_res_id not in highlights:
highlights[hl_res_id] = set()
if hl_res_id in highlights_generators:
v = highlights_generators[hl_res_id](hit)
highlights[hl_res_id].update(v)
return highlights

@classmethod
def from_es_results(cls, results: dict[str, Any]) -> "SearchResults":
def from_es_results(
cls,
results: dict[str, Any],
highlights_generators: dict[str, Callable[[dict[str, Any]], list[str]]]
| None = None,
) -> "SearchResults":
return cls(
hits=[
SearchHit(
Expand All @@ -50,7 +63,9 @@ def from_es_results(cls, results: dict[str, Any]) -> "SearchResults":
level=hit["_source"]["level"],
position=hit["_source"]["position"],
score=hit["_score"],
highlight=cls.__transform_highlights(hit),
highlight=cls.__transform_highlights(
hit, highlights_generators or {}
),
)
for hit in results["hits"]["hits"]
],
Expand Down
11 changes: 11 additions & 0 deletions Tekst-API/tekst/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pkgutil

from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import lru_cache
from os.path import realpath
from pathlib import Path
Expand Down Expand Up @@ -365,6 +366,16 @@ def rtype_es_queries(
Common content fields are not included in the returned queries.
"""

@classmethod
def highlights_generator(cls) -> Callable[[dict[str, Any]], list[str]] | None:
"""
For resource types that need a custom highlights generator, this method can be
overwritten to return a function that takes a list of search hits and returns
custom highlights for them. If this function returns None (the default if not
overwritten), the default highlighting will be used.
"""
return None

@classmethod
@abstractmethod
async def export(
Expand Down
97 changes: 62 additions & 35 deletions Tekst-API/tekst/resources/text_annotation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import csv

from collections.abc import Callable
from pathlib import Path
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BeforeValidator, Field, StringConstraints, field_validator
from typing_extensions import TypeAliasType
Expand Down Expand Up @@ -35,6 +37,42 @@ def content_model(cls) -> type["TextAnnotationContent"]:
def search_query_model(cls) -> type["TextAnnotationSearchQuery"]:
return TextAnnotationSearchQuery

@classmethod
def rtype_index_doc_props(cls) -> dict[str, Any]:
return {
"tokens": {
"type": "nested",
"dynamic": True,
"properties": {
"token": {
"type": "keyword",
"normalizer": "no_diacritics_normalizer",
"fields": {"strict": {"type": "keyword"}},
}
},
},
}

@classmethod
def rtype_index_doc_data(
cls,
content: "TextAnnotationContent",
) -> dict[str, Any]:
return {
"tokens": [
{
"token": token.token or "",
"annotations": {
anno.key: anno.value[0] if len(anno.value) == 1 else anno.value
for anno in token.annotations
}
if token.annotations
else None,
}
for token in content.tokens
],
}

@classmethod
def rtype_es_queries(
cls,
Expand All @@ -45,9 +83,10 @@ def rtype_es_queries(
es_queries = []
strict_suffix = ".strict" if strict else ""
res_id = str(query.common.resource_id)
q_id = str(uuid4())

if (
not query.resource_type_specific.token.strip("* ")
query.resource_type_specific.token.strip(" ") == "*"
and not query.resource_type_specific.annotations
):
# handle empty/match-all query (query for existing target resource field)
Expand All @@ -60,10 +99,14 @@ def rtype_es_queries(
"field": f"resources.{res_id}.tokens.token",
}
},
"inner_hits": {"name": q_id},
}
}
)
else:
elif (
query.resource_type_specific.token
or query.resource_type_specific.annotations
):
# construct token query
token_query = (
{
Expand Down Expand Up @@ -136,47 +179,31 @@ def rtype_es_queries(
],
},
},
"inner_hits": {"name": q_id},
}
}
)

return es_queries

@classmethod
def rtype_index_doc_props(cls) -> dict[str, Any]:
return {
"tokens": {
"type": "nested",
"dynamic": True,
"properties": {
"token": {
"type": "keyword",
"normalizer": "no_diacritics_normalizer",
"fields": {"strict": {"type": "keyword"}},
}
},
},
}
def highlights_generator(cls) -> Callable[[dict[str, Any]], list[str]] | None:
def _highlights_generator(hit: dict[str, Any]) -> list[str]:
hl_strings = []
for hl_k, hl_v in hit["highlight"].items():
if ".comment" in hl_k:
hl_strings.extend(hl_v)
for ih in hit.get("inner_hits", {}).values():
for ih_hit in ih.get("hits", {}).get("hits", []):
token = ih_hit["_source"]["token"]
annos = ih_hit["_source"]["annotations"]
annos = (
f" ({'; '.join([a for a in annos.values()])})" if annos else ""
)
hl_strings.append(f"{token} {annos}")
return hl_strings

@classmethod
def rtype_index_doc_data(
cls,
content: "TextAnnotationContent",
) -> dict[str, Any]:
return {
"tokens": [
{
"token": token.token or "",
"annotations": {
anno.key: anno.value[0] if len(anno.value) == 1 else anno.value
for anno in token.annotations
}
if token.annotations
else None,
}
for token in content.tokens
],
}
return _highlights_generator

@classmethod
async def export(
Expand Down
10 changes: 6 additions & 4 deletions Tekst-API/tekst/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,15 +454,16 @@ async def search_advanced(
# construct all the sub-queries
sub_queries_must = []
sub_queries_should = []

highlights_generators = {} # special highlights generators, if any
for q in queries:
if str(q.common.resource_id) in target_resource_ids:
resource_es_queries = resource_types_mgr.get(
q.resource_type_specific.resource_type
).es_queries(
res_type = resource_types_mgr.get(q.resource_type_specific.resource_type)
resource_es_queries = res_type.es_queries(
query=q,
strict=settings_general.strict,
)
if (hl_gen := res_type.highlights_generator()) is not None:
highlights_generators[str(q.common.resource_id)] = hl_gen
if q.common.required:
sub_queries_must.extend(resource_es_queries)
else:
Expand Down Expand Up @@ -497,4 +498,5 @@ async def search_advanced(
source={"includes": QUERY_SOURCE_INCLUDES},
timeout=_cfg.es.timeout_search_s,
),
highlights_generators=highlights_generators,
)

0 comments on commit 448f6ab

Please sign in to comment.