From 026adca7d79b14a262cca689618d89b57b65e939 Mon Sep 17 00:00:00 2001 From: echoht <553052687@qq.com> Date: Sat, 6 May 2023 11:05:50 +0800 Subject: [PATCH 1/3] fix batch size bug --- train.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 535a414..f11f9b7 100644 --- a/train.py +++ b/train.py @@ -246,14 +246,28 @@ def get_score(self, logit_label, labels): return scores def rrhf_loss(self, scores, idxs, rw_scores): - diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b - rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b - aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] + # diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b + # #print(rw_scores) + # rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b # batch * cand + # aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] + # return -diff[aval].sum() + print(scores.shape) # score shape (batch * cand) + cand = rw_scores.shape[1] + new_scores = scores.reshape(-1, cand) # batch * cand + diff = new_scores.unsqueeze(1) - new_scores.unsqueeze(-1) # batch * cand * cand + rw_diff = rw_scores.unsqueeze(1) - rw_scores.unsqueeze(-1) + aval = torch.bitwise_and(rw_diff > 0, diff < 0) return -diff[aval].sum() - def sft_loss(self, logit_label, idxs, rw_scores): - max_idx = torch.argmax(rw_scores) - return -logit_label[max_idx].mean() + def sft_loss(self, logit_label, idxs, rw_scores): # (batch * cand) *L + max_idx = torch.argmax(rw_scores, dim=1) # batch + # 每个task的response个数均相同 + cand = rw_scores.shape[1] + print("logit_label:", logit_label.shape) + logit_label_batch = torch.reshape(logit_label, (-1, cand, logit_label.shape[-1])) # batch * cand * L + expert_response_logit_label = logit_label_batch[:1, max_idx].squeeze() # batch * L + return -torch.sum(expert_response_logit_label.mean()) + #return -logit_label[max_idx].mean() def compute_loss(self, model, inputs, return_outputs=False): if self.args.only_use_provide: From 77d7a6096e02ad6b4c7dcf45ac317b9d4453cf8c Mon Sep 17 00:00:00 2001 From: echoht <553052687@qq.com> Date: Sat, 6 May 2023 11:12:19 +0800 Subject: [PATCH 2/3] --other=remove print info --- train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train.py b/train.py index f11f9b7..4c25348 100644 --- a/train.py +++ b/train.py @@ -251,7 +251,7 @@ def rrhf_loss(self, scores, idxs, rw_scores): # rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b # batch * cand # aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] # return -diff[aval].sum() - print(scores.shape) # score shape (batch * cand) + cand = rw_scores.shape[1] new_scores = scores.reshape(-1, cand) # batch * cand diff = new_scores.unsqueeze(1) - new_scores.unsqueeze(-1) # batch * cand * cand @@ -263,11 +263,9 @@ def sft_loss(self, logit_label, idxs, rw_scores): # (batch * cand) *L max_idx = torch.argmax(rw_scores, dim=1) # batch # 每个task的response个数均相同 cand = rw_scores.shape[1] - print("logit_label:", logit_label.shape) logit_label_batch = torch.reshape(logit_label, (-1, cand, logit_label.shape[-1])) # batch * cand * L expert_response_logit_label = logit_label_batch[:1, max_idx].squeeze() # batch * L return -torch.sum(expert_response_logit_label.mean()) - #return -logit_label[max_idx].mean() def compute_loss(self, model, inputs, return_outputs=False): if self.args.only_use_provide: From 701aaebc6bc395c67be4c96fe4ecfd03c84eca4a Mon Sep 17 00:00:00 2001 From: echoht <553052687@qq.com> Date: Sat, 6 May 2023 11:57:31 +0800 Subject: [PATCH 3/3] --other=fix index error --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 4c25348..f0cfa9f 100644 --- a/train.py +++ b/train.py @@ -264,7 +264,7 @@ def sft_loss(self, logit_label, idxs, rw_scores): # (batch * cand) *L # 每个task的response个数均相同 cand = rw_scores.shape[1] logit_label_batch = torch.reshape(logit_label, (-1, cand, logit_label.shape[-1])) # batch * cand * L - expert_response_logit_label = logit_label_batch[:1, max_idx].squeeze() # batch * L + expert_response_logit_label = torch.gather(logit_label_batch, dim=1, index=max_idx.view(-1, 1, 1).repeat(1, 1, logit_label_batch.size(-1))).squeeze() # batch * L return -torch.sum(expert_response_logit_label.mean()) def compute_loss(self, model, inputs, return_outputs=False):