Skip to content

Commit

Permalink
Merge branch 'rr/auth-session-support' into devel
Browse files Browse the repository at this point in the history
* rr/auth-session-support:
  change way to remove params
  fix lint
  add test
  fix pop missing key
  fix get api_config before params assignment
  clean request payload
  fix session object not getting through
  add test
  add test for api_config
  fix null access api_config issue
  add AuthorizedSession support
  • Loading branch information
couture-ql committed Jun 17, 2022
2 parents 1e81355 + 24390ef commit 31e76e5
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 22 deletions.
1 change: 1 addition & 0 deletions nasdaqdatalink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .model.point_in_time import PointInTime
from .model.data import Data
from .model.merged_dataset import MergedDataset
from .model.authorized_session import AuthorizedSession
from .get import get
from .bulkdownload import bulkdownload
from .export_table import export_table
Expand Down
23 changes: 23 additions & 0 deletions nasdaqdatalink/api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ class ApiConfig:
retry_status_codes = [429] + list(range(500, 512))
verify_ssl = True

def read_key(self, filename=None):
if not os.path.isfile(filename):
raise_empty_file(filename)

with open(filename, 'r') as f:
apikey = get_first_non_empty(f)

if not apikey:
raise_empty_file(filename)

self.api_key = apikey


def create_file(config_filename):
# Create the file as well as the parent dir if needed.
Expand Down Expand Up @@ -102,3 +114,14 @@ def read_key(filename=None):
read_key_from_environment_variable()
elif config_file_exists(filename):
read_key_from_file(filename)


def get_config_from_kwargs(kwargs):
result = ApiConfig
if isinstance(kwargs, dict):
params = kwargs.get('params')
if isinstance(params, dict):
result = params.get('api_config')
if not isinstance(result, ApiConfig):
result = ApiConfig
return result
32 changes: 24 additions & 8 deletions nasdaqdatalink/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@

from .util import Util
from .version import VERSION
from .api_config import ApiConfig
from .api_config import ApiConfig, get_config_from_kwargs
from nasdaqdatalink.errors.data_link_error import (
DataLinkError, LimitExceededError, InternalServerError,
AuthenticationError, ForbiddenError, InvalidRequestError,
NotFoundError, ServiceUnavailableError)

KW_TO_REMOVE = [
'session',
'api_config'
]


class Connection:
@classmethod
Expand All @@ -22,31 +27,37 @@ def request(cls, http_verb, url, **options):
headers = options['headers']
else:
headers = {}
api_config = get_config_from_kwargs(options)

accept_value = 'application/json'
if ApiConfig.api_version:
accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version
if api_config.api_version:
accept_value += ", application/vnd.data.nasdaq+json;version=%s" % api_config.api_version

headers = Util.merge_to_dicts({'accept': accept_value,
'request-source': 'python',
'request-source-version': VERSION}, headers)
if ApiConfig.api_key:
headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers)
if api_config.api_key:
headers = Util.merge_to_dicts({'x-api-token': api_config.api_key}, headers)

options['headers'] = headers

abs_url = '%s/%s' % (ApiConfig.api_base, url)
abs_url = '%s/%s' % (api_config.api_base, url)

return cls.execute_request(http_verb, abs_url, **options)

@classmethod
def execute_request(cls, http_verb, url, **options):
session = cls.get_session()
session = options.get('params', {}).get('session', None)
if session is None:
session = cls.get_session()

api_config = get_config_from_kwargs(options)

cls.options_kw_strip(options)
try:
response = session.request(method=http_verb,
url=url,
verify=ApiConfig.verify_ssl,
verify=api_config.verify_ssl,
**options)
if response.status_code < 200 or response.status_code >= 300:
cls.handle_api_error(response)
Expand Down Expand Up @@ -118,3 +129,8 @@ def handle_api_error(cls, resp):
klass = d_klass.get(code_letter, DataLinkError)

raise klass(message, resp.status_code, resp.text, resp.headers, code)

@classmethod
def options_kw_strip(self, options):
for kw in KW_TO_REMOVE:
options.get('params', {}).pop(kw, None)
5 changes: 3 additions & 2 deletions nasdaqdatalink/get_point_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_point_in_time(datatable_code, **options):

data = None
page_count = 0
api_config = options.get('api_config', ApiConfig)
while True:
next_options = copy.deepcopy(options)
next_data = PointInTime(datatable_code, pit=pit_options).data(params=next_options)
Expand All @@ -32,10 +33,10 @@ def get_point_in_time(datatable_code, **options):
else:
data.extend(next_data)

if page_count >= ApiConfig.page_limit:
if page_count >= api_config.page_limit:
raise LimitExceededError(
Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code,
ApiConfig.api_key
api_config.api_key
)
)

