From 5acb6809ed97dcdc0665fa9850e4afd0a5130961 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 13 Mar 2024 19:45:52 +0100 Subject: [PATCH] wip refactor --- convertinglinear.py | 30 ------ dyntrainmodel.py | 190 ++++++++++++++++++++++++++++++++++++ modules.py | 68 +++++++++++++ train_dynamic.py | 229 ++++++-------------------------------------- utils.py | 7 ++ 5 files changed, 292 insertions(+), 232 deletions(-) delete mode 100644 convertinglinear.py create mode 100644 dyntrainmodel.py create mode 100644 modules.py create mode 100644 utils.py diff --git a/convertinglinear.py b/convertinglinear.py deleted file mode 100644 index 09bad44..0000000 --- a/convertinglinear.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch - - -class ConvertingLinear(torch.nn.Linear): - def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None): - super().__init__(in_features, out_features, bias, device, dtype) - self.output_dtype = output_dtype - - def forward(self, input: torch.Tensor): - output_dtype = input.dtype if self.output_dtype is None else self.output_dtype - output_device = input.device - if input.device != self.weight.device: - input = input.to(self.weight.device) - if input.dtype != self.weight.dtype: - input = input.to(self.weight.dtype) - output = torch.nn.Linear.forward(self, input) - if torch.isnan(output).any() or self.weight.dtype != torch.float32: - breakpoint() - return output.to(output_device).to(output_dtype) - - @classmethod - def fromLinear(cls, in_module: torch.nn.Linear): - new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features, - out_features=in_module.out_features, - bias=in_module.bias is not None, - device=in_module.weight.device, - dtype=in_module.weight.dtype) - new_module.weight = in_module.weight - new_module.bias = in_module.bias - return new_module diff --git a/dyntrainmodel.py b/dyntrainmodel.py new file mode 100644 index 0000000..433ef4f --- /dev/null +++ b/dyntrainmodel.py @@ -0,0 +1,190 @@ +from transformers import AutoModelForCausalLM +import torch +from utils import replace_module +from modules import ConvertingLinear, Linear +from random import randint + + +def find_all_linear_module_names(model) -> list[str]: + module_names = set() + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear): + module_names.add(name) + + if 'lm_head' in module_names: # needed for 16-bit + module_names.remove('lm_head') + return list(module_names) + + +def find_all_outher_module_names(model) -> list[str]: + module_names = set() + for name, module in model.named_modules(): + if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)): + module_names.add(name) + return list(module_names) + + +class LinearGroup: + def __init__(self, model, group_names: list): + self.modules = list() + model_modules = dict(model.named_modules()) + for name in group_names: + self.modules.append(model_modules[name]) + assert isinstance(self.modules[0], ConvertingLinear) + assert isinstance(self.modules[-1], ConvertingLinear) + + def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None: + for module in self.modules: + module.inplaceTo(dtype, device) + self.modules[-1].setOutputDevice(output_device) + + def setFrozen(self, frozen: bool) -> None: + for module in self.modules: + module.setFrozen(frozen) + + def isFrozen(self) -> bool: + return self.modules[0].isFrozen() + + def parameters(self) -> 
list[torch.nn.Parameter]:
+        params = list()
+        for module in self.modules:
+            params.extend(module.parameters())
+        return params
+
+    def paramCount(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
+
+class DyntrainModel:
+    def __init__(self, model_name_or_path: str, cache_dir: str,
+                 target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir,
+            torch_dtype=torch.float32,
+            trust_remote_code=trust_remote_code,
+            device_map=None,
+        )
+        self.linear_groups = list()
+        self.target_active_params = target_active_params
+
+        self._prepare()
+        self.reshuffleActive()
+
+    @staticmethod
+    def _get_nonlinear_names(layer: torch.nn.Module) -> list[str]:
+        names = list()
+        modules = dict(layer.named_modules())
+
+        for key in modules.keys():
+            if not isinstance(modules[key], torch.nn.Linear):
+                names.append(key)
+        return names
+
+    @staticmethod
+    def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
+        linear_groups = list()
+        list_counter = 0
+        in_sequence = False
+        modules = dict(layer.named_modules())
+
+        for key in modules.keys():
+            if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
+                if not in_sequence:
+                    linear_groups.append(list())
+                    in_sequence = True
+                linear_groups[list_counter].append(key)
+            elif in_sequence:
+                in_sequence = False
+                list_counter = list_counter + 1
+        return linear_groups
+
+    def _prepare(self) -> None:
+        modules = dict(self.model.named_modules())
+        linear_groups = DyntrainModel._get_linear_group_names(self.model)
+
+        for group in linear_groups:
+            # The first and last linear of each group convert dtype/device at the group boundary.
+            replace_module(self.model, group[0],
+                           ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
+            replace_module(self.model, group[-1],
+                           ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
+            if len(group) > 2:
+                for index in range(1, len(group) - 1):
+                    replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
+            self.linear_groups.append(LinearGroup(self.model, group))
+
+    def dynamicParameters(self) -> list:
+        parameters = list()
+        for group in self.linear_groups:
+            parameters.extend(group.parameters())
+        return parameters
+
+    def staticParameters(self) -> list:
+        modules = dict(self.model.named_modules())
+        dynamic_param_ids = set([id(p) for p in self.dynamicParameters()])
+        parameters = list()
+        for key in modules.keys():
+            for param in modules[key].parameters():
+                if id(param) not in dynamic_param_ids:
+                    parameters.append(param)
+        return parameters
+
+    def dynamicParameterCount(self) -> int:
+        return sum(p.numel() for p in self.dynamicParameters())
+
+    def staticParameterCount(self) -> int:
+        return sum(p.numel() for p in self.staticParameters())
+
+    def activeParameterCount(self) -> int:
+        total_params = self.dynamicParameters() + self.staticParameters()
+        return sum(p.numel() for p in total_params if p.requires_grad)
+
+    def reshuffleActive(self) -> None:
+        for group in self.linear_groups:
+            group.setFrozen(True)
+
+        # Randomly unfreeze linear groups until the active-parameter budget is reached.
+        indices = list(range(0, len(self.linear_groups)))
+        params = self.staticParameterCount()
+        while params < self.target_active_params and len(indices) > 0:
+            i = randint(0, len(indices) - 1)
+            self.linear_groups[indices[i]].setFrozen(False)
+            params += self.linear_groups[indices[i]].paramCount()
+            indices.pop(i)
+
+        for group in self.linear_groups:
+            if group.isFrozen():
+                group.inplaceTo(dtype=torch.float16)
+            else:
+                group.inplaceTo(dtype=torch.float32)
+
+    def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
+        modules = dict(self.model.named_modules())
+        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
+        total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
+        static_param_count = self.staticParameterCount()
+        total_parameter_count = static_param_count + self.dynamicParameterCount()
+        params_per_byte = total_parameter_count / float(total_memory)
+        print(f"{1/params_per_byte} bytes available per parameter")
+
+        # Non-linear modules always live on the primary device.
+        for key in DyntrainModel._get_nonlinear_names(self.model):
+            replace_module(self.model, key, modules[key].to(primary_device))
+
+        # Fill the primary device up to 80% of its proportional budget, then spread
+        # the remaining linear groups across the secondary devices.
+        group_index = 0
+        params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
+        primary_params = static_param_count
+        while params_for_primary > primary_params and group_index < len(self.linear_groups):
+            self.linear_groups[group_index].inplaceTo(device=primary_device)
+            primary_params += self.linear_groups[group_index].paramCount()
+            group_index += 1
+
+        for device in secondary_devices[:-1]:
+            params_for_device = torch.cuda.get_device_properties(device).total_memory * params_per_byte
+            params = 0
+            while params_for_device > params and group_index < len(self.linear_groups):
+                self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
+                params += self.linear_groups[group_index].paramCount()
+                group_index += 1
+
+        while group_index < len(self.linear_groups):
+            self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
+            group_index += 1
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000..470b0d7
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,68 @@
+import torch
+
+
+class Linear(torch.nn.Linear):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+        super().__init__(in_features, out_features, bias, device, dtype)
+
+    @classmethod
+    def fromLinear(cls, in_module: torch.nn.Linear):
+        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
+                                              out_features=in_module.out_features,
+                                              bias=in_module.bias is not None,
+                                              device=in_module.weight.device,
+                                              dtype=in_module.weight.dtype)
+        new_module.weight = in_module.weight
+        new_module.bias = in_module.bias
+        return new_module
+
+    def setFrozen(self, frozen: bool):
+        self.weight.requires_grad = not frozen
+        if self.bias is not None:
+            self.bias.requires_grad = not frozen
+
+    def isFrozen(self) -> bool:
+        return not self.weight.requires_grad
+
+    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+        # Preserve the frozen state: wrapping a tensor in a new Parameter resets requires_grad.
+        frozen = self.isFrozen()
+        if dtype is not None:
+            self.weight = torch.nn.Parameter(self.weight.to(dtype))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(dtype))
+        if device is not None:
+            self.weight = torch.nn.Parameter(self.weight.to(device))
+            if self.bias is not None:
+                self.bias = torch.nn.Parameter(self.bias.to(device))
+        self.setFrozen(frozen)
+
+
+class ConvertingLinear(Linear):
+    def __init__(self,
+                 in_features, out_features, bias=True, device=None, dtype=None,
+                 output_dtype=None, output_device=None):
+        super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
+        self.output_device = output_device
+
+    @classmethod
+    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
+        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
+                                              out_features=in_module.out_features,
+                                              bias=in_module.bias is not None,
+                                              device=in_module.weight.device,
+                                              dtype=in_module.weight.dtype)
+        new_module.output_dtype = output_dtype
+        new_module.output_device = output_device
+        new_module.weight = in_module.weight
+        new_module.bias = in_module.bias
+        return new_module
+
+    def setOutputDevice(self, output_device: torch.device):
+        self.output_device = output_device
+
+    def forward(self, input: torch.Tensor):
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
+        output_device = input.device if self.output_device is None else self.output_device
+        if input.device != self.weight.device:
+            input = input.to(self.weight.device)
+        if input.dtype != self.weight.dtype:
+            input = input.to(self.weight.dtype)
+        output = torch.nn.Linear.forward(self, input)
+        if torch.isnan(output).any():
+            raise RuntimeError("NaN detected in ConvertingLinear output")
+        return output.to(output_device).to(output_dtype)
diff --git a/train_dynamic.py b/train_dynamic.py
index 18f5a16..92dcbd7 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -1,156 +1,18 @@
 import transformers
-from transformers import AutoModelForCausalLM, get_scheduler
-from peft.utils import _get_submodules
+from transformers import get_scheduler
 import torch
 from torch.utils import tensorboard
 import os
 import shutil
 import math
-import collections
 from tqdm.auto import tqdm
-from random import randint
-import collections
-
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
-from convertinglinear import ConvertingLinear
 from tokenizer import get_tokenizer
-
-def find_all_linear_module_names(model):
-    module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
-            module_names.add(name)
-
-    if 'lm_head' in module_names:  # needed for 16-bit
-        module_names.remove('lm_head')
-    return list(module_names)
-
-
-def find_all_outher_module_names(model):
-    module_names = set()
-    for name, module in model.named_modules():
-        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
-            module_names.add(name)
-    return list(module_names)
-
-
-def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
-    dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
-    print(f'loading base model {model_args.model_name_or_path} in {dtype}...')
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=cache_dir,
-        torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
-        trust_remote_code=model_args.trust_remote_code,
-        device_map=None,
-        attn_implementation="flash_attention_2"
-    )
-
-    modules = dict(model.named_modules())
-    keys = find_all_linear_module_names(model)
-    for key in keys:
-        modules[key].weight = modules[key].weight.to(dtype)
-        if modules[key].bias is not None:
-            modules[key].bias = modules[key].bias.to(dtype)
-
-    return model
-
-
-@torch.no_grad()
-def recursive_setattr(obj, attr, value):
-    attr = attr.split('.', 1)
-    if len(attr) == 1:
-        setattr(obj, attr[0], value)
-    else:
-        recursive_setattr(getattr(obj, attr[0]), attr[1], value)
-
-
-@torch.no_grad()
-def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
-    new_module = torch.nn.Linear(module.in_features,
-                                 module.out_features,
-                                 module.bias is not None,
-                                 module.weight.device,
-                                 dtype)
-    new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
-    
new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None - new_module.weight.requires_grad = not frozen - if new_module.bias is not None: - new_module.bias.requires_grad = not frozen - return new_module - - -@torch.no_grad() -def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device): - if type(module) is torch.nn.Linear: - if frozen: - module.weight.requires_grad = False - if module.bias is not None: - module.bias.requires_grad = False - return module.to(dtype).to(device) - else: - new_module = ConvertingLinear.fromLinear(module).to(dtype) - new_module.weight.requires_grad = True - if new_module.bias is not None: - new_module.bias.requires_grad = True - return new_module.to(device) - elif type(module) is ConvertingLinear: - if not frozen: - module.weight.requires_grad = True - if module.bias is not None: - module.bias.requires_grad = True - assert False - return module.to(dtype).to(device) - else: - new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - device=module.weight.device, - dtype=dtype) - new_module.weight = torch.nn.Parameter(module.weight.to(dtype)) - new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None - new_module.weight.requires_grad = False - if new_module.bias is not None: - new_module.bias.requires_grad = False - return new_module.to(device) - else: - assert False - - -@torch.no_grad() -def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device): - modules = dict(model.named_modules()) - linear_names = find_all_linear_module_names(model) - - for key in linear_names: - if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad: - parent, target, target_name = _get_submodules(model, key) - setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device)) - modules = dict(model.named_modules()) - - active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad) - if active_paramter_count > target_params: - raise RuntimeError("Enough paramters must be available to train at least one linear layer") - - while active_paramter_count < target_params and len(linear_names) > 0: - i = randint(0, len(linear_names) - 1) - parent, target, target_name = _get_submodules(model, linear_names[i]) - new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device) - setattr(parent, target_name, new_module) - active_paramter_count += modules[linear_names[i]].weight.numel() - if modules[linear_names[i]].bias is not None: - active_paramter_count += modules[linear_names[i]].bias.numel() - linear_names.pop(i) - modules = dict() - - assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad) - - return active_paramter_count +from dyntrainmodel import DyntrainModel def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0): @@ -172,26 +34,15 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = shutil.rmtree(delete_checkpoit_dir) -def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float, +def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters: list[torch.nn.parameter], lr: float, 
static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
-
-    all_keys = dynamic_module_names + static_module_names
-    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
-    if len(duplicated) > 0:
-        print("duplicated items:")
-        for item in duplicated:
-            print(item)
-        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
-
     parameters = list()
-    modules = dict(model.named_modules())
-    for key in dynamic_module_names:
-        parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
     param_ids = set([id(p['params']) for p in parameters])
-    for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
-        for p in modules[key].parameters():
-            param_ids.add(id(p))
+    for param in static_parameters:
+        if param.requires_grad and id(param) not in param_ids:
+            parameters.append({'params': param, 'lr': static_lr})
+            param_ids.add(id(param))
 
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
@@ -204,18 +55,17 @@ def get_optimizer(model, dynamic_module_names: list, static_module_names: list,
     return optimizer
 
 
-def compute_dynamic_parameter_ratio(model):
-    modules = dict(model.named_modules())
-    active_linear_parameters = 0
-    total_linear_parameters = 0
-    for key in find_all_linear_module_names(model):
-        active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
-        total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
-    return math.ceil(total_linear_parameters / active_linear_parameters)
+def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
+    primary_device = torch.device(training_args.primary_device)
+    secondary_device = torch.device(training_args.secondary_device)
+    log_writer = tensorboard.SummaryWriter()
+
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
+                          model_args.max_instant_params * 1e6, gradient_checkpointing=True, trust_remote_code=True)
+    model.toDevices(primary_device, [secondary_device])
 
-def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
-    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
+    parameter_count = sum(p.numel() for p in model.model.parameters())
+    active_parameter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
+    print(f"Training model with {parameter_count/1e6}m parameters and {active_parameter_count/1e6}m instantaneous active parameters")
 
     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 
@@ -239,48 +89,24 @@ def prepare(model_args: ModelArguments, data_args: DataArguments, training_args:
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['eval_dataset'] is not None else None
 
-    if model_args.max_instant_params != 0:
-        print(f"Target params {model_args.max_instant_params}m")
-        freeze_random_modules(model, model_args.max_instant_params * 1e6,
-                              torch.float16 if training_args.storage_fp16 else torch.float32,
-                              frozen_device=primary_device, active_device=secondary_device)
-
-    paramter_count = sum(p.numel() for p in model.parameters())
-    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
-
-    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
-    print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}")
-
+    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
     total_steps = steps_per_epoch * training_args.epochs
 
-    optimizer = get_optimizer(model, find_all_linear_module_names(model),
-                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
+    optimizer = get_optimizer(model.dynamicParameters(),
+                              model.staticParameters(),
                               training_args.learning_rate,
                               training_args.learning_rate / dynamic_param_ratio,
                               training_args.weight_decay,
                               training_args.adam_epsilon,
                               training_args.adam8bit)
+
     lr_scheduler = get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.warmup_steps,
         num_training_steps=total_steps
     )
-    return model, optimizer, lr_scheduler, train_dataloader
-
-
-def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    primary_device = torch.device(training_args.primary_device)
-    secondary_device = torch.device(training_args.secondary_device)
-    log_writer = tensorboard.SummaryWriter()
-
-    model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)
-
-    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
-    total_steps = steps_per_epoch * training_args.epochs
-    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
 
     if training_args.do_train:
         progress_bar = tqdm(range(total_steps))
@@ -307,13 +133,12 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
             print(loss)
 
             if global_step % 10 == 0 and model_args.max_instant_params != 0:
-                param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
-                                                    torch.float16 if training_args.storage_fp16 else torch.float32,
-                                                    frozen_device=primary_device,
-                                                    active_device=secondary_device)
-                log_writer.add_scalar("Parameters/train", param_count, global_step)
-                optimizer = get_optimizer(model, find_all_linear_module_names(model),
-                                          find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
+                lr_scheduler.optimizer = None
+                del optimizer
+                model.reshuffleActive()
+                log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
+                optimizer = get_optimizer(model.dynamicParameters(),
+                                          model.staticParameters(),
                                           training_args.learning_rate,
                                           training_args.learning_rate / dynamic_param_ratio,
                                           training_args.weight_decay,
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..6d11c0c
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,7 @@
+from peft.utils import _get_submodules
+import torch
+
+
+def replace_module(model, key: str, module: torch.nn.Module):
+    parent, target, target_name = _get_submodules(model, key)
+    setattr(parent, target_name, module)
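
Usage sketch (not part of the patch): a minimal example of how the DyntrainModel API introduced above is meant to be driven, mirroring what train_dynamic.py does. The model name, device strings, learning rate and the 100M active-parameter budget are illustrative placeholders, not values taken from this change.

    import torch
    from dyntrainmodel import DyntrainModel

    # Load the model in fp32; the constructor randomly unfreezes linear groups
    # until roughly target_active_params parameters are trainable.
    model = DyntrainModel("some/causal-lm", cache_dir=None,
                          target_active_params=int(100e6),
                          gradient_checkpointing=False, trust_remote_code=False)
    model.toDevices(torch.device("cuda:0"), [torch.device("cuda:1")])

    # Frozen groups are stored in fp16, active groups in fp32, so only parameters
    # with requires_grad=True belong in the optimizer.
    optimizer = torch.optim.AdamW(
        [{"params": p} for p in model.dynamicParameters() if p.requires_grad], lr=1e-5)

    # Later, during training: rotate the active set and rebuild the optimizer,
    # as train_dynamic.py does every few steps.
    model.reshuffleActive()
    optimizer = torch.optim.AdamW(
        [{"params": p} for p in model.dynamicParameters() if p.requires_grad], lr=1e-5)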