# bloom-ds-zero-inference.py (forked from huggingface/transformers-bloom-inference)
# usage:
# deepspeed --num_gpus 8 bloom-ds-zero-inference.py --name bigscience/bloom
#
# to run benchmarks:
# deepspeed --num_gpus 8 bloom-ds-zero-inference.py --name bigscience/bloom --benchmark
#
# This is going to improve, but at the moment the process is a bit cumbersome - we:
# 1. use Deepspeed-ZeRO (stage 3) to instantiate the model directly on the gpus, sharded across ranks,
#    instead of first materializing the full checkpoint on a single device
# 2. optionally offload the params to cpu or nvme
# 3. run generate
# Done.
#
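# additional example invocations, using the offload flags defined further below
# (the gpu count and nvme path are illustrative, not prescriptive):
#
#   deepspeed --num_gpus 1 bloom-ds-zero-inference.py --name bigscience/bloom --cpu_offload --benchmark
#   deepspeed --num_gpus 1 bloom-ds-zero-inference.py --name bigscience/bloom --nvme_offload_path /path/to/nvme --benchmark
#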
import gc
import math
import os
import time
from argparse import ArgumentParser
import torch
import torch.distributed as dist
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock
t_start = time.time()
num_tokens = 100
parser = ArgumentParser()
parser.add_argument("--name", required=True, type=str, help="model_name")
parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
parser.add_argument("--cpu_offload", action="store_true", help="whether to activate CPU offload")
parser.add_argument("--nvme_offload_path", help="whether to activate NVME offload and the path on nvme")
args = parser.parse_args()
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
deepspeed.init_distributed("nccl")
rank = dist.get_rank()
def print_rank0(*msg):
    if rank != 0:
        return
    print(*msg)
### Model loading and instantiating on GPU (via ZeRO)
model_name = args.name
print_rank0(f"*** Loading the model {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16
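# a possible alternative, assuming the checkpoint's config.json records a dtype (newer
# checkpoints do; older ones may leave it unset) - kept here only as a commented sketch:
#   config_dtype = getattr(config, "torch_dtype", None)
#   dtype = config_dtype if config_dtype is not None else torch.float16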
model_hidden_size = config.hidden_size
train_batch_size = 1 * world_size
ds_config = {
    "fp16": {
        "enabled": dtype == torch.float16,
    },
    "bf16": {
        "enabled": dtype == torch.bfloat16,
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 0,
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False,
}
if args.cpu_offload and args.nvme_offload_path:
    raise ValueError("Use one of --cpu_offload or --nvme_offload_path and not both")

if args.cpu_offload:
    ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=True)

if args.nvme_offload_path:
    ds_config["zero_optimization"]["offload_param"] = dict(
        device="nvme",
        pin_memory=True,
        nvme_path=args.nvme_offload_path,
        buffer_size=4e9,
    )
dschf = HfDeepSpeedConfig(ds_config) # this tells from_pretrained to instantiate directly on gpus
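# note: the HfDeepSpeedConfig object has to be created before from_pretrained is called and
# kept alive, so that the ZeRO-3 context is active and the weights are loaded pre-sharded
# across the ranks instead of being fully materialized on each process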
if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()
    deepspeed.runtime.utils.see_memory_usage("pre-from-pretrained", force=True)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)

if args.benchmark:
    deepspeed.runtime.utils.see_memory_usage("post-from-pretrained", force=True)

model = model.eval()

print_rank0(ds_config)

# wrap the model in the ZeRO-3 engine; generation then runs through ds_engine.module
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

if args.benchmark:
    t_ready = time.time()
    deepspeed.runtime.utils.see_memory_usage("start-of-generate", force=True)
### Generate
print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")
input_sentences = [
"DeepSpeed is a machine learning framework",
"He is working on",
"He has a",
"He got all",
"Everyone is happy and I can",
"The new movie that got Oscar this year",
"In the far far distance from our galaxy,",
"Peace is the only way",
]
if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))
generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
print_rank0(f"Generate args {generate_kwargs}")
inputs = input_sentences[: args.batch_size]
def generate():
    """returns a zip of inputs, outputs and the number of new tokens per sequence"""
    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]
    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]

    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)
# XXX: this currently runs world_size identical streams on world_size gpus, so each rank could be
# fed different inputs, in which case the effective per-token time can be divided by world_size
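# a minimal sketch of what per-rank inputs could look like (hypothetical, not what this
# script does today - every rank currently generates from the same `inputs`):
#   per_rank_inputs = input_sentences[rank::world_size][: args.batch_size]
# and then tokenize `per_rank_inputs` inside generate() instead of the shared `inputs`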
print_rank0("*** Running generate")
t_generate_start = time.time()
pairs = generate()
t_generate_span = time.time() - t_generate_start
for i, o, _ in pairs:
    print_rank0(f"{'-'*60}\nin={i}\nout={o}\n")
### Benchmark
if args.benchmark:
    # clear cache / free memory
    torch.cuda.empty_cache()
    gc.collect()
    deepspeed.runtime.utils.see_memory_usage("end-of-generate", force=True)

    print_rank0("*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _, _, new_tokens in generated)
    torch.cuda.synchronize()

    # note that we actually generate world_size unique streams (though the benchmark feeds the same inputs)
    total_new_tokens_generated *= world_size
    throughput = (time.time() - t0) / (total_new_tokens_generated)
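    # `throughput` above is really the per-token latency in secs/token; an aggregate
    # tokens/sec number, if preferred, would be its inverse, e.g.:
    #   tokens_per_sec = total_new_tokens_generated / (time.time() - t0)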
    print_rank0(
        f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
"""
    )