add table prompt formatting and splitting

remove talkjs code

devxpy committed Oct 30, 2023
1 parent 6a32208 commit f2e87e7
Showing 11 changed files with 236 additions and 193 deletions.
178 changes: 122 additions & 56 deletions daras_ai_v2/azure_doc_extract.py
@@ -1,20 +1,30 @@
import json
import csv
import io
import re
import typing
from time import sleep

import requests
from furl import furl
from jinja2.lexer import whitespace_re

from daras_ai_v2 import settings
from gooeysite import wsgi

assert wsgi

from time import sleep
from daras_ai_v2.redis_cache import redis_cache_decorator
from daras_ai_v2.text_splitter import default_length_function

auth_headers = {"Ocp-Apim-Subscription-Key": settings.AZURE_FORM_RECOGNIZER_KEY}


def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"):
result = azure_form_recognizer(pdf_url, model_id)
return [
records_to_text(extract_records(result, page["pageNumber"]))
for page in result["pages"]
]


@redis_cache_decorator
def azure_form_recognizer(pdf_url: str, model_id: str):
r = requests.post(
str(
furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT)
@@ -27,19 +37,12 @@ def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"):
r.raise_for_status()
location = r.headers["Operation-Location"]
while True:
r = requests.get(
location,
headers=auth_headers,
)
r = requests.get(location, headers=auth_headers)
r.raise_for_status()
r_json = r.json()
match r_json.get("status"):
case "succeeded":
result = r_json["analyzeResult"]
return [
records_to_text(extract_records(result, page["pageNumber"]))
for page in result["pages"]
]
return r_json["analyzeResult"]
case "failed":
raise Exception(r_json)
case _:
@@ -60,14 +63,14 @@ def extract_records(result: dict, page_num: int) -> list[dict]:
outer=table["polygon"], inner=para["boundingRegions"][0]["polygon"]
):
if not table.get("added"):
records.append({"role": "table", "content": table["content"]})
records.append({"role": "csv", "content": table["content"]})
table["added"] = True
break
else:
records.append(
{
"role": para.get("role", ""),
"content": remove_selection_marks(para["content"]).strip(),
"content": strip_content(para["content"]),
}
)
return records
@@ -87,26 +90,13 @@ def records_to_text(records: list[dict]) -> str:
return ret.strip()


# def table_to_html(table):
# with redirect_stdout(io.StringIO()) as f:
# print("<table>")
# print("<tr>")
# idx = 0
# for cell in table["cells"]:
# if idx != cell["rowIndex"]:
# print("</tr>")
# print("<tr>")
# idx = cell["rowIndex"]
# if cell.get("kind") == "columnHeader":
# tag = "th"
# else:
# tag = "td"
# print(
# f"<{tag} rowspan={cell.get('rowSpan', 1)} colspan={cell.get('columnSpan',1)}>{cell['content'].strip()}</{tag}>"
# )
# print("</tr>")
# print("</table>")
# return f.getvalue()
def rect_contains(*, outer: list[int], inner: list[int]):
tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = outer
for pt_x, pt_y in zip(inner[::2], inner[1::2]):
# if the point is inside the bounding box, return True
if tl_x <= pt_x <= tr_x and tl_y <= pt_y <= bl_y:
return True
return False
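
# example: with an outer box whose corners are (0,0) (10,0) (10,5) (0,5),
# any inner vertex landing inside it marks the paragraph as part of the table:
#   rect_contains(outer=[0, 0, 10, 0, 10, 5, 0, 5], inner=[2, 1, 4, 1, 4, 2, 2, 2])  -> True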


def extract_tables(result, page):
@@ -117,7 +107,7 @@ def extract_tables(result, page):
continue
except (KeyError, IndexError):
continue
plain = table_to_plain(table)
plain = table_to_csv(table)
table_polys.append(
{
"polygon": table["boundingRegions"][0]["polygon"],
@@ -128,36 +118,112 @@
return table_polys


def rect_contains(*, outer: list[int], inner: list[int]):
tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = outer
for pt_x, pt_y in zip(inner[::2], inner[1::2]):
# if the point is inside the bounding box, return True
if tl_x <= pt_x <= tr_x and tl_y <= pt_y <= bl_y:
return True
return False
def table_to_csv(table: dict) -> str:
return table_arr_to_csv(table_to_arr(table))


def table_to_plain(table):
    ## note: ' |' ' |\n' ' |--' are individual tokens in the gpt-4 tokenizer, and must be handled with care
    ret = ""
    for i, row in enumerate(arr):
        ret += " |" + " |".join(row) + " |\n"
        if i == 0:
            ret += " " + ("|--" * table["columnCount"]) + "|\n"
    return ret
THEAD = "**"


def table_to_arr(table: dict) -> list[list[str]]:
with open(f"table-{table['columnCount']}.json", "w") as f:
f.write(str(table))
arr = [["" for _ in range(table["columnCount"])] for _ in range(table["rowCount"])]
for cell in table["cells"]:
for i in range(cell.get("rowSpan", 1)):
row_idx = cell["rowIndex"] + i
for j in range(cell.get("columnSpan", 1)):
col_idx = cell["columnIndex"] + j
arr[row_idx][col_idx] = remove_selection_marks(cell["content"]).strip()
content = strip_content(cell["content"])
if cell.get("kind") in ("rowHeader", "columnHeader", "stubHead"):
content = THEAD + content + THEAD
arr[row_idx][col_idx] = content
return arr


# NOTE: These are individual tokens in the gpt-4 vocab, and must be handled with care
THEAD_SEP = "|--"
TROW_END = "|\n"
TROW_SEP = " |"


def table_arr_to_prompt(arr: typing.Iterable[list[str]]) -> str:
text = ""
prev_is_header = True
for row in arr:
is_header = _strip_header_from_row(row)
row = _remove_long_dupe_header(row)
if prev_is_header and not is_header:
text += THEAD_SEP * len(row) + TROW_END
text += TROW_SEP + TROW_SEP.join(row) + TROW_END
prev_is_header = is_header
return text


def table_arr_to_prompt_chunked(
arr: typing.Iterable[list[str]], chunk_size: int
) -> typing.Iterable[str]:
header = ""
chunk = ""
prev_is_header = True
for row in arr:
is_header = _strip_header_from_row(row)
row = _remove_long_dupe_header(row)
if prev_is_header and not is_header:
header += THEAD_SEP * len(row) + TROW_END
next_chunk = TROW_SEP + TROW_SEP.join(row) + TROW_END
if is_header:
header += next_chunk
if default_length_function(header) > chunk_size:
yield header
header = ""
else:
if default_length_function(header + chunk + next_chunk) > chunk_size:
yield header + chunk.rstrip()
chunk = ""
chunk += next_chunk
prev_is_header = is_header
if chunk:
yield header + chunk.rstrip()


def _strip_header_from_row(row):
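    # strip the THEAD ("**") markers from header cells in place;
    # return True if any cell in this row was a header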
is_header = False
for i, cell in enumerate(row):
if cell.startswith(THEAD) and cell.endswith(THEAD):
row[i] = cell[len(THEAD) : -len(THEAD)]
is_header = True
return is_header


def _remove_long_dupe_header(row: list[str], cutoff: int = 2) -> list[str]:
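    # if the same value repeats across the trailing cells of a row (e.g. a
    # header duplicated over merged columns), keep the first copy and blank
    # the rest, provided there are at least `cutoff` duplicates:
    #   _remove_long_dupe_header(["Total", "12", "12", "12"]) -> ["Total", "12", "", ""]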
r = -1
l = 0
for cell in row[-2::-1]:
if not row[r]:
r -= 1
l -= 1
continue
if cell == row[r]:
l -= 1
else:
break
if -l >= cutoff:
row = row[:l] + [""] * -l
return row

def table_arr_to_csv(arr: typing.Iterable[list[str]]) -> str:
f = io.StringIO()
writer = csv.writer(f)
writer.writerows(arr)
return f.getvalue()


selection_marks_re = re.compile(r":(un)?selected:")


def remove_selection_marks(text):
return selection_marks_re.sub("", text)
def strip_content(text: str) -> str:
text = selection_marks_re.sub("", text)
text = whitespace_re.sub(" ", text)
return text.strip()
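
Below is a minimal usage sketch of the new prompt helpers on a toy table. It assumes the functions above are importable from daras_ai_v2.azure_doc_extract; the ** markers stand in for the THEAD tokens that table_to_arr attaches to header cells.

from daras_ai_v2.azure_doc_extract import (
    table_arr_to_prompt,
    table_arr_to_prompt_chunked,
)

# header cells carry the THEAD ("**") markers, as table_to_arr would emit them
arr = [
    ["**Name**", "**Qty**"],
    ["apples", "2"],
    ["pears", "3"],
]

print(table_arr_to_prompt(arr))
# prints:
#  |Name |Qty|
# |--|--|
#  |apples |2|
#  |pears |3|

# the chunked variant re-emits the header rows (and the |-- separator) at the
# top of every yielded chunk, so each chunk stays a self-contained table
for chunk in table_arr_to_prompt_chunked(arr, chunk_size=1024):
    print(chunk)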
14 changes: 7 additions & 7 deletions daras_ai_v2/db.py
@@ -1,5 +1,3 @@
from google.cloud import firestore

FIREBASE_SESSION_COOKIE = "firebase_session"
ANONYMOUS_USER_COOKIE = "anonymous_user"

@@ -14,13 +12,15 @@


def get_client():
from google.cloud import firestore

global _client
if _client is None:
_client = firestore.Client()
return _client


def get_doc_field(doc_ref: firestore.DocumentReference, field: str, default=None):
def get_doc_field(doc_ref: "firestore.DocumentReference", field: str, default=None):
snapshot = doc_ref.get([field])
if not snapshot.exists:
return default
@@ -30,13 +30,13 @@ def get_doc_field(doc_ref: firestore.DocumentReference, field: str, default=None
return default


def get_user_doc_ref(uid: str) -> firestore.DocumentReference:
def get_user_doc_ref(uid: str) -> "firestore.DocumentReference":
return get_doc_ref(collection_id=USERS_COLLECTION, document_id=uid)


def get_or_create_doc(
doc_ref: firestore.DocumentReference,
) -> firestore.DocumentSnapshot:
doc_ref: "firestore.DocumentReference",
) -> "firestore.DocumentSnapshot":
doc = doc_ref.get()
if not doc.exists:
doc_ref.create({})
@@ -50,7 +50,7 @@ def get_doc_ref(
collection_id=DEFAULT_COLLECTION,
sub_collection_id: str = None,
sub_document_id: str = None,
) -> firestore.DocumentReference:
) -> "firestore.DocumentReference":
db_collection = get_client().collection(collection_id)
doc_ref = db_collection.document(document_id)
if sub_collection_id:
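
The change above defers the google.cloud.firestore import to first use and quotes the type annotations, so daras_ai_v2.db can be imported without the dependency installed. A standalone sketch of the same lazy-import pattern (the TYPE_CHECKING guard is an optional refinement, not part of this commit):

import typing

if typing.TYPE_CHECKING:
    # resolved only by type checkers, never imported at runtime
    from google.cloud import firestore

_client = None


def get_client() -> "firestore.Client":
    from google.cloud import firestore  # heavy import deferred to first call

    global _client
    if _client is None:
        _client = firestore.Client()
    return _client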
10 changes: 2 additions & 8 deletions daras_ai_v2/language_model.py
@@ -19,6 +19,7 @@
from daras_ai_v2.redis_cache import (
get_redis_cache,
)
from daras_ai_v2.text_splitter import default_length_function

DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."

@@ -120,20 +121,13 @@ def is_chat_model(self) -> bool:
LargeLanguageModels.llama2_70b_chat: 4096,
}

threadlocal = threading.local()


def calc_gpt_tokens(
text: str | list[str] | dict | list[dict],
*,
sep: str = "",
is_chat_model: bool = True,
) -> int:
try:
enc = threadlocal.gpt2enc
except AttributeError:
enc = tiktoken.get_encoding("gpt2")
threadlocal.gpt2enc = enc
if isinstance(text, (str, dict)):
messages = [text]
else:
@@ -151,7 +145,7 @@ def calc_gpt_tokens(
else str(entry)
)
)
return len(enc.encode(combined))
return default_length_function(combined)


def get_openai_error_cls():
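
calc_gpt_tokens now counts tokens through the same default_length_function used by the text splitter, so prompt building and chunking agree on lengths. That also offers a quick way to sanity-check the single-token claim made for the table separators in azure_doc_extract.py (an illustrative check, not part of the commit):

from daras_ai_v2.text_splitter import default_length_function

# per the NOTE above THEAD_SEP / TROW_END / TROW_SEP, each of these
# separators should encode to exactly one token in the gpt-4 vocab
for sep in (" |", "|\n", "|--"):
    print(repr(sep), default_length_function(sep))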
8 changes: 5 additions & 3 deletions daras_ai_v2/text_splitter.py
@@ -42,10 +42,10 @@

def default_length_function(text: str) -> int:
try:
enc = threadlocal.gpt2enc
enc = threadlocal.enc
except AttributeError:
enc = tiktoken.get_encoding("gpt2")
threadlocal.gpt2enc = enc
enc = tiktoken.encoding_for_model("gpt-4")
threadlocal.enc = enc
return len(enc.encode(text))


@@ -60,12 +60,14 @@ def __init__(
text: str,
span: tuple[int, int],
length_function: L = default_length_function,
**kwargs,
):
self.text = text
self.span = span
self.start = self.span[0]
self.end = self.span[1]
self.length_function = length_function
self.kwargs = kwargs

def __len__(self):
if self._length is None:
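
default_length_function now caches a gpt-4 encoder in thread-local storage (tiktoken maps "gpt-4" to the cl100k_base encoding) instead of the old gpt2 one, so splitter lengths match what the chat models actually consume. The same pattern in isolation, runnable on its own:

import threading

import tiktoken

threadlocal = threading.local()


def gpt4_length_function(text: str) -> int:
    # build the encoder once per thread, then reuse it on every call
    try:
        enc = threadlocal.enc
    except AttributeError:
        enc = tiktoken.encoding_for_model("gpt-4")
        threadlocal.enc = enc
    return len(enc.encode(text))


print(gpt4_length_function("hello world"))  # 2 tokens under cl100k_base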
(7 more changed files not shown)
