You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
import torch
|
|
from transformers import AutoModelForCausalLM
|
|
from accelerate import dispatch_model
|
|
|
|
|
|
def _device_map(num_gpus, num_layers):
|
|
per_gpu_layers = (num_layers + 2) / num_gpus
|
|
|
|
device_map = {
|
|
'transformer.wte': 0,
|
|
'transformer.ln_f': 0,
|
|
'lm_head': num_gpus-1
|
|
}
|
|
|
|
used = 1
|
|
gpu_target = 0
|
|
for i in range(num_layers):
|
|
if used >= per_gpu_layers:
|
|
gpu_target += 1
|
|
used = 0 if gpu_target < num_gpus-1 else 1
|
|
assert gpu_target < num_gpus
|
|
device_map[f'transformer.h.{i}'] = gpu_target
|
|
used += 1
|
|
|
|
return device_map
|
|
|
|
|
|
def load_model_on_gpus(model_name_or_path, num_gpus: int = 2):
    """Load a causal LM in eval mode, optionally sharded across several GPUs.

    Args:
        model_name_or_path: Forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained``.
        num_gpus: Number of GPUs to use. With ``1`` the model is loaded with
            ``device_map='auto'``. With a value greater than 1 (and at most
            ``torch.cuda.device_count()``) the model is first loaded with
            ``device_map='cpu'`` and then dispatched according to the map
            produced by ``_device_map``.

    Returns:
        The loaded model (the result of ``dispatch_model`` in the
        multi-GPU case).

    Raises:
        KeyError: If ``num_gpus`` is less than 1 or larger than the number
            of visible CUDA devices.
    """
    num_devices = torch.cuda.device_count()

    # NOTE(security): trust_remote_code=True allows the checkpoint source to
    # supply Python code that is executed on load; only use trusted sources.
    if num_gpus == 1:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map='auto',
            trust_remote_code=True,
        ).eval()
    elif 1 < num_gpus <= num_devices:
        # Load on CPU first so the layers can be placed explicitly afterwards.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map='cpu',
            trust_remote_code=True,
        ).eval()
        num_layers = model.config.num_hidden_layers
        device_map = _device_map(num_gpus, num_layers)
        print(device_map)
        model = dispatch_model(model, device_map=device_map)
    else:
        # KeyError is kept for backward compatibility with existing callers;
        # the message now says what was requested and what is available.
        raise KeyError(
            f"num_gpus must be between 1 and the number of visible CUDA "
            f"devices ({num_devices}), got {num_gpus}"
        )

    return model
|