Skip to content

Commit

Permalink
Fix memory analysis crash when there is multiple free with same addre…
Browse files Browse the repository at this point in the history
…ss (#560)
  • Loading branch information
guotuofeng committed Mar 11, 2022
1 parent 742260c commit 63099f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
7 changes: 0 additions & 7 deletions tb_plugin/torch_tb_profiler/profiler/memory_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def get_preprocessed_records(self):
memory_records = sorted(self.memory_records, key=lambda r: r.ts)

alloc = {} # allocation events may or may not have paired free event
free = {} # free events that does not have paired alloc event
prev_ts = float('-inf') # ensure ordered memory records is ordered
for i, r in enumerate(memory_records):
if r.addr is None:
Expand All @@ -326,10 +325,4 @@ def get_preprocessed_records(self):
r.op_name = alloc_r.op_name
r.parent_op_name = alloc_r.parent_op_name
del alloc[addr]
else:
assert addr not in free
free[addr] = i

if free:
logger.debug(f'{len(free)} memory records do not have associated operator.')
return memory_records
9 changes: 6 additions & 3 deletions tb_plugin/torch_tb_profiler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from . import consts
from . import consts, utils
from .profiler.diffrun import compare_op_tree, diff_summary
from .profiler.memory_parser import MemoryMetrics, MemoryRecord, MemorySnapshot
from .profiler.module_op import Stats
from .profiler.node import OperatorNode
from .utils import Canonicalizer, DisplayRounder

logger = utils.get_logger()


class Run(object):
""" A profiler run. For visualization purpose only.
Expand Down Expand Up @@ -341,7 +343,7 @@ def get_op_name_or_ctx(record: MemoryRecord):
# profile json data prior to pytorch 1.10 do not have addr
# we should ignore them
continue
assert prev_ts < r.ts
assert prev_ts <= r.ts
prev_ts = r.ts
addr = r.addr
size = r.bytes
Expand All @@ -362,7 +364,8 @@ def get_op_name_or_ctx(record: MemoryRecord):
])
del alloc[addr]
else:
assert addr not in free
if addr in free:
logger.warning(f'Address {addr} is freed multiple times')
free[addr] = i

for i in alloc.values():
Expand Down

0 comments on commit 63099f9

Please sign in to comment.