-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
train_dolly.py
205 lines (161 loc) · 7.73 KB
/
train_dolly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Databricks notebook source
# MAGIC %md
# MAGIC ## Train Dolly
# MAGIC
# MAGIC This notebook fine-tunes EleutherAI Pythia models
# MAGIC (e.g. [pythia-2.8b](https://huggingface.co/EleutherAI/pythia-2.8b),
# MAGIC [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b), or
# MAGIC [pythia-12b](https://huggingface.co/EleutherAI/pythia-12b)) on
# MAGIC the [databricks-dolly-15k](https://github.com/databrickslabs/dolly/tree/master/data) dataset.
# MAGIC
# MAGIC ```
# MAGIC Licensed under the Apache License, Version 2.0 (the "License");
# MAGIC you may not use this file except in compliance with the License.
# MAGIC You may obtain a copy of the License at
# MAGIC
# MAGIC http://www.apache.org/licenses/LICENSE-2.0
# MAGIC
# MAGIC Unless required by applicable law or agreed to in writing, software
# MAGIC distributed under the License is distributed on an "AS IS" BASIS,
# MAGIC WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# MAGIC See the License for the specific language governing permissions and
# MAGIC limitations under the License.
# MAGIC ```
# MAGIC
# MAGIC The EleutherAI Pythia models are [Apache 2.0 licensed](https://huggingface.co/EleutherAI/pythia-12b) and
# MAGIC the [databricks-dolly-15k](https://github.com/databrickslabs/dolly/tree/master/data) dataset is licensed under the terms
# MAGIC of the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://creativecommons.org/licenses/by-sa/3.0/legalcode),
# MAGIC which means both the models and the dataset can be used for academic or commercial purposes (the dataset subject to its attribution and share-alike terms).
# COMMAND ----------
# MAGIC %md
# MAGIC Install these additional NVIDIA libraries for Databricks Runtime 13.x+ ML:
# COMMAND ----------
# MAGIC !wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb -O /tmp/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb && \
# MAGIC wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcublas-dev-11-7_11.10.1.25-1_amd64.deb -O /tmp/libcublas-dev-11-7_11.10.1.25-1_amd64.deb && \
# MAGIC wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb -O /tmp/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb && \
# MAGIC wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcurand-dev-11-7_10.2.10.91-1_amd64.deb -O /tmp/libcurand-dev-11-7_10.2.10.91-1_amd64.deb && \
# MAGIC dpkg -i /tmp/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb && \
# MAGIC dpkg -i /tmp/libcublas-dev-11-7_11.10.1.25-1_amd64.deb && \
# MAGIC dpkg -i /tmp/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb && \
# MAGIC dpkg -i /tmp/libcurand-dev-11-7_10.2.10.91-1_amd64.deb
# COMMAND ----------
# MAGIC %pip install -r requirements.txt
# COMMAND ----------
# MAGIC %load_ext autoreload
# MAGIC %autoreload 2
# COMMAND ----------
import logging

# Configure root logging for the notebook: INFO level, timestamped records tagged with the
# emitting logger's name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Raise the threshold of chatty third-party loggers so they do not drown out training logs.
_QUIETED_LOGGERS = {"py4j": logging.WARNING, "sh.command": logging.ERROR}
for _logger_name, _min_level in _QUIETED_LOGGERS.items():
    logging.getLogger(_logger_name).setLevel(_min_level)
# COMMAND ----------
import os
import re
from datetime import datetime
from training.consts import DEFAULT_INPUT_MODEL, SUGGESTED_INPUT_MODELS
from training.trainer import load_training_dataset, load_tokenizer
# Declare the notebook parameters. The input model is a combobox seeded with the suggested
# Pythia checkpoints; the free-form settings below are plain text boxes that default to ""
# and use their own name as the label.
dbutils.widgets.combobox("input_model", DEFAULT_INPUT_MODEL, SUGGESTED_INPUT_MODELS, "input_model")
for _widget_name in ("num_gpus", "local_training_root", "dbfs_output_root", "experiment_id"):
    dbutils.widgets.text(_widget_name, "", _widget_name)
# GPU family drives the DeepSpeed config, batch size and precision choices further down.
dbutils.widgets.combobox("gpu_family", "a100", ["v100", "a10", "a100"])
# COMMAND ----------
# Build the name of this run's checkpoint directory: "dolly[__<experiment_id>]__<timestamp>".
timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = "dolly"
# Strip before testing, so a whitespace-only widget value counts as "no experiment id"
# instead of producing a dangling "dolly__" prefix.
experiment_id = dbutils.widgets.get("experiment_id").strip()
input_model = dbutils.widgets.get("input_model")
if experiment_id:
    # The id ends up in directory names that are interpolated, unquoted, into the deepspeed
    # shell command and joined into filesystem paths. Collapse every run of characters other
    # than word characters, '.' and '-' (whitespace, '/', shell metacharacters, ...) into a
    # single underscore.
    experiment_id = re.sub(r"[^\w.-]+", "_", experiment_id)
    model_name = f"{model_name}__{experiment_id}"
checkpoint_dir_name = f"{model_name}__{timestamp}"
dolly_training_dir_name = "dolly_training"
# Local training root: an explicit widget value wins. Otherwise prefer /local_disk0 (the
# preferred location on a Databricks cluster) when it exists, else the user's home directory.
local_training_root = dbutils.widgets.get("local_training_root")
if not local_training_root:
    _default_base = "/local_disk0" if os.path.exists("/local_disk0") else os.path.expanduser('~')
    local_training_root = os.path.join(_default_base, dolly_training_dir_name)

# DBFS output root: an explicit widget value wins, else /dbfs/<dolly_training_dir_name>.
dbfs_output_root = dbutils.widgets.get("dbfs_output_root") or f"/dbfs/{dolly_training_dir_name}"

