-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmodeling_live_llama.py
79 lines (68 loc) · 3.03 KB
/
modeling_live_llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torch import nn
from transformers import LlamaForCausalLM, Cache
from transformers.activations import GELUActivation
from transformers.utils import logging
from .configuration_live_llama import LiveLlamaConfig
from ..modeling_live import build_live, LiveMixin
logger = logging.get_logger(__name__)
class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
    """Llama causal LM extended for streaming ("live") video-language modeling.

    Adds a two-layer MLP ``connector`` that projects vision-encoder features
    into the LM hidden space, and replaces the parent's unweighted
    cross-entropy with a token-weighted loss that scales positions holding the
    vision placeholder token by ``config.stream_loss_weight``.
    """
    config_class = LiveLlamaConfig
    # vision_encoder / connector weights are attached outside the base Llama
    # checkpoint, so their absence at load time is expected.
    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']

    def __init__(self, config: LiveLlamaConfig):
        super().__init__(config)
        # Vision-feature -> LM-embedding projector.
        # BUGFIX: the original passed ``GELUActivation(config.hidden_size)``;
        # GELUActivation's only parameter is ``use_gelu_python: bool``, so the
        # hidden size was silently treated as a truthy flag. The activation
        # takes no size argument (both code paths compute the exact erf GELU,
        # so this change does not alter the math).
        self.connector = torch.nn.Sequential(
            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
            GELUActivation(),
            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        frames: torch.FloatTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        labels: torch.LongTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ):
        """Run the LM on jointly-embedded text tokens and video frames.

        Args:
            input_ids: token ids; vision-placeholder positions are replaced
                by projected frame features via ``joint_embed``.
            frames: vision-encoder features for the placeholder positions
                (shape is determined by ``joint_embed`` — see LiveMixin).
            labels: target ids; positions with values < 0 (ignore_index) are
                excluded from the loss denominator. NOTE(review): the loss
                branch reads ``input_ids`` directly, so passing ``labels``
                together with precomputed ``inputs_embeds`` but no
                ``input_ids`` would fail — confirm callers always supply ids
                when training.
            remaining args: forwarded unchanged to ``LlamaForCausalLM.forward``.

        Returns:
            The parent model's output with ``loss`` replaced by the
            stream-weighted cross-entropy (tuple or ModelOutput depending on
            ``return_dict``).
        """
        if inputs_embeds is None:
            inputs_embeds = self.joint_embed(input_ids, frames)
        outputs = super().forward(
            attention_mask = attention_mask,
            position_ids = position_ids,
            past_key_values = past_key_values,
            inputs_embeds = inputs_embeds,
            # labels are deliberately NOT forwarded: the weighted loss below
            # replaces the parent's unweighted cross-entropy.
            use_cache = use_cache,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            return_dict = return_dict,
            cache_position=cache_position,
        )
        loss = None
        if labels is not None:
            logits = outputs[0]
            # Per-token weight: vision-placeholder positions count with
            # stream_loss_weight, every other position counts with 1.
            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
            weight = v_mask * self.config.stream_loss_weight + ~v_mask
            # reduction='none' keeps per-token losses; ignore_index (-100)
            # positions contribute 0, and the mean is taken over valid labels.
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
            loss = loss.sum() / (labels >= 0).sum()
        if not return_dict:
            # BUGFIX: the original returned ``(loss,) + outputs[1:]``, which
            # dropped the logits (outputs[0]); HF convention is
            # (loss, logits, past_key_values, ...). ``outputs[:]`` yields a
            # plain tuple for both tuple and ModelOutput results.
            return (loss,) + outputs[:] if loss is not None else outputs
        outputs.loss = loss
        return outputs

    def generate_after_embed(self, input_ids, frames, **kwargs):
        """Generate from the jointly-embedded (text + frames) inputs."""
        return super().generate(inputs_embeds=self.joint_embed(input_ids, frames), **kwargs)
def build_live_llama(**kwargs):
    """Build a (model, tokenizer) pair for LiveLlama.

    Thin convenience wrapper around ``build_live`` with the LiveLlama
    config/model classes fixed; all keyword arguments pass straight through.
    """
    return build_live(model_class=LiveLlamaForCausalLM, config_class=LiveLlamaConfig, **kwargs)
if __name__ == '__main__':
    # Smoke test: show the default training arguments, then build the model
    # and tokenizer from them and print both.
    from ..arguments_live import LiveOnePlusTrainingArguments
    default_args = LiveOnePlusTrainingArguments().to_dict()
    print(default_args)
    model, tokenizer = build_live_llama(is_training=True, **LiveOnePlusTrainingArguments().to_dict())
    print(model.config, tokenizer)