Skip to content

Commit

Permalink
Preserve query parameters in the base URL when joining urls (#101)
Browse files Browse the repository at this point in the history
* Preserve query parameters in the base URL.

* additional import

* tests
  • Loading branch information
skisel-bt authored May 23, 2024
1 parent 67a5985 commit a5a9337
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
12 changes: 10 additions & 2 deletions sw_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import signal
from typing import Any
from urllib.parse import urlparse, urlunparse, urlencode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,10 +47,17 @@ async def sleep(self, seconds: int | float) -> None:
seconds -= 1


def urljoin(*args):
def urljoin(base, *args):
"""
Better version of `urllib.parse.urljoin`
Allows multiple arguments.
Consistent behavior with or without ending slashes.
Preserves query parameters in the base URL.
"""
return '/'.join(map(lambda x: str(x).strip('/'), args))
url_parts = list(urlparse(base))
path = '/'.join(map(lambda x: str(x).strip('/'), args))
if url_parts[2]:
url_parts[2] = '/'.join([url_parts[2].strip('/'), path.strip('/')])
else:
url_parts[2] = path
return urlunparse(url_parts)
28 changes: 28 additions & 0 deletions sw_utils/tests/test_urljoin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from sw_utils.common import urljoin

def test_urljoin_basic():
assert urljoin('http://example.com', 'path') == 'http://example.com/path'
assert urljoin('http://example.com/', 'path') == 'http://example.com/path'
assert urljoin('http://example.com', 'path1', 'path2') == 'http://example.com/path1/path2'
assert urljoin('http://example.com/', 'path1', 'path2') == 'http://example.com/path1/path2'

def test_urljoin_with_query():
assert urljoin('http://example.com?query=1', 'path') == 'http://example.com/path?query=1'
assert urljoin('http://example.com/?query=1', 'path') == 'http://example.com/path?query=1'
assert urljoin('http://example.com?query=1', 'path1', 'path2') == 'http://example.com/path1/path2?query=1'
assert urljoin('http://example.com/?query=1', 'path1', 'path2') == 'http://example.com/path1/path2?query=1'

def test_urljoin_with_fragment():
assert urljoin('http://example.com#fragment', 'path') == 'http://example.com/path#fragment'
assert urljoin('http://example.com/#fragment', 'path') == 'http://example.com/path#fragment'
assert urljoin('http://example.com#fragment', 'path1', 'path2') == 'http://example.com/path1/path2#fragment'
assert urljoin('http://example.com/#fragment', 'path1', 'path2') == 'http://example.com/path1/path2#fragment'

def test_urljoin_edge_cases():
assert urljoin('http://example.com', '') == 'http://example.com'
assert urljoin('http://example.com/', '') == 'http://example.com/'
assert urljoin('http://example.com', '/', '/') == 'http://example.com/'
assert urljoin('http://example.com/', '/', '/') == 'http://example.com/'
assert urljoin('http://example.com//', 'path') == 'http://example.com/path'
assert urljoin('http://example.com/', '//path') == 'http://example.com/path'

0 comments on commit a5a9337

Please sign in to comment.