Skip to content

Commit

Permalink
feat(signals): Improve send signals process (#24)
Browse files Browse the repository at this point in the history
* feat(sqlstorage): Add limit and offset arguments in get_all_signals function
* feat(sqlstorage): Change delete_signals and delete_machines method to improve pruning
* test(*): Add some test paths to gitignore
* feat(storage): Modify get_all_signals to have sent and is_failing param
* feat(client): Add batch_size param for send_signals and prune_failing_machines_signals
* feat(storage): Rename get_all_signals to get_signals
* feat(log): Improve log messages
* feat(signal): Return number of sent and pruned signal in related methods
* test(*): Add test for get_signals
* chore(changelog): Prepare release 0.4.0
* fix(send): Do not offset even when no pruning occurs, as we only fetch unsent signals
* feat(signals): Rename _send_signals to _send_signals_to_capi
* style(*): Pass through lint
* style(*): Pass through lint
  • Loading branch information
julienloizelet authored Feb 23, 2024
1 parent 9826f2a commit f1c43d3
Show file tree
Hide file tree
Showing 12 changed files with 364 additions and 77 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,7 @@ src/cscapi/_version.py
.vscode

#ddev
.ddev
.ddev

# custom scripts
examples/**/perf*
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@ functions provided by the `src/cscapi` folder.

---

## [0.4.0](https://github.com/crowdsecurity/python-capi-sdk/releases/tag/v0.4.0) - 2024-02-23
[_Compare with previous release_](https://github.com/crowdsecurity/python-capi-sdk/compare/v0.3.0...v0.4.0)


### Changed

- **Breaking change**: Rename `StorageInterface::get_all_signals` to `get_signals` and add `limit`, `offset`, `sent` and `is_failing` arguments
- **Breaking change**: Change `StorageInterface::delete_signals` signature to require a list of signal ids
- **Breaking change**: Change `StorageInterface::delete_machines` signature to require a list of machine ids
- Add `batch_size` argument to `CAPIClient::send_signals` and `CAPIClient::prune_failing_machines_signals` methods
- `CAPIClient::send_signals` and `CAPIClient::prune_failing_machines_signals` now return the number of signals sent or pruned
- `CAPIClient::send_signals` and `CAPIClient::prune_failing_machines_signals` now send and prune signals in batches


### Removed

- **Breaking change**: Remove `CAPIClient::_prune_sent_signals` method


---

## [0.3.0](https://github.com/crowdsecurity/python-capi-sdk/releases/tag/v0.3.0) - 2024-02-16
[_Compare with previous release_](https://github.com/crowdsecurity/python-capi-sdk/compare/v0.2.1...v0.3.0)

Expand Down
16 changes: 16 additions & 0 deletions examples/shell_scripts/prune_failing_machines_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@

import argparse
import sys
import logging

from cscapi.client import CAPIClient, CAPIClientConfig
from cscapi.sql_storage import SQLStorage

logger = logging.getLogger("capi-py-sdk")
logger.setLevel(logging.DEBUG) # Change this to the level you want
console_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


class CustomHelpFormatter(argparse.HelpFormatter):
def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
Expand All @@ -26,6 +34,12 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
help="Local database name. Example: cscapi.db",
required=True,
)
parser.add_argument(
"--batch_size",
type=int,
help="Batch size for pruning signals. Example: 1000",
default=1000,
)
args = parser.parse_args()
except argparse.ArgumentError as e:
print(e)
Expand All @@ -34,11 +48,13 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):

database = args.database
database_message = f"\tLocal storage database: {database}\n"
batch_size_message = f"\tBatch size: {args.batch_size}\n"

print(
f"\nPruning signals for failing machines\n\n"
f"Details:\n"
f"{database_message}"
f"{batch_size_message}"
f"\n\n"
)

Expand Down
20 changes: 19 additions & 1 deletion examples/shell_scripts/send_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import sys
import logging
import time
from cscapi.client import CAPIClient, CAPIClientConfig
from cscapi.sql_storage import SQLStorage
from cscapi.utils import create_signal
Expand Down Expand Up @@ -65,6 +66,12 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
help="Local database name. Example: cscapi.db",
default=None,
)
parser.add_argument(
"--batch_size",
type=int,
help="Batch size for sending signals. Example: 1000",
default=1000,
)
parser.add_argument(
"--context",
type=str,
Expand Down Expand Up @@ -103,6 +110,7 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
else "cscapi_examples_prod.db" if args.prod else "cscapi_examples_dev.db"
)
database_message = f"\tLocal storage database: {database}\n"
batch_size_message = f"\tBatch size: {args.batch_size}\n"

