
The code for counting the duration is wrong. #76

Open
Mrliduanyang opened this issue Jan 29, 2021 · 0 comments

Comments

@Mrliduanyang

I found two errors in the duration-measuring code in `analyzer.py`:

1. In PyTorch, GPU execution is asynchronous. If we use the following code to record the start and end times, the measured duration will be far too short, because the end time is recorded without waiting for the GPU to finish the computation:

   ```python
   module_stats.start_time = time.time()
   # ... the forward pass only *launches* GPU kernels here ...
   module_stats.end_time = time.time()
   ```

2. If a module in the CNN runs forward propagation multiple times, the following code records only the duration of the last forward pass, not the duration of each one:

   ```python
   module_stats.duration = module_stats.end_time - module_stats.start_time
   ```
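The asynchronous-launch pitfall in point 1 can be reproduced without a GPU. Below is a small sketch (not from the issue; all names are illustrative) that uses a background thread as a stand-in for a CUDA kernel: launching is non-blocking, so timing without a synchronization point undercounts badly, just as `time.time()` does without `torch.cuda.synchronize()`.

```python
import threading
import time

def async_work():
    # Stands in for a GPU kernel: the launch returns immediately,
    # and the work finishes later on another "device" (here, a thread).
    time.sleep(0.2)

start = time.time()
worker = threading.Thread(target=async_work)
worker.start()                          # like a kernel launch: non-blocking
naive_duration = time.time() - start    # measured without synchronizing

worker.join()                           # analogous to torch.cuda.synchronize()
synced_duration = time.time() - start   # measured after the work completed

print(f"naive:  {naive_duration:.3f}s")   # tiny: only the launch overhead
print(f"synced: {synced_duration:.3f}s")  # ~0.2s: the real cost of the work
```

The same reasoning is why the fix below synchronizes in both the pre-hook and the post-hook: the pre-hook must also wait for *earlier* modules' pending work, or their cost would be billed to the current module.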

Here is my solution:

```python
# tensorwatch\tensorwatch\model_graph\torchstat\analyzer.py
import time

import torch
import torch.nn as nn

class ModuleStats:
    def __init__(self, name) -> None:
        # self.duration = 0.0
        self.duration = []  # one entry per forward pass

def _forward_pre_hook(module_stats: ModuleStats, module: nn.Module, input):
    assert not module_stats.done
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending GPU work before timing
    module_stats.start_time = time.time()

def _forward_post_hook(module_stats: ModuleStats, module: nn.Module, input, output):
    assert not module_stats.done
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait until this module's GPU work has finished
    module_stats.end_time = time.time()
    # Use a list to store the duration of each forward pass
    # instead of overwriting a scalar on every call.
    # module_stats.duration = module_stats.end_time - module_stats.start_time
    module_stats.duration.append(module_stats.end_time - module_stats.start_time)
    # other code
```
```python
# tensorwatch\tensorwatch\model_graph\torchstat\stat_tree.py
class StatNode(object):
    def __init__(self, name=str(), parent=None):
        # self.duration = 0
        self._duration = []

    @property
    def duration(self):
        # total_duration = self._duration
        total_duration = sum(self._duration)
        for child in self.children:
            total_duration += child.duration
        return total_duration
        # or return the per-call list instead:
        # return self._duration
```
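To make the bookkeeping difference in the patch concrete, here is a torch-free sketch (not part of the patch; the classes, sleep lengths, and helper are all illustrative) contrasting the old scalar `duration`, where each forward pass overwrites the previous value, with the new list, which keeps one entry per pass:

```python
import time

class ModuleStatsOld:
    def __init__(self):
        self.duration = 0.0   # scalar: every forward pass overwrites it

class ModuleStatsNew:
    def __init__(self):
        self.duration = []    # list: one entry per forward pass

def timed_forward(stats, work_seconds):
    start = time.time()
    time.sleep(work_seconds)               # stands in for the module's forward()
    end = time.time()
    if isinstance(stats.duration, list):
        stats.duration.append(end - start)  # keep every call's duration
    else:
        stats.duration = end - start        # last call wins

old, new = ModuleStatsOld(), ModuleStatsNew()
for t in (0.01, 0.02, 0.03):   # e.g. the same relu called three times
    timed_forward(old, t)
    timed_forward(new, t)

print(old.duration)            # only the last pass survives (~0.03)
print(len(new.duration))       # 3: one duration per pass
print(sum(new.duration))       # total, as StatNode.duration now reports
```

With the list, `StatNode.duration` can report the true total across all calls, which is what the Bottleneck/relu comparison below illustrates.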

I also provide a simple comparison. In the Bottleneck block of the ResNet backbone, the same relu module is called three times, so there should be three corresponding durations. But in the TensorWatch statistics we see only one relu record for the Bottleneck.

https://github.com/open-mmlab/mmdetection/blob/f07de13b82b746dde558202f720ec2225f276d73/mmdet/models/backbones/resnet.py#L260-L299

[screenshot 1: original TensorWatch statistics, one relu record]

But with my modified code, the durations of all three relu calls are recorded.

[screenshot 2: modified statistics, three relu durations]
