diff --git a/et_replay/et_replay_utils.py b/et_replay/et_replay_utils.py index 5013bb62..f02e08fa 100644 --- a/et_replay/et_replay_utils.py +++ b/et_replay/et_replay_utils.py @@ -368,7 +368,7 @@ def build_torchscript_func(n): if ( n.op_schema == "" or n.name == "aten::record_stream" - or n.name.startswith("aten::_foreach") + #or n.name.startswith("aten::_foreach") ): return None, None diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 4fbc328f..3b431e2b 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -617,7 +617,6 @@ def generate_io_tensors( curComm: commsArgs, commsParams: commsParamsHolderBase, regenerateTensors: bool, - ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor]]: # Use exactly specified inMsgSize/outMsgSize if call from trace replay # This avoid regenerating sizes such as in _prep_all_gather_base commsParams.size_from_trace = True @@ -1097,7 +1096,6 @@ def replaySingle( ): commsParams.collective = collName commsParams.srcOrDst = curComm.root if curComm.root is not None else 0 - self.dcheck( commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor ) diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 480c078f..0e7dcaa4 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -5,9 +5,12 @@ import os import sys import time -from collections import defaultdict +import copy +from collections import defaultdict, Counter +from collections.abc import Iterable from datetime import datetime -from typing import Dict +from functools import reduce, partial +from typing import Optional, Tuple, Dict, List import numpy as np import torch @@ -33,7 +36,7 @@ TORCH_DTYPES_RNG, TORCH_DTYPES_RNG_str, ) -from et_replay.execution_trace import ExecutionTrace +from et_replay.execution_trace import ExecutionTrace, NodeType, Node from et_replay.utils import trace_handler from param_bench.train.compute.python.lib import pytorch as lib_pytorch from param_bench.train.compute.python.lib.init_helper import load_modules @@ -70,6 +73,8 @@ def __init__(self): self.commsParams = None self.regenerate_tensors = None + self.recycle_storages = False + self.cuda = "cuda" self.device = torch.device(self.cuda) @@ -99,9 +104,18 @@ def __init__(self): # Dict that stores the shapes of a tensor, for the convenience of quickly determining whether # to create a unique tensor in replay if the id is same but shape is different. self.tensor_shapes = defaultdict(set) - # Dict that maps tensor storage id to its size, and a map for {device, torch.Tensor}. + # Dict that maps tensor storage id to a map for {device, torch.Tensor}. # The tensor with the same storage id may located on different devices. - self.tensor_storage_map: Dict[int, []] = defaultdict(set) + self.tensor_storage_map: Dict[int, Dict[torch.device, torch.Tensor]] = {} + self.tensor_alloc_set = set() + # Dict that maps tensor storage id to its size + self.tensor_storage_sizes: Dict[int, int] = defaultdict(int) + # List that holds referenced tensor storages and devices when replay an operation, after the operation complete, it will be cleared + self.referenced_tensor_storage_ids: List[Tuple[int, torch.device]] = [] + # Ref cnt of instantiate tensors, used to recycle storages + self.instantiate_tensor_ref_cnts: Dict[int, Dict[int, Dict[torch.device, Counter]]] = defaultdict(lambda: defaultdict(Counter)) + # Ref cnt for each iter + self.instantiate_tensor_ref_cnts_per_it: Dict[int, Dict[int, Dict[torch.device, Counter]]] = defaultdict(lambda: defaultdict(Counter)) # Mark those tensors that occur first as an input in the original et as needing to be instantiated in replay # at the very beginning. self.instantiate = set() @@ -199,6 +213,7 @@ def initBench(self): self.wait_delay = self.args.delay self.cpu = self.args.cpu self.tf32 = self.args.tf32 + self.recycle_storages = self.args.recycle_storages # Single trace. if not self.args.trace_path: @@ -240,7 +255,7 @@ def initBench(self): ) self.et = ExecutionTrace(json.load(et)) else: - self.trace_file = f"{self.args.trace_path}/rank{self.comms_env_params['global_rank']}.json" + self.trace_file = f"{self.args.trace_path}/rank-{self.comms_env_params['global_rank']}.json" with open(self.trace_file, "r") as f: self.et = ExecutionTrace(json.load(f)) @@ -283,8 +298,8 @@ def reset_registry(self): None if v is None else ( - v.cpu() - if self.tensor_device[k] == "cpu" or self.cpu + v + if self.tensor_device[k] == "cpu" or self.cpu or (isinstance(v, tuple) and v[0] == "lazy_alloc") else v.cuda(self.device) ) ) @@ -296,13 +311,13 @@ def reset_registry(self): None if v is None else ( - v.cpu() - if k in self.cpu_tensor or self.cpu + v if k in self.cpu_tensor or self.cpu or (isinstance(v, tuple) and v[0] == "lazy_alloc") else v.cuda(self.device) ) ) for k, v in self.tensor_registry_permanent.items() } + self.tensor_registry_permanent.clear() gc.collect() torch.cuda.empty_cache() @@ -383,27 +398,7 @@ def has_parallel_parent(node): assert len(self.parallel_nodes_ids) == len(set(self.parallel_nodes_ids)) def analyze_tensors(self): - def add_storage_tensor(t_id, device): - # t_id is a tupe of (tensor_id, storage_id, offset, number of element, - # number of bytes for each element, device) - - # ET does not save the size of the tensor storage, so we iterate over all the - # tensors to find the maximum size of the storage. - storage_id = t_id[1] - if storage_id not in self.tensor_storage_map: - # the storage size for this tensor is the sum of the storage offset and - # number of elements * number of bytes per element. - self.tensor_storage_map[storage_id] = [ - t_id[2] + t_id[3] * t_id[4], - {}, - ] - else: - self.tensor_storage_map[storage_id][0] = max( - self.tensor_storage_map[storage_id][0], t_id[2] + t_id[3] * t_id[4] - ) - def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1): - add_storage_tensor(t_id, device) # If we did not see this tensor before, add it as a unique tensor. if t_id not in self.original_unique_tensors: self.original_unique_tensors.add(t_id) @@ -492,9 +487,64 @@ def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1): if t_id in self.input_tensor_ids: output_set.add(self.tensors_mapping[(node.id, t_id, False)]) - def allocate_tensors(self): - start_ns = time.time_ns() + def get_tensor_from_storage(self, node, storage_id, data_offset, elem_bytes, device, shape, data_type): + tensor_data = self.tensor_storage_map.setdefault(storage_id, {}) + device = torch.device(device) + + if device not in tensor_data: + if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: + storage_tensor = torch.rand( + (self.tensor_storage_sizes[storage_id] // elem_bytes), dtype=data_type, device=device + ) + else: + storage_tensor = torch.ones( + (self.tensor_storage_sizes[storage_id] // elem_bytes), dtype=data_type, device=device + ) + + tensor_data[device] = storage_tensor + + if (storage_id, device) not in self.tensor_alloc_set: + self.tensor_alloc_set.add((storage_id, device)) + else: + print("repeat alloc, may caused by wrong recycle:", (storage_id, device)) + exit(1) + else: + storage_tensor = tensor_data[device] + x = torch.empty(0, dtype=data_type) + if device != torch.device("cpu"): + x = x.cuda(torch.device(device)) + x = x.set_( + storage_tensor.untyped_storage(), + storage_offset=data_offset, + size=shape, + ) + + return x + def add_tensor_registry_permanent(self, node_id, data_type, storage_id, storage_offset, element_num, item_size, device_str, shape, replay_t_id): + if data_type == "Tensor(signed char)": + dtype, _ = TORCH_DTYPES_RNG["signed char"] + else: + dtype, _ = TORCH_DTYPES_RNG[ + data_type.lstrip("Tensor(").rstrip(")") + ] + + device = torch.device(device_str) + + self.tensor_storage_sizes[storage_id] = max(storage_offset + element_num * item_size, self.tensor_storage_sizes[storage_id]) + + self.tensor_registry_permanent[replay_t_id] = ("lazy_alloc", (storage_id, device), + partial(self.get_tensor_from_storage, + storage_id=storage_id, + data_offset=storage_offset, + elem_bytes=item_size, + device=device_str, + shape=shape, + data_type=dtype + ) + ) + + def allocate_tensors(self): for node in self.sorted_nodes: if node.name == "record_param_comms" and ( self.compute_only or self.args.separate @@ -519,11 +569,16 @@ def allocate_tensors(self): ) tensor_strides = node.get_input_tensor_strides() for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)): - device = self.device + tensor_id, storage_id, storage_offset, element_num, item_size, device_str = t_id if self.tensor_with_device: - device = t_id[5] t_id = tuple(list(t_id)[:5]) + else: + device_str = self.device replay_t_id = self.tensors_mapping[(node.id, t_id, True)] + + if replay_t_id in self.instantiate and device_str != "": + self.instantiate_tensor_ref_cnts[storage_id][torch.device(device_str)][node.id] += 1 + if ( t_id in self.input_tensor_ids and replay_t_id not in self.tensor_registry_permanent.keys() @@ -541,26 +596,17 @@ def allocate_tensors(self): if "fbgemm::split_embedding_codegen_lookup" in node.name: self.unchangeable_intermediate_tensors.add(replay_t_id) else: - if data_type == "Tensor(signed char)": - dtype, _ = TORCH_DTYPES_RNG["signed char"] - else: - dtype, _ = TORCH_DTYPES_RNG[ - data_type.lstrip("Tensor(").rstrip(")") - ] - - strides = None - if node.input_strides is not None: - strides = tensor_strides[idx] - tensor = self.get_tensor_from_storage( - t_id[1], # storage_id - t_id[2], # offset - t_id[4], # number of bytes per element - device, - shape, - dtype, - strides, + self.add_tensor_registry_permanent( + node_id=node.id, + data_type=data_type, + storage_id=storage_id, + storage_offset=storage_offset, + element_num=element_num, + item_size=item_size, + device_str=device_str, + shape=shape, + replay_t_id=replay_t_id ) - self.tensor_registry_permanent[replay_t_id] = tensor if node.name == "aten::embedding_bag": self.unchangeable_intermediate_tensors.add(replay_t_id) if node.name == "aten::pin_memory" and idx == 0: @@ -590,8 +636,6 @@ def allocate_tensors(self): ][i] = (i * nnz) ###### - print(f"Tensor allocation time: {(time.time_ns() - start_ns) / 1000000.0} ms") - def build_func(self, node): if is_fbgemm_forward(node): if self.cpu: @@ -990,44 +1034,6 @@ def _generate_run_ops_str(override): print(code_str, file=f) exec(code_str) - def get_tensor_from_storage( - self, storage_id, data_offset, elem_bytes, device, shape, data_type, strides - ): - assert storage_id in self.tensor_storage_map - - tensor_data = self.tensor_storage_map[storage_id] - device = torch.device(device) - if device not in tensor_data[1]: - if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: - storage_tensor = torch.rand( - (tensor_data[0] // elem_bytes), dtype=data_type, device=device - ) - else: - storage_tensor = torch.ones( - (tensor_data[0] // elem_bytes), dtype=data_type, device=device - ) - tensor_data[1][device] = storage_tensor - else: - storage_tensor = tensor_data[1][device] - x = torch.empty(0, dtype=data_type) - if device != torch.device("cpu"): - x = x.cuda(device) - if strides is None: - x = x.set_( - storage_tensor.untyped_storage(), - storage_offset=data_offset, - size=shape, - ) - else: - x = x.set_( - storage_tensor.untyped_storage(), - storage_offset=data_offset, - size=shape, - stride=strides, - ) - - return x - def get_inputs(self, node): try: if is_fbgemm_forward(node): @@ -1103,89 +1109,116 @@ def get_inputs(self, node): except Exception as e: print(f"Inputs error: {e} at node: {node.id}") + def lazy_alloc_tensors(self, inputs, node): + for i in range(len(inputs)): + if isinstance(inputs[i], tuple) and inputs[i][0] == "lazy_alloc": + inputs[i] = inputs[i][2](node) + elif isinstance(inputs[i], list): + for j in range(len(inputs[i])): + if isinstance(inputs[i][j], tuple) and inputs[i][j][0] == "lazy_alloc": + inputs[i][j] = inputs[i][j][2](node) + + def recycle_instantiate_tensors(self, node_id, storage_id, device): + assert node_id in self.instantiate_tensor_ref_cnts_per_it[storage_id][device] + self.instantiate_tensor_ref_cnts_per_it[storage_id][device][node_id] -= 1 + if self.recycle_storages and (not any(self.instantiate_tensor_ref_cnts_per_it[storage_id][device].values())): + del self.tensor_storage_map[storage_id][device] + def run_op(self, node, iter): - if node.name == "record_param_comms" and not self.compute_only: + if node.name == "record_param_comms": return if self.debug and iter >= self.numWarmupIters: start_ns = time.time_ns() func, output_count = self.funcs[node.id] - if not func: - return + inputs = self.get_inputs(node) + outputs = [] - # Workaround to eliminate the "strides() called on undefined Tensor" error. - if node.name == "aten::convolution_backward": - inputs[-1] = [True, True, True] + self.lazy_alloc_tensors(inputs, node) - # Workaround to handle tensor with same id but different data types (ads_cmf10x_single_iter_512_newest_eg.json). - if node.name == "aten::index_add_": - inputs[3] = inputs[3].to(torch.float64) - inputs[2] = inputs[2].to(torch.int) - if node.name == "aten::index_copy_": - if node.input_types[3] == "Tensor(double)": - inputs[3] = inputs[3].to(torch.float64) - if node.input_types[2] == "Tensor(long)": - inputs[2] = inputs[2].to(torch.int64) - if node.name == "aten::index_select": - inputs[2] = inputs[2].to(torch.int) + if func: + # Workaround to eliminate the "strides() called on undefined Tensor" error. + if node.name == "aten::convolution_backward": + inputs[-1] = [True, True, True] - if self.debug and iter >= self.numWarmupIters: - before_execution = time.time_ns() + # Workaround to handle tensor with same id but different data types (ads_cmf10x_single_iter_512_newest_eg.json). + if node.name == "aten::index_add_": + inputs[3] = inputs[3].to(torch.float64) + inputs[2] = inputs[2].to(torch.int) + if node.name == "aten::index_copy_": + if node.input_types[3] == "Tensor(double)": + inputs[3] = inputs[3].to(torch.float64) + if node.input_types[2] == "Tensor(long)": + inputs[2] = inputs[2].to(torch.int64) + if node.name == "aten::index_select": + inputs[2] = inputs[2].to(torch.int) + + if self.debug and iter >= self.numWarmupIters: + before_execution = time.time_ns() - try: - outputs = [] - if output_count == 0: - if node.kernel_backend == "triton": - exec( - f"func.run(*inputs[:-2], grid={inputs[-2]}, stream={inputs[-1]})" - ) - else: - func(*inputs) - else: - if output_count == 1: - tmp = (func(*inputs),) + try: + if output_count == 0: + if node.kernel_backend == "triton": + exec( + f"func.run(*inputs[:-2], grid={inputs[-2]}, stream={inputs[-1]})" + ) + else: + func(*inputs) else: - tmp = func(*inputs) - # Flatten any tensor lists - # TODO: Simplify this - if not tmp: - print(f"Not expect that {node.id} has no output.") - return - for x in tmp: - if isinstance(x, list) and isinstance(x[0], torch.Tensor): - outputs.extend(x) - elif isinstance(x, torch.Tensor): - outputs.append(x) - except Exception as e: - print( - f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}" - ) - exit(1) - - if node.name == "aten::repeat_interleave": - current_len = node.input_shapes[0][0] - target_len = node.output_shapes[0][0] - if current_len < target_len: - dtype, _ = TORCH_DTYPES_RNG[ - node.output_types[0].lstrip("Tensor(").rstrip(")") - ] - tmp = torch.zeros(target_len - current_len).to(dtype).cuda(self.device) - outputs[0] = torch.cat((tmp, outputs[0])) + if output_count == 1: + tmp = (func(*inputs),) + else: + tmp = func(*inputs) + # Flatten any tensor lists + # TODO: Simplify this + if not tmp: + print(f"Not expect that {node.id} has no output.") + return + for x in tmp: + if isinstance(x, list) and isinstance(x[0], torch.Tensor): + outputs.extend(x) + elif isinstance(x, torch.Tensor): + outputs.append(x) + except Exception as e: + print( + f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}" + ) + exit(1) - if self.debug and iter >= self.numWarmupIters: - after_execution = time.time_ns() + if node.name == "aten::repeat_interleave": + current_len = node.input_shapes[0][0] + target_len = node.output_shapes[0][0] + if current_len < target_len: + dtype, _ = TORCH_DTYPES_RNG[ + node.output_types[0].lstrip("Tensor(").rstrip(")") + ] + tmp = torch.zeros(target_len - current_len).to(dtype).cuda(self.device) + outputs[0] = torch.cat((tmp, outputs[0])) + if self.debug and iter >= self.numWarmupIters: + after_execution = time.time_ns() + + need_del_replay_t_ids_in_input = set() # deal with scenario that the tensor used multi times in the input list for _, t_id, _ in get_input_tensors(node): + tensor_id, storage_id, storage_offset, element_num, item_size, device_str = t_id if self.tensor_with_device: t_id = tuple(list(t_id)[:5]) + device = torch.device(device_str) if device_str != "" else None + else: + device = self.device replay_t_id = self.tensors_mapping[(node.id, t_id, True)] if ( node.id >= self.replay_tensor_id_to_last_node_id_map[replay_t_id] and replay_t_id not in self.instantiate ): - del self.tensor_registry[replay_t_id] + need_del_replay_t_ids_in_input.add(replay_t_id) + elif replay_t_id in self.instantiate and device is not None: + self.recycle_instantiate_tensors(node.id, storage_id, device) + + for replay_t_id in need_del_replay_t_ids_in_input: + del self.tensor_registry[replay_t_id] for (_, t_id, _), output in zip(get_output_tensors(node), outputs): if self.tensor_with_device: @@ -1201,6 +1234,8 @@ def run_op(self, node, iter): self.tensor_registry[replay_t_id] = output else: del output + if replay_t_id in self.tensor_registry: + del self.tensor_registry[replay_t_id] else: del output @@ -1276,6 +1311,17 @@ def benchTime(self): event_1 = torch.cuda.Event(enable_timing=True) event_2 = torch.cuda.Event(enable_timing=True) + def run_op(event_1, event_2, iter): + self.instantiate_tensor_ref_cnts_per_it = copy.deepcopy(self.instantiate_tensor_ref_cnts) + event_1.record() + for node in self.sorted_nodes: + self.run_op(node, iter) + event_2.record() + self.tensor_alloc_set.clear() + torch.cuda.synchronize(self.device) + gc.collect() + torch.cuda.empty_cache() + if self.et_profile: et_file = "/tmp/replay_et.json" et = ExecutionTraceObserver() @@ -1338,12 +1384,9 @@ def benchTime(self): ) prev_iter = iter start_ns = time.time_ns() - event_1.record() - for node in self.sorted_nodes: - self.run_op(node, iter) + run_op(event_1, event_2, iter) print("Finished one iteration.") - event_2.record() - torch.cuda.synchronize(self.device) + if iter >= self.numWarmupIters: total_time += event_1.elapsed_time(event_2) prof.step() @@ -1375,11 +1418,7 @@ def benchTime(self): ) prev_iter = iter start_ns = time.time_ns() - event_1.record() - for node in self.sorted_nodes: - self.run_op(node, iter) - event_2.record() - torch.cuda.synchronize(self.device) + run_op(event_1, event_2, iter) if iter >= self.numWarmupIters: total_time += event_1.elapsed_time(event_2) benchmark_result["execution finished"] = True @@ -1579,6 +1618,12 @@ def readComputeArgs(self, check_args: bool = True): default=True, help="when a et_id is being replayed multiple times, setting this to false will use temsors from previous runs.", ) + parser.add_argument( + "--recycle_storages", + action="store_true", + default=False, + help="when hit out of memory issues during replaying, set this flag to recycle tensor storages." + ) self.args, _ = parser.parse_known_args() # Check if both 'input' and 'trace_path' are not provided