print(
f"\nSending signal for {machine_id_message}\n\n"
Expand All @@ -114,6 +122,7 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
f"{context_message}"
f"{machine_scenarios_message}"
f"{database_message}"
f"{batch_size_message}"
f"{user_agent_message}"
f"\n\n"
)
Expand Down Expand Up @@ -145,4 +154,13 @@ def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):

client.add_signals(signals)

client.send_signals()
total_start_time = time.time()

print(f"Starting time elapsed for sending signals: {total_start_time} seconds")

client.send_signals(batch_size=args.batch_size)

total_end_time = time.time()
print(
f"Total time elapsed for sending signals: {total_end_time - total_start_time:.2f} seconds"
)
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
black
pytest
pytest-dotenv
pytest-httpx
pytest-httpx==0.29.0
pytest-timeout
sqlalchemy-utils
mysql-connector-python
pymysql
pg8000
psutil

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sqlalchemy>=1.4
python-dateutil
httpx
httpx==0.26.*
dacite
importlib-metadata
pyjwt
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package_dir =
packages = find:
python_requires = >=3.9
install_requires =
sqlalchemy
sqlalchemy>=1.4
python-dateutil
httpx==0.26.*
dacite
Expand Down
97 changes: 64 additions & 33 deletions src/cscapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from dataclasses import asdict, replace, dataclass
from importlib import metadata
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, Tuple

import httpx
import jwt
Expand All @@ -24,6 +24,7 @@
CAPI_SIGNALS_ENDPOINT = "/signals"
CAPI_DECISIONS_ENDPOINT = "/decisions/stream"
CAPI_METRICS_ENDPOINT = "/metrics"
SIGNAL_BATCH_LIMIT = 1000


def has_valid_token(
Expand Down Expand Up @@ -85,23 +86,48 @@ def add_signals(self, signals: List[SignalModel]):
for signal in signals:
self.storage.update_or_create_signal(signal)

def prune_failing_machines_signals(self):
signals = self.storage.get_all_signals()
for machine_id, signals in _group_signals_by_machine_id(signals).items():
machine = self.storage.get_machine_by_id(machine_id)
if machine.is_failing:
self.storage.delete_signals(signals)
def prune_failing_machines_signals(
self, batch_size: int = SIGNAL_BATCH_LIMIT
) -> int:
total_pruned = 0
while True:
signals = self.storage.get_signals(limit=batch_size, is_failing=True)
if not signals:
break

def send_signals(self, prune_after_send: bool = True):
unsent_signals_by_machineid = _group_signals_by_machine_id(
filter(lambda signal: not signal.sent, self.storage.get_all_signals())
)
self._send_signals_by_machine_id(unsent_signals_by_machineid, prune_after_send)
signal_ids = [signal.alert_id for signal in signals]
self.storage.delete_signals(signal_ids)
total_pruned += len(signals)

self.logger.info(f"Total pruned signals: {total_pruned}")
return total_pruned

def send_signals(
self, prune_after_send: bool = True, batch_size: int = SIGNAL_BATCH_LIMIT
) -> int:
offset = 0
total_sent = 0
while True:
signals = self.storage.get_signals(
limit=batch_size, offset=offset, sent=False, is_failing=False
)
if not signals:
self.logger.info(f"No signals to send, stopping sending")
break
unsent_signals_by_machineid = _group_signals_by_machine_id(signals)

batch_sent = self._send_signals_by_machine_id(
unsent_signals_by_machineid, prune_after_send
)
total_sent += batch_sent

self.logger.info(f"Total sent signals: {total_sent}")
return total_sent

def _has_valid_scenarios(self, machine: MachineModel) -> bool:
current_scenarios = self.scenarios
stored_scenarios = machine.scenarios
if len(stored_scenarios) == 0:
if not stored_scenarios:
return False

return current_scenarios == stored_scenarios
Expand All @@ -110,16 +136,17 @@ def _send_signals_by_machine_id(
self,
signals_by_machineid: Dict[str, List[SignalModel]],
prune_after_send: bool = False,
):
) -> int:
machines_to_process_attempts: List[MachineModel] = [
MachineModel(machine_id=machine_id, scenarios=self.scenarios)
for machine_id in signals_by_machineid.keys()
]

