From 3a84d819df81377cd0730d0304d21529c3fe9247 Mon Sep 17 00:00:00 2001 From: songyant Date: Thu, 1 Aug 2024 03:57:17 -0700 Subject: [PATCH 1/3] split the tensor allocate operation of comm tensors from function prepComms to function generate_io_tensors --- et_replay/tools/comm_replay.py | 44 ++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 384f690f..d7a96577 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -12,7 +12,7 @@ import logging import os import time -from typing import Dict, List, Set +from typing import Dict, List, Set, Tuple import numpy as np import torch @@ -610,12 +610,35 @@ def hashEtCommsOp(self, commsOp: commsArgs) -> int: return hash(op) + def generate_io_tensors( + self, + curComm: commsArgs, + commsParams: commsParamsHolderBase, + regenerateTensors: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Use exactly specified inMsgSize/outMsgSize if call from trace replay + # This avoid regenerating sizes such as in _prep_all_gather_base + commsParams.size_from_trace = True + commsParams.dtype = self.dtypeMap[curComm.dtype] + if not curComm.id or regenerateTensors: + return super().prepComm(curComm, commsParams) + else: + commsOpHash = self.hashEtCommsOp(curComm) + if commsOpHash in self.et_to_tensors: + # Allocate input/output tensors if first time replay, otherwise the previous ones. + super().prepComm(curComm, commsParams, False) + (ipTensor, opTensor) = self.et_to_tensors[commsOpHash] + else: + (ipTensor, opTensor) = super().prepComm(curComm, commsParams, True) + self.et_to_tensors[commsOpHash] = (ipTensor, opTensor) + return (ipTensor, opTensor) + def prepComms( self, curComm: commsArgs, commsParams: commsParamsHolderBase, regenerateTensors: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepares the appropriate tensors for the current collective communication. @@ -686,22 +709,7 @@ def prepComms( f"shrink message sizes to curInNumElem {curComm.inMsgSize}, curOutNumElem {curComm.outMsgSize}" ) - # Use exactly specified inMsgSize/outMsgSize if call from trace replay - # This avoid regenerating sizes such as in _prep_all_gather_base - commsParams.size_from_trace = True - commsParams.dtype = self.dtypeMap[curComm.dtype] - if not curComm.id or regenerateTensors: - return super().prepComm(curComm, commsParams) - else: - commsOpHash = self.hashEtCommsOp(curComm) - if commsOpHash in self.et_to_tensors: - # Allocate input/output tensors if first time replay, otherwise the previous ones. 
- super().prepComm(curComm, commsParams, False) - (ipTensor, opTensor) = self.et_to_tensors[commsOpHash] - else: - (ipTensor, opTensor) = super().prepComm(curComm, commsParams, True) - self.et_to_tensors[commsOpHash] = (ipTensor, opTensor) - return (ipTensor, opTensor) + return self.generate_io_tensors(curComm, commsParams, regenerateTensors) def commRebalance(self, curComm: commsArgs) -> None: """ From 24e502417abff20e8ade81c74b417c0c588d3336 Mon Sep 17 00:00:00 2001 From: songyant Date: Thu, 1 Aug 2024 05:13:57 -0700 Subject: [PATCH 2/3] split one comm node replay functional code into function replaySingle and invoke replaySingle to replay every node in function replayTrace --- et_replay/tools/comm_replay.py | 268 +++++++++++++-------------------- 1 file changed, 107 insertions(+), 161 deletions(-) diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index d7a96577..d774c301 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -121,6 +121,8 @@ def __init__(self): self.out_path = "" self.outputRanks = None self.colls_per_batch = -1 + self.coll_in_batch_num = 0 + self.replay_start_time = -1 self.use_timestamp = False self.num_replays = 1 self.profiler_num_replays_start = 0 @@ -997,189 +999,133 @@ def replayTrace( Returns: None """ + self.coll_in_batch_num = 0 + self.replay_start_time = time.monotonic_ns() + for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]): + self.replaySingle(commsParams, curComm, cnt, warmup) + + def replaySingle(self, commsParams: commsParamsHolderBase, curComm: commsArgs, cnt: int, warmup: bool = False): if warmup: logLable = "[Warm-up]" else: logLable = f"[Replay {self.replayIter}]" - coll_in_batch_num = 0 - startTime = time.monotonic_ns() - for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]): - curBlocks = curComm.markerStack if curComm.markerStack is not None else [] - curBlockStack = ( - " ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown" - ) - - # Replay compute - if curComm.compute is not None: - # Prepare to run the compute function - computeFunc = self.prepComputeReplay(commsParams, curComm) - - # Running the kernel - logger.info( - f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {curComm.compute}" - ) - - # Run the kernel and report the total time - (latency, global_latency) = self.runCompute( - func=computeFunc, curBlockStack=curBlockStack - ) - recordName = curComm.compute - - # Replay comm - else: - if warmup: - self.commRebalance(curComm) - - # Get the name of the collective from the comm object - collName = paramToCommName(curComm.comms) - (groupRank, groupDesc) = self.getCommGroupInfo(curComm, commsParams) - # Skip comm if the local process doesn't belong to the PG or encounter an unexpected collective - if ( - collName not in self.allowList - or groupRank == -1 - or ( - collName in ("send", "isend") - and curComm.src_rank != self.backendFuncs.get_global_rank() - ) - or ( - collName in ("recv", "irecv") - and curComm.dst_rank != self.backendFuncs.get_global_rank() - ) - ): - continue + curBlocks = curComm.markerStack if curComm.markerStack is not None else [] + curBlockStack = ( + " ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown" + ) - if groupRank >= 0: - commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}" - if curComm.comms == "all_to_allv": - commDesc += ( - f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}" - ) 
- if curComm.comms in supportedP2pOps: - commDesc += f", Src_Rank={curComm.src_rank}, Dst_Rank={curComm.dst_rank}" - logger.info( - f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {commDesc} with {groupDesc}" - ) + # Replay compute + if curComm.compute is not None: + # Prepare to run the compute function + computeFunc = self.prepComputeReplay(commsParams, curComm) - # read fields and prepare the tensors - ( - self.collectiveArgs.ipTensor, - self.collectiveArgs.opTensor, - ) = self.prepComms(curComm, commsParams, not self.reuse_tensors) + # Running the kernel + logger.info( + f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {curComm.compute}" + ) - if not warmup and self.colls_per_batch > 0 and coll_in_batch_num == 0: - batch_begin = time.monotonic() + # Run the kernel and report the total time + (latency, global_latency) = self.runCompute( + func=computeFunc, curBlockStack=curBlockStack + ) + recordName = curComm.compute - # wait for collective timestamp if enabled. - if not warmup and self.use_timestamp: - self.waitForTimestamp(curComm, startTime) + # Replay comm + else: + if warmup: + self.commRebalance(curComm) - # send comm request to pytorch backend - (latency, global_latency) = self.runComms( - collName, curComm, curBlockStack + # Get the name of the collective from the comm object + collName = paramToCommName(curComm.comms) + (groupRank, groupDesc) = self.getCommGroupInfo(curComm, commsParams) + # Skip comm if the local process doesn't belong to the PG or encounter an unexpected collective + if ( + collName not in self.allowList + or groupRank == -1 + or ( + collName in ("send", "isend") + and curComm.src_rank != self.backendFuncs.get_global_rank() ) - - # perform data validation check on the final opTensor - if ( - self.is_blocking - and commsParams.dcheck == 1 - and collName not in ("wait", "barrier") - ): - commsParams.collective = collName - commsParams.srcOrDst = ( - curComm.root if curComm.root is not None else 0 - ) - self.dcheck( - commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor - ) - - # calculating batch latency (batch defined by --colls-per-batch) - if not warmup and collName == "wait" and self.colls_per_batch > 0: - coll_in_batch_num += 1 - if coll_in_batch_num == self.colls_per_batch: - batch_latency = ( - time.monotonic() - batch_begin - ) * 1e3 # make it millisecond - coll_in_batch_num = 0 - self.batchLat.append(batch_latency) - - recordName = collName - - if not warmup: - # record performance metrics - self.recordCommReplay( - commsParams, - curComm, - recordName, - latency, - curBlockStack, - global_latency, - curBlocks, + or ( + collName in ("recv", "irecv") + and curComm.dst_rank != self.backendFuncs.get_global_rank() ) + ): + return - if self.backendFuncs.get_global_rank() == 0: + if groupRank >= 0: + commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}" + if curComm.comms == "all_to_allv": + commDesc += ( + f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}" + ) + if curComm.comms in supportedP2pOps: + commDesc += f", Src_Rank={curComm.src_rank}, Dst_Rank={curComm.dst_rank}" logger.info( - f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... 
{global_latency:.2f} us" + f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {commDesc} with {groupDesc}" ) - def replaySingle( - self, - commsParams: commsParamsHolderBase, - id: int, - regenerateTensors: bool = True, - ) -> torch.Tensor: - """ - Replay comms trace. - Args: - commsParams: Run-time parameters for replay. - id: comms op id. - Returns: - Output tensor. - """ - for _, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]): - if curComm.id == id: - collName = paramToCommName(curComm.comms) - if collName not in self.allowList: - return torch.Tensor() - - curBlocks = ( - curComm.markerStack if curComm.markerStack is not None else [] - ) - curBlockStack = ( - " ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown" - ) + # read fields and prepare the tensors + ( + self.collectiveArgs.ipTensor, + self.collectiveArgs.opTensor, + ) = self.prepComms(curComm, commsParams, not self.reuse_tensors) - if self.backendFuncs.get_global_rank() == 0: - logger.debug( - f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm.comms)}\n" - ) + if not warmup and self.colls_per_batch > 0 and self.coll_in_batch_num == 0: + batch_begin = time.monotonic() - # read fields and prepare the tensors - ( - self.collectiveArgs.ipTensor, - self.collectiveArgs.opTensor, - ) = self.prepComms(curComm, commsParams, regenerateTensors) + # wait for collective timestamp if enabled. + if not warmup and self.use_timestamp: + self.waitForTimestamp(curComm, self.replay_start_time) - # send comm request to pytorch backend - (latency, global_latency) = self.runComms( - collName, curComm, curBlockStack + # send comm request to pytorch backend + (latency, global_latency) = self.runComms( + collName, curComm, curBlockStack + ) + + # perform data validation check on the final opTensor + if ( + self.is_blocking + and commsParams.dcheck == 1 + and collName not in ("wait", "barrier") + ): + commsParams.collective = collName + commsParams.srcOrDst = ( + curComm.root if curComm.root is not None else 0 + ) + self.dcheck( + commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor ) - # perform data validation check on the final opTensor - if ( - self.is_blocking - and commsParams.dcheck == 1 - and collName not in ("wait", "barrier") - ): - commsParams.collective = collName - commsParams.srcOrDst = ( - curComm.root if curComm.root is not None else 0 - ) - self.dcheck( - commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor - ) + # calculating batch latency (batch defined by --colls-per-batch) + if not warmup and collName == "wait" and self.colls_per_batch > 0: + self.coll_in_batch_num += 1 + if self.coll_in_batch_num == self.colls_per_batch: + batch_latency = ( + time.monotonic() - batch_begin + ) * 1e3 # make it millisecond + self.coll_in_batch_num = 0 + self.batchLat.append(batch_latency) + + recordName = collName + + if not warmup: + # record performance metrics + self.recordCommReplay( + commsParams, + curComm, + recordName, + latency, + curBlockStack, + global_latency, + curBlocks, + ) - return self.collectiveArgs.opTensor + if self.backendFuncs.get_global_rank() == 0: + logger.info( + f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... 
{global_latency:.2f} us" + ) def benchTime(self, commsParams: commsParamsHolderBase) -> None: """ From d7dfd4146ebb212d671779a8932138f789e331f5 Mon Sep 17 00:00:00 2001 From: songyant Date: Tue, 6 Aug 2024 22:41:31 -0700 Subject: [PATCH 3/3] add tensor lazy allocate and recyle strategy to comp replay --- et_replay/et_replay_utils.py | 2 +- et_replay/tools/et_replay.py | 371 ++++++++++++++++++++--------------- 2 files changed, 209 insertions(+), 164 deletions(-) diff --git a/et_replay/et_replay_utils.py b/et_replay/et_replay_utils.py index 5013bb62..f02e08fa 100644 --- a/et_replay/et_replay_utils.py +++ b/et_replay/et_replay_utils.py @@ -368,7 +368,7 @@ def build_torchscript_func(n): if ( n.op_schema == "" or n.name == "aten::record_stream" - or n.name.startswith("aten::_foreach") + #or n.name.startswith("aten::_foreach") ): return None, None diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 480c078f..0e7dcaa4 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -5,9 +5,12 @@ import os import sys import time -from collections import defaultdict +import copy +from collections import defaultdict, Counter +from collections.abc import Iterable from datetime import datetime -from typing import Dict +from functools import reduce, partial +from typing import Optional, Tuple, Dict, List import numpy as np import torch @@ -33,7 +36,7 @@ TORCH_DTYPES_RNG, TORCH_DTYPES_RNG_str, ) -from et_replay.execution_trace import ExecutionTrace +from et_replay.execution_trace import ExecutionTrace, NodeType, Node from et_replay.utils import trace_handler from param_bench.train.compute.python.lib import pytorch as lib_pytorch from param_bench.train.compute.python.lib.init_helper import load_modules @@ -70,6 +73,8 @@ def __init__(self): self.commsParams = None self.regenerate_tensors = None + self.recycle_storages = False + self.cuda = "cuda" self.device = torch.device(self.cuda) @@ -99,9 +104,18 @@ def __init__(self): # Dict that stores the shapes of a tensor, for the convenience of quickly determining whether # to create a unique tensor in replay if the id is same but shape is different. self.tensor_shapes = defaultdict(set) - # Dict that maps tensor storage id to its size, and a map for {device, torch.Tensor}. + # Dict that maps tensor storage id to a map for {device, torch.Tensor}. # The tensor with the same storage id may located on different devices. - self.tensor_storage_map: Dict[int, []] = defaultdict(set) + self.tensor_storage_map: Dict[int, Dict[torch.device, torch.Tensor]] = {} + self.tensor_alloc_set = set() + # Dict that maps tensor storage id to its size + self.tensor_storage_sizes: Dict[int, int] = defaultdict(int) + # List that holds referenced tensor storages and devices when replay an operation, after the operation complete, it will be cleared + self.referenced_tensor_storage_ids: List[Tuple[int, torch.device]] = [] + # Ref cnt of instantiate tensors, used to recycle storages + self.instantiate_tensor_ref_cnts: Dict[int, Dict[int, Dict[torch.device, Counter]]] = defaultdict(lambda: defaultdict(Counter)) + # Ref cnt for each iter + self.instantiate_tensor_ref_cnts_per_it: Dict[int, Dict[int, Dict[torch.device, Counter]]] = defaultdict(lambda: defaultdict(Counter)) # Mark those tensors that occur first as an input in the original et as needing to be instantiated in replay # at the very beginning. 
self.instantiate = set() @@ -199,6 +213,7 @@ def initBench(self): self.wait_delay = self.args.delay self.cpu = self.args.cpu self.tf32 = self.args.tf32 + self.recycle_storages = self.args.recycle_storages # Single trace. if not self.args.trace_path: @@ -240,7 +255,7 @@ def initBench(self): ) self.et = ExecutionTrace(json.load(et)) else: - self.trace_file = f"{self.args.trace_path}/rank{self.comms_env_params['global_rank']}.json" + self.trace_file = f"{self.args.trace_path}/rank-{self.comms_env_params['global_rank']}.json" with open(self.trace_file, "r") as f: self.et = ExecutionTrace(json.load(f)) @@ -283,8 +298,8 @@ def reset_registry(self): None if v is None else ( - v.cpu() - if self.tensor_device[k] == "cpu" or self.cpu + v + if self.tensor_device[k] == "cpu" or self.cpu or (isinstance(v, tuple) and v[0] == "lazy_alloc") else v.cuda(self.device) ) ) @@ -296,13 +311,13 @@ def reset_registry(self): None if v is None else ( - v.cpu() - if k in self.cpu_tensor or self.cpu + v if k in self.cpu_tensor or self.cpu or (isinstance(v, tuple) and v[0] == "lazy_alloc") else v.cuda(self.device) ) ) for k, v in self.tensor_registry_permanent.items() } + self.tensor_registry_permanent.clear() gc.collect() torch.cuda.empty_cache() @@ -383,27 +398,7 @@ def has_parallel_parent(node): assert len(self.parallel_nodes_ids) == len(set(self.parallel_nodes_ids)) def analyze_tensors(self): - def add_storage_tensor(t_id, device): - # t_id is a tupe of (tensor_id, storage_id, offset, number of element, - # number of bytes for each element, device) - - # ET does not save the size of the tensor storage, so we iterate over all the - # tensors to find the maximum size of the storage. - storage_id = t_id[1] - if storage_id not in self.tensor_storage_map: - # the storage size for this tensor is the sum of the storage offset and - # number of elements * number of bytes per element. - self.tensor_storage_map[storage_id] = [ - t_id[2] + t_id[3] * t_id[4], - {}, - ] - else: - self.tensor_storage_map[storage_id][0] = max( - self.tensor_storage_map[storage_id][0], t_id[2] + t_id[3] * t_id[4] - ) - def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1): - add_storage_tensor(t_id, device) # If we did not see this tensor before, add it as a unique tensor. 
if t_id not in self.original_unique_tensors: self.original_unique_tensors.add(t_id) @@ -492,9 +487,64 @@ def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1): if t_id in self.input_tensor_ids: output_set.add(self.tensors_mapping[(node.id, t_id, False)]) - def allocate_tensors(self): - start_ns = time.time_ns() + def get_tensor_from_storage(self, node, storage_id, data_offset, elem_bytes, device, shape, data_type): + tensor_data = self.tensor_storage_map.setdefault(storage_id, {}) + device = torch.device(device) + + if device not in tensor_data: + if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: + storage_tensor = torch.rand( + (self.tensor_storage_sizes[storage_id] // elem_bytes), dtype=data_type, device=device + ) + else: + storage_tensor = torch.ones( + (self.tensor_storage_sizes[storage_id] // elem_bytes), dtype=data_type, device=device + ) + + tensor_data[device] = storage_tensor + + if (storage_id, device) not in self.tensor_alloc_set: + self.tensor_alloc_set.add((storage_id, device)) + else: + print("repeat alloc, may caused by wrong recycle:", (storage_id, device)) + exit(1) + else: + storage_tensor = tensor_data[device] + x = torch.empty(0, dtype=data_type) + if device != torch.device("cpu"): + x = x.cuda(torch.device(device)) + x = x.set_( + storage_tensor.untyped_storage(), + storage_offset=data_offset, + size=shape, + ) + + return x + def add_tensor_registry_permanent(self, node_id, data_type, storage_id, storage_offset, element_num, item_size, device_str, shape, replay_t_id): + if data_type == "Tensor(signed char)": + dtype, _ = TORCH_DTYPES_RNG["signed char"] + else: + dtype, _ = TORCH_DTYPES_RNG[ + data_type.lstrip("Tensor(").rstrip(")") + ] + + device = torch.device(device_str) + + self.tensor_storage_sizes[storage_id] = max(storage_offset + element_num * item_size, self.tensor_storage_sizes[storage_id]) + + self.tensor_registry_permanent[replay_t_id] = ("lazy_alloc", (storage_id, device), + partial(self.get_tensor_from_storage, + storage_id=storage_id, + data_offset=storage_offset, + elem_bytes=item_size, + device=device_str, + shape=shape, + data_type=dtype + ) + ) + + def allocate_tensors(self): for node in self.sorted_nodes: if node.name == "record_param_comms" and ( self.compute_only or self.args.separate @@ -519,11 +569,16 @@ def allocate_tensors(self): ) tensor_strides = node.get_input_tensor_strides() for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)): - device = self.device + tensor_id, storage_id, storage_offset, element_num, item_size, device_str = t_id if self.tensor_with_device: - device = t_id[5] t_id = tuple(list(t_id)[:5]) + else: + device_str = self.device replay_t_id = self.tensors_mapping[(node.id, t_id, True)] + + if replay_t_id in self.instantiate and device_str != "": + self.instantiate_tensor_ref_cnts[storage_id][torch.device(device_str)][node.id] += 1 + if ( t_id in self.input_tensor_ids and replay_t_id not in self.tensor_registry_permanent.keys() @@ -541,26 +596,17 @@ def allocate_tensors(self): if "fbgemm::split_embedding_codegen_lookup" in node.name: self.unchangeable_intermediate_tensors.add(replay_t_id) else: - if data_type == "Tensor(signed char)": - dtype, _ = TORCH_DTYPES_RNG["signed char"] - else: - dtype, _ = TORCH_DTYPES_RNG[ - data_type.lstrip("Tensor(").rstrip(")") - ] - - strides = None - if node.input_strides is not None: - strides = tensor_strides[idx] - tensor = self.get_tensor_from_storage( - t_id[1], # storage_id - t_id[2], # offset - t_id[4], # number of bytes per 
element - device, - shape, - dtype, - strides, + self.add_tensor_registry_permanent( + node_id=node.id, + data_type=data_type, + storage_id=storage_id, + storage_offset=storage_offset, + element_num=element_num, + item_size=item_size, + device_str=device_str, + shape=shape, + replay_t_id=replay_t_id ) - self.tensor_registry_permanent[replay_t_id] = tensor if node.name == "aten::embedding_bag": self.unchangeable_intermediate_tensors.add(replay_t_id) if node.name == "aten::pin_memory" and idx == 0: @@ -590,8 +636,6 @@ def allocate_tensors(self): ][i] = (i * nnz) ###### - print(f"Tensor allocation time: {(time.time_ns() - start_ns) / 1000000.0} ms") - def build_func(self, node): if is_fbgemm_forward(node): if self.cpu: @@ -990,44 +1034,6 @@ def _generate_run_ops_str(override): print(code_str, file=f) exec(code_str) - def get_tensor_from_storage( - self, storage_id, data_offset, elem_bytes, device, shape, data_type, strides - ): - assert storage_id in self.tensor_storage_map - - tensor_data = self.tensor_storage_map[storage_id] - device = torch.device(device) - if device not in tensor_data[1]: - if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: - storage_tensor = torch.rand( - (tensor_data[0] // elem_bytes), dtype=data_type, device=device - ) - else: - storage_tensor = torch.ones( - (tensor_data[0] // elem_bytes), dtype=data_type, device=device - ) - tensor_data[1][device] = storage_tensor - else: - storage_tensor = tensor_data[1][device] - x = torch.empty(0, dtype=data_type) - if device != torch.device("cpu"): - x = x.cuda(device) - if strides is None: - x = x.set_( - storage_tensor.untyped_storage(), - storage_offset=data_offset, - size=shape, - ) - else: - x = x.set_( - storage_tensor.untyped_storage(), - storage_offset=data_offset, - size=shape, - stride=strides, - ) - - return x - def get_inputs(self, node): try: if is_fbgemm_forward(node): @@ -1103,89 +1109,116 @@ def get_inputs(self, node): except Exception as e: print(f"Inputs error: {e} at node: {node.id}") + def lazy_alloc_tensors(self, inputs, node): + for i in range(len(inputs)): + if isinstance(inputs[i], tuple) and inputs[i][0] == "lazy_alloc": + inputs[i] = inputs[i][2](node) + elif isinstance(inputs[i], list): + for j in range(len(inputs[i])): + if isinstance(inputs[i][j], tuple) and inputs[i][j][0] == "lazy_alloc": + inputs[i][j] = inputs[i][j][2](node) + + def recycle_instantiate_tensors(self, node_id, storage_id, device): + assert node_id in self.instantiate_tensor_ref_cnts_per_it[storage_id][device] + self.instantiate_tensor_ref_cnts_per_it[storage_id][device][node_id] -= 1 + if self.recycle_storages and (not any(self.instantiate_tensor_ref_cnts_per_it[storage_id][device].values())): + del self.tensor_storage_map[storage_id][device] + def run_op(self, node, iter): - if node.name == "record_param_comms" and not self.compute_only: + if node.name == "record_param_comms": return if self.debug and iter >= self.numWarmupIters: start_ns = time.time_ns() func, output_count = self.funcs[node.id] - if not func: - return + inputs = self.get_inputs(node) + outputs = [] - # Workaround to eliminate the "strides() called on undefined Tensor" error. - if node.name == "aten::convolution_backward": - inputs[-1] = [True, True, True] + self.lazy_alloc_tensors(inputs, node) - # Workaround to handle tensor with same id but different data types (ads_cmf10x_single_iter_512_newest_eg.json). 
- if node.name == "aten::index_add_": - inputs[3] = inputs[3].to(torch.float64) - inputs[2] = inputs[2].to(torch.int) - if node.name == "aten::index_copy_": - if node.input_types[3] == "Tensor(double)": - inputs[3] = inputs[3].to(torch.float64) - if node.input_types[2] == "Tensor(long)": - inputs[2] = inputs[2].to(torch.int64) - if node.name == "aten::index_select": - inputs[2] = inputs[2].to(torch.int) + if func: + # Workaround to eliminate the "strides() called on undefined Tensor" error. + if node.name == "aten::convolution_backward": + inputs[-1] = [True, True, True] - if self.debug and iter >= self.numWarmupIters: - before_execution = time.time_ns() + # Workaround to handle tensor with same id but different data types (ads_cmf10x_single_iter_512_newest_eg.json). + if node.name == "aten::index_add_": + inputs[3] = inputs[3].to(torch.float64) + inputs[2] = inputs[2].to(torch.int) + if node.name == "aten::index_copy_": + if node.input_types[3] == "Tensor(double)": + inputs[3] = inputs[3].to(torch.float64) + if node.input_types[2] == "Tensor(long)": + inputs[2] = inputs[2].to(torch.int64) + if node.name == "aten::index_select": + inputs[2] = inputs[2].to(torch.int) + + if self.debug and iter >= self.numWarmupIters: + before_execution = time.time_ns() - try: - outputs = [] - if output_count == 0: - if node.kernel_backend == "triton": - exec( - f"func.run(*inputs[:-2], grid={inputs[-2]}, stream={inputs[-1]})" - ) - else: - func(*inputs) - else: - if output_count == 1: - tmp = (func(*inputs),) + try: + if output_count == 0: + if node.kernel_backend == "triton": + exec( + f"func.run(*inputs[:-2], grid={inputs[-2]}, stream={inputs[-1]})" + ) + else: + func(*inputs) else: - tmp = func(*inputs) - # Flatten any tensor lists - # TODO: Simplify this - if not tmp: - print(f"Not expect that {node.id} has no output.") - return - for x in tmp: - if isinstance(x, list) and isinstance(x[0], torch.Tensor): - outputs.extend(x) - elif isinstance(x, torch.Tensor): - outputs.append(x) - except Exception as e: - print( - f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}" - ) - exit(1) - - if node.name == "aten::repeat_interleave": - current_len = node.input_shapes[0][0] - target_len = node.output_shapes[0][0] - if current_len < target_len: - dtype, _ = TORCH_DTYPES_RNG[ - node.output_types[0].lstrip("Tensor(").rstrip(")") - ] - tmp = torch.zeros(target_len - current_len).to(dtype).cuda(self.device) - outputs[0] = torch.cat((tmp, outputs[0])) + if output_count == 1: + tmp = (func(*inputs),) + else: + tmp = func(*inputs) + # Flatten any tensor lists + # TODO: Simplify this + if not tmp: + print(f"Not expect that {node.id} has no output.") + return + for x in tmp: + if isinstance(x, list) and isinstance(x[0], torch.Tensor): + outputs.extend(x) + elif isinstance(x, torch.Tensor): + outputs.append(x) + except Exception as e: + print( + f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}" + ) + exit(1) - if self.debug and iter >= self.numWarmupIters: - after_execution = time.time_ns() + if node.name == "aten::repeat_interleave": + current_len = node.input_shapes[0][0] + target_len = node.output_shapes[0][0] + if current_len < target_len: + dtype, _ = TORCH_DTYPES_RNG[ + node.output_types[0].lstrip("Tensor(").rstrip(")") + ] + tmp = torch.zeros(target_len - current_len).to(dtype).cuda(self.device) + outputs[0] = torch.cat((tmp, outputs[0])) + if self.debug and iter >= self.numWarmupIters: + after_execution = time.time_ns() + + 
need_del_replay_t_ids_in_input = set() # deal with scenario that the tensor used multi times in the input list for _, t_id, _ in get_input_tensors(node): + tensor_id, storage_id, storage_offset, element_num, item_size, device_str = t_id if self.tensor_with_device: t_id = tuple(list(t_id)[:5]) + device = torch.device(device_str) if device_str != "" else None + else: + device = self.device replay_t_id = self.tensors_mapping[(node.id, t_id, True)] if ( node.id >= self.replay_tensor_id_to_last_node_id_map[replay_t_id] and replay_t_id not in self.instantiate ): - del self.tensor_registry[replay_t_id] + need_del_replay_t_ids_in_input.add(replay_t_id) + elif replay_t_id in self.instantiate and device is not None: + self.recycle_instantiate_tensors(node.id, storage_id, device) + + for replay_t_id in need_del_replay_t_ids_in_input: + del self.tensor_registry[replay_t_id] for (_, t_id, _), output in zip(get_output_tensors(node), outputs): if self.tensor_with_device: @@ -1201,6 +1234,8 @@ def run_op(self, node, iter): self.tensor_registry[replay_t_id] = output else: del output + if replay_t_id in self.tensor_registry: + del self.tensor_registry[replay_t_id] else: del output @@ -1276,6 +1311,17 @@ def benchTime(self): event_1 = torch.cuda.Event(enable_timing=True) event_2 = torch.cuda.Event(enable_timing=True) + def run_op(event_1, event_2, iter): + self.instantiate_tensor_ref_cnts_per_it = copy.deepcopy(self.instantiate_tensor_ref_cnts) + event_1.record() + for node in self.sorted_nodes: + self.run_op(node, iter) + event_2.record() + self.tensor_alloc_set.clear() + torch.cuda.synchronize(self.device) + gc.collect() + torch.cuda.empty_cache() + if self.et_profile: et_file = "/tmp/replay_et.json" et = ExecutionTraceObserver() @@ -1338,12 +1384,9 @@ def benchTime(self): ) prev_iter = iter start_ns = time.time_ns() - event_1.record() - for node in self.sorted_nodes: - self.run_op(node, iter) + run_op(event_1, event_2, iter) print("Finished one iteration.") - event_2.record() - torch.cuda.synchronize(self.device) + if iter >= self.numWarmupIters: total_time += event_1.elapsed_time(event_2) prof.step() @@ -1375,11 +1418,7 @@ def benchTime(self): ) prev_iter = iter start_ns = time.time_ns() - event_1.record() - for node in self.sorted_nodes: - self.run_op(node, iter) - event_2.record() - torch.cuda.synchronize(self.device) + run_op(event_1, event_2, iter) if iter >= self.numWarmupIters: total_time += event_1.elapsed_time(event_2) benchmark_result["execution finished"] = True @@ -1579,6 +1618,12 @@ def readComputeArgs(self, check_args: bool = True): default=True, help="when a et_id is being replayed multiple times, setting this to false will use temsors from previous runs.", ) + parser.add_argument( + "--recycle_storages", + action="store_true", + default=False, + help="when hit out of memory issues during replaying, set this flag to recycle tensor storages." + ) self.args, _ = parser.parse_known_args() # Check if both 'input' and 'trace_path' are not provided
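
Patch 1 moves the tensor-allocation branch of prepComms into generate_io_tensors, keeping the existing et_to_tensors cache keyed by hashEtCommsOp(curComm): unless regenerateTensors is requested, later replays of a collective reuse the tensors allocated on its first replay. A minimal sketch of that caching pattern, using hypothetical names (TensorCache, get_io_tensors) rather than the replayer's own API:

    from typing import Dict, Tuple

    import torch


    class TensorCache:
        """Cache input/output buffers per comm-op hash so replays reuse them."""

        def __init__(self) -> None:
            self._cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

        def get_io_tensors(
            self, op_hash: int, in_elems: int, out_elems: int, dtype: torch.dtype
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            if op_hash not in self._cache:
                # First replay of this op: allocate and remember the buffers.
                self._cache[op_hash] = (
                    torch.zeros(in_elems, dtype=dtype),
                    torch.zeros(out_elems, dtype=dtype),
                )
            # Later replays reuse the cached pair instead of reallocating.
            return self._cache[op_hash]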
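
Patch 2 extracts the per-op body of replayTrace into replaySingle. State that used to be local to the loop becomes instance state (coll_in_batch_num, replay_start_time) so each call can read and update it, and the loop's continue for skipped collectives becomes an early return. A compact sketch of the same refactor shape, with a hypothetical Replayer class; the sketch also keeps the batch start time on the instance so it survives across calls:

    import time


    class Replayer:
        def __init__(self, ops, colls_per_batch: int) -> None:
            self.ops = ops
            self.colls_per_batch = colls_per_batch
            self.coll_in_batch_num = 0
            self.replay_start_time = -1
            self.batch_latencies = []

        def replay_trace(self) -> None:
            self.coll_in_batch_num = 0
            self.replay_start_time = time.monotonic_ns()
            for cnt, op in enumerate(self.ops):
                self.replay_single(op, cnt)

        def replay_single(self, op, cnt: int) -> None:
            if self.coll_in_batch_num == 0:
                # Start timing a new batch (defined by --colls-per-batch).
                self._batch_begin = time.monotonic()
            # ... run the collective for `op` here ...
            if op == "wait":
                self.coll_in_batch_num += 1
                if self.coll_in_batch_num == self.colls_per_batch:
                    # Record batch latency in milliseconds and reset the counter.
                    self.batch_latencies.append((time.monotonic() - self._batch_begin) * 1e3)
                    self.coll_in_batch_num = 0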
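
Patch 3 defers tensor allocation in the compute replayer: instead of materializing every input tensor up front, allocate_tensors records a ("lazy_alloc", (storage_id, device), partial(get_tensor_from_storage, ...)) entry in tensor_registry_permanent, and lazy_alloc_tensors materializes the backing storage the first time an op actually touches it, building each tensor as a view over the shared untyped storage with Tensor.set_. A minimal standalone sketch of that idea (storage_buffers and lazy_view are illustrative names, not the replayer's):

    import torch

    storage_buffers = {}  # (storage_id, torch.device) -> 1-D backing tensor


    def lazy_view(storage_id: int, device: str, total_elems: int,
                  offset: int, shape, dtype: torch.dtype) -> torch.Tensor:
        key = (storage_id, torch.device(device))
        if key not in storage_buffers:
            # Allocate the whole backing storage only when a tensor first needs it.
            storage_buffers[key] = torch.zeros(total_elems, dtype=dtype, device=key[1])
        view = torch.empty(0, dtype=dtype, device=key[1])
        # Point the view at the shared untyped storage at the requested offset and shape.
        return view.set_(storage_buffers[key].untyped_storage(),
                         storage_offset=offset, size=shape)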
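
The recycling half of patch 3 pairs the lazy allocation with reference counting: while scanning the trace, allocate_tensors counts how many nodes consume each pre-instantiated storage per device (instantiate_tensor_ref_cnts); every iteration then works on a deep copy of those counts, decrements them as nodes finish, and, when the new --recycle_storages flag is set, drops the backing storage from tensor_storage_map once no consumer remains so torch.cuda.empty_cache() can reclaim the memory. A condensed sketch of that scheme (the flat (storage_id, device) keying and the free step are simplifications of the patch's nested dicts):

    import copy
    from collections import Counter

    ref_counts: Counter = Counter()  # (storage_id, device) -> number of consuming nodes


    def start_iteration() -> Counter:
        # Fresh working copy per iteration, like instantiate_tensor_ref_cnts_per_it.
        return copy.deepcopy(ref_counts)


    def after_node(iter_counts: Counter, key, storages: dict) -> None:
        iter_counts[key] -= 1
        if iter_counts[key] <= 0 and key in storages:
            # No remaining consumer in this iteration: release the backing buffer.
            del storages[key]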