From 8a2762beee0b7cca99941cba0a01b33ddffd0a4c Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 27 Nov 2024 15:32:48 +0100 Subject: [PATCH] [viz tool] add policy pred column --- lerobot/scripts/visualize_dataset_html.py | 5 +++-- .../templates/visualize_dataset_template.html | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 51272cb15..39ce81e6e 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -88,6 +88,7 @@ def run_server( port: str, static_folder: Path, template_folder: Path, + has_policy = False, ): app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache @@ -130,7 +131,7 @@ def show_episode(dataset_namespace, dataset_name, episode_id): dataset_info=dataset_info, videos_info=videos_info, ep_csv_url=ep_csv_url, - has_policy=False, + has_policy = has_policy, ) app.run(host=host, port=port) @@ -344,7 +345,7 @@ def visualize_dataset_html( write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy) if serve: - run_server(dataset, episodes, host, port, static_dir, template_dir) + run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy=policy is not None) def main(): diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 0fa1e713e..cb178af65 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -229,7 +229,8 @@

dygraph: null, currentFrameData: null, columnNames: ["state", "action", "pred action"], - nColumns: 2, + hasPolicy: {% if has_policy %}true{% else %}false{% endif %}, + nColumns: {% if has_policy %}3{% else %}2{% endif %}, nStates: 0, nActions: 0, checked: [], @@ -278,6 +279,9 @@

const seriesNames = this.dygraph.getLabels().slice(1); this.nStates = seriesNames.findIndex(item => item.startsWith('action_')); this.nActions = seriesNames.length - this.nStates; + if(this.hasPolicy){ + this.nActions = Math.floor(this.nActions / 2); + } const colors = []; const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness // colors for "state" lines @@ -290,6 +294,13 @@

const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`; colors.push(color); } + if(this.hasPolicy){ + // colors for "action" lines + for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) { + const color = `hsl(${hue}, 100%, ${LIGHTNESS[2]}%)`; + colors.push(color); + } + } this.dygraph.updateOptions({ colors }); this.colors = colors; @@ -327,6 +338,10 @@

// row consists of [state value, action value] row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row + if(this.hasPolicy){ + const predActionValueIdx = stateValueIdx + this.nStates + this.nActions; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN, pred_action1, ..., pred_actionN] + row.push(rowIndex < this.nActions ? this.currentFrameData[predActionValueIdx] : nullCell); // push "action value" to row + } rowIndex += 1; rows.push(row); }