attempt_count = 0
total_sent = 0

while machines_to_process_attempts:
self.logger.info(f"attempt {attempt_count} to send signals")
self.logger.info(f"attempt {attempt_count + 1} to send signals")
retry_machines_to_process_attempts: List[MachineModel] = []
if attempt_count >= self.max_retries:
for machine_to_process in machines_to_process_attempts:
Expand All @@ -142,11 +169,15 @@ def _send_signals_by_machine_id(
self.logger.info(
f"sending signals for machine {machine_to_process.machine_id}"
)
sent_signal_ids = []
try:
self._send_signals(
sent_signal_ids = self._send_signals_to_capi(
machine_to_process.token,
signals_by_machineid[machine_to_process.machine_id],
)
sent_signal_ids_count = len(sent_signal_ids)
total_sent += sent_signal_ids_count
self.logger.info(f"sent {sent_signal_ids_count} signals")
except httpx.HTTPStatusError as exc:
self.logger.error(
f"error while sending signals: {exc} for machine {machine_to_process.machine_id}"
Expand All @@ -160,11 +191,11 @@ def _send_signals_by_machine_id(
machine_to_process.token = None
retry_machines_to_process_attempts.append(machine_to_process)
continue
if prune_after_send:
if prune_after_send and sent_signal_ids:
self.logger.info(
f"pruning sent signals for machine {machine_to_process.machine_id}"
)
self._prune_sent_signals()
self.storage.delete_signals(sent_signal_ids)

self.logger.info(
f"sending metrics for machine {machine_to_process.machine_id}"
Expand All @@ -179,15 +210,20 @@ def _send_signals_by_machine_id(

attempt_count += 1
machines_to_process_attempts = retry_machines_to_process_attempts
if (len(retry_machines_to_process_attempts) != 0) and (
if retry_machines_to_process_attempts and (
attempt_count < self.max_retries
):
self.logger.info(
f"waiting {self.retry_delay} seconds before retrying sending signals"
)
time.sleep(self.retry_delay)

def _send_signals(self, token: str, signals: SignalModel):
return total_sent

def _send_signals_to_capi(
self, token: str, signals: List[SignalModel]
) -> List[int]:
result = []
for signal_batch in batched(signals, 250):
body = [asdict(signal) for signal in signal_batch]
resp = self.http_client.post(
Expand All @@ -196,11 +232,17 @@ def _send_signals(self, token: str, signals: SignalModel):
headers={"Authorization": token},
)
resp.raise_for_status()
self._mark_signals_as_sent(signal_batch)
result.extend(self._mark_signals_as_sent(signal_batch))

return result

def _mark_signals_as_sent(self, signals: List[SignalModel]):
def _mark_signals_as_sent(self, signals: Tuple[SignalModel]) -> List[int]:
result = []
for signal in signals:
self.storage.update_or_create_signal(replace(signal, sent=True))
result.append(signal.alert_id)

return result

def _send_metrics_for_machine(self, machine: MachineModel):
for _ in range(self.max_retries + 1):
Expand All @@ -227,17 +269,6 @@ def _send_metrics_for_machine(self, machine: MachineModel):
f"received error {exc} while sending metrics for machine {machine.machine_id}"
)

def _prune_sent_signals(self):
signals = list(
filter(lambda signal: signal.sent, self.storage.get_all_signals())
)

self.storage.delete_signals(signals)

def _clear_all_signals(self):
signals = self.storage.get_all_signals()
self.storage.delete_signals(signals)

def _refresh_machine_token(self, machine: MachineModel) -> MachineModel:
machine.scenarios = self.scenarios
resp = self.http_client.post(
Expand Down
Loading

0 comments on commit f1c43d3

Please sign in to comment.