# Reconstructed from the collapsed patch "Preserve query parameters in the
# base URL when joining urls (#101)": the post-patch `urljoin` from
# sw_utils/common.py, followed by the tests the patch adds in
# sw_utils/tests/test_urljoin.py.
from urllib.parse import urlsplit, urlunsplit


def urljoin(base: object, *args: object) -> str:
    """
    Better version of `urllib.parse.urljoin`
    Allows multiple arguments.
    Consistent behavior with or without ending slashes.
    Preserves query parameters in the base URL.

    Args:
        base: URL (or URL fragment) the segments are appended to. It is
            converted with `str()`, like every other argument.
        *args: Path segments. Each one is converted with `str()` and has
            its leading and trailing slashes stripped before joining.

    Returns:
        The base URL with the segments appended to its path. The query
        string and fragment of `base` are kept after the new path.
    """
    # `urlsplit` rather than `urlparse`: `urlparse` detaches `;params` from
    # the last path segment, which would move them behind the appended
    # segments when the URL is reassembled.
    parts = urlsplit(str(base))
    tail = '/'.join(str(arg).strip('/') for arg in args)
    if parts.path:
        path = '/'.join([parts.path.strip('/'), tail.strip('/')])
    else:
        path = tail
    return urlunsplit(parts._replace(path=path))


def test_urljoin_basic():
    """Segments are appended with or without a trailing slash on the base."""
    assert urljoin('http://example.com', 'path') == 'http://example.com/path'
    assert urljoin('http://example.com/', 'path') == 'http://example.com/path'
    assert (
        urljoin('http://example.com', 'path1', 'path2')
        == 'http://example.com/path1/path2'
    )
    assert (
        urljoin('http://example.com/', 'path1', 'path2')
        == 'http://example.com/path1/path2'
    )


def test_urljoin_with_query():
    """The query string of the base URL stays behind the joined path."""
    assert (
        urljoin('http://example.com?query=1', 'path')
        == 'http://example.com/path?query=1'
    )
    assert (
        urljoin('http://example.com/?query=1', 'path')
        == 'http://example.com/path?query=1'
    )
    assert (
        urljoin('http://example.com?query=1', 'path1', 'path2')
        == 'http://example.com/path1/path2?query=1'
    )
    assert (
        urljoin('http://example.com/?query=1', 'path1', 'path2')
        == 'http://example.com/path1/path2?query=1'
    )


def test_urljoin_with_fragment():
    """The fragment of the base URL stays behind the joined path."""
    assert (
        urljoin('http://example.com#fragment', 'path')
        == 'http://example.com/path#fragment'
    )
    assert (
        urljoin('http://example.com/#fragment', 'path')
        == 'http://example.com/path#fragment'
    )
    assert (
        urljoin('http://example.com#fragment', 'path1', 'path2')
        == 'http://example.com/path1/path2#fragment'
    )
    assert (
        urljoin('http://example.com/#fragment', 'path1', 'path2')
        == 'http://example.com/path1/path2#fragment'
    )


def test_urljoin_edge_cases():
    """Empty segments and redundant slashes collapse predictably."""
    assert urljoin('http://example.com', '') == 'http://example.com'
    assert urljoin('http://example.com/', '') == 'http://example.com/'
    assert urljoin('http://example.com', '/', '/') == 'http://example.com/'
    assert urljoin('http://example.com/', '/', '/') == 'http://example.com/'
    assert urljoin('http://example.com//', 'path') == 'http://example.com/path'
    assert urljoin('http://example.com/', '//path') == 'http://example.com/path'