for _root_dir in (local_training_root, dbfs_output_root):
    os.makedirs(_root_dir, exist_ok=True)

# Each run gets its own sub-directory under both roots, named after checkpoint_dir_name.
local_output_dir = os.path.join(local_training_root, checkpoint_dir_name)
dbfs_output_dir = os.path.join(dbfs_output_root, checkpoint_dir_name)
tensorboard_display_dir = f"{local_output_dir}/runs"

print(f"Local Output Dir: {local_output_dir}")
print(f"DBFS Output Dir: {dbfs_output_dir}")
print(f"Tensorboard Display Dir: {tensorboard_display_dir}")
# Pick the DeepSpeed config matching the selected GPU family: config/<gpu_family>_config.json,
# resolved against the current working directory.
# The combobox also accepts typed values, so normalise whitespace/case to match the lowercase
# option names ("v100", "a10", "a100") that this and later cells compare against.
gpu_family = dbutils.widgets.get("gpu_family").strip().lower()
config_file_name = f"{gpu_family}_config.json"
deepspeed_config = os.path.join(os.getcwd(), "config", config_file_name)
print(f"Deepspeed config file: {deepspeed_config}")
# Fail fast here: otherwise a missing config only surfaces later as a deepspeed launch failure,
# which the `!` shell escape does not turn into a notebook error.
if not os.path.isfile(deepspeed_config):
    raise FileNotFoundError(
        f"No DeepSpeed config found for gpu_family={gpu_family!r}: {deepspeed_config}"
    )
# Per-device train/eval batch size by GPU family; any other family falls back to 3.
_BATCH_SIZE_BY_GPU_FAMILY = {"v100": 3, "a10": 4, "a100": 6}
batch_size = _BATCH_SIZE_BY_GPU_FAMILY.get(gpu_family, 3)
# Optionally pin the number of GPUs for the deepspeed launcher. When the widget is blank
# (or whitespace only), no --num_gpus flag is passed and num_gpus stays "".
num_gpus_flag = ""
num_gpus = dbutils.widgets.get("num_gpus").strip()
if num_gpus:
    num_gpus = int(num_gpus)  # raises ValueError for non-integer input
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be a positive integer, got {num_gpus}")
    num_gpus_flag = f"--num_gpus={num_gpus}"
# bf16 is disabled only on V100 (presumably because Volta GPUs lack native bfloat16 support);
# every other GPU family trains with bf16 enabled.
bf16_flag = "--bf16 false" if gpu_family == "v100" else "--bf16 true"
# Turn off Hugging Face tokenizers' internal parallelism (presumably to avoid fork-related
# warnings once the training processes are launched).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# COMMAND ----------
# MAGIC %load_ext tensorboard
# MAGIC %tensorboard --logdir '{tensorboard_display_dir}'
# COMMAND ----------
# Launch fine-tuning of `input_model` via the DeepSpeed launcher running the training.trainer
# module. The `{...}` placeholders are expanded by IPython from the Python variables set in
# the configuration cell (paths, DeepSpeed config, batch size, GPU count, bf16 flag).
# NOTE(review): IPython `!` commands do not raise on a non-zero exit status, so a failed run
# is only noticed when a later cell tries to use its output.
# NOTE(review): the interpolated values are not shell-quoted; paths or model names containing
# spaces or shell metacharacters will break this command.
!deepspeed {num_gpus_flag} \
--module training.trainer \
--input-model {input_model} \
--deepspeed {deepspeed_config} \
--epochs 2 \
--local-output-dir {local_output_dir} \
--dbfs-output-dir {dbfs_output_dir} \
--per-device-train-batch-size {batch_size} \
--per-device-eval-batch-size {batch_size} \
--logging-steps 10 \
--save-steps 200 \
--save-total-limit 20 \
--eval-steps 50 \
--warmup-steps 50 \
--test-size 200 \
--lr 5e-6 \
{bf16_flag}
# COMMAND ----------
from training.generate import generate_response, load_model_tokenizer_for_generate

# Load the fine-tuned model and tokenizer from this run's DBFS output directory (the path
# passed to the trainer as --dbfs-output-dir).
# The `!deepspeed` shell escape does not raise when training fails, so verify the output
# exists first; otherwise the loader would be handed a nonexistent path and likely fail with
# a much less obvious error.
if not os.path.isdir(dbfs_output_dir):
    raise FileNotFoundError(
        f"Trained model not found at {dbfs_output_dir}; check the deepspeed training output above."
    )
model, tokenizer = load_model_tokenizer_for_generate(dbfs_output_dir)
# COMMAND ----------
# Examples from https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html
instructions = [
    "Write a love letter to Edgar Allan Poe.",
    "Write a tweet announcing Dolly, a large language model from Databricks.",
    "I'm selling my Nikon D-750, write a short blurb for my ad.",
    "Explain to me the difference between nuclear fission and fusion.",
    "Give me a list of 5 science fiction books I should read next.",
]

# Extra kwargs forwarded to generate_response: torch_dtype is float16 on V100, bfloat16 on
# A10/A100, and "auto" for any other GPU family.
_TORCH_DTYPE_BY_GPU_FAMILY = {"v100": "float16", "a10": "bfloat16", "a100": "bfloat16"}
pipeline_kwargs = {'torch_dtype': _TORCH_DTYPE_BY_GPU_FAMILY.get(gpu_family, "auto")}

# Generate a response for each instruction; empty responses are skipped silently.
for instruction in instructions:
    response = generate_response(instruction, model=model, tokenizer=tokenizer, **pipeline_kwargs)
    if not response:
        continue
    print(f"Instruction: {instruction}\n\n{response}\n\n-----------\n")
# COMMAND ----------