add gpu memory rebalancing
@@ -3,6 +3,7 @@ import torch
 from utils import replace_module
 from modules import ConvertingLinear, Linear
 from random import randint
+import math
 
 
 def find_all_linear_module_names(model) -> list[str]:
@@ -30,8 +31,8 @@ class LinearGroup:
         model_modules = dict(model.named_modules())
         for name in group_names:
             self.modules.append(model_modules[name])
-        assert isinstance(self.modules[0], ConvertingLinear)
-        assert isinstance(self.modules[-1], ConvertingLinear)
+        for module in self.modules:
+            assert isinstance(module, Linear)
 
     def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
         for module in self.modules:
@@ -54,6 +55,9 @@ class LinearGroup:
     def paramCount(self) -> int:
         return sum(p.numel() for p in self.parameters())
 
+    def getDevice(self) -> torch.device:
+        return self.modules[0].weight.device
+
 
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str,
@@ -63,11 +67,17 @@ class DyntrainModel:
             cache_dir=cache_dir,
             torch_dtype=torch.float32,
             trust_remote_code=trust_remote_code,
-            device_map=None,
+            device_map=None
         )
+        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
         self.linear_groups = list()
         self.target_active_params = target_active_params
 
+        self.devices = list()
+
+        if gradient_checkpointing:
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
         self._prepare()
         self.reshuffleActive()
 
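
The constructor now keeps the embedding table in half precision and can enable activation checkpointing through the standard transformers API. A minimal sketch of those two calls in isolation; the checkpoint name is a placeholder, and the `model.model.embed_tokens` path (like the commit itself) assumes a Llama-style layout:

import torch
from transformers import AutoModelForCausalLM

# Sketch only: substitute a real Llama-style checkpoint for the placeholder name.
model = AutoModelForCausalLM.from_pretrained(
    "org/llama-style-checkpoint",   # placeholder, not from the commit
    torch_dtype=torch.float32,
    device_map=None
)

# Keep the large embedding table in fp16 to save GPU memory.
model.model.embed_tokens = model.model.embed_tokens.to(torch.float16)

# Trade compute for activation memory; the non-reentrant variant does not require
# checkpointed inputs to have requires_grad, which matters when most layers are frozen.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
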
@@ -76,7 +86,7 @@ class DyntrainModel:
         modules = dict(layer.named_modules())
 
         for key in modules.keys():
-            if not isinstance(modules[key], torch.nn.Linear):
+            if not isinstance(modules[key], torch.nn.Linear) and len(list(modules[key].children())) == 0 or key == "lm_head":
                 names.append(key)
         return names
 
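
One subtlety in the rewritten condition above: `and` binds tighter than `or`, so the test reads as "(non-Linear leaf module) or lm_head". A tiny self-contained check of that behaviour:

import torch

# Non-Linear leaf modules are collected, and "lm_head" is collected unconditionally
# because the `or` branch bypasses the type check.
modules = {
    "q_proj": torch.nn.Linear(2, 2),
    "act_fn": torch.nn.SiLU(),
    "lm_head": torch.nn.Linear(2, 2),
}
names = []
for key in modules.keys():
    if not isinstance(modules[key], torch.nn.Linear) and len(list(modules[key].children())) == 0 or key == "lm_head":
        names.append(key)
print(names)  # ['act_fn', 'lm_head']
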
@@ -97,16 +107,26 @@ class DyntrainModel:
             list_counter = list_counter + 1
         return linear_groups
 
+    def isModuleIn16bitOutlist(key: str) -> bool:
+        key = key.split('.')[-1]
+        whitelist = set({
+            "gate_proj",
+            "up_proj",
+            "q_proj",
+            "k_proj",
+            "v_proj"})
+        return key in whitelist
+
     def _prepare(self) -> None:
         modules = dict(self.model.named_modules())
         linear_groups = DyntrainModel._get_linear_group_names(self.model)
 
         for group in linear_groups:
-            replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
-            replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
-            if len(group) > 2:
-                for index in range(1, len(group) - 1):
-                    replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
+            for key in group:
+                if DyntrainModel.isModuleIn16bitOutlist(key):
+                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
+                else:
+                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
             self.linear_groups.append(LinearGroup(self.model, group))
 
     def dynamicParameters(self) -> list:
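
The rewritten `_prepare` drives every replacement off the name whitelist: projections on the list keep fp16 outputs, everything else converts back to fp32 at the boundary. A self-contained sketch of that decision rule (the helper name below is illustrative, not the repo's):

import torch

# Illustrative helper mirroring DyntrainModel.isModuleIn16bitOutlist plus the dtype choice in _prepare.
FP16_OUT_WHITELIST = {"gate_proj", "up_proj", "q_proj", "k_proj", "v_proj"}

def output_dtype_for(module_name: str) -> torch.dtype:
    leaf = module_name.split(".")[-1]   # "model.layers.0.self_attn.q_proj" -> "q_proj"
    return torch.float16 if leaf in FP16_OUT_WHITELIST else torch.float32

assert output_dtype_for("model.layers.0.self_attn.q_proj") == torch.float16
assert output_dtype_for("model.layers.0.mlp.down_proj") == torch.float32   # not whitelisted, so fp32 out
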
@@ -133,7 +153,7 @@ class DyntrainModel:
 
     def activeParameterCount(self) -> int:
         total_params = self.dynamicParameters() + self.staticParameters()
-        return sum(p.numel() for p in total_params if total_params)
+        return sum(p.numel() for p in total_params if p.requires_grad)
 
     def reshuffleActive(self) -> None:
         for group in self.linear_groups:
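
The one-line fix above matters: the old filter `if total_params` tested the (always non-empty) list itself, so every parameter was counted; the new filter counts only parameters with `requires_grad` set. In isolation:

import torch

params = [
    torch.nn.Parameter(torch.zeros(10)),                      # active
    torch.nn.Parameter(torch.zeros(5), requires_grad=False),  # frozen
]
print(sum(p.numel() for p in params if params))           # 15 -- old behaviour, filter is always truthy
print(sum(p.numel() for p in params if p.requires_grad))  # 10 -- new behaviour, frozen params excluded
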
@@ -146,45 +166,56 @@ class DyntrainModel:
             self.linear_groups[indecies[i]].setFrozen(False)
             params += self.linear_groups[indecies[i]].paramCount()
             indecies.pop(i)
+        print(math.ceil(params / 1e6))
 
         for group in self.linear_groups:
             if group.isFrozen():
                 group.inplaceTo(dtype=torch.float16)
             else:
                 group.inplaceTo(dtype=torch.float32)
-        print(group.modules[0].weight.dtype)
-
-    def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
+        active_params = self.activeParameterCount()
+        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+
+    def balanceActive(self) -> None:
+        device_groups = list()
+        for index in range(0, len(self.devices)):
+            device_groups.append(list())
+
+        for group in self.linear_groups:
+            if not group.isFrozen():
+                device_groups[self.devices.index(group.getDevice())].append(group)
+
+        min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
+        max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
+
+        if max_count - 2 > min_count:
+            device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
+            self.balanceActive()
+
+    def toDevices(self, devices: list[torch.device]) -> None:
+        assert len(devices) > 0
         modules = dict(self.model.named_modules())
-        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
-        total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
+        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
         static_param_count = self.staticParameterCount()
         total_parameter_count = static_param_count + self.dynamicParameterCount()
         params_per_byte = total_parameter_count / float(total_memory)
-        print(f"{1/params_per_byte} bytes available per parameter")
+        print(f"{math.floor(1/params_per_byte)} bytes available per parameter")
 
-        breakpoint()
+        self.devices = devices
 
         for key in DyntrainModel._get_nonlinear_names(self.model):
-            replace_module(self.model, key, modules[key].to(primary_device))
-
-        breakpoint()
+            replace_module(self.model, key, modules[key].to(devices[0]))
 
         group_index = 0
-        params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
-        primary_params = static_param_count
-        while params_for_primary > primary_params and group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=primary_device)
-            primary_params += self.linear_groups[group_index].paramCount()
-            group_index += 1
-
-        for device in secondary_devices[:-1]:
-            params_for_device = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte
+        for device in devices[:-1]:
+            params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
             params = 0
             while params_for_device > params and group_index < len(self.linear_groups):
-                self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
+                self.linear_groups[group_index].inplaceTo(device=device)
                 params += self.linear_groups[group_index].paramCount()
                 group_index += 1
 
         while group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
+            self.linear_groups[group_index].inplaceTo(device=devices[-1])
+            group_index += 1
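
The new `balanceActive` is the piece that gives the commit its name: it buckets the currently unfrozen groups by the device their weights live on, and while the most loaded device holds more than two groups over the least loaded one, it moves a group across and recurses. The same greedy rule over plain counters, as a self-contained sketch (`loads[i]` stands in for the number of active groups on device i):

# Greedy rebalancing sketch; mirrors balanceActive's stopping rule.
def rebalance(loads: list[int]) -> list[int]:
    min_i = min(range(len(loads)), key=lambda i: loads[i])
    max_i = max(range(len(loads)), key=lambda i: loads[i])
    if loads[max_i] - 2 > loads[min_i]:
        loads[max_i] -= 1   # real code: device_groups[max_i][0].inplaceTo(device=devices[min_i])
        loads[min_i] += 1
        return rebalance(loads)
    return loads

print(rebalance([7, 1, 2]))  # -> [4, 3, 3]

`toDevices` was reworked in the same spirit: instead of a privileged primary device, every device gets a parameter budget proportional to its total memory (via `params_per_byte`), and groups are assigned in order until each budget is full. The hunks below are from the module file that defines `Linear` and `ConvertingLinear`.
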
@@ -25,10 +25,16 @@ class Linear(torch.nn.Linear):
         return not self.weight.requires_grad
 
     def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+        frozen = self.isFrozen()
         if dtype is not None:
             self.weight = torch.nn.Parameter(self.weight.to(dtype))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(dtype))
         if device is not None:
             self.weight = torch.nn.Parameter(self.weight.to(device))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(device))
+        self.setFrozen(frozen)
 
 
 class ConvertingLinear(Linear):
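
The `inplaceTo` change fixes a subtle interaction with freezing: wrapping a tensor in a fresh `torch.nn.Parameter` resets `requires_grad` to True, so a dtype or device move could silently unfreeze a frozen layer. The method now records the frozen flag first and restores it afterwards, and it also moves the bias along with the weight. A minimal self-contained illustration of the pattern (the class below is a stand-in, not the repo's `Linear`):

import torch

class FrozenAwareLinear(torch.nn.Linear):
    # Stand-in for the repo's Linear: freezing is tracked through requires_grad on the parameters.
    def setFrozen(self, frozen: bool) -> None:
        self.weight.requires_grad = not frozen
        if self.bias is not None:
            self.bias.requires_grad = not frozen

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad

    def inplaceTo(self, dtype=None, device=None) -> None:
        frozen = self.isFrozen()                     # remember the state before re-wrapping
        for name in ("weight", "bias"):
            param = getattr(self, name)
            if param is not None:
                moved = param.detach().to(dtype=dtype, device=device)
                setattr(self, name, torch.nn.Parameter(moved))   # fresh Parameter: requires_grad is True again
        self.setFrozen(frozen)                       # restore it, as the commit now does

layer = FrozenAwareLinear(4, 4)
layer.setFrozen(True)
layer.inplaceTo(dtype=torch.float16)
assert layer.isFrozen()   # without the save/restore this assertion would fail
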
@@ -63,6 +69,6 @@ class ConvertingLinear(Linear):
         if input.dtype != self.weight.dtype:
             input = input.to(self.weight.dtype)
         output = torch.nn.Linear.forward(self, input)
-        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
+        if torch.isnan(output).any():
             breakpoint()
         return output.to(output_device).to(output_dtype)
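
For context, `ConvertingLinear.forward` is the glue that lets fp16-stored groups sit inside an otherwise fp32 graph: cast the incoming activations to the weight dtype, run the plain linear, then cast the result to the configured output dtype and device; the guard above now breaks only on NaNs instead of on every non-fp32 weight. A self-contained sketch of that cast-compute-cast pattern, with the output dtype as a plain attribute instead of the repo's `fromLinear` constructor:

import torch

class CastingLinear(torch.nn.Linear):
    # Sketch of the ConvertingLinear idea; output_dtype is a plain class attribute here.
    output_dtype: torch.dtype = torch.float32

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)      # cast activations down to the storage dtype
        output = torch.nn.Linear.forward(self, input)
        return output.to(self.output_dtype)          # hand the result back in the caller's dtype

layer = CastingLinear(8, 8).to(torch.float16)        # weights stored in fp16
y = layer(torch.randn(2, 8))                         # fp32 in ...
assert y.dtype == torch.float32                      # ... fp32 out, fp16 matmul in between

The remaining hunks are from the training script, where the wrapper's inner model is now addressed as `model.model`.
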
@@ -61,13 +61,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     log_writer = tensorboard.SummaryWriter()
 
     model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
-    model = model.toDevices(primary_device, [secondary_device])
+    model.toDevices([primary_device, secondary_device])
+    model.balanceActive()
 
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
 
-    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
+    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
     if data_args.dataset.endswith("json"):
         print("Loading dataset in s2s mode")
@@ -89,7 +90,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['eval_dataset'] is not None else None
 
     dynamic_param_ratio = (model.staticParamterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
     total_steps = steps_per_epoch * training_args.epochs
 
@@ -111,14 +112,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     if training_args.do_train:
         progress_bar = tqdm(range(total_steps))
         global_step = 0
-        model.train()
+        model.model.train()
         for epoch in range(0, training_args.epochs):
             print("*** Train ***")
             print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
             for step, batch in enumerate(train_dataloader):
                 for key in batch:
                     batch[key] = batch[key].to("cuda:0")
-                outputs = model(**batch)
+                outputs = model.model(**batch)
                 loss = outputs.loss / training_args.gradient_accumulation_steps
                 log_writer.add_scalar("Loss/train", loss, global_step)
                 loss.backward()
@@ -127,7 +128,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                     optimizer.step()
                     lr_scheduler.step()
 
-                    model.zero_grad()
+                    model.model.zero_grad()
 
                     if global_step % 10 == 0:
                         print(loss)
@@ -136,6 +137,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                         lr_scheduler.optimizer = None
                         del optimizer
                         model.reshuffleActive()
+                        model.balanceActive()
                         log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                         optimizer = get_optimizer(model.dynamicParameters(),
                                                   model.staticParameters(),
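
The block above completes the cadence the commit is after: every reshuffle the optimizer is dropped, a fresh random subset of groups is made trainable, the new `balanceActive` spreads them across GPUs, and a new optimizer is built over the resulting trainable parameters. A toy, single-device version of that cadence (everything here is a stand-in; only the order of operations mirrors the training script):

import torch
from random import sample

layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))
model = torch.nn.Sequential(*layers)

def reshuffle_active(n_active: int = 2) -> None:
    # Freeze everything, then unfreeze a random subset -- a stand-in for DyntrainModel.reshuffleActive().
    for p in model.parameters():
        p.requires_grad = False
    for layer in sample(list(layers), n_active):
        for p in layer.parameters():
            p.requires_grad = True

reshuffle_active()
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)

for step in range(300):
    loss = model(torch.randn(16, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    model.zero_grad()

    if (step + 1) % 100 == 0:       # illustrative interval; the real script reshuffles on its own schedule
        del optimizer               # stale optimizer state points at parameters about to be frozen
        reshuffle_active()          # model.reshuffleActive() in the real code
        # model.balanceActive() would follow here in the multi-GPU case
        optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)
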
@@ -150,7 +152,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 progress_bar.update()
 
                 if global_step % training_args.save_steps == 0:
-                    save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints)
+                    save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                 if training_args.flush_allocator:
                     torch.cuda.empty_cache()
 
@@ -158,7 +160,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     if training_args.do_eval:
         print("*** Evaluate ***")
 
-    save_model(model, global_step, training_args.output_dir)
+    save_model(model.model, global_step, training_args.output_dir)
 
     return
 
|