add gpu memory rebalanceing

This commit is contained in:
2024-03-17 22:54:33 +01:00
parent 5acb6809ed
commit 38a7f7cfc4
3 changed files with 78 additions and 39 deletions

View File

@ -3,6 +3,7 @@ import torch
from utils import replace_module
from modules import ConvertingLinear, Linear
from random import randint
import math
def find_all_linear_module_names(model) -> list[str]:
@ -30,8 +31,8 @@ class LinearGroup:
model_modules = dict(model.named_modules())
for name in group_names:
self.modules.append(model_modules[name])
assert isinstance(self.modules[0], ConvertingLinear)
assert isinstance(self.modules[-1], ConvertingLinear)
for module in self.modules:
assert isinstance(module, Linear)
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
for module in self.modules:
@ -54,6 +55,9 @@ class LinearGroup:
def paramCount(self) -> int:
return sum(p.numel() for p in self.parameters())
def getDevice(self) -> torch.device:
return self.modules[0].weight.device
class DyntrainModel:
def __init__(self, model_name_or_path: str, cache_dir: str,
@ -63,11 +67,17 @@ class DyntrainModel:
cache_dir=cache_dir,
torch_dtype=torch.float32,
trust_remote_code=trust_remote_code,
device_map=None,
device_map=None
)
self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
self.linear_groups = list()
self.target_active_params = target_active_params
self.devices = list()
if gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
self._prepare()
self.reshuffleActive()
@ -76,7 +86,7 @@ class DyntrainModel:
modules = dict(layer.named_modules())
for key in modules.keys():
if not isinstance(modules[key], torch.nn.Linear):
if not isinstance(modules[key], torch.nn.Linear) and len(list(modules[key].children())) == 0 or key == "lm_head":
names.append(key)
return names
@ -97,16 +107,26 @@ class DyntrainModel:
list_counter = list_counter + 1
return linear_groups
def isModuleIn16bitOutlist(key: str) -> bool:
key = key.split('.')[-1]
whitelist = set({
"gate_proj",
"up_proj",
"q_proj",
"k_proj",
"v_proj"})
return key in whitelist
def _prepare(self) -> None:
modules = dict(self.model.named_modules())
linear_groups = DyntrainModel._get_linear_group_names(self.model)
for group in linear_groups:
replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
if len(group) > 2:
for index in range(1, len(group) - 1):
replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
for key in group:
if DyntrainModel.isModuleIn16bitOutlist(key):
replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
else:
replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
self.linear_groups.append(LinearGroup(self.model, group))
def dynamicParameters(self) -> list:
@ -133,7 +153,7 @@ class DyntrainModel:
def activeParameterCount(self) -> int:
total_params = self.dynamicParameters() + self.staticParameters()
return sum(p.numel() for p in total_params if total_params)
return sum(p.numel() for p in total_params if p.requires_grad)
def reshuffleActive(self) -> None:
for group in self.linear_groups:
@ -146,45 +166,56 @@ class DyntrainModel:
self.linear_groups[indecies[i]].setFrozen(False)
params += self.linear_groups[indecies[i]].paramCount()
indecies.pop(i)
print(math.ceil(params / 1e6))
for group in self.linear_groups:
if group.isFrozen():
group.inplaceTo(dtype=torch.float16)
else:
group.inplaceTo(dtype=torch.float32)
print(group.modules[0].weight.dtype)
active_params = self.activeParameterCount()
def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
def balanceActive(self) -> None:
device_groups = list()
for index in range(0, len(self.devices)):
device_groups.append(list())
for group in self.linear_groups:
if not group.isFrozen():
device_groups[self.devices.index(group.getDevice())].append(group)
min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
if max_count - 2 > min_count:
device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
self.balanceActive()
def toDevices(self, devices: list[torch.device]) -> None:
assert len(devices) > 0
modules = dict(self.model.named_modules())
total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
static_param_count = self.staticParameterCount()
total_parameter_count = static_param_count + self.dynamicParameterCount()
params_per_byte = total_parameter_count / float(total_memory)
print(f"{1/params_per_byte} bytes available per parameter")
print(f"{math.floor(1/params_per_byte)} bytes available per parameter")
breakpoint()
self.devices = devices
for key in DyntrainModel._get_nonlinear_names(self.model):
replace_module(self.model, key, modules[key].to(primary_device))
breakpoint()
replace_module(self.model, key, modules[key].to(devices[0]))
group_index = 0
params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
primary_params = static_param_count
while params_for_primary > primary_params and group_index < len(self.linear_groups):
self.linear_groups[group_index].inplaceTo(device=primary_device)
primary_params += self.linear_groups[group_index].paramCount()
group_index += 1
for device in secondary_devices[:-1]:
params_for_device = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte
for device in devices[:-1]:
params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
params = 0
while params_for_device > params and group_index < len(self.linear_groups):
self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
self.linear_groups[group_index].inplaceTo(device=device)
params += self.linear_groups[group_index].paramCount()
group_index += 1
while group_index < len(self.linear_groups):
self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
self.linear_groups[group_index].inplaceTo(device=devices[-1])
group_index += 1