Merge pull request #13 from Zipstack/feat/support-to-configure-db-path
feat: Added support to configure DB path, `result` in CSV report
ritwik-g authored Dec 12, 2024
2 parents ae62f8e + ee85994 commit adc6aeb
Showing 3 changed files with 62 additions and 49 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,2 +1,5 @@
 *.db
-.venv/
+*.csv
+.mypy_cache/
+.venv/
+.python-version
3 changes: 2 additions & 1 deletion README.md
@@ -59,6 +59,8 @@ This will display detailed usage information.
 - `-t`, `--api_timeout`: Timeout (in seconds) for API requests (default: 10).
 - `-i`, `--poll_interval`: Interval (in seconds) between API status polls (default: 5).
 - `-p`, `--parallel_call_count`: Number of parallel API calls (default: 10).
+- `--csv_report`: Path to export the detailed report as a CSV file.
+- `--db_path`: Path where the SQLite DB file is stored (default: './file_processing.db').
 - `--retry_failed`: Retry processing of failed files.
 - `--retry_pending`: Retry processing of pending files by making new requests.
 - `--skip_pending`: Skip processing of pending files.
@@ -67,7 +69,6 @@ This will display detailed usage information.
 - `--print_report`: Print a detailed report of all processed files at the end.
 - `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
 - `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
-- `--csv_report`: Path to export the detailed report as a CSV file.
 
 ## Usage Examples
 
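As a quick illustration of the new `result` column, here is a minimal sketch of consuming the exported report (a sketch only: the `report.csv` path is an example for whatever was passed via `--csv_report`, and the JSON parsing is an assumption since the `Result` column stores the per-file result as text):

```python
import csv
import json

# "report.csv" is an example path; use whatever was passed to --csv_report.
with open("report.csv", newline="") as f:
    for row in csv.DictReader(f):
        # "Result" is the column this commit adds; it holds the raw per-file
        # result as text, so parse it defensively.
        raw = row.get("Result") or ""
        try:
            result = json.loads(raw)
        except json.JSONDecodeError:
            result = raw  # keep as plain text if it isn't JSON
        print(row["File Name"], row["Execution Status"], bool(raw))
```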
103 changes: 56 additions & 47 deletions main.py
@@ -16,8 +16,6 @@
 from tqdm import tqdm
 from unstract.api_deployments.client import APIDeploymentsClient
 
-DB_NAME = "file_processing.db"
-global_arguments = None
 logger = logging.getLogger(__name__)
 
 
@@ -29,6 +27,7 @@ class Arguments:
     api_timeout: int = 10
     poll_interval: int = 5
     input_folder_path: str = ""
+    db_path: str = ""
     parallel_call_count: int = 5
     retry_failed: bool = False
     retry_pending: bool = False
@@ -42,8 +41,8 @@
 
 
 # Initialize SQLite DB
-def init_db():
-    conn = sqlite3.connect(DB_NAME)
+def init_db(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Create the table if it doesn't exist
@@ -89,7 +88,7 @@
 
 # Check if the file is already processed
 def skip_file_processing(file_name, args: Arguments):
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT execution_status FROM file_status WHERE file_name = ?", (file_name,)
@@ -124,6 +123,7 @@ def update_db(
     time_taken,
     status_code,
     status_api_endpoint,
+    args: Arguments,
 ):
 
     total_embedding_cost = None
@@ -138,7 +138,7 @@
     if execution_status == "ERROR":
         error_message = extract_error_message(result)
 
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     conn.set_trace_callback(
         lambda x: (
             logger.debug(f"[{file_name}] Executing statement: {x}")
@@ -232,8 +232,8 @@ def extract_error_message(result):
     return result.get("error", "No error message found")
 
 # Print final summary with count of each status and average time using a single SQL query
-def print_summary():
-    conn = sqlite3.connect(DB_NAME)
+def print_summary(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch count and average time for each status
@@ -255,8 +255,8 @@
         print(f"Status '{status}': {count}")
 
 
-def print_report():
-    conn = sqlite3.connect(DB_NAME)
+def print_report(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch required fields, including total_cost and total_tokens
@@ -318,36 +318,36 @@ def print_report():
 
     print("\nNote: For more detailed error messages, use the CSV report argument.")
 
-def export_report_to_csv(output_path):
-    conn = sqlite3.connect(DB_NAME)
+def export_report_to_csv(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     c.execute(
         """
-        SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
+        SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
         FROM file_status
         """
     )
     report_data = c.fetchall()
     conn.close()
 
     if not report_data:
-        print("No data available to export.")
+        print("No data available to export as CSV.")
         return
 
     # Define the headers
     headers = [
-        "File Name", "Execution Status", "Time Elapsed (seconds)",
+        "File Name", "Execution Status", "Result", "Time Elapsed (seconds)",
         "Total Embedding Cost", "Total Embedding Tokens",
         "Total LLM Cost", "Total LLM Tokens", "Error Message"
    ]
 
     try:
-        with open(output_path, 'w', newline='') as csvfile:
+        with open(args.csv_report, 'w', newline='') as csvfile:
             writer = csv.writer(csvfile)
             writer.writerow(headers)  # Write headers
             writer.writerows(report_data)  # Write data rows
-        print(f"CSV successfully exported to {output_path}")
+        print(f"CSV successfully exported to '{args.csv_report}'")
     except Exception as e:
         print(f"Error exporting to CSV: {e}")
 
@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
     status_endpoint = None
 
     # If retry_pending is True, check if the status API endpoint is available
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')",
@@ -382,7 +382,7 @@
 
     # Fresh API call to process the file
     execution_status = "STARTING"
-    update_db(file_path, execution_status, None, None, None, None)
+    update_db(file_path, execution_status, None, None, None, None, args=args)
     response = client.structure_file(file_paths=[file_path])
     logger.debug(f"[{file_path}] Response of initial API call: {response}")
     status_endpoint = response.get(
@@ -397,6 +397,7 @@
         None,
         status_code,
         status_endpoint,
+        args=args,
     )
     return status_endpoint, execution_status, response
 
@@ -436,7 +437,7 @@ def process_file(
     execution_status = response.get("execution_status")
     status_code = response.get("status_code")  # Default to 200 if not provided
     update_db(
-        file_path, execution_status, None, None, status_code, status_endpoint
+        file_path, execution_status, None, None, status_code, status_endpoint, args=args
     )
 
     result = response
@@ -456,7 +457,7 @@
     end_time = time.time()
     time_taken = round(end_time - start_time, 2)
     update_db(
-        file_path, execution_status, result, time_taken, status_code, status_endpoint
+        file_path, execution_status, result, time_taken, status_code, status_endpoint, args=args
     )
     logger.info(f"[{file_path}]: Processing completed: {execution_status}")
 
@@ -501,14 +502,14 @@ def load_folder(args: Arguments):
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Process files using the API.")
+    parser = argparse.ArgumentParser(description="Process files using Unstract's API deployment")
     parser.add_argument(
         "-e",
         "--api_endpoint",
         dest="api_endpoint",
         type=str,
         required=True,
-        help="API Endpoint to use for processing the files.",
+        help="API Endpoint to use for processing the files",
     )
     parser.add_argument(
         "-k",
@@ -524,55 +525,68 @@
         dest="api_timeout",
         type=int,
         default=10,
-        help="Time in seconds to wait before switching to async mode.",
+        help="Time in seconds to wait before switching to async mode (default: 10)",
     )
     parser.add_argument(
         "-i",
         "--poll_interval",
         dest="poll_interval",
         type=int,
         default=5,
-        help="Time in seconds the process will sleep between polls in async mode.",
+        help="Time in seconds the process will sleep between polls in async mode (default: 5)",
     )
     parser.add_argument(
         "-f",
         "--input_folder_path",
         dest="input_folder_path",
         type=str,
         required=True,
-        help="Path where the files to process are present.",
+        help="Path where the files to process are present",
     )
     parser.add_argument(
         "-p",
         "--parallel_call_count",
         dest="parallel_call_count",
         type=int,
         default=5,
-        help="Number of calls to be made in parallel.",
+        help="Number of calls to be made in parallel (default: 5)",
     )
+    parser.add_argument(
+        "--db_path",
+        dest="db_path",
+        type=str,
+        default="file_processing.db",
+        help="Path where the SQLite DB file is stored (default: './file_processing.db')",
+    )
+    parser.add_argument(
+        '--csv_report',
+        dest="csv_report",
+        type=str,
+        help='Path to export the detailed report as a CSV file',
+    )
     parser.add_argument(
         "--retry_failed",
         dest="retry_failed",
         action="store_true",
-        help="Retry processing of failed files.",
+        help="Retry processing of failed files (default: False)",
     )
     parser.add_argument(
         "--retry_pending",
         dest="retry_pending",
         action="store_true",
-        help="Retry processing of pending files as new request (Without this it will try to fetch the results using status API).",
+        help="Retry processing of pending files as a new request (without this it will try to fetch the results using the status API) (default: False)",
     )
     parser.add_argument(
         "--skip_pending",
         dest="skip_pending",
         action="store_true",
-        help="Skip processing of pending files (Over rides --retry-pending).",
+        help="Skip processing of pending files (overrides --retry_pending) (default: False)",
     )
     parser.add_argument(
         "--skip_unprocessed",
         dest="skip_unprocessed",
         action="store_true",
-        help="Skip unprocessed files while retry processing of failed files.",
+        help="Skip unprocessed files while retrying processing of failed files (default: False)",
     )
     parser.add_argument(
         "--log_level",
@@ -586,52 +600,47 @@
         "--print_report",
         dest="print_report",
         action="store_true",
-        help="Print a detailed report of all file processed.",
+        help="Print a detailed report of all files processed (default: False)",
     )
 
     parser.add_argument(
         "--exclude_metadata",
         dest="include_metadata",
         action="store_false",
-        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.",
+        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file (default: False)",
     )
 
     parser.add_argument(
         "--no_verify",
         dest="verify",
         action="store_false",
-        help="Disable SSL certificate verification.",
-    )
-
-    parser.add_argument(
-        '--csv_report',
-        dest="csv_report",
-        type=str,
-        help='Path to export the detailed report as a CSV file',
+        help="Disable SSL certificate verification (default: False)",
     )
 
     args = Arguments(**vars(parser.parse_args()))
 
     ch = logging.StreamHandler(sys.stdout)
     ch.setLevel(args.log_level)
     formatter = logging.Formatter(
         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
     ch.setFormatter(formatter)
     logging.basicConfig(level=args.log_level, handlers=[ch])
 
     logger.warning(f"Running with params: {args}")
 
-    init_db()  # Initialize DB
+    init_db(args=args)  # Initialize DB
 
     load_folder(args=args)
 
-    print_summary()  # Print summary at the end
+    print_summary(args=args)  # Print summary at the end
     if args.print_report:
-        print_report()
+        print_report(args=args)
     logger.warning(
         "Elapsed time calculation of a file which was resumed"
         " from pending state will not be correct"
     )
 
     if args.csv_report:
-        export_report_to_csv(args.csv_report)
+        export_report_to_csv(args=args)
 
 
 if __name__ == "__main__":
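And since the DB location is now configurable, a small sketch of inspecting a run database after the fact (assumptions: the `./runs/file_processing.db` path is an example for whatever was passed via `--db_path`, and the query approximates the per-status summary that `print_summary()` derives from the `file_status` table shown above):

```python
import sqlite3

db_path = "./runs/file_processing.db"  # example; matches the --db_path flag

conn = sqlite3.connect(db_path)
c = conn.cursor()
# Count files and average time taken, grouped by execution status.
c.execute(
    "SELECT execution_status, COUNT(*), AVG(time_taken) "
    "FROM file_status GROUP BY execution_status"
)
for status, count, avg_time in c.fetchall():
    avg = f"{avg_time:.2f}s" if avg_time is not None else "n/a"
    print(f"{status}: {count} file(s), avg time {avg}")
conn.close()
```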
