diff --git a/src/python_opensky/opensky.py b/src/python_opensky/opensky.py index 2e539a5..089d402 100644 --- a/src/python_opensky/opensky.py +++ b/src/python_opensky/opensky.py @@ -37,15 +37,35 @@ class OpenSky: _auth: BasicAuth | None = None _contributing_user: bool = False - def authenticate(self, auth: BasicAuth, *, contributing_user: bool = False) -> None: + async def authenticate( + self, + auth: BasicAuth, + *, + contributing_user: bool = False, + ) -> None: """Authenticate the user.""" self._auth = auth + try: + await self.get_states() + except OpenSkyUnauthenticatedError as exc: + self._auth = None + raise OpenSkyUnauthenticatedError from exc self._contributing_user = contributing_user if contributing_user: self.opensky_credits = 8000 else: self.opensky_credits = 4000 + @property + def is_contributing_user(self) -> bool: + """Return if the user is contributing to OpenSky.""" + return self._contributing_user + + @property + def is_authenticated(self) -> bool: + """Return if the user is correctly authenticated.""" + return self._auth is not None + async def _request( self, uri: str, @@ -107,6 +127,8 @@ async def _request( ClientResponseError, socket.gaierror, ) as exception: + if isinstance(exception, ClientResponseError) and exception.status == 401: + raise OpenSkyUnauthenticatedError from exception msg = "Error occurred while communicating with OpenSky API" raise OpenSkyConnectionError(msg) from exception diff --git a/tests/test_states.py b/tests/test_states.py index 8ac121e..1d48941 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -91,6 +91,16 @@ async def test_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving own states.""" + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + ) aresponses.add( OPENSKY_URL, "/api/states/own", @@ -103,10 +113,14 @@ async def test_own_states( ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) - opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + BasicAuth(login="test", password="test"), + contributing_user=True, + ) response: StatesResponse = await opensky.get_own_states() assert len(response.states) == 4 - assert opensky.remaining_credits() == opensky.opensky_credits + assert opensky.opensky_credits == 8000 + assert opensky.remaining_credits() == 7996 await opensky.close() @@ -114,6 +128,16 @@ async def test_unavailable_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving no own states.""" + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + ) aresponses.add( OPENSKY_URL, "/api/states/own", @@ -126,7 +150,10 @@ async def test_unavailable_own_states( ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) - opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + BasicAuth(login="test", password="test"), + contributing_user=True, + ) response: StatesResponse = await opensky.get_own_states() assert response.states is not None assert len(response.states) == 0 @@ -244,27 +271,64 @@ def response_handler(request: BaseRequest) -> Response: "/api/states/all", "GET", response_handler, + repeat=2, ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) - opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate(BasicAuth(login="test", password="test")) await opensky.get_states() await opensky.close() -async def test_user_credits() -> None: +async def test_unauthorized(aresponses: ResponsesMockServer) -> None: + """Test request authentication.""" + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=401, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + ) + + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + try: + await opensky.authenticate(BasicAuth(login="test", password="test")) + pytest.fail("Should've thrown exception") + except OpenSkyUnauthenticatedError: + pass + assert opensky.is_authenticated is False + await opensky.close() + + +async def test_user_credits(aresponses: ResponsesMockServer) -> None: """Test authenticated user credits.""" + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + repeat=2, + ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) assert opensky.opensky_credits == 400 - opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate(BasicAuth(login="test", password="test")) assert opensky.opensky_credits == 4000 - opensky.authenticate( + await opensky.authenticate( BasicAuth(login="test", password="test"), contributing_user=True, ) assert opensky.opensky_credits == 8000 + assert opensky.is_contributing_user is True await opensky.close()