forked from huggingface/nanotron
-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_generate.py
251 lines (218 loc) · 10.5 KB
/
run_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
Nanotron Inference Script
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=4 run_generate.py --ckpt-path checkpoints/test/4
```
"""
import argparse
import os
from pathlib import Path
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
GenerationArgs,
LoggingArgs,
ParallelismArgs,
get_config_from_file,
)
from nanotron.generation.decode import (
GenerationInput,
TokenizerConfig,
decode_text,
decode_tokenized,
)
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
OneForwardOneBackwardPipelineEngine,
)
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
set_random_seed,
)
from nanotron.serialize import load_weights
from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters
try:
from transformers import AutoTokenizer
except ImportError:
AutoTokenizer = None
# Module-level logger from nanotron's logging wrapper; passed to the `log_rank` calls below.
logger = logging.get_logger(__name__)
def get_args(argv=None):
    """Parse command-line arguments for the generation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is how the script is normally invoked.

    Returns:
        argparse.Namespace with ``ckpt_path`` (Path), ``dp``, ``pp``, ``tp``
        and ``max_new_tokens`` (ints).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path")
    # For the parallelism sizes, 0 means "use the value from the checkpoint's
    # config.yaml": main() combines each of them with the config value via `or`.
    # Note that --dp defaults to 1, so the checkpoint's dp is only used with --dp 0.
    parser.add_argument("--dp", type=int, default=1, help="Data-parallel size (0: use checkpoint config)")
    parser.add_argument("--pp", type=int, default=0, help="Pipeline-parallel size (0: use checkpoint config)")
    parser.add_argument("--tp", type=int, default=0, help="Tensor-parallel size (0: use checkpoint config)")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate")
    return parser.parse_args(argv)
def _ensure_pad_token(tokenizer, model):
    """Give ``tokenizer`` a pad token id so batched prompts can be padded.

    Does nothing if a pad id is already set. Otherwise tries, in order: the
    tokenizer's EOS id, ``model.config.pad_token_id``, ``model.config.eos_token_id``.
    As a last resort it registers a new ``[PAD]`` special token.
    """
    if tokenizer.pad_token_id is not None:
        return
    # getattr so a model without a `.config` attribute falls through to the
    # last-resort branch instead of raising AttributeError.
    model_cfg = getattr(model, "config", None)
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif getattr(model_cfg, "pad_token_id", None) is not None:
        tokenizer.pad_token_id = int(model_cfg.pad_token_id)
    elif getattr(model_cfg, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = int(model_cfg.eos_token_id)
    else:
        # NOTE(review): a newly added token is assigned an id past the original
        # vocabulary and the model's embeddings are not resized here. Confirm it
        # stays below the model's vocab size, otherwise embedding lookups of
        # padded positions may fail.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})


def _log_outputs(outputs, tokenizer=None):
    """Log every generation result from ``outputs`` on rank 0.

    With a ``tokenizer``, the prompt (truncated to 1000 characters) and the
    continuation are decoded to text. Without one, the raw continuation ids
    are logged.
    """
    for output in outputs:
        input_ids = output.input_ids
        generated_ids = output.generation_ids
        if isinstance(input_ids, TensorPointer):
            # This rank holds no materialised result (presumably it is not the
            # pipeline stage producing outputs), so there is nothing to log.
            assert isinstance(generated_ids, TensorPointer)
            continue
        assert isinstance(generated_ids, torch.Tensor)
        # Skip the first len(input_ids) positions to keep only newly generated tokens.
        new_ids = generated_ids[len(input_ids) :]
        if tokenizer is not None:
            log_rank(
                f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )
            log_rank(
                f"generation: {tokenizer.decode(new_ids, clean_up_tokenization_spaces=False)}",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )
        else:
            log_rank(
                f"generation: {new_ids}",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )
        log_rank(
            "--------------------------------------------------",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )


def main():
    """Load a nanotron checkpoint and greedily generate from sample prompts.

    Reads model, tokenizer and parallelism settings from
    ``<ckpt-path>/config.yaml``. Non-zero ``--dp/--pp/--tp`` values override
    the config. The model is built in bfloat16, the checkpoint weights are
    loaded, and generations are logged on rank 0. If ``transformers`` is not
    installed, it instead generates from a single dummy token id (0) and logs
    raw token ids.

    Raises:
        FileNotFoundError: if ``--ckpt-path`` does not exist.
        ValueError: if the checkpoint's model config class is unsupported.
    """
    args = get_args()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not args.ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint path {args.ckpt_path} does not exist")
    config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix())
    model_config = config.model.model_config
    tokenizer_path = config.tokenizer.tokenizer_name_or_path

    # A zero CLI value falls back to the checkpoint's setting (`or`).
    parallel_config = ParallelismArgs(
        dp=args.dp or config.parallelism.dp,
        pp=args.pp or config.parallelism.pp,
        tp=args.tp or config.parallelism.tp,
        pp_engine=OneForwardOneBackwardPipelineEngine(),
        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
        tp_linear_async_communication=False,
    )

    # Initialise all process groups
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )

    # Set log levels
    logging_config = LoggingArgs(
        log_level="info",
        log_level_replica="info",
    )
    set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config)

    log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0)
    log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0)

    dtype = torch.bfloat16

    # Set random states
    set_random_seed(42)

    model_config_cls = model_config.__class__.__name__
    if model_config_cls not in CONFIG_TO_MODEL_CLASS:
        raise ValueError(
            f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
        )

    # Get synchronized random states
    if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
        random_states = RandomStates(
            {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)}
        )
    else:
        # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
        random_states = RandomStates({})

    model = build_model(
        model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls](
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=random_states,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
    )

    # Mark some parameters as tied
    # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead?
    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    # Sanity check model
    sanity_check(root_module=model)

    # Load checkpoint
    checkpoint_path = args.ckpt_path
    log_rank(
        f"Loading checkpoint from {checkpoint_path}:",
        logger=logger,
        level=logging.INFO,
        rank=0,
    )
    load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path)

    model.eval()
    if AutoTokenizer is not None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        _ensure_pad_token(tokenizer, model)
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"  # TODO @nouamane: do we want this?
        dummy_inputs = [
            "The future of AI is",
            # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
            "def fib(n)",
            # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.',
            # "Advancements in technology will lead to",
            # "Tomorrow's world is shaped by",
        ]
        outputs = decode_text(
            input_iter=(GenerationInput(text=text) for text in dummy_inputs),
            tokenizer=tokenizer,
            # TODO @thomasw21: From ModelWithLoss extract the model.
            model=model.model,
            parallel_context=parallel_context,
            max_new_tokens=args.max_new_tokens,
            max_micro_batch_size=2,
            generation_config=GenerationArgs(sampler="greedy", use_cache=True),
            tokenizer_config=TokenizerConfig(max_input_length=None),
            is_bench=os.environ.get("USE_BENCH", "0") == "1",
        )
        _log_outputs(outputs, tokenizer=tokenizer)
    else:
        # No tokenizer available: generate from a single dummy token id (0).
        outputs = decode_tokenized(
            input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"),
            input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"),
            model=model.model,
            parallel_context=parallel_context,
            generation_config=GenerationArgs(sampler="greedy", use_cache=True),
            max_micro_batch_size=1,
            # Honour --max-new-tokens here too, as the tokenizer path does.
            max_new_tokens=args.max_new_tokens,
            returns_logits=False,
        )
        _log_outputs(outputs)

    dist.barrier()
# Run only when executed as a script (e.g. launched via `torchrun`), not on import.
if __name__ == "__main__":
    main()