diff --git a/README.md b/README.md index dd3b199..1b61993 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,18 @@ data_link_log = logging.getLogger("nasdaqdatalink") data_link_log.setLevel(logging.DEBUG) ``` +### Session + +By default, every API request will create a new session; This will have a performance impact when you wish to make multiple requests(see #16). You can use `AuthorizedSession` to take advantage of the reusing session: + +```python +import nasdaqdatalink +session = nasdaqdatalink.AuthorizedSession() +data1 = session.get_table('ZACKS/FC', ticker='AAPL') +data2 = session.get_table('ZACKS/FC', ticker='MFST') +data3 = session.get_table('ZACKS/FC', ticker='NVDA') +``` + ### Detailed Usage Our API can provide more than just data. It can also be used to search and provide metadata or to programmatically retrieve data. For these more advanced techniques please follow our [Detailed Method Guide](./FOR_DEVELOPERS.md). diff --git a/nasdaqdatalink/__init__.py b/nasdaqdatalink/__init__.py index d1e4654..7aacf70 100644 --- a/nasdaqdatalink/__init__.py +++ b/nasdaqdatalink/__init__.py @@ -10,6 +10,7 @@ from .model.point_in_time import PointInTime from .model.data import Data from .model.merged_dataset import MergedDataset +from .model.authorized_session import AuthorizedSession from .get import get from .bulkdownload import bulkdownload from .export_table import export_table diff --git a/nasdaqdatalink/api_config.py b/nasdaqdatalink/api_config.py index dea1dd0..d86d576 100644 --- a/nasdaqdatalink/api_config.py +++ b/nasdaqdatalink/api_config.py @@ -17,6 +17,18 @@ class ApiConfig: retry_status_codes = [429] + list(range(500, 512)) verify_ssl = True + def read_key(self, filename=None): + if not os.path.isfile(filename): + raise_empty_file(filename) + + with open(filename, 'r') as f: + apikey = get_first_non_empty(f) + + if not apikey: + raise_empty_file(filename) + + self.api_key = apikey + def create_file(config_filename): # Create the file as well as the parent dir if needed. @@ -102,3 +114,14 @@ def read_key(filename=None): read_key_from_environment_variable() elif config_file_exists(filename): read_key_from_file(filename) + + +def get_config_from_kwargs(kwargs): + result = ApiConfig + if isinstance(kwargs, dict): + params = kwargs.get('params') + if isinstance(params, dict): + result = params.get('api_config') + if not isinstance(result, ApiConfig): + result = ApiConfig + return result diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 350ed49..d820d24 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -8,12 +8,17 @@ from .util import Util from .version import VERSION -from .api_config import ApiConfig +from .api_config import ApiConfig, get_config_from_kwargs from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +KW_TO_REMOVE = [ + 'session', + 'api_config' +] + class Connection: @classmethod @@ -22,31 +27,37 @@ def request(cls, http_verb, url, **options): headers = options['headers'] else: headers = {} + api_config = get_config_from_kwargs(options) accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + if api_config.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % api_config.api_version headers = Util.merge_to_dicts({'accept': accept_value, 'request-source': 'python', 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + if api_config.api_key: + headers = Util.merge_to_dicts({'x-api-token': api_config.api_key}, headers) options['headers'] = headers - abs_url = '%s/%s' % (ApiConfig.api_base, url) + abs_url = '%s/%s' % (api_config.api_base, url) return cls.execute_request(http_verb, abs_url, **options) @classmethod def execute_request(cls, http_verb, url, **options): - session = cls.get_session() + session = options.get('params', {}).get('session', None) + if session is None: + session = cls.get_session() + + api_config = get_config_from_kwargs(options) + cls.options_kw_strip(options) try: response = session.request(method=http_verb, url=url, - verify=ApiConfig.verify_ssl, + verify=api_config.verify_ssl, **options) if response.status_code < 200 or response.status_code >= 300: cls.handle_api_error(response) @@ -118,3 +129,8 @@ def handle_api_error(cls, resp): klass = d_klass.get(code_letter, DataLinkError) raise klass(message, resp.status_code, resp.text, resp.headers, code) + + @classmethod + def options_kw_strip(self, options): + for kw in KW_TO_REMOVE: + options.get('params', {}).pop(kw, None) diff --git a/nasdaqdatalink/get_point_in_time.py b/nasdaqdatalink/get_point_in_time.py index c4f7578..73d6ed2 100644 --- a/nasdaqdatalink/get_point_in_time.py +++ b/nasdaqdatalink/get_point_in_time.py @@ -23,6 +23,7 @@ def get_point_in_time(datatable_code, **options): data = None page_count = 0 + api_config = options.get('api_config', ApiConfig) while True: next_options = copy.deepcopy(options) next_data = PointInTime(datatable_code, pit=pit_options).data(params=next_options) @@ -32,10 +33,10 @@ def get_point_in_time(datatable_code, **options): else: data.extend(next_data) - if page_count >= ApiConfig.page_limit: + if page_count >= api_config.page_limit: raise LimitExceededError( Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code, - ApiConfig.api_key + api_config.api_key ) ) diff --git a/nasdaqdatalink/get_table.py b/nasdaqdatalink/get_table.py index c07d3c8..32188bb 100644 --- a/nasdaqdatalink/get_table.py +++ b/nasdaqdatalink/get_table.py @@ -14,6 +14,7 @@ def get_table(datatable_code, **options): data = None page_count = 0 + api_config = options.get('api_config', ApiConfig) while True: next_options = copy.deepcopy(options) next_data = Datatable(datatable_code).data(params=next_options) @@ -23,10 +24,10 @@ def get_table(datatable_code, **options): else: data.extend(next_data) - if page_count >= ApiConfig.page_limit: + if page_count >= api_config.page_limit: raise LimitExceededError( Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code, - ApiConfig.api_key + api_config.api_key ) ) diff --git a/nasdaqdatalink/model/authorized_session.py b/nasdaqdatalink/model/authorized_session.py new file mode 100644 index 0000000..7c2c9e2 --- /dev/null +++ b/nasdaqdatalink/model/authorized_session.py @@ -0,0 +1,57 @@ +import nasdaqdatalink +from nasdaqdatalink.api_config import ApiConfig +from urllib3.util.retry import Retry +from requests.adapters import HTTPAdapter +import requests +import urllib + + +def get_retries(api_config=nasdaqdatalink.ApiConfig): + retries = None + if not api_config.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = api_config.max_wait_between_retries + retries = Retry(total=api_config.number_of_retries, + connect=api_config.number_of_retries, + read=api_config.number_of_retries, + status_forcelist=api_config.retry_status_codes, + backoff_factor=api_config.retry_backoff_factor, + raise_on_status=False) + return retries + + +class AuthorizedSession: + def __init__(self, api_config=ApiConfig) -> None: + super(AuthorizedSession, self).__init__() + if not isinstance(api_config, ApiConfig): + api_config = ApiConfig + self._api_config = api_config + self._auth_session = requests.Session() + retries = get_retries(self._api_config) + adapter = HTTPAdapter(max_retries=retries) + self._auth_session.mount(api_config.api_protocol, adapter) + + proxies = urllib.request.getproxies() + if proxies is not None: + self._auth_session.proxies.update(proxies) + + def get(self, dataset, **kwargs): + nasdaqdatalink.get(dataset, session=self._auth_session, + api_config=self._api_config, **kwargs) + + def bulkdownload(self, database, **kwargs): + nasdaqdatalink.bulkdownload(database, session=self._auth_session, + api_config=self._api_config, **kwargs) + + def export_table(self, datatable_code, **kwargs): + nasdaqdatalink.export_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **kwargs) + + def get_table(self, datatable_code, **options): + nasdaqdatalink.get_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) + + def get_point_in_time(self, datatable_code, **options): + nasdaqdatalink.get_point_in_time(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..443daf9 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -3,7 +3,7 @@ from six.moves.urllib.parse import urlencode, urlparse import nasdaqdatalink.model.dataset -from nasdaqdatalink.api_config import ApiConfig +from nasdaqdatalink.api_config import get_config_from_kwargs from nasdaqdatalink.connection import Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message @@ -21,15 +21,16 @@ def get_code_from_meta(cls, metadata): return metadata['database_code'] def bulk_download_url(self, **options): + api_config = get_config_from_kwargs(options) url = self._bulk_download_path() - url = ApiConfig.api_base + '/' + url + url = api_config.api_base + '/' + url if 'params' not in options: options['params'] = {} - if ApiConfig.api_key: - options['params']['api_key'] = ApiConfig.api_key - if ApiConfig.api_version: - options['params']['api_version'] = ApiConfig.api_version + if api_config.api_key: + options['params']['api_key'] = api_config.api_key + if api_config.api_version: + options['params']['api_version'] = api_config.api_version if list(options.keys()): url += '?' + urlencode(options['params']) diff --git a/nasdaqdatalink/utils/request_type_util.py b/nasdaqdatalink/utils/request_type_util.py index a53af61..97d63cc 100644 --- a/nasdaqdatalink/utils/request_type_util.py +++ b/nasdaqdatalink/utils/request_type_util.py @@ -1,5 +1,5 @@ from urllib.parse import urlencode -from nasdaqdatalink.api_config import ApiConfig +from nasdaqdatalink.api_config import get_config_from_kwargs class RequestType(object): @@ -13,7 +13,8 @@ class RequestType(object): @classmethod def get_request_type(cls, url, **params): query_string = urlencode(params['params']) - request_url = '%s/%s/%s' % (ApiConfig.api_base, url, query_string) + api_config = get_config_from_kwargs(params) + request_url = '%s/%s/%s' % (api_config.api_base, url, query_string) if RequestType.USE_GET_REQUEST and (len(request_url) < cls.MAX_URL_LENGTH_FOR_GET): return 'get' else: diff --git a/test/test_api_config.py b/test/test_api_config.py index 2c183b5..c21efe0 100644 --- a/test/test_api_config.py +++ b/test/test_api_config.py @@ -132,3 +132,58 @@ def test_read_key_from_file_with_tab(self): def test_read_key_from_file_with_multi_newline(self): given = "keyfordefaultfile\n\nanotherkey\n" self._read_key_from_file_helper(given, TEST_DEFAULT_FILE_CONTENTS) + + def test_default_instance_will_have_share_values_with_singleton(self): + os.environ['NASDAQ_DATA_LINK_API_KEY'] = 'setinenv' + ApiConfig.api_key = None + read_key() + api_config = ApiConfig() + self.assertEqual(api_config.api_key, "setinenv") + # make sure change in instance will not affect the singleton + api_config.api_key = None + self.assertEqual(ApiConfig.api_key, "setinenv") + + def test_get_config_from_kwargs_return_api_config_if_present(self): + api_config = get_config_from_kwargs({ + 'params': { + 'api_config': ApiConfig() + } + }) + self.assertTrue(isinstance(api_config, ApiConfig)) + + def test_get_config_from_kwargs_return_singleton_if_not_present_or_wrong_type(self): + api_config = get_config_from_kwargs(None) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + api_config = get_config_from_kwargs(1) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + api_config = get_config_from_kwargs({ + 'params': None + }) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + + def test_instance_read_key_should_raise_error(self): + api_config = ApiConfig() + with self.assertRaises(TypeError): + api_config.read_key(None) + with self.assertRaises(ValueError): + api_config.read_key('') + + def test_instance_read_key_should_raise_error_when_empty(self): + save_key("", TEST_KEY_FILE) + api_config = ApiConfig() + with self.assertRaises(ValueError): + # read empty file + api_config.read_key(TEST_KEY_FILE) + + def test_instance_read_the_right_key(self): + expected_key = 'ilovepython' + save_key(expected_key, TEST_KEY_FILE) + api_config = ApiConfig() + api_config.api_key = '' + api_config.read_key(TEST_KEY_FILE) + self.assertEqual(ApiConfig.api_key, expected_key) + + diff --git a/test/test_authorized_session.py b/test/test_authorized_session.py new file mode 100644 index 0000000..60f20ba --- /dev/null +++ b/test/test_authorized_session.py @@ -0,0 +1,64 @@ +import unittest +from nasdaqdatalink.model.authorized_session import AuthorizedSession +from nasdaqdatalink.api_config import ApiConfig +from requests.sessions import Session +from requests.adapters import HTTPAdapter +from mock import patch + + +class AuthorizedSessionTest(unittest.TestCase): + def test_authorized_session_assign_correct_internal_config(self): + authed_session = AuthorizedSession() + self.assertTrue(issubclass(authed_session._api_config, ApiConfig)) + authed_session = AuthorizedSession(None) + self.assertTrue(issubclass(authed_session._api_config, ApiConfig)) + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + self.assertTrue(isinstance(authed_session._api_config, ApiConfig)) + + def test_authorized_session_pass_created_session(self): + ApiConfig.use_retries = True + ApiConfig.number_of_retries = 130 + authed_session = AuthorizedSession() + self.assertTrue(isinstance(authed_session._auth_session, Session)) + adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol) + self.assertTrue(isinstance(adapter, HTTPAdapter)) + self.assertEqual(adapter.max_retries.connect, 130) + + @patch("nasdaqdatalink.get") + def test_call_get_with_session_and_api_config(self, mock): + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + authed_session.get('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=api_config, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.bulkdownload") + def test_call_bulkdownload_with_session_and_api_config(self, mock): + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + authed_session.bulkdownload('NSE') + mock.assert_called_with('NSE', api_config=api_config, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.export_table") + def test_call_export_table_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.export_table('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.get_table") + def test_call_get_table_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.get_table('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.get_point_in_time") + def test_call_get_point_in_time_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.get_point_in_time('DATABASE/CODE', interval='asofdate', date='2020-01-01') + mock.assert_called_with('DATABASE/CODE', interval='asofdate', + date='2020-01-01', api_config=ApiConfig, + session=authed_session._auth_session) diff --git a/test/test_connection.py b/test/test_connection.py index 7777d6e..f8a513f 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -6,6 +6,7 @@ NotFoundError, ServiceUnavailableError) from test.test_retries import ModifyRetrySettingsTestCase from test.helpers.httpretty_extension import httpretty +import requests import json from mock import patch, call from nasdaqdatalink.version import VERSION @@ -81,3 +82,48 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + @parameterized.expand(['GET', 'POST']) + @patch('nasdaqdatalink.connection.Connection.execute_request') + def test_build_request_with_custom_api_config(self, request_method, mock): + ApiConfig.api_key = 'api_token' + ApiConfig.api_version = '2015-04-09' + api_config = ApiConfig() + api_config.api_key = 'custom_api_token' + api_config.api_version = '2022-06-09' + session = requests.session() + params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session} + headers = {'x-custom-header': 'header value'} + Connection.request(request_method, 'databases', headers=headers, params=params) + expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', + headers={'x-custom-header': 'header value', + 'x-api-token': 'custom_api_token', + 'accept': ('application/json, ' + 'application/vnd.data.nasdaq+json;version=2022-06-09'), + 'request-source': 'python', + 'request-source-version': VERSION}, + params={'per_page': 10, 'page': 2, + 'session': session, 'api_config': api_config}) + self.assertEqual(mock.call_args, expected) + + def test_remove_session_and_api_config_param(self): + ApiConfig.api_key = 'api_token' + ApiConfig.api_version = '2015-04-09' + ApiConfig.verify_ssl = True + api_config = ApiConfig() + api_config.api_key = 'custom_api_token' + api_config.api_version = '2022-06-09' + api_config.verify_ssl = False + session = requests.Session() + params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session} + headers = {'x-custom-header': 'header value'} + dummy_response = requests.Response() + dummy_response.status_code = 200 + with patch.object(session, 'request', return_value=dummy_response) as mock: + Connection.execute_request( + 'GET', 'https://data.nasdaq.com/api/v3/databases', headers=headers, params=params) + mock.assert_called_once_with(method='GET', + url='https://data.nasdaq.com/api/v3/databases', + verify=False, + headers=headers, + params={'per_page': 10, 'page': 2})