
Stage 2 Training Fails with NaN Loss on Single GPU Due to Inconsistent Checkpoint Keys #254

5Hyeons commented Jun 13, 2024

Description

When trying to start stage 2 training after completing stage 1 on a single A100 80GB GPU with my Korean dataset, I encountered an issue where g_loss becomes NaN.

Upon investigation, I found that the y_rec_gt_pred output from model.decoder was NaN:

ipdb> y_rec_gt_pred
tensor([[[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0')

The s variable, an input to the decoder, had abnormally large values:

ipdb> p s
tensor([[ 4.1757e+18, -2.2990e+18, -4.4499e+18, -4.2868e+18,  2.8710e+18,
          1.4094e+17, -2.4173e+18, -5.7211e+18,  1.2887e+18,  1.2334e+18,
          ...
         -3.6897e+18,  2.0664e+17, -3.9657e+18,  2.1473e+18,  2.9162e+18,
         -2.3997e+18,  4.6772e+18,  3.3755e+17, -1.0300e+17, -1.7092e+18,
          2.6885e+18, -3.8825e+18, -2.4909e+18]], device='cuda:0',
       grad_fn=<GatherBackward>)

The s input is derived from the style encoder:
s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))

The inputs st and gt were within normal ranges, but I found that the style encoder's weights had not been properly loaded from the checkpoint:

# Compare the checkpoint parameters against the weights actually loaded into the model
print(type(model.style_encoder.shared[0]))

print('params weight:')
print(params['style_encoder']['shared.0.weight_orig'][0])
print('model weight:')
print(model.style_encoder.shared[0].weight[0], model.style_encoder.shared[0].weight[0].device)
print('params bias:')
print(params['style_encoder']['shared.0.bias'][0])
print('model bias:')
print(model.style_encoder.shared[0].bias[0], model.style_encoder.shared[0].bias[0].device)
Output:

<class 'torch.nn.modules.conv.Conv2d'>

params weight:
tensor([[[-0.0216, -0.2234, -0.3011],
         [ 0.3677,  0.4262,  0.0299],
         [ 0.3779,  0.2360,  0.1636]]])
model weight:
tensor([[[ 0.1406, -0.1316, -0.1859],
         [-0.2805, -0.0880, -0.2270],
         [ 0.1440, -0.0139,  0.1625]]]) cpu
params bias:
tensor(0.1457)
model bias:
tensor(-0.0927, device='cuda:0', grad_fn=<SelectBackward0>) cuda:0

Cause

The issue arises from inconsistent key names when loading checkpoints between the first and second stages. In the second stage, the MyDataParallel class is used, which prefixes all model keys with 'module.'. However, when training on a single GPU, the first stage does not apply this prefix when saving checkpoints. --> #120

This inconsistency prevents the proper loading of the model parameters, leading to NaN values in the loss calculation.
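
For context, the mismatch is just the 'module.' prefix that DataParallel adds to parameter names when the wrapped model is saved. As a minimal sketch, assuming the prefix is the only difference, the checkpoint keys could also be reconciled by stripping it directly (strip_module_prefix is a hypothetical helper, not part of the repo):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # Drop a leading 'module.' from every key so a checkpoint saved from a
    # DataParallel-wrapped model can be loaded into an unwrapped one.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k[len('module.'):] if k.startswith('module.') else k] = v
    return new_state_dict

The fix I ended up with instead matches keys positionally inside load_checkpoint, as shown below.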

Solution

To address this, I updated the load_checkpoint function to handle cases where the checkpoint keys do not match the model keys: if direct loading fails, it builds a new state_dict whose keys match the model's.

Updated load_checkpoint Function

import torch
from collections import OrderedDict

def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
    state = torch.load(path, map_location='cpu')
    params = state['net']

    for key in model:
        if key in params and key not in ignore_modules:
            print('%s loaded' % key)
            try:
                # Fast path: the checkpoint keys match the model keys exactly.
                model[key].load_state_dict(params[key], strict=True)
            except RuntimeError:
                # Fallback: the key names differ (e.g. a 'module.' prefix added by
                # MyDataParallel), so rebuild the state_dict using the model's own
                # key names, pairing entries positionally.
                state_dict = params[key]
                new_state_dict = OrderedDict()
                print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
                for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
                    new_state_dict[k_m] = v_c
                model[key].load_state_dict(new_state_dict, strict=True)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"]
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
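
For illustration, the patched function is called the same way as before; a hypothetical call site (the checkpoint path and ignore_modules value are placeholders, not taken from the repo):

model, optimizer, start_epoch, start_iter = load_checkpoint(
    model, optimizer,
    'first_stage.pth',        # placeholder: path to the stage 1 checkpoint
    load_only_params=True,
    ignore_modules=[])        # modules to skip loading, if any

Note that the positional zip in the fallback assumes the checkpoint and the model contain the same parameters in the same order, which holds when the only difference between them is the 'module.' prefix.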

Additionally, I have submitted a PR to address this issue: #253
