Skip to content

Commit

Permalink
Added better security around updating settings and usage of custom queries
Browse files Browse the repository at this point in the history
  • Loading branch information
aschonfeld committed Mar 6, 2024
1 parent 57faae4 commit e6abe29
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 32 deletions.
6 changes: 3 additions & 3 deletions dtale/code_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def build_code_export(data_id, imports="import pandas as pd\n\n", query=None):
final_history = [startup_str] + history
final_query = query
if final_query is None:
final_query = build_query(data_id, settings.get("query"))
final_query = build_query(data_id, global_state.get_query(data_id))

if final_query is not None and final_query != "":
if len(ctxt_vars or {}):
Expand All @@ -113,9 +113,9 @@ def build_code_export(data_id, imports="import pandas as pd\n\n", query=None):
final_history.append(
"df = df.query({})\n".format(triple_quote(final_query))
)
elif settings.get("query"):
elif global_state.get_query(data_id):
final_history.append(
"df = df.query({})\n".format(triple_quote(settings["query"]))
"df = df.query({})\n".format(triple_quote(global_state.get_query(data_id)))
)
if "sortInfo" in settings:
cols, dirs = [], []
Expand Down
3 changes: 1 addition & 2 deletions dtale/column_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class ColumnAnalysis(object):
def __init__(self, data_id, req):
self.data_id = data_id
self.analysis_type = get_str_arg(req, "type")
curr_settings = global_state.get_settings(data_id) or {}
self.query = build_query(data_id, curr_settings.get("query"))
self.query = build_query(data_id, global_state.get_query(data_id))
self.data = load_filterable_data(data_id, req, query=self.query)
self.selected_col = find_selected_column(
self.data, get_str_arg(req, "col", "values")
Expand Down
3 changes: 1 addition & 2 deletions dtale/dash_application/drilldown_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ def load_drilldown_content(
if click_data:
click_point = next((p for p in click_data.get("points", [])), None)
if click_point:
curr_settings = global_state.get_settings(data_id) or {}
query = build_query(
data_id, all_inputs.get("query") or curr_settings.get("query")
data_id, all_inputs.get("query") or global_state.get_query(data_id)
)
x_col = all_inputs.get("x")
y_col = next((y2 for y2 in make_list(all_inputs.get("y"))), None)
Expand Down
2 changes: 1 addition & 1 deletion dtale/data_reshapers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, data_id, shape_type, cfg):
def reshape(self):
    """Run the configured reshaper against the (optionally filtered) data.

    The data instance's stored custom query is applied first via
    ``run_query`` (``global_state.get_query`` returns ``None`` unless the
    ``enable_custom_filters`` flag is enabled), and the resulting frame is
    handed to the builder.

    :return: reshaped data as produced by ``self.builder.reshape``
    """
    # NOTE: fix of diff residue — the stale pre-change argument
    # "(global_state.get_settings(self.data_id) or {}).get('query')" was
    # left in alongside its replacement, making the call invalid.
    data = run_query(
        global_state.get_data(self.data_id),
        global_state.get_query(self.data_id),
        global_state.get_context_variables(self.data_id),
    )
    return self.builder.reshape(data)
Expand Down
6 changes: 6 additions & 0 deletions dtale/global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,12 @@ def get_name(self, data_id):
def get_settings(self, data_id):
return self.get_data_inst(data_id).settings

def get_query(self, data_id):
    """Return the custom query stored in the settings for *data_id*.

    Yields ``None`` unless the ``enable_custom_filters`` flag is set for
    this data instance — custom pandas queries are only honored when
    explicitly enabled (they can execute arbitrary expressions).
    """
    if not load_flag(data_id, "enable_custom_filters", False):
        return None
    settings = self.get_settings(data_id) or {}
    return settings.get("query")

def get_metadata(self, data_id):
return self.get_data_inst(data_id).metadata

Expand Down
5 changes: 2 additions & 3 deletions dtale/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def build_col_key(col):
return "`{}`".format(col)


def build_query(data_id, query=None):
    """Combine *query* with the settings saved for *data_id*.

    :param data_id: identifier of the data instance whose settings are used
    :param query: optional pandas query string taking precedence over the
                  query stored in settings
    :return: final query string (delegates to ``inner_build_query``)
    """
    # NOTE: fix of diff residue — the pre-change signature line
    # (with the removed "as_query_builder" parameter) was left in
    # alongside the replacement def line, which is a SyntaxError.
    curr_settings = global_state.get_settings(data_id) or {}
    return inner_build_query(curr_settings, query)

Expand Down Expand Up @@ -204,9 +204,8 @@ def load_filterable_data(data_id, req, query=None, columns=None):
data = instance.load_data(columns=columns)
data, _ = format_data(data)
return data
curr_settings = global_state.get_settings(data_id) or {}
if filtered:
final_query = query or build_query(data_id, curr_settings.get("query"))
final_query = query or build_query(data_id, global_state.get_query(data_id))
return run_query(
handle_predefined(data_id),
final_query,
Expand Down
2 changes: 1 addition & 1 deletion dtale/timeseries_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, data_id, report_type, cfg):
def run(self):
    """Execute the configured time-series report on the (optionally filtered) data.

    The data instance's stored custom query is applied first via
    ``run_query`` (``global_state.get_query`` returns ``None`` unless
    ``enable_custom_filters`` is enabled for the instance).

    :return: output of ``self.report.run``
    """
    # NOTE: fix of diff residue — the stale pre-change argument
    # "(global_state.get_settings(self.data_id) or {}).get('query')" was
    # left in alongside its replacement, making the call invalid.
    data = run_query(
        global_state.get_data(self.data_id),
        global_state.get_query(self.data_id),
        global_state.get_context_variables(self.data_id),
    )
    return self.report.run(data)
Expand Down
46 changes: 26 additions & 20 deletions dtale/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,11 @@ def update_settings(data_id):
:return: JSON
"""

global_state.update_settings(data_id, get_json_arg(request, "settings", {}))
updated_settings = get_json_arg(request, "settings", {})
if not global_state.load_flag(data_id, "enable_custom_filters", False):
updated_settings.pop("query", None)

global_state.update_settings(data_id, updated_settings)
return jsonify(dict(success=True))


Expand Down Expand Up @@ -2590,12 +2594,11 @@ def edit_cell(data_id):
row_index = get_int_arg(request, "rowIndex")
updated = get_str_arg(request, "updated")
updated_str = updated
curr_settings = global_state.get_settings(data_id)

# make sure to load filtered data in order to get correct row index
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
ignore_empty=True,
)
Expand Down Expand Up @@ -2771,7 +2774,7 @@ def get_data(data_id):

curr_settings = global_state.get_settings(data_id) or {}
curr_locked = curr_settings.get("locked", [])
final_query = build_query(data_id, curr_settings.get("query"))
final_query = build_query(data_id, global_state.get_query(data_id))
highlight_filter = curr_settings.get("highlightFilter") or False

if global_state.is_arcticdb:
Expand Down Expand Up @@ -3040,7 +3043,7 @@ def load_file(fpath, encoding="utf-8"):
@exception_decorator
def load_filtered_ranges(data_id):
curr_settings = global_state.get_settings(data_id) or {}
final_query = build_query(data_id, curr_settings.get("query"))
final_query = build_query(data_id, global_state.get_query(data_id))
if not final_query:
return {}
curr_filtered_ranges = curr_settings.get("filteredRanges", {})
Expand Down Expand Up @@ -3081,11 +3084,10 @@ def _filter_numeric(col):
@dtale.route("/data-export/<data_id>")
@exception_decorator
def data_export(data_id):
curr_settings = global_state.get_settings(data_id) or {}
curr_dtypes = global_state.get_dtypes(data_id) or []
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
ignore_empty=True,
)
Expand Down Expand Up @@ -3174,10 +3176,9 @@ def build_correlations_matrix_image(


def build_correlations_matrix(data_id, is_pps=False, encode_strings=False, image=False):
curr_settings = global_state.get_settings(data_id) or {}
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
)
valid_corr_cols, valid_str_corr_cols, valid_date_cols = correlations.get_col_groups(
Expand Down Expand Up @@ -3445,10 +3446,9 @@ def get_correlations_ts(data_id):
data: {:col1:col2: {data: [{corr: 0.99, date: 'YYYY-MM-DD'},...], max: 0.99, min: 0.99}
} or {error: 'Exception message', traceback: 'Exception stacktrace'}
"""
curr_settings = global_state.get_settings(data_id) or {}
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
)
cols = get_json_arg(request, "cols")
Expand Down Expand Up @@ -3565,10 +3565,9 @@ def get_scatter(data_id):
date_col = get_str_arg(request, "dateCol")
rolling = get_bool_arg(request, "rolling")