Expand Down
5 changes: 3 additions & 2 deletions nasdaqdatalink/get_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def get_table(datatable_code, **options):

data = None
page_count = 0
api_config = options.get('api_config', ApiConfig)
while True:
next_options = copy.deepcopy(options)
next_data = Datatable(datatable_code).data(params=next_options)
Expand All @@ -23,10 +24,10 @@ def get_table(datatable_code, **options):
else:
data.extend(next_data)

if page_count >= ApiConfig.page_limit:
if page_count >= api_config.page_limit:
raise LimitExceededError(
Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code,
ApiConfig.api_key
api_config.api_key
)
)

Expand Down
57 changes: 57 additions & 0 deletions nasdaqdatalink/model/authorized_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import nasdaqdatalink
from nasdaqdatalink.api_config import ApiConfig
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
import requests
import urllib


def get_retries(api_config=nasdaqdatalink.ApiConfig):
retries = None
if not api_config.use_retries:
return Retry(total=0)

Retry.BACKOFF_MAX = api_config.max_wait_between_retries
retries = Retry(total=api_config.number_of_retries,
connect=api_config.number_of_retries,
read=api_config.number_of_retries,
status_forcelist=api_config.retry_status_codes,
backoff_factor=api_config.retry_backoff_factor,
raise_on_status=False)
return retries


class AuthorizedSession:
def __init__(self, api_config=ApiConfig) -> None:
super(AuthorizedSession, self).__init__()
if not isinstance(api_config, ApiConfig):
api_config = ApiConfig
self._api_config = api_config
self._auth_session = requests.Session()
retries = get_retries(self._api_config)
adapter = HTTPAdapter(max_retries=retries)
self._auth_session.mount(api_config.api_protocol, adapter)

proxies = urllib.request.getproxies()
if proxies is not None:
self._auth_session.proxies.update(proxies)

def get(self, dataset, **kwargs):
nasdaqdatalink.get(dataset, session=self._auth_session,
api_config=self._api_config, **kwargs)

def bulkdownload(self, database, **kwargs):
nasdaqdatalink.bulkdownload(database, session=self._auth_session,
api_config=self._api_config, **kwargs)

def export_table(self, datatable_code, **kwargs):
nasdaqdatalink.export_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **kwargs)

def get_table(self, datatable_code, **options):
nasdaqdatalink.get_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)

def get_point_in_time(self, datatable_code, **options):
nasdaqdatalink.get_point_in_time(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)
13 changes: 7 additions & 6 deletions nasdaqdatalink/model/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from six.moves.urllib.parse import urlencode, urlparse

import nasdaqdatalink.model.dataset
from nasdaqdatalink.api_config import ApiConfig
from nasdaqdatalink.api_config import get_config_from_kwargs
from nasdaqdatalink.connection import Connection
from nasdaqdatalink.errors.data_link_error import DataLinkError
from nasdaqdatalink.message import Message
Expand All @@ -21,15 +21,16 @@ def get_code_from_meta(cls, metadata):
return metadata['database_code']

def bulk_download_url(self, **options):
api_config = get_config_from_kwargs(options)
url = self._bulk_download_path()
url = ApiConfig.api_base + '/' + url
url = api_config.api_base + '/' + url

if 'params' not in options:
options['params'] = {}
if ApiConfig.api_key:
options['params']['api_key'] = ApiConfig.api_key
if ApiConfig.api_version:
options['params']['api_version'] = ApiConfig.api_version
if api_config.api_key:
options['params']['api_key'] = api_config.api_key
if api_config.api_version:
options['params']['api_version'] = api_config.api_version

if list(options.keys()):
url += '?' + urlencode(options['params'])
Expand Down
5 changes: 3 additions & 2 deletions nasdaqdatalink/utils/request_type_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from urllib.parse import urlencode
from nasdaqdatalink.api_config import ApiConfig
from nasdaqdatalink.api_config import get_config_from_kwargs


class RequestType(object):
Expand All @@ -13,7 +13,8 @@ class RequestType(object):
@classmethod
def get_request_type(cls, url, **params):
query_string = urlencode(params['params'])
request_url = '%s/%s/%s' % (ApiConfig.api_base, url, query_string)
api_config = get_config_from_kwargs(params)
request_url = '%s/%s/%s' % (api_config.api_base, url, query_string)
if RequestType.USE_GET_REQUEST and (len(request_url) < cls.MAX_URL_LENGTH_FOR_GET):
return 'get'
else:
Expand Down
55 changes: 55 additions & 0 deletions test/test_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,58 @@ def test_read_key_from_file_with_tab(self):
def test_read_key_from_file_with_multi_newline(self):
given = "keyfordefaultfile\n\nanotherkey\n"
self._read_key_from_file_helper(given, TEST_DEFAULT_FILE_CONTENTS)

def test_default_instance_will_have_share_values_with_singleton(self):
os.environ['NASDAQ_DATA_LINK_API_KEY'] = 'setinenv'
ApiConfig.api_key = None
read_key()
api_config = ApiConfig()
self.assertEqual(api_config.api_key, "setinenv")
# make sure change in instance will not affect the singleton
api_config.api_key = None
self.assertEqual(ApiConfig.api_key, "setinenv")

def test_get_config_from_kwargs_return_api_config_if_present(self):
api_config = get_config_from_kwargs({
'params': {
'api_config': ApiConfig()
}
})
self.assertTrue(isinstance(api_config, ApiConfig))

def test_get_config_from_kwargs_return_singleton_if_not_present_or_wrong_type(self):
api_config = get_config_from_kwargs(None)
self.assertTrue(issubclass(api_config, ApiConfig))
self.assertFalse(isinstance(api_config, ApiConfig))
api_config = get_config_from_kwargs(1)
self.assertTrue(issubclass(api_config, ApiConfig))
self.assertFalse(isinstance(api_config, ApiConfig))
api_config = get_config_from_kwargs({
'params': None
})
self.assertTrue(issubclass(api_config, ApiConfig))
self.assertFalse(isinstance(api_config, ApiConfig))

def test_instance_read_key_should_raise_error(self):
api_config = ApiConfig()
with self.assertRaises(TypeError):
api_config.read_key(None)
with self.assertRaises(ValueError):
api_config.read_key('')

def test_instance_read_key_should_raise_error_when_empty(self):
save_key("", TEST_KEY_FILE)
api_config = ApiConfig()
with self.assertRaises(ValueError):
# read empty file
api_config.read_key(TEST_KEY_FILE)

def test_instance_read_the_right_key(self):
expected_key = 'ilovepython'
save_key(expected_key, TEST_KEY_FILE)
api_config = ApiConfig()
api_config.api_key = ''
api_config.read_key(TEST_KEY_FILE)
self.assertEqual(ApiConfig.api_key, expected_key)


64 changes: 64 additions & 0 deletions test/test_authorized_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
from nasdaqdatalink.model.authorized_session import AuthorizedSession
from nasdaqdatalink.api_config import ApiConfig
from requests.sessions import Session
from requests.adapters import HTTPAdapter
from mock import patch


class AuthorizedSessionTest(unittest.TestCase):
def test_authorized_session_assign_correct_internal_config(self):
authed_session = AuthorizedSession()
self.assertTrue(issubclass(authed_session._api_config, ApiConfig))
authed_session = AuthorizedSession(None)
self.assertTrue(issubclass(authed_session._api_config, ApiConfig))
api_config = ApiConfig()
authed_session = AuthorizedSession(api_config)
self.assertTrue(isinstance(authed_session._api_config, ApiConfig))

def test_authorized_session_pass_created_session(self):
ApiConfig.use_retries = True
ApiConfig.number_of_retries = 130
authed_session = AuthorizedSession()
self.assertTrue(isinstance(authed_session._auth_session, Session))
adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol)
self.assertTrue(isinstance(adapter, HTTPAdapter))
self.assertEqual(adapter.max_retries.connect, 130)

@patch("nasdaqdatalink.get")
def test_call_get_with_session_and_api_config(self, mock):
api_config = ApiConfig()
authed_session = AuthorizedSession(api_config)
authed_session.get('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=api_config,
session=authed_session._auth_session)

@patch("nasdaqdatalink.bulkdownload")
def test_call_bulkdownload_with_session_and_api_config(self, mock):
api_config = ApiConfig()
authed_session = AuthorizedSession(api_config)
authed_session.bulkdownload('NSE')
mock.assert_called_with('NSE', api_config=api_config,
session=authed_session._auth_session)

@patch("nasdaqdatalink.export_table")
def test_call_export_table_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.export_table('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
session=authed_session._auth_session)

@patch("nasdaqdatalink.get_table")
def test_call_get_table_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.get_table('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
session=authed_session._auth_session)

@patch("nasdaqdatalink.get_point_in_time")
def test_call_get_point_in_time_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.get_point_in_time('DATABASE/CODE', interval='asofdate', date='2020-01-01')
mock.assert_called_with('DATABASE/CODE', interval='asofdate',
date='2020-01-01', api_config=ApiConfig,
session=authed_session._auth_session)
Loading

0 comments on commit 31e76e5

Please sign in to comment.