Update the device-mapping function to support multi-GPU inference
parent
2996651acb
commit
8310e25513
@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
from accelerate import dispatch_model
|
||||||
|
|
||||||
|
|
||||||
|
def _device_map(num_gpus, num_layers):
|
||||||
|
per_gpu_layers = (num_layers + 2) / num_gpus
|
||||||
|
|
||||||
|
device_map = {
|
||||||
|
'transformer.wte': 0,
|
||||||
|
'transformer.ln_f': 0,
|
||||||
|
'lm_head': num_gpus-1
|
||||||
|
}
|
||||||
|
|
||||||
|
used = 1
|
||||||
|
gpu_target = 0
|
||||||
|
for i in range(num_layers):
|
||||||
|
if used >= per_gpu_layers:
|
||||||
|
gpu_target += 1
|
||||||
|
used = 0 if gpu_target < num_gpus-1 else 1
|
||||||
|
assert gpu_target < num_gpus
|
||||||
|
device_map[f'transformer.h.{i}'] = gpu_target
|
||||||
|
used += 1
|
||||||
|
|
||||||
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_on_gpus(model_name_or_path, num_gpus: int = 2):
    """Load a causal LM and dispatch it across one or more GPUs.

    Args:
        model_name_or_path: Model identifier or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        num_gpus: Number of GPUs to use. ``1`` delegates placement to
            ``device_map='auto'``; values ``2..torch.cuda.device_count()``
            load on CPU first and shard with a custom layer map.

    Returns:
        The model in eval mode, dispatched onto the requested devices.

    Raises:
        KeyError: if ``num_gpus`` is not in ``[1, torch.cuda.device_count()]``.
    """
    num_devices = torch.cuda.device_count()

    if num_gpus == 1:
        # Single GPU: let accelerate/transformers place everything.
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map='auto',
                                                     trust_remote_code=True).eval()
    elif 1 < num_gpus <= num_devices:
        # Multi-GPU: materialize on CPU first, then shard layer-by-layer
        # with the custom map so pinned modules land on the right devices.
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map='cpu',
                                                     trust_remote_code=True).eval()
        num_layers = model.config.num_hidden_layers
        device_map = _device_map(num_gpus, num_layers)
        print(device_map)
        model = dispatch_model(model, device_map=device_map)
    else:
        # KeyError kept (not ValueError) for backward compatibility with any
        # callers already catching it; now carries a diagnostic message
        # instead of being raised bare.
        raise KeyError(
            f'num_gpus must be between 1 and {num_devices} '
            f'(available CUDA devices), got {num_gpus}')

    return model
|
Loading…
Reference in New Issue