diff --git a/arguments.py b/arguments.py
index 2879de1..d5c55a9 100644
--- a/arguments.py
+++ b/arguments.py
@@ -41,10 +41,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
     )
-    max_instant_params: int = field(
-        default=0,
-        metadata={"help": "Maximum amount of paramters to optimize per step in millions"}
-    )
     noresize: Optional[bool] = field(
         default=False,
         metadata={"help": "Never resize tokenizer embeddings"}
@@ -93,3 +89,5 @@ class TrainingArguments():
     secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
     train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
     flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'})
+    max_instant_params: int = field(default=0, metadata={"help": "Maximum number of parameters to optimize per step, in millions"})
+    churn_percent: int = field(default=0, metadata={"help": "The percentage of active parameters to replace when reshuffling the active parameters"})
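Usage note: the two new TrainingArguments fields are consumed by the DyntrainModel call in train_dynamic.py further down. A minimal sketch of the mapping that call implies, where training_args stands for the parsed TrainingArguments instance:

# Illustrative only, mirroring the DyntrainModel(...) call in train_dynamic.py below.
target_active_params = int(training_args.max_instant_params * 1e6)  # millions -> parameter count
reshuffle_fraction = training_args.churn_percent / 100.0            # percent -> fraction, validated to lie in [0.1, 1.0]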
diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index 687ee2b..ef4f595 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -1,30 +1,11 @@
 from transformers import AutoModelForCausalLM
 import torch
 from utils import replace_module
-from modules import ConvertingLinear, Linear
+from modules import DynamicConvertingLinear, Linear
 from random import randint
 import math


-def find_all_linear_module_names(model) -> list[str]:
-    module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
-            module_names.add(name)
-
-    if 'lm_head' in module_names:  # needed for 16-bit
-        module_names.remove('lm_head')
-    return list(module_names)
-
-
-def find_all_outher_module_names(model) -> list[str]:
-    module_names = set()
-    for name, module in model.named_modules():
-        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
-            module_names.add(name)
-    return list(module_names)
-
-
 class LinearGroup:
     def __init__(self, model, group_names: list):
         self.modules = list()
@@ -61,7 +42,7 @@ class LinearGroup:

 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str,
-                 target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -69,16 +50,30 @@ class DyntrainModel:
             trust_remote_code=trust_remote_code,
             device_map=None
         )
-        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
-        self.linear_groups = list()
         self.target_active_params = target_active_params
-
+        self.reshuffle_fraction = reshuffle_fraction
+        if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
+            raise RuntimeError("reshuffle_fraction must be between 0.1 and 1.0")
         self.devices = list()

         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

-        self._prepare()
+        modules = dict(self.model.named_modules())
+        self.frozen_linear_groups = list()
+        self.active_linear_groups = list()
+
+        linear_group_names = DyntrainModel._get_linear_group_names(self.model)
+        for group in linear_group_names:
+            for key in group:
+                if DyntrainModel.isModuleIn16bitOutlist(key):
+                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
+                else:
+                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
+            self.frozen_linear_groups.append(LinearGroup(self.model, group))
+        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
+        for group in self.frozen_linear_groups:
+            group.setFrozen(True)
         self.reshuffleActive()

     def _get_nonlinear_names(layer: torch.nn.Module):
@@ -117,21 +112,9 @@
                           "v_proj"})
         return key in whitelist

-    def _prepare(self) -> None:
-        modules = dict(self.model.named_modules())
-        linear_groups = DyntrainModel._get_linear_group_names(self.model)
-
-        for group in linear_groups:
-            for key in group:
-                if DyntrainModel.isModuleIn16bitOutlist(key):
-                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
-                else:
-                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
-            self.linear_groups.append(LinearGroup(self.model, group))
-
     def dynamicParameters(self) -> list:
         parameters = list()
-        for group in self.linear_groups:
+        for group in self.frozen_linear_groups + self.active_linear_groups:
             parameters.extend(group.parameters())
         return parameters

@@ -156,23 +139,24 @@
         return sum(p.numel() for p in total_params if p.requires_grad)

     def reshuffleActive(self) -> None:
-        for group in self.linear_groups:
+        active_count = len(self.active_linear_groups)
+        while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
+            group = self.active_linear_groups.pop(0)
             group.setFrozen(True)
+            self.frozen_linear_groups.append(group)

-        indecies = list(range(0, len(self.linear_groups)))
-        params = self.staticParameterCount()
-        while params < self.target_active_params and len(indecies) > 0:
-            i = randint(0, len(indecies) - 1)
-            self.linear_groups[indecies[i]].setFrozen(False)
-            params += self.linear_groups[indecies[i]].paramCount()
-            indecies.pop(i)
+        params = self.activeParameterCount()
+
+        if params >= self.target_active_params:
+            raise RuntimeError("Insufficient active parameters to reshuffle the active set")
+        while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
+            i = randint(0, len(self.frozen_linear_groups) - 1)
+            group = self.frozen_linear_groups.pop(i)
+            group.setFrozen(False)
+            params += group.paramCount()
+            self.active_linear_groups.append(group)
         print(math.ceil(params / 1e6))

-        for group in self.linear_groups:
-            if group.isFrozen():
-                group.inplaceTo(dtype=torch.float16)
-            else:
-                group.inplaceTo(dtype=torch.float32)

         active_params = self.activeParameterCount()
         assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
@@ -182,9 +166,8 @@
         for index in range(0, len(self.devices)):
             device_groups.append(list())

-        for group in self.linear_groups:
-            if not group.isFrozen():
-                device_groups[self.devices.index(group.getDevice())].append(group)
+        for group in self.active_linear_groups:
+            device_groups[self.devices.index(group.getDevice())].append(group)

         min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
         max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
@@ -207,15 +190,17 @@
         for key in DyntrainModel._get_nonlinear_names(self.model):
             replace_module(self.model, key, modules[key].to(devices[0]))

+        linear_groups = self.active_linear_groups + self.frozen_linear_groups
+
         group_index = 0
         for device in devices[:-1]:
             params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
             params = 0
-            while params_for_device > params and group_index < len(self.linear_groups):
-                self.linear_groups[group_index].inplaceTo(device=device)
-                params += self.linear_groups[group_index].paramCount()
+            while params_for_device > params and group_index < len(linear_groups):
+                linear_groups[group_index].inplaceTo(device=device)
+                params += linear_groups[group_index].paramCount()
                 group_index += 1

-        while group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=devices[-1])
+        while group_index < len(linear_groups):
+            linear_groups[group_index].inplaceTo(device=devices[-1])
             group_index += 1
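Usage note: a minimal sketch of the reworked DyntrainModel lifecycle, based only on the calls visible in train_dynamic.py; the model name, cache directory and devices are placeholders:

import torch
from dyntrainmodel import DyntrainModel

# Placeholder arguments; real values come from ModelArguments/TrainingArguments.
model = DyntrainModel("some/causal-lm", None, target_active_params=int(100e6),
                      reshuffle_fraction=0.1, gradient_checkpointing=True, trust_remote_code=False)
model.toDevices([torch.device("cuda:0"), torch.device("cuda:1")])
model.balanceActive()

# Later, every few steps: freeze reshuffle_fraction of the active LinearGroups,
# then activate randomly chosen frozen groups until target_active_params is reached.
model.reshuffleActive()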
diff --git a/modules.py b/modules.py
index 385c6da..a68e855 100644
--- a/modules.py
+++ b/modules.py
@@ -1,4 +1,8 @@
 import torch
+import bitsandbytes as bnb
+import torch.multiprocessing as multiprocessing
+from typing import overload, Optional, Union
+from functools import wraps


 class Linear(torch.nn.Linear):
@@ -34,12 +38,22 @@ class Linear(torch.nn.Linear):
         self.weight = torch.nn.Parameter(self.weight.to(device))
         if self.bias is not None:
             self.bias = torch.nn.Parameter(self.bias.to(device))
-        self.setFrozen(frozen)
+        Linear.setFrozen(self, frozen)
+
+    def _apply(self, fn, recurse: bool = True):
+        if fn.__name__ == "convert":
+            return self
+        else:
+            return super()._apply(fn, recurse)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        breakpoint()
+        return self


-class ConvertingLinear(Linear):
-    def __init__(self,
-                 in_features, out_features, bias=True, device=None, dtype=None,
+class DynamicConvertingLinear(Linear):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
                  output_dtype=None, output_device=None):
         super().__init__(in_features, out_features, bias, device, dtype)
         self.output_dtype = output_dtype
@@ -58,6 +72,14 @@ class ConvertingLinear(Linear):
         new_module.bias = in_module.bias
         return new_module

+    def setFrozen(self, frozen: bool):
+        super().setFrozen(frozen)
+
+        if frozen:
+            self.inplaceTo(torch.float16)
+        else:
+            self.inplaceTo(torch.float32)
+
     def setOutputDevice(self, output_device: torch.device):
         self.output_device = output_device

@@ -69,6 +91,101 @@
         if input.dtype != self.weight.dtype:
             input = input.to(self.weight.dtype)
         output = torch.nn.Linear.forward(self, input)
-        if torch.isnan(output).any():
-            breakpoint()
         return output.to(output_device).to(output_dtype)
+
+
+class DynamicQantizedLinear(Linear):
+    def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
+                 output_dtype=None, compute_dtype=None, output_device=None):
+        super().__init__(in_features, out_features, bias, cold_device, torch.float32)
+        self.active_device = active_device
+        self.cold_device = cold_device
+        self.output_device = output_device
+        self.output_dtype = output_dtype
+        self.compute_dtype = compute_dtype
+        self.weight_quantized = None
+        self.weight_state = None
+        self.bias_quantized = None
+        self.bias_state = None
+        self.block_size = 128
+        self.quant_type = 'nf4'
+
+    @classmethod
+    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
+                   output_dtype=None, compute_dtype=torch.float16, output_device=None):
+        new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
+                         active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
+                         compute_dtype=compute_dtype, output_device=output_device)
+        new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
+        new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
+        return new_module
+
+    def quantize(self):
+        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
+        self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
+                                                                                compress_statistics=False, quant_type=self.quant_type)
+        if self.bias is not None:
+            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
+            self.bias_quantized, self.bias_state = bnb.functional.quantize_4bit(bias, blocksize=self.block_size,
+                                                                                compress_statistics=False, quant_type=self.quant_type)
+
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+
+    def dequantize(self):
+        if self.weight_quantized is None:
+            raise RuntimeError("dequantize() called before quantized weights are available")
+        dtype = self.weight.dtype
+        self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
+        if self.bias_quantized is not None:
+            self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
+        self.weight_quantized = None
+        self.weight_state = None
+        self.bias_quantized = None
+        self.bias_state = None
+
+    def checkDistance(self) -> float:
+        if self.weight_quantized is None:
+            raise RuntimeError("checkDistance() called without quantized weights available")
+        original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
+        quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
+                                                                                           blocksize=self.block_size,
+                                                                                           compress_statistics=True,
+                                                                                           quant_type=self.quant_type)
+        dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
+        dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
+        return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) / dequantized_original_weight.numel()).item()
+
+    def setOutputDevice(self, output_device: torch.device):
+        self.output_device = output_device
+
+    def setFrozen(self, frozen: bool) -> None:
+        if frozen == self.isFrozen():
+            return
+
+        super().setFrozen(frozen)
+
+        if frozen:
+            self.quantize()
+        else:
+            self.dequantize()
+
+    def forward(self, x: torch.Tensor):
+        output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
+        output_device = x.device if self.output_device is None else self.output_device
+
+        if not self.isFrozen():
+            if x.device != self.weight.device:
+                x = x.to(self.weight.device)
+            if x.dtype != self.weight.dtype:
+                x = x.to(self.weight.dtype)
+            return super().forward(x).to(output_device).to(output_dtype)
+        else:
+            if self.weight_quantized is None:
+                raise RuntimeError("forward() called in quantized state before quantized weights are available")
+            bias = None
+            if self.bias_quantized is not None:
+                bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
+            out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)
+
+            return out.to(output_device).to(output_dtype)
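Usage note: DynamicQantizedLinear is not exercised elsewhere in this patch. A rough sketch of the freeze/quantize lifecycle it appears to implement, assuming a freshly built module starts unfrozen (the base Linear's initial frozen state is not shown here):

import torch
from modules import DynamicQantizedLinear

base = torch.nn.Linear(1024, 1024)                     # stand-in for a model layer
layer = DynamicQantizedLinear.fromLinear(base, active_device=torch.device("cuda:0"),
                                         cold_device=torch.device("cpu"))

layer.setFrozen(True)        # quantize(): nf4 copy on the active device, fp32 master parked on the cold device
x = torch.randn(4, 1024, dtype=torch.float16, device="cuda:0")
y = layer(x)                 # frozen path: bnb.matmul_4bit against the quantized weight

layer.setFrozen(False)       # dequantize(): back to a trainable full-precision weight on the active device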
diff --git a/train_dynamic.py b/train_dynamic.py
index c9e4de7..6de2f2b 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -60,7 +60,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     secondary_device = torch.device(training_args.secondary_device)
     log_writer = tensorboard.SummaryWriter()

-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
+                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
     model.toDevices([primary_device, secondary_device])
     model.balanceActive()

@@ -133,7 +134,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         if global_step % 10 == 0:
             print(loss)

-        if global_step % 10 == 0 and model_args.max_instant_params != 0:
+        if global_step % 10 == 0 and training_args.max_instant_params != 0:
            lr_scheduler.optimizer = None
            del optimizer
            model.reshuffleActive()
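Usage note: the hunk above discards the optimizer before reshuffling, but the rebuild happens outside the shown context. A hedged sketch of what presumably follows, using only methods visible in this patch; the optimizer class and learning-rate source are assumptions:

model.balanceActive()  # assumed: rebalance the newly active groups across devices
optimizer = torch.optim.AdamW((p for p in model.dynamicParameters() if p.requires_grad),
                              lr=training_args.learning_rate)  # training_args.learning_rate is assumed to exist
lr_scheduler.optimizer = optimizer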
diff --git a/utils.py b/utils.py
index 6d11c0c..c58bc06 100644
--- a/utils.py
+++ b/utils.py
@@ -5,3 +5,22 @@ import torch
 def replace_module(model, key: str, module: torch.nn.Module):
     parent, target, target_name = _get_submodules(model, key)
     setattr(parent, target_name, module)
+
+
+def find_all_linear_module_names(model) -> list[str]:
+    module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            module_names.add(name)
+
+    if 'lm_head' in module_names:  # needed for 16-bit
+        module_names.remove('lm_head')
+    return list(module_names)
+
+
+def find_all_outher_module_names(model) -> list[str]:
+    module_names = set()
+    for name, module in model.named_modules():
+        if not isinstance(module, torch.nn.Linear):
+            module_names.add(name)
+    return list(module_names)
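Usage note: an illustrative sketch of how the helpers now living in utils.py combine with replace_module, mirroring what DyntrainModel.__init__ does per linear group; model stands for any already-loaded transformers model:

import torch
from utils import find_all_linear_module_names, replace_module
from modules import DynamicConvertingLinear

modules = dict(model.named_modules())
for name in find_all_linear_module_names(model):
    # Swap each plain Linear for the converting variant, as DyntrainModel.__init__ does per group.
    replace_module(model, name, DynamicConvertingLinear.fromLinear(modules[name].to(torch.float16), output_dtype=torch.float32))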