From 38a7f7cfc4a339cd4519b773c9ad6d4b0e308da8 Mon Sep 17 00:00:00 2001
From: uvos
Date: Sun, 17 Mar 2024 22:54:33 +0100
Subject: [PATCH] add gpu memory rebalancing

---
 dyntrainmodel.py | 91 ++++++++++++++++++++++++++++++++----------------
 modules.py       |  8 ++++-
 train_dynamic.py | 18 +++++-----
 3 files changed, 78 insertions(+), 39 deletions(-)

diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index 433ef4f..687ee2b 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -3,6 +3,7 @@ import torch
 from utils import replace_module
 from modules import ConvertingLinear, Linear
 from random import randint
+import math
 
 
 def find_all_linear_module_names(model) -> list[str]:
@@ -30,8 +31,8 @@ class LinearGroup:
         model_modules = dict(model.named_modules())
         for name in group_names:
             self.modules.append(model_modules[name])
-        assert isinstance(self.modules[0], ConvertingLinear)
-        assert isinstance(self.modules[-1], ConvertingLinear)
+        for module in self.modules:
+            assert isinstance(module, Linear)
 
     def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
         for module in self.modules:
@@ -54,6 +55,9 @@ class LinearGroup:
     def paramCount(self) -> int:
         return sum(p.numel() for p in self.parameters())
 
+    def getDevice(self) -> torch.device:
+        return self.modules[0].weight.device
+
 
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str,
@@ -63,11 +67,17 @@ class DyntrainModel:
                                                            cache_dir=cache_dir,
                                                            torch_dtype=torch.float32,
                                                            trust_remote_code=trust_remote_code,
-                                                           device_map=None,
+                                                           device_map=None
                                                            )
+        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
 
         self.linear_groups = list()
         self.target_active_params = target_active_params
+        self.devices = list()
+
+        if gradient_checkpointing:
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
         self._prepare()
         self.reshuffleActive()
 
@@ -76,7 +86,7 @@ class DyntrainModel:
 
         modules = dict(layer.named_modules())
         for key in modules.keys():
-            if not isinstance(modules[key], torch.nn.Linear):
+            if not isinstance(modules[key], torch.nn.Linear) and len(list(modules[key].children())) == 0 or key == "lm_head":
                 names.append(key)
         return names
 
@@ -97,16 +107,26 @@ class DyntrainModel:
                     list_counter = list_counter + 1
         return linear_groups
 
+    def isModuleIn16bitOutlist(key: str) -> bool:
+        key = key.split('.')[-1]
+        whitelist = set({
+            "gate_proj",
+            "up_proj",
+            "q_proj",
+            "k_proj",
+            "v_proj"})
+        return key in whitelist
+
     def _prepare(self) -> None:
         modules = dict(self.model.named_modules())
         linear_groups = DyntrainModel._get_linear_group_names(self.model)
 
         for group in linear_groups:
-            replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
-            replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
-            if len(group) > 2:
-                for index in range(1, len(group) - 1):
-                    replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
+            for key in group:
+                if DyntrainModel.isModuleIn16bitOutlist(key):
+                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
+                else:
+                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
             self.linear_groups.append(LinearGroup(self.model, group))
 
     def dynamicParameters(self) -> list:
@@ -133,7 +153,7 @@ class DyntrainModel:
 
     def activeParameterCount(self) -> int:
         total_params = self.dynamicParameters() + self.staticParameters()
-        return sum(p.numel() for p in total_params if total_params)
+        return sum(p.numel() for p in total_params if p.requires_grad)
 
     def reshuffleActive(self) -> None:
         for group in self.linear_groups:
@@ -146,45 +166,56 @@ class DyntrainModel:
             self.linear_groups[indecies[i]].setFrozen(False)
             params += self.linear_groups[indecies[i]].paramCount()
             indecies.pop(i)
+        print(math.ceil(params / 1e6))
 
         for group in self.linear_groups:
             if group.isFrozen():
                 group.inplaceTo(dtype=torch.float16)
             else:
                 group.inplaceTo(dtype=torch.float32)
-                print(group.modules[0].weight.dtype)
+        active_params = self.activeParameterCount()
 
-    def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
+        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+
+    def balanceActive(self) -> None:
+        device_groups = list()
+        for index in range(0, len(self.devices)):
+            device_groups.append(list())
+
+        for group in self.linear_groups:
+            if not group.isFrozen():
+                device_groups[self.devices.index(group.getDevice())].append(group)
+
+        min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
+        max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
+
+        if max_count - 2 > min_count:
+            device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
+            self.balanceActive()
+
+    def toDevices(self, devices: list[torch.device]) -> None:
+        assert len(devices) > 0
         modules = dict(self.model.named_modules())
-        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
-        total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
+        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
         static_param_count = self.staticParameterCount()
         total_parameter_count = static_param_count + self.dynamicParameterCount()
         params_per_byte = total_parameter_count / float(total_memory)
-        print(f"{1/params_per_byte} bytes available per parameter")
+        print(f"{math.floor(1/params_per_byte)} bytes available per parameter")
 
-        breakpoint()
+        self.devices = devices
 
         for key in DyntrainModel._get_nonlinear_names(self.model):
-            replace_module(self.model, key, modules[key].to(primary_device))
-
-        breakpoint()
+            replace_module(self.model, key, modules[key].to(devices[0]))
 
         group_index = 0
-        params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
-        primary_params = static_param_count
-        while params_for_primary > primary_params and group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=primary_device)
-            primary_params += self.linear_groups[group_index].paramCount()
-            group_index += 1
-
-        for device in secondary_devices[:-1]:
-            params_for_device = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte
+        for device in devices[:-1]:
+            params_for_device = torch.cuda.get_device_properties(device).total_memory * params_per_byte
             params = 0
             while params_for_device > params and group_index < len(self.linear_groups):
-                self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
+                self.linear_groups[group_index].inplaceTo(device=device)
                 params += self.linear_groups[group_index].paramCount()
                 group_index += 1
 
         while group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
+            self.linear_groups[group_index].inplaceTo(device=devices[-1])
+            group_index += 1
diff --git a/modules.py b/modules.py
index 470b0d7..385c6da 100644
--- a/modules.py
+++ b/modules.py
@@ -25,10 +25,16 @@ class Linear(torch.nn.Linear):
         return not self.weight.requires_grad
 
     def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+        frozen = self.isFrozen()
         if dtype is not None:
             self.weight = torch.nn.Parameter(self.weight.to(dtype))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(dtype))
         if device is not None:
             self.weight = torch.nn.Parameter(self.weight.to(device))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(device))
+        self.setFrozen(frozen)
 
 
 class ConvertingLinear(Linear):
@@ -63,6 +69,6 @@ class ConvertingLinear(Linear):
         if input.dtype != self.weight.dtype:
             input = input.to(self.weight.dtype)
         output = torch.nn.Linear.forward(self, input)
-        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
+        if torch.isnan(output).any():
             breakpoint()
         return output.to(output_device).to(output_dtype)
diff --git a/train_dynamic.py b/train_dynamic.py
index 92dcbd7..c9e4de7 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -61,13 +61,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     log_writer = tensorboard.SummaryWriter()
 
     model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
-    model = model.toDevices(primary_device, [secondary_device])
+    model.toDevices([primary_device, secondary_device])
+    model.balanceActive()
 
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
 
-    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
+    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
     if data_args.dataset.endswith("json"):
         print("Loading dataset in s2s mode")
@@ -89,7 +90,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['eval_dataset'] is not None else None
 
-    dynamic_param_ratio = (model.staticParamterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
+    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
     total_steps = steps_per_epoch * training_args.epochs
 
@@ -111,14 +112,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     if training_args.do_train:
         progress_bar = tqdm(range(total_steps))
         global_step = 0
-        model.train()
+        model.model.train()
         for epoch in range(0, training_args.epochs):
             print("*** Train ***")
             print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
             for step, batch in enumerate(train_dataloader):
                 for key in batch:
                     batch[key] = batch[key].to("cuda:0")
-                outputs = model(**batch)
+                outputs = model.model(**batch)
                 loss = outputs.loss / training_args.gradient_accumulation_steps
                 log_writer.add_scalar("Loss/train", loss, global_step)
                 loss.backward()
@@ -127,7 +128,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
                     optimizer.step()
                     lr_scheduler.step()
-                    model.zero_grad()
+                    model.model.zero_grad()
 
                     if global_step % 10 == 0:
                         print(loss)
@@ -136,6 +137,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
                         lr_scheduler.optimizer = None
                         del optimizer
                        model.reshuffleActive()
+                        model.balanceActive()
                         log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                         optimizer = get_optimizer(model.dynamicParameters(), model.staticParameters(),
@@ -150,7 +152,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                     progress_bar.update()
 
                     if global_step % training_args.save_steps == 0:
-                        save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints)
+                        save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
 
                 if training_args.flush_allocator:
                     torch.cuda.empty_cache()
@@ -158,7 +160,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     if training_args.do_eval:
         print("*** Evaluate ***")
 
-    save_model(model, global_step, training_args.output_dir)
+    save_model(model.model, global_step, training_args.output_dir)
 
     return
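
Usage note (illustrative sketch, not part of the diff above): the snippet below shows how the entry points touched by this patch are meant to be driven, mirroring the call sites in train_dynamic.py. The model path, cache directory and device list are placeholders, and DyntrainModel is constructed with the same positional arguments train_dynamic.py passes.

    import torch
    from dyntrainmodel import DyntrainModel

    devices = [torch.device("cuda:0"), torch.device("cuda:1")]  # placeholder device list

    # Same positional arguments as in train_dynamic.py:
    # (model_name_or_path, cache_dir, target_active_params, and the two flags used there)
    model = DyntrainModel("path/to/model", "cache/", 100e6, True, True)

    model.toDevices(devices)   # initial placement: linear groups spread across all GPUs by memory share
    model.balanceActive()      # then even out the unfrozen (fp32, trainable) groups between devices

    # Later, whenever the active set is reshuffled during training:
    model.reshuffleActive()
    model.balanceActive()

balanceActive() moves one unfrozen group at a time from the most loaded device to the least loaded one and recurses until the per-device counts of active groups differ by at most two, which is why it is re-run after every reshuffleActive() in the training loop.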