curr_settings = global_state.get_settings(data_id) or {}
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
)
idx_col = str("_corr_index")
Expand Down Expand Up @@ -4025,10 +4024,9 @@ def build_column_text(data_id):
columns = request.json.get("columns")
columns = json.loads(columns)

curr_settings = global_state.get_settings(data_id) or {}
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
ignore_empty=True,
)
Expand All @@ -4042,10 +4040,9 @@ def build_row_text(data_id):
request.json.get(p) for p in ["start", "end", "rows", "columns"]
)
columns = json.loads(columns)
curr_settings = global_state.get_settings(data_id) or {}
data = run_query(
handle_predefined(data_id),
build_query(data_id, curr_settings.get("query")),
build_query(data_id, global_state.get_query(data_id)),
global_state.get_context_variables(data_id),
ignore_empty=True,
)
Expand Down Expand Up @@ -4229,7 +4226,7 @@ def build_missingno_chart(chart_type, data_id):
@exception_decorator
def drop_filtered_rows(data_id):
curr_settings = global_state.get_settings(data_id) or {}
final_query = build_query(data_id, curr_settings.get("query"))
final_query = build_query(data_id, global_state.get_query(data_id))
curr_history = global_state.get_history(data_id) or []
curr_history += [
(
Expand Down Expand Up @@ -4265,8 +4262,17 @@ def drop_filtered_rows(data_id):
@dtale.route("/move-filters-to-custom/<data_id>")
@exception_decorator
def move_filters_to_custom(data_id):
curr_settings = global_state.get_settings(data_id) or {}
query = build_query(data_id, curr_settings.get("query"))
if not global_state.load_flag(data_id, "enable_custom_filters", False):
return jsonify(
dict(
success=False,
error=(
"Custom Filters not enabled! Custom filters are vulnerable to code injection attacks, please only "
"use in trusted environments."
),
)
)
query = build_query(data_id, global_state.get_query(data_id))
global_state.update_settings(
data_id,
{
Expand Down
13 changes: 13 additions & 0 deletions tests/dtale/correlations/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@

from pkg_resources import parse_version

import dtale.global_state as global_state
import dtale.pandas_util as pandas_util

from dtale.charts.utils import CHART_POINTS_LIMIT

from tests.dtale.test_views import app, build_ts_data
from tests.dtale import build_data_inst, build_settings, build_dtypes


def setup_function(function):
    """Reset dtale's global state before each test so settings/flags from one test cannot leak into another."""
    global_state.cleanup()


def teardown_function(function):
    """Clean up dtale's global state after each test (mirrors setup_function for symmetry)."""
    global_state.cleanup()


CORRELATIONS_CODE = """# DISCLAIMER: 'df' refers to the data you passed in when calling 'dtale.show'
import numpy as np
Expand Down Expand Up @@ -90,6 +100,7 @@ def test_get_correlations(unittest, test_data, rolling_data):
with app.test_client() as c:
build_data_inst({c.port: test_data})
build_dtypes({c.port: views.build_dtypes_state(test_data)})
global_state.set_app_settings(dict(enable_custom_filters=True))
settings = {c.port: {"query": "missing_col == 'blah'"}}
build_settings(settings)
response = c.get("/dtale/correlations/{}".format(c.port))
Expand Down Expand Up @@ -309,6 +320,7 @@ def test_get_correlations_ts(unittest, rolling_data):

with app.test_client() as c:
build_data_inst({c.port: test_data})
global_state.set_app_settings(dict(enable_custom_filters=True))
settings = {c.port: {"query": "missing_col == 'blah'"}}
build_settings(settings)
response = c.get("/dtale/correlations-ts/{}".format(c.port))
Expand Down Expand Up @@ -483,6 +495,7 @@ def test_get_scatter(unittest, rolling_data):
with app.test_client() as c:
build_data_inst({c.port: test_data})
build_dtypes({c.port: views.build_dtypes_state(test_data)})
global_state.set_app_settings(dict(enable_custom_filters=True))
settings = {c.port: {"query": "missing_col == 'blah'"}}
build_settings(settings)
params = dict(dateCol="date", cols=json.dumps(["foo", "bar"]), date="20000101")
Expand Down
1 change: 1 addition & 0 deletions tests/dtale/test_column_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_get_column_analysis(unittest, test_data):
build_data_inst({c.port: test_data})
build_dtypes({c.port: views.build_dtypes_state(test_data)})
build_settings({c.port: {}})
global_state.set_app_settings(dict(enable_custom_filters=True))
response = c.get(
"/dtale/column-analysis/{}".format(c.port),
query_string=dict(col="foo", filtered="true"),
Expand Down
1 change: 1 addition & 0 deletions tests/dtale/test_reshapers.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def test_transpose(custom_data, unittest):
build_data_inst(data)
build_dtypes(dtypes)
build_settings(settings)
global_state.set_app_settings(dict(enable_custom_filters=True))
reshape_cfg = dict(index=["security_id"], columns=["Col0"])
resp = c.get(
"/dtale/reshape/{}".format(c.port),
Expand Down
1 change: 1 addition & 0 deletions tests/dtale/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,7 @@ def test_get_data(unittest, test_data):
with app.test_client() as c:
build_data_inst({c.port: test_data})
build_dtypes({c.port: views.build_dtypes_state(test_data)})
global_state.set_app_settings(dict(enable_custom_filters=True))
build_settings({c.port: dict(query="missing_col == 'blah'")})
response = c.get(
"/dtale/data/{}".format(c.port), query_string=dict(ids=json.dumps(["0"]))
Expand Down

0 comments on commit e6abe29

Please sign in to comment.