From 63099f9f0041bead5f5893108a1dfedd11aec4bb Mon Sep 17 00:00:00 2001 From: Mike Guo Date: Fri, 11 Mar 2022 08:36:27 +0800 Subject: [PATCH] Fix memory analysis crash when there is multiple free with same address (#560) --- tb_plugin/torch_tb_profiler/profiler/memory_parser.py | 7 ------- tb_plugin/torch_tb_profiler/run.py | 9 ++++++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tb_plugin/torch_tb_profiler/profiler/memory_parser.py b/tb_plugin/torch_tb_profiler/profiler/memory_parser.py index 55205899d..014542c56 100644 --- a/tb_plugin/torch_tb_profiler/profiler/memory_parser.py +++ b/tb_plugin/torch_tb_profiler/profiler/memory_parser.py @@ -305,7 +305,6 @@ def get_preprocessed_records(self): memory_records = sorted(self.memory_records, key=lambda r: r.ts) alloc = {} # allocation events may or may not have paired free event - free = {} # free events that does not have paired alloc event prev_ts = float('-inf') # ensure ordered memory records is ordered for i, r in enumerate(memory_records): if r.addr is None: @@ -326,10 +325,4 @@ def get_preprocessed_records(self): r.op_name = alloc_r.op_name r.parent_op_name = alloc_r.parent_op_name del alloc[addr] - else: - assert addr not in free - free[addr] = i - - if free: - logger.debug(f'{len(free)} memory records do not have associated operator.') return memory_records diff --git a/tb_plugin/torch_tb_profiler/run.py b/tb_plugin/torch_tb_profiler/run.py index 353ae0b34..e54109eeb 100644 --- a/tb_plugin/torch_tb_profiler/run.py +++ b/tb_plugin/torch_tb_profiler/run.py @@ -4,13 +4,15 @@ from collections import defaultdict from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from . import consts +from . import consts, utils from .profiler.diffrun import compare_op_tree, diff_summary from .profiler.memory_parser import MemoryMetrics, MemoryRecord, MemorySnapshot from .profiler.module_op import Stats from .profiler.node import OperatorNode from .utils import Canonicalizer, DisplayRounder +logger = utils.get_logger() + class Run(object): """ A profiler run. For visualization purpose only. @@ -341,7 +343,7 @@ def get_op_name_or_ctx(record: MemoryRecord): # profile json data prior to pytorch 1.10 do not have addr # we should ignore them continue - assert prev_ts < r.ts + assert prev_ts <= r.ts prev_ts = r.ts addr = r.addr size = r.bytes @@ -362,7 +364,8 @@ def get_op_name_or_ctx(record: MemoryRecord): ]) del alloc[addr] else: - assert addr not in free + if addr in free: + logger.warning(f'Address {addr} is freed multiple times') free[addr] = i for i in alloc.values():