Skip to content

Commit

Permalink
combine auxiliary.py and util.py (#594)
Browse files Browse the repository at this point in the history
Moved all of util.py into auxiliary.py
Removed preserialize_object function
Moved generate_embed_from_kwargs to the ipinfo function directly
  • Loading branch information
ObsoleteXero authored Aug 29, 2023
1 parent 02ee198 commit 0842a25
Show file tree
Hide file tree
Showing 35 changed files with 301 additions and 380 deletions.
17 changes: 8 additions & 9 deletions techsupport_bot/base/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import error
import expiringdict
import munch
import util
from base import auxiliary
from discord.ext import commands
from unidecode import unidecode
Expand Down Expand Up @@ -612,9 +611,9 @@ async def on_message_edit(self, before, after):
return

attrs = ["content", "embeds"]
diff = util.get_object_diff(before, after, attrs)
diff = auxiliary.get_object_diff(before, after, attrs)
embed = discord.Embed()
embed = util.add_diff_fields(embed, diff)
embed = auxiliary.add_diff_fields(embed, diff)
embed.add_field(name="Author", value=before.author)
embed.add_field(name="Channel", value=getattr(before.channel, "name", "DM"))
embed.add_field(
Expand Down Expand Up @@ -792,10 +791,10 @@ async def on_guild_channel_update(self, before, after):
"permissions_synced",
"position",
]
diff = util.get_object_diff(before, after, attrs)
diff = auxiliary.get_object_diff(before, after, attrs)

embed = discord.Embed()
embed = util.add_diff_fields(embed, diff)
embed = auxiliary.add_diff_fields(embed, diff)
embed.add_field(name="Channel Name", value=before.name)
embed.add_field(name="Server", value=before.guild.name)

Expand Down Expand Up @@ -981,7 +980,7 @@ async def on_guild_update(self, before, after):
"""
See: https://discordpy.readthedocs.io/en/latest/api.html#discord.on_guild_update
"""
diff = util.get_object_diff(
diff = auxiliary.get_object_diff(
before,
after,
[
Expand Down Expand Up @@ -1009,7 +1008,7 @@ async def on_guild_update(self, before, after):
)

embed = discord.Embed()
embed = util.add_diff_fields(embed, diff)
embed = auxiliary.add_diff_fields(embed, diff)
embed.add_field(name="Server", value=before.name)

log_channel = await self.get_log_channel_from_guild(
Expand Down Expand Up @@ -1053,10 +1052,10 @@ async def on_guild_role_delete(self, role):
async def on_guild_role_update(self, before, after):
"""See: https://discordpy.readthedocs.io/en/latest/api.html#discord.on_guild_role_update"""
attrs = ["color", "mentionable", "name", "permissions", "position", "tags"]
diff = util.get_object_diff(before, after, attrs)
diff = auxiliary.get_object_diff(before, after, attrs)

embed = discord.Embed()
embed = util.add_diff_fields(embed, diff)
embed = auxiliary.add_diff_fields(embed, diff)
embed.add_field(name="Server", value=before.name)

log_channel = await self.get_log_channel_from_guild(
Expand Down
196 changes: 193 additions & 3 deletions techsupport_bot/base/auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
This replaces duplicate or similar code across many extensions
"""

import json
from functools import wraps

import discord
import munch


def generate_basic_embed(
Expand Down Expand Up @@ -77,7 +81,7 @@ async def add_list_of_reactions(message: discord.Message, reactions: list) -> No
await message.add_reaction(emoji)


def construct_mention_string(targets: list) -> str:
def construct_mention_string(targets: list[discord.User]) -> str:
"""Builds a string of mentions from a list of users.
parameters:
Expand Down Expand Up @@ -128,7 +132,7 @@ def prepare_deny_embed(message: str) -> discord.Embed:


async def send_deny_embed(
message: str, channel: discord.abc.Messageable, author: discord.Member = None
message: str, channel: discord.abc.Messageable, author: discord.Member | None = None
) -> discord.Message:
"""Sends a formatted deny embed to the given channel
Expand Down Expand Up @@ -166,7 +170,7 @@ def prepare_confirm_embed(message: str) -> discord.Embed:


async def send_confirm_embed(
message: str, channel: discord.abc.Messageable, author: discord.Member = None
message: str, channel: discord.abc.Messageable, author: discord.Member | None = None
) -> discord.Message:
"""Sends a formatted deny embed to the given channel
Expand All @@ -184,3 +188,189 @@ async def send_confirm_embed(
content=construct_mention_string([author]), embed=embed
)
return message


async def get_json_from_attachments(
message: discord.Message, as_string: bool = False, allow_failure: bool = False
) -> munch.Munch | str | None:
"""Returns concatted JSON from a message's attachments.
parameters:
ctx (discord.ext.Context): the context object for the message
message (Message): the message object
as_string (bool): True if the serialized JSON should be returned
allow_failure (bool): True if an exception should be ignored when parsing attachments
"""
if not message.attachments:
return None

attachment_jsons = []
for attachment in message.attachments:
try:
json_bytes = await attachment.read()
attachment_jsons.append(json.loads(json_bytes.decode("UTF-8")))
except Exception as exception:
if allow_failure:
continue
raise exception

if len(attachment_jsons) == 1:
attachment_jsons = attachment_jsons[0]

return (
json.dumps(attachment_jsons) if as_string else munch.munchify(attachment_jsons)
)


def config_schema_matches(input_config: dict, current_config: dict) -> list[str] | None:
"""Performs a schema check on an input guild config.
parameters:
input_config (dict): the config to be added
current_config (dict): the current config
"""
if (
any(key not in current_config for key in input_config.keys())
or len(current_config) != len(input_config) + 1
):
added_keys = []
removed_keys = []

for key in input_config.keys():
if key not in current_config and key != "_id":
added_keys.append(key)

for key in current_config.keys():
if key not in input_config and key != "_id":
removed_keys.append(key)

result = []
for key in added_keys:
result.append("added: " + key)

for key in removed_keys:
result.append("removed: " + key)

return result

return None


def with_typing(command: discord.ext.commands.Command) -> discord.ext.commands.Command:
"""Decorator for commands to utilize "async with" ctx.typing()
This will show the bot as typing... until the command completes
parameters:
command (discord.ext.commands.Command): the command object to modify
"""
original_callback = command.callback

@wraps(original_callback)
async def typing_wrapper(*args, **kwargs):
context = args[1]

typing_func = getattr(context, "typing", None)

if not typing_func:
await original_callback(*args, **kwargs)
else:
try:
async with typing_func():
await original_callback(*args, **kwargs)
except discord.Forbidden:
await original_callback(*args, **kwargs)

# this has to be done so invoke will see the original signature
typing_wrapper.__name__ = command.name

# calls the internal setter
command.callback = typing_wrapper
command.callback.__module__ = original_callback.__module__

return command


def get_object_diff(
before: object, after: object, attrs_to_check: list
) -> munch.Munch | dict:
"""Finds differences in before, after object pairs.
before (obj): the before object
after (obj): the after object
attrs_to_check (list): the attributes to compare
"""
result = {}

for attr in attrs_to_check:
after_value = getattr(after, attr, None)
if not after_value:
continue

before_value = getattr(before, attr, None)
if not before_value:
continue

if before_value != after_value:
result[attr] = munch.munchify(
{"before": before_value, "after": after_value}
)

return result


def add_diff_fields(embed: discord.Embed, diff: dict) -> discord.Embed:
"""Adds fields to an embed based on diff data.
parameters:
embed (discord.Embed): the embed object
diff (dict): the diff data for an object
"""
for attr, diff_data in diff.items():
attru = attr.upper()
if isinstance(diff_data.before, list):
action = (
"added" if len(diff_data.before) < len(diff_data.after) else "removed"
)
list_diff = set(repr(diff_data.after)) ^ set(repr(diff_data.before))

embed.add_field(
name=f"{attru} {action}", value=",".join(str(o) for o in list_diff)
)
continue

# Checking if content is a string, and not anything else for guild update.
if isinstance(diff_data.before, str):
# expanding the before data to 4096 characters
embed.add_field(name=f"{attru} (before)", value=diff_data.before[:1024])
if len(diff_data.before) > 1024:
embed.add_field(
name=f"{attru} (before continue)", value=diff_data.before[1025:2048]
)
if len(diff_data.before) > 2048 and len(diff_data.after) <= 2800:
embed.add_field(
name=f"{attru} (before continue)", value=diff_data.before[2049:3072]
)
if len(diff_data.before) > 3072 and len(diff_data.after) <= 1800:
embed.add_field(
name=f"{attru} (before continue)", value=diff_data.before[3073:4096]
)

# expanding the after data to 4096 characters
embed.add_field(name=f"{attru} (after)", value=diff_data.after[:1024])
if len(diff_data.after) > 1024:
embed.add_field(
name=f"{attru} (after continue)", value=diff_data.after[1025:2048]
)
if len(diff_data.after) > 2048:
embed.add_field(
name=f"{attru} (after continue)", value=diff_data.after[2049:3072]
)
if len(diff_data.after) > 3072:
embed.add_field(
name=f"{attru} (after continue)", value=diff_data.after[3073:4096]
)
else:
embed.add_field(name=f"{attru} (before)", value=diff_data.before or None)
embed.add_field(name=f"{attru} (after)", value=diff_data.after or None)
return embed
Loading

0 comments on commit 0842a25

Please sign in to comment.