Skip to content

Commit

Permalink
[#116] initial implementation of type hints (#117)
Browse files Browse the repository at this point in the history
* [#116] initial implementation of type hints

* Apply suggestions from code review - using Tuple from typing module

Co-authored-by: Timothy Hodson <[email protected]>

---------

Co-authored-by: Timothy Hodson <[email protected]>
  • Loading branch information
pkdash and thodson-usgs authored Nov 7, 2023
1 parent f6d18aa commit 4d6bfdd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
32 changes: 20 additions & 12 deletions dataretrieval/nwis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
"""

import re
import warnings
import pandas as pd
from io import StringIO
import re
from typing import List, Optional, Union, Tuple

import pandas as pd

from dataretrieval.utils import to_str, format_datetime, update_merge
from dataretrieval.utils import BaseMetadata
from dataretrieval.utils import format_datetime, to_str, update_merge
from .utils import query

WATERDATA_BASE_URL = 'https://nwis.waterdata.usgs.gov/'
Expand Down Expand Up @@ -270,28 +272,31 @@ def _discharge_measurements(ssl_check=True, **kwargs):
return _read_rdb(response.text), NWIS_Metadata(response, **kwargs)


def get_discharge_peaks(sites=None, start=None, end=None,
multi_index=True, ssl_check=True, **kwargs):
def get_discharge_peaks(sites: Optional[Union[List[str], str]] = None,
start: Optional[str] = None, end: Optional[str] = None,
multi_index: bool = True,
ssl_check: bool = True, **kwargs) -> Tuple[pd.DataFrame, BaseMetadata]:
"""
Get discharge peaks from the waterdata service.
Parameters
----------
sites: array of strings
sites: list of strings, string, Optional
If the waterdata parameter site_no is supplied, it will overwrite the
sites parameter
start: string
start: string, Optional
If the waterdata parameter begin_date is supplied, it will overwrite
the start parameter (YYYY-MM-DD)
end: string
end: string, Optional
If the waterdata parameter end_date is supplied, it will overwrite
the end parameter (YYYY-MM-DD)
multi_index: boolean
If False, a dataframe with a single-level index (datetime) is returned
ssl_check: bool
multi_index: boolean, Optional
If False, a dataframe with a single-level index (datetime) is returned,
default is True
ssl_check: boolean, Optional
If True, check SSL certificates, if False, do not check SSL,
default is True
**kwargs: optional
**kwargs: Optional
If supplied, will be used as query parameters
Returns
Expand All @@ -314,6 +319,9 @@ def get_discharge_peaks(sites=None, start=None, end=None,
... start='1980-01-01', end='1980-01-02', stateCd='HI')
"""
if sites and not isinstance(sites, str):
assert isinstance(sites, list), "sites must be a string or a list of strings"

start = kwargs.pop('begin_date', start)
end = kwargs.pop('end_date', end)
sites = kwargs.pop('site_no', sites)
Expand Down
20 changes: 19 additions & 1 deletion tests/waterservices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def test_get_gwlevels(requests_mock):
assert df.size == 16
assert_metadata(requests_mock, request_url, md, site, None, format)


def test_get_discharge_peaks(requests_mock):
"""Tests get_discharge_peaks method correctly generates the request url and returns the result in a DataFrame"""
format = "rdb"
Expand All @@ -149,6 +148,25 @@ def test_get_discharge_peaks(requests_mock):
assert df.size == 240
assert_metadata(requests_mock, request_url, md, site, None, format)

@pytest.mark.parametrize("site_input_type_list", [True, False])
def test_get_discharge_peaks_sites_value_types(requests_mock, site_input_type_list):
"""Tests get_discharge_peaks for valid input types of the 'sites' parameter"""

format = "rdb"
site = '01594440'
request_url = 'https://nwis.waterdata.usgs.gov/nwis/peaks?format={}&site_no={}' \
'&begin_date=2000-02-14&end_date=2020-02-15'.format(format, site)
response_file_path = 'data/waterservices_peaks.txt'
mock_request(requests_mock, request_url, response_file_path)
if site_input_type_list:
sites = [site]
else:
sites = site

df, md = get_discharge_peaks(sites=sites, start='2000-02-14', end='2020-02-15')
assert type(df) is DataFrame
assert df.size == 240


def test_get_discharge_measurements(requests_mock):
"""Tests get_discharge_measurements method correctly generates the request url and returns the result in a
Expand Down

0 comments on commit 4d6bfdd

Please sign in to comment.