From e6abe2974fb7e0da7152a669b92de171d01df15a Mon Sep 17 00:00:00 2001 From: "Andrew Schonfeld (Boston)" Date: Tue, 5 Mar 2024 09:25:32 -0500 Subject: [PATCH] Added better security around updating settings and usage of custom queries --- dtale/code_export.py | 6 +-- dtale/column_analysis.py | 3 +- dtale/dash_application/drilldown_modal.py | 3 +- dtale/data_reshapers.py | 2 +- dtale/global_state.py | 6 +++ dtale/query.py | 5 +-- dtale/timeseries_analysis.py | 2 +- dtale/views.py | 46 +++++++++++++---------- tests/dtale/correlations/test_views.py | 13 +++++++ tests/dtale/test_column_analysis.py | 1 + tests/dtale/test_reshapers.py | 1 + tests/dtale/test_views.py | 1 + 12 files changed, 57 insertions(+), 32 deletions(-) diff --git a/dtale/code_export.py b/dtale/code_export.py index f9659ae3..4a78d6d5 100644 --- a/dtale/code_export.py +++ b/dtale/code_export.py @@ -95,7 +95,7 @@ def build_code_export(data_id, imports="import pandas as pd\n\n", query=None): final_history = [startup_str] + history final_query = query if final_query is None: - final_query = build_query(data_id, settings.get("query")) + final_query = build_query(data_id, global_state.get_query(data_id)) if final_query is not None and final_query != "": if len(ctxt_vars or {}): @@ -113,9 +113,9 @@ def build_code_export(data_id, imports="import pandas as pd\n\n", query=None): final_history.append( "df = df.query({})\n".format(triple_quote(final_query)) ) - elif settings.get("query"): + elif global_state.get_query(data_id): final_history.append( - "df = df.query({})\n".format(triple_quote(settings["query"])) + "df = df.query({})\n".format(triple_quote(global_state.get_query(data_id))) ) if "sortInfo" in settings: cols, dirs = [], [] diff --git a/dtale/column_analysis.py b/dtale/column_analysis.py index de156375..4b81b2fa 100644 --- a/dtale/column_analysis.py +++ b/dtale/column_analysis.py @@ -83,8 +83,7 @@ class ColumnAnalysis(object): def __init__(self, data_id, req): self.data_id = data_id self.analysis_type = get_str_arg(req, "type") - curr_settings = global_state.get_settings(data_id) or {} - self.query = build_query(data_id, curr_settings.get("query")) + self.query = build_query(data_id, global_state.get_query(data_id)) self.data = load_filterable_data(data_id, req, query=self.query) self.selected_col = find_selected_column( self.data, get_str_arg(req, "col", "values") diff --git a/dtale/dash_application/drilldown_modal.py b/dtale/dash_application/drilldown_modal.py index 77f0b3ac..52cf996f 100644 --- a/dtale/dash_application/drilldown_modal.py +++ b/dtale/dash_application/drilldown_modal.py @@ -219,9 +219,8 @@ def load_drilldown_content( if click_data: click_point = next((p for p in click_data.get("points", [])), None) if click_point: - curr_settings = global_state.get_settings(data_id) or {} query = build_query( - data_id, all_inputs.get("query") or curr_settings.get("query") + data_id, all_inputs.get("query") or global_state.get_query(data_id) ) x_col = all_inputs.get("x") y_col = next((y2 for y2 in make_list(all_inputs.get("y"))), None) diff --git a/dtale/data_reshapers.py b/dtale/data_reshapers.py index 3949df33..90bdd950 100644 --- a/dtale/data_reshapers.py +++ b/dtale/data_reshapers.py @@ -43,7 +43,7 @@ def __init__(self, data_id, shape_type, cfg): def reshape(self): data = run_query( global_state.get_data(self.data_id), - (global_state.get_settings(self.data_id) or {}).get("query"), + global_state.get_query(self.data_id), global_state.get_context_variables(self.data_id), ) return self.builder.reshape(data) diff --git a/dtale/global_state.py b/dtale/global_state.py index 8a966a71..28eaeabc 100644 --- a/dtale/global_state.py +++ b/dtale/global_state.py @@ -443,6 +443,12 @@ def get_name(self, data_id): def get_settings(self, data_id): return self.get_data_inst(data_id).settings + def get_query(self, data_id): + if load_flag(data_id, "enable_custom_filters", False): + curr_settings = self.get_settings(data_id) or {} + return curr_settings.get("query") + return None + def get_metadata(self, data_id): return self.get_data_inst(data_id).metadata diff --git a/dtale/query.py b/dtale/query.py index c2dfaba2..6bc90b1f 100644 --- a/dtale/query.py +++ b/dtale/query.py @@ -11,7 +11,7 @@ def build_col_key(col): return "`{}`".format(col) -def build_query(data_id, query=None, as_query_builder=False): +def build_query(data_id, query=None): curr_settings = global_state.get_settings(data_id) or {} return inner_build_query(curr_settings, query) @@ -204,9 +204,8 @@ def load_filterable_data(data_id, req, query=None, columns=None): data = instance.load_data(columns=columns) data, _ = format_data(data) return data - curr_settings = global_state.get_settings(data_id) or {} if filtered: - final_query = query or build_query(data_id, curr_settings.get("query")) + final_query = query or build_query(data_id, global_state.get_query(data_id)) return run_query( handle_predefined(data_id), final_query, diff --git a/dtale/timeseries_analysis.py b/dtale/timeseries_analysis.py index b117509d..c7b3614b 100644 --- a/dtale/timeseries_analysis.py +++ b/dtale/timeseries_analysis.py @@ -42,7 +42,7 @@ def __init__(self, data_id, report_type, cfg): def run(self): data = run_query( global_state.get_data(self.data_id), - (global_state.get_settings(self.data_id) or {}).get("query"), + global_state.get_query(self.data_id), global_state.get_context_variables(self.data_id), ) return self.report.run(data) diff --git a/dtale/views.py b/dtale/views.py index 46f10d09..db2c7da3 100644 --- a/dtale/views.py +++ b/dtale/views.py @@ -1583,7 +1583,11 @@ def update_settings(data_id): :return: JSON """ - global_state.update_settings(data_id, get_json_arg(request, "settings", {})) + updated_settings = get_json_arg(request, "settings", {}) + if not global_state.load_flag(data_id, "enable_custom_filters", False): + updated_settings.pop("query", None) + + global_state.update_settings(data_id, updated_settings) return jsonify(dict(success=True)) @@ -2590,12 +2594,11 @@ def edit_cell(data_id): row_index = get_int_arg(request, "rowIndex") updated = get_str_arg(request, "updated") updated_str = updated - curr_settings = global_state.get_settings(data_id) # make sure to load filtered data in order to get correct row index data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ignore_empty=True, ) @@ -2771,7 +2774,7 @@ def get_data(data_id): curr_settings = global_state.get_settings(data_id) or {} curr_locked = curr_settings.get("locked", []) - final_query = build_query(data_id, curr_settings.get("query")) + final_query = build_query(data_id, global_state.get_query(data_id)) highlight_filter = curr_settings.get("highlightFilter") or False if global_state.is_arcticdb: @@ -3040,7 +3043,7 @@ def load_file(fpath, encoding="utf-8"): @exception_decorator def load_filtered_ranges(data_id): curr_settings = global_state.get_settings(data_id) or {} - final_query = build_query(data_id, curr_settings.get("query")) + final_query = build_query(data_id, global_state.get_query(data_id)) if not final_query: return {} curr_filtered_ranges = curr_settings.get("filteredRanges", {}) @@ -3081,11 +3084,10 @@ def _filter_numeric(col): @dtale.route("/data-export/") @exception_decorator def data_export(data_id): - curr_settings = global_state.get_settings(data_id) or {} curr_dtypes = global_state.get_dtypes(data_id) or [] data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ignore_empty=True, ) @@ -3174,10 +3176,9 @@ def build_correlations_matrix_image( def build_correlations_matrix(data_id, is_pps=False, encode_strings=False, image=False): - curr_settings = global_state.get_settings(data_id) or {} data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ) valid_corr_cols, valid_str_corr_cols, valid_date_cols = correlations.get_col_groups( @@ -3445,10 +3446,9 @@ def get_correlations_ts(data_id): data: {:col1:col2: {data: [{corr: 0.99, date: 'YYYY-MM-DD'},...], max: 0.99, min: 0.99} } or {error: 'Exception message', traceback: 'Exception stacktrace'} """ - curr_settings = global_state.get_settings(data_id) or {} data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ) cols = get_json_arg(request, "cols") @@ -3565,10 +3565,9 @@ def get_scatter(data_id): date_col = get_str_arg(request, "dateCol") rolling = get_bool_arg(request, "rolling") - curr_settings = global_state.get_settings(data_id) or {} data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ) idx_col = str("_corr_index") @@ -4025,10 +4024,9 @@ def build_column_text(data_id): columns = request.json.get("columns") columns = json.loads(columns) - curr_settings = global_state.get_settings(data_id) or {} data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ignore_empty=True, ) @@ -4042,10 +4040,9 @@ def build_row_text(data_id): request.json.get(p) for p in ["start", "end", "rows", "columns"] ) columns = json.loads(columns) - curr_settings = global_state.get_settings(data_id) or {} data = run_query( handle_predefined(data_id), - build_query(data_id, curr_settings.get("query")), + build_query(data_id, global_state.get_query(data_id)), global_state.get_context_variables(data_id), ignore_empty=True, ) @@ -4229,7 +4226,7 @@ def build_missingno_chart(chart_type, data_id): @exception_decorator def drop_filtered_rows(data_id): curr_settings = global_state.get_settings(data_id) or {} - final_query = build_query(data_id, curr_settings.get("query")) + final_query = build_query(data_id, global_state.get_query(data_id)) curr_history = global_state.get_history(data_id) or [] curr_history += [ ( @@ -4265,8 +4262,17 @@ def drop_filtered_rows(data_id): @dtale.route("/move-filters-to-custom/") @exception_decorator def move_filters_to_custom(data_id): - curr_settings = global_state.get_settings(data_id) or {} - query = build_query(data_id, curr_settings.get("query")) + if not global_state.load_flag(data_id, "enable_custom_filters", False): + return jsonify( + dict( + success=False, + error=( + "Custom Filters not enabled! Custom filters are vulnerable to code injection attacks, please only " + "use in trusted environments." + ), + ) + ) + query = build_query(data_id, global_state.get_query(data_id)) global_state.update_settings( data_id, { diff --git a/tests/dtale/correlations/test_views.py b/tests/dtale/correlations/test_views.py index e6f7d678..4b2aec90 100644 --- a/tests/dtale/correlations/test_views.py +++ b/tests/dtale/correlations/test_views.py @@ -6,6 +6,7 @@ from pkg_resources import parse_version +import dtale.global_state as global_state import dtale.pandas_util as pandas_util from dtale.charts.utils import CHART_POINTS_LIMIT @@ -13,6 +14,15 @@ from tests.dtale.test_views import app, build_ts_data from tests.dtale import build_data_inst, build_settings, build_dtypes + +def setup_function(function): + global_state.cleanup() + + +def teardown_function(function): + global_state.cleanup() + + CORRELATIONS_CODE = """# DISCLAIMER: 'df' refers to the data you passed in when calling 'dtale.show' import numpy as np @@ -90,6 +100,7 @@ def test_get_correlations(unittest, test_data, rolling_data): with app.test_client() as c: build_data_inst({c.port: test_data}) build_dtypes({c.port: views.build_dtypes_state(test_data)}) + global_state.set_app_settings(dict(enable_custom_filters=True)) settings = {c.port: {"query": "missing_col == 'blah'"}} build_settings(settings) response = c.get("/dtale/correlations/{}".format(c.port)) @@ -309,6 +320,7 @@ def test_get_correlations_ts(unittest, rolling_data): with app.test_client() as c: build_data_inst({c.port: test_data}) + global_state.set_app_settings(dict(enable_custom_filters=True)) settings = {c.port: {"query": "missing_col == 'blah'"}} build_settings(settings) response = c.get("/dtale/correlations-ts/{}".format(c.port)) @@ -483,6 +495,7 @@ def test_get_scatter(unittest, rolling_data): with app.test_client() as c: build_data_inst({c.port: test_data}) build_dtypes({c.port: views.build_dtypes_state(test_data)}) + global_state.set_app_settings(dict(enable_custom_filters=True)) settings = {c.port: {"query": "missing_col == 'blah'"}} build_settings(settings) params = dict(dateCol="date", cols=json.dumps(["foo", "bar"]), date="20000101") diff --git a/tests/dtale/test_column_analysis.py b/tests/dtale/test_column_analysis.py index 7ef57fde..b0e532fe 100644 --- a/tests/dtale/test_column_analysis.py +++ b/tests/dtale/test_column_analysis.py @@ -24,6 +24,7 @@ def test_get_column_analysis(unittest, test_data): build_data_inst({c.port: test_data}) build_dtypes({c.port: views.build_dtypes_state(test_data)}) build_settings({c.port: {}}) + global_state.set_app_settings(dict(enable_custom_filters=True)) response = c.get( "/dtale/column-analysis/{}".format(c.port), query_string=dict(col="foo", filtered="true"), diff --git a/tests/dtale/test_reshapers.py b/tests/dtale/test_reshapers.py index d8d13cdf..246a2db1 100644 --- a/tests/dtale/test_reshapers.py +++ b/tests/dtale/test_reshapers.py @@ -269,6 +269,7 @@ def test_transpose(custom_data, unittest): build_data_inst(data) build_dtypes(dtypes) build_settings(settings) + global_state.set_app_settings(dict(enable_custom_filters=True)) reshape_cfg = dict(index=["security_id"], columns=["Col0"]) resp = c.get( "/dtale/reshape/{}".format(c.port), diff --git a/tests/dtale/test_views.py b/tests/dtale/test_views.py index 109cf3aa..b5ad343c 100644 --- a/tests/dtale/test_views.py +++ b/tests/dtale/test_views.py @@ -1758,6 +1758,7 @@ def test_get_data(unittest, test_data): with app.test_client() as c: build_data_inst({c.port: test_data}) build_dtypes({c.port: views.build_dtypes_state(test_data)}) + global_state.set_app_settings(dict(enable_custom_filters=True)) build_settings({c.port: dict(query="missing_col == 'blah'")}) response = c.get( "/dtale/data/{}".format(c.port), query_string=dict(ids=json.dumps(["0"]))