Skip to content

Commit

Permalink
295 Fix extra requirements for MFA (#296)
Browse files Browse the repository at this point in the history
* fix extra requirements for MFA

* fix linter error
  • Loading branch information
dantownsend authored Sep 9, 2024
1 parent 724d5b8 commit 2f30742
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
25 changes: 10 additions & 15 deletions piccolo_api/encryption/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
from abc import ABCMeta, abstractmethod

if t.TYPE_CHECKING:
import cryptography
import nacl
from cryptography.fernet import Fernet


logger = logging.getLogger(__name__)


def get_cryptography() -> cryptography: # type: ignore
def get_fernet_class() -> t.Type[Fernet]: # type: ignore
try:
import cryptography
from cryptography.fernet import Fernet
except ImportError as e:
print(
"Install pip install piccolo_api[cryptography] to use this "
"feature."
)
raise e

return cryptography
return Fernet


class EncryptionProvider(metaclass=ABCMeta):
Expand Down Expand Up @@ -108,14 +108,12 @@ def __init__(self, encryption_key: bytes):

@staticmethod
def get_new_key() -> bytes:
cryptography = get_cryptography()
return cryptography.fernet.Fernet.generate_key() # type: ignore
Fernet = get_fernet_class()
return Fernet.generate_key() # type: ignore

def encrypt(self, value: str, add_prefix: bool = True) -> str:
cryptography = get_cryptography()
fernet = cryptography.fernet.Fernet( # type: ignore
self.encryption_key
)
Fernet = get_fernet_class()
fernet = Fernet(self.encryption_key) # type: ignore
encrypted_value = fernet.encrypt(value.encode("utf-8")).decode("utf-8")
return (
self.add_prefix(encrypted_value=encrypted_value)
Expand All @@ -127,11 +125,8 @@ def decrypt(self, encrypted_value: str, has_prefix: bool = True) -> str:
if has_prefix:
encrypted_value = self.remove_prefix(encrypted_value)

cryptography = get_cryptography()

fernet = cryptography.fernet.Fernet( # type: ignore
self.encryption_key
)
Fernet = get_fernet_class()
fernet = Fernet(self.encryption_key) # type: ignore
return fernet.decrypt(encrypted_value.encode("utf-8")).decode("utf-8")


Expand Down
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
LONG_DESCRIPTION = f.read()


EXTRAS = ["s3"]
EXTRAS = [
"authenticator",
"cryptography",
"pynacl",
"s3",
]


def parse_requirement(req_path: str) -> t.List[str]:
Expand Down

0 comments on commit 2f30742

Please sign in to comment.