diff --git a/arguments.py b/arguments.py
index d5c55a9..410f772 100644
--- a/arguments.py
+++ b/arguments.py
@@ -45,6 +45,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Never resize tokenizer embeddings"}
     )
+    quantize: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Quantize parameters not currently being actively trained"}
+    )


 @dataclass
@@ -85,9 +89,8 @@ class TrainingArguments():
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
     max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
     save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
-    primary_device: str = field(default="cuda:0", metadata={"help": 'The primary device to use'})
-    secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
-    train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
+    train_non_linear_layers: Optional[bool] = field(default=False, metadata={"help": 'train non linear layers'})
     flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'})
     max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of paramters to optimize per step in millions"})
-    churn_percent: int = field(default=0, metadata={"help": "The percentage of active parameters to replace when changeing active parameters"})
+    churn_percent: int = field(default=100, metadata={"help": "The percentage of active parameters to replace when changing active parameters"})
+    eval_steps: int = field(default=-1, metadata={"help": "Number of optimization steps after which to compute the evaluation loss"})
diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index ef4f595..c436983 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -1,9 +1,10 @@
 from transformers import AutoModelForCausalLM
 import torch
 from utils import replace_module
-from modules import DynamicConvertingLinear, Linear
+from modules import DynamicConvertingLinear, Linear, DynamicQantizedLinear
 from random import randint
 import math
+from tqdm import tqdm


 class LinearGroup:
@@ -20,9 +21,9 @@ class LinearGroup:
             module.inplaceTo(dtype, device)
         self.modules[-1].setOutputDevice(output_device)

-    def setFrozen(self, frozen: bool) -> None:
+    def setFrozen(self, frozen: bool, convert: bool = True) -> None:
         for module in self.modules:
-            module.setFrozen(frozen)
+            module.setFrozen(frozen, convert)

     def isFrozen(self) -> bool:
         return self.modules[0].isFrozen()
@@ -39,9 +40,26 @@ class LinearGroup:
     def getDevice(self) -> torch.device:
         return self.modules[0].weight.device

+    def compress(self) -> None:
+        for module in self.modules:
+            module.compress()
+
+    def decompress(self) -> None:
+        for module in self.modules:
+            module.decompress()
+
+    def checkDistance(self) -> tuple[float, float]:
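+        # Aggregate the per-module (distance, error) readings into one root-mean-square
+        # value per group, so groups with different module counts stay comparable.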
+        distance_accum = 0.0
+        error_accum = 0.0
+        for module in self.modules:
+            distance, error = module.checkDistance()
+            distance_accum += distance**2
+            error_accum += error**2
+        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+

 class DyntrainModel:
-    def __init__(self, model_name_or_path: str, cache_dir: str,
+    def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
                  target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
@@ -55,28 +73,32 @@ class DyntrainModel:
         if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
         self.devices = list()
+        self.inital_reshufle = True

         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

-        modules = dict(self.model.named_modules())
         self.frozen_linear_groups = list()
         self.active_linear_groups = list()

-        linear_group_names = DyntrainModel._get_linear_group_names(self.model)
+        linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
         for group in linear_group_names:
             for key in group:
-                if DyntrainModel.isModuleIn16bitOutlist(key):
-                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
-                else:
-                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
+                replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
             self.frozen_linear_groups.append(LinearGroup(self.model, group))
         self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
         for group in self.frozen_linear_groups:
-            group.setFrozen(True)
-        self.reshuffleActive()
+            group.setFrozen(True, False)

-    def _get_nonlinear_names(layer: torch.nn.Module):
+    def _getModule(self, key: str, quantize: bool, active_device: torch.device, cold_device: torch.device):
+        output_dtype = torch.float16 if DyntrainModel.isModuleIn16bitOutlist(key) else torch.float32
+        modules = dict(self.model.named_modules())
+        if quantize:
+            return DynamicQantizedLinear.fromLinear(modules[key], active_device, cold_device, output_dtype, torch.float16)
+        else:
+            return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)
+
+    def _getNonlinearNames(layer: torch.nn.Module):
         names = list()
         modules = dict(layer.named_modules())

@@ -85,7 +107,7 @@ class DyntrainModel:
                 names.append(key)
         return names

-    def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
+    def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
         linear_groups = list()
         list_counter = 0
         in_sequence = False
@@ -140,8 +162,11 @@ class DyntrainModel:

     def reshuffleActive(self) -> None:
         active_count = len(self.active_linear_groups)
+        index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            group = self.active_linear_groups.pop(0)
+            distance, error = self.active_linear_groups[index].checkDistance()
+            print(f"linear group has moved {distance} with an error of {error}")
+            group = self.active_linear_groups.pop(index)
             group.setFrozen(True)
             self.frozen_linear_groups.append(group)

@@ -161,25 +186,39 @@ class DyntrainModel:

         assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params

+    def activeParamtersByDevice(self) -> list[int]:
+        out = [0] * len(self.devices)
+        for group in self.active_linear_groups:
+            out[self.devices.index(group.getDevice())] += group.paramCount()
+        return out
+
     def balanceActive(self) -> None:
-        device_groups = list()
-        for index in range(0, len(self.devices)):
-            device_groups.append(list())
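+        # Balance the active (trainable) parameters by parameters-per-byte of device
+        # memory instead of by group count; device 0 is treated as if it only had 80% of
+        # its memory, presumably to leave headroom for the modules that are pinned there.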
+        active_counts = self.activeParamtersByDevice()
+        bits_per_param = list()
+        for i, count in enumerate(active_counts):
+            memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
+            if i == 0:
+                memory = memory * 0.8
+            bits_per_param.append(count / memory)
+
+        max_index, max_bits_per_param = max(enumerate(bits_per_param), key=lambda x: x[1])
+        min_index, min_bits_per_param = min(enumerate(bits_per_param), key=lambda x: x[1])

         for group in self.active_linear_groups:
-            device_groups[self.devices.index(group.getDevice())].append(group)
-
-        min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
-        max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
-
-        if max_count - 2 > min_count:
-            device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
-            self.balanceActive()
+            if group.getDevice() == self.devices[max_index]:
+                memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
+                if max_index == 0:
+                    memory = memory * 0.8
+                swing = group.paramCount() / memory
+                if max_bits_per_param - swing > min_bits_per_param + swing:
+                    group.inplaceTo(device=self.devices[min_index])
+                    self.balanceActive()

     def toDevices(self, devices: list[torch.device]) -> None:
         assert len(devices) > 0
         modules = dict(self.model.named_modules())
         total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
+        total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
         static_param_count = self.staticParameterCount()
         total_parameter_count = static_param_count + self.dynamicParameterCount()
         params_per_byte = total_parameter_count / float(total_memory)
@@ -187,14 +226,17 @@ class DyntrainModel:

         self.devices = devices

-        for key in DyntrainModel._get_nonlinear_names(self.model):
+        for key in DyntrainModel._getNonlinearNames(self.model):
             replace_module(self.model, key, modules[key].to(devices[0]))

         linear_groups = self.active_linear_groups + self.frozen_linear_groups

         group_index = 0
-        for device in devices[:-1]:
-            params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
+        for i, device in enumerate(devices[:-1]):
+            memory = torch.cuda.get_device_properties(device).total_memory
+            if i == 0:
+                memory = memory * 0.8
+            params_for_device = memory * params_per_byte
             params = 0
             while params_for_device > params and group_index < len(linear_groups):
                 linear_groups[group_index].inplaceTo(device=device)
@@ -204,3 +246,6 @@ class DyntrainModel:
         while group_index < len(linear_groups):
             linear_groups[group_index].inplaceTo(device=devices[-1])
             group_index += 1
+
+        for group in tqdm(linear_groups, desc="Preparing layers"):
+            group.compress()
diff --git a/modules.py b/modules.py
index a68e855..cf50422 100644
--- a/modules.py
+++ b/modules.py
@@ -20,10 +20,23 @@ class Linear(torch.nn.Linear):
         new_module.bias = in_module.bias
         return new_module

-    def setFrozen(self, frozen: bool):
+    def compress(self) -> None:
+        self.inplaceTo(torch.float16)
+
+    def decompress(self) -> None:
+        self.inplaceTo(torch.float32)
+
+    def setFrozen(self, frozen: bool, convert: bool = True):
         self.weight.requires_grad = not frozen
         if self.bias is not None:
             self.bias.requires_grad = not frozen
+        if convert:
+            if frozen:
+                self.compress()
+            else:
+                self.decompress()
+                self.weightStart = torch.Tensor(self.weight).clone().detach()

     def isFrozen(self) -> bool:
         return not self.weight.requires_grad
@@ -38,7 +51,7 @@ class Linear(torch.nn.Linear):
             self.weight = torch.nn.Parameter(self.weight.to(device))
             if self.bias is not None:
                 self.bias = torch.nn.Parameter(self.bias.to(device))
-        Linear.setFrozen(self, frozen)
+        Linear.setFrozen(self, frozen, False)

     def _apply(self, fn, recurse: bool = True):
         if fn.__name__ == "convert":
@@ -72,17 +85,12 @@ class DynamicConvertingLinear(Linear):
         new_module.bias = in_module.bias
         return new_module

-    def setFrozen(self, frozen: bool):
-        super().setFrozen(frozen)
-
-        if frozen:
-            self.inplaceTo(torch.float16)
-        else:
-            self.inplaceTo(torch.float32)
-
     def setOutputDevice(self, output_device: torch.device):
         self.output_device = output_device

+    def checkDistance(self) -> tuple[float, float]:
+        return (10.0, 0.0)
+
     def forward(self, input: torch.Tensor):
         output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device if self.output_device is None else self.output_device
@@ -120,7 +128,7 @@ class DynamicQantizedLinear(Linear):
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
         return new_module

-    def quantize(self):
+    def compress(self) -> None:
         weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
         self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
                                                                                 compress_statistics=False, quant_type=self.quant_type)
@@ -132,19 +140,15 @@ class DynamicQantizedLinear(Linear):
         weight = torch.nn.Parameter(self.weight.to(self.cold_device))
         bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None

-    def dequantize(self):
+    def decompress(self) -> None:
         if self.weight_quantized is None:
-            raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
+            raise RuntimeError("decompress() called in quantized state before quantized weights are available")
         dtype = self.weight.dtype
         self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
         if self.bias_quantized:
             self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
-        self.weight_quantized = None
-        self.weight_state = None
-        self.bias_quantized = None
-        self.bias_state = None

-    def checkDistance(self) -> float:
+    def checkDistance(self) -> tuple[float, float]:
         if self.weight_quantized is None:
             raise RuntimeError("checkDistance() called without quantized weights avialable")
         original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
@@ -154,22 +158,13 @@ class DynamicQantizedLinear(Linear):
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight, blocksize=self.block_size, compress_statistics=False, quant_type=self.quant_type)
         dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
         dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
-        return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) / dequantized_original_weight.numel()).item()
+        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
+        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
+        return (distance, error)

     def setOutputDevice(self, output_device: torch.device):
         self.output_device = output_device

-    def setFrozen(self, frozen: bool) -> None:
-        if frozen == self.isFrozen():
-            return
-
-        super().setFrozen(frozen)
-
-        if frozen:
-            self.quantize()
-        else:
-            self.dequantize()
-
     def forward(self, x: torch.Tensor):
         output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
         output_device = x.device if self.output_device is None else self.output_device
@@ -183,9 +178,27 @@ class DynamicQantizedLinear(Linear):
         else:
             if self.weight_quantized is None:
                 raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
+            if x.device != self.weight_quantized.device:
+                x = x.to(self.weight_quantized.device)
             bias = None
             if self.bias_quantized is not None:
                 bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
             out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)

         return out.to(output_device).to(output_dtype)
+
+    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+        if dtype is not None:
+            super().inplaceTo(dtype=dtype)
+        if device is not None:
+            frozen = self.isFrozen()
+            self.active_device = device
+            if self.weight_quantized is not None:
+                self.weight_quantized = self.weight_quantized.to(device)
+                self.weight_state = self.weight_state.to(device)
+            if self.bias_quantized is not None:
+                self.bias_quantized = self.bias_quantized.to(device)
+                self.bias_state = self.bias_state.to(device)
+            if not frozen:
+                super().inplaceTo(device=device)
+            self.setFrozen(frozen, False)
diff --git a/train_dynamic.py b/train_dynamic.py
index 6de2f2b..941c3ca 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -39,10 +39,11 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
     parameters = list()
     parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
     param_ids = set([id(p['params']) for p in parameters])
-    for param in static_parameters:
-        if param.requires_grad and id(param) not in param_ids:
-            parameters.append({'params': param, 'lr': static_lr})
-            param_ids.add(id(param))
+    if static_parameters is not None:
+        for param in static_parameters:
+            if param.requires_grad and id(param) not in param_ids:
+                parameters.append({'params': param, 'lr': static_lr})
+                param_ids.add(id(param))

     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
@@ -55,19 +56,34 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
     return optimizer


+def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
+    print("*** Eval ***")
+    loss = torch.zeros((1), device="cuda:0")
+    model.model.eval()
+    for batch in dataloader:
+        for key in batch:
+            batch[key] = batch[key].to("cuda:0")
+        with torch.no_grad():
+            outputs = model.model(**batch)
+        loss += outputs.loss
+    loss = loss / len(dataloader)
+    print(f"Eval Loss {loss.item()}")
+    return loss.item()
+
+
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    primary_device = torch.device(training_args.primary_device)
-    secondary_device = torch.device(training_args.secondary_device)
     log_writer = tensorboard.SummaryWriter()

     model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
-                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
-    model.toDevices([primary_device, secondary_device])
+                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
+                          quantize=model_args.quantize)
+    devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
+    model.toDevices(devices)
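+    # Unfreeze an initial set of linear groups and spread the active groups across the
+    # available GPUs before parameter counts are reported and the optimizer is created.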
+    model.reshuffleActive()
     model.balanceActive()

     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
-    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
+    static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
+    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantaneous active parameters of which {static_parameter_count} are static")

     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

@@ -96,7 +112,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     total_steps = steps_per_epoch * training_args.epochs

     optimizer = get_optimizer(model.dynamicParameters(),
-                              model.staticParameters(),
+                              model.staticParameters() if training_args.train_non_linear_layers else None,
                               training_args.learning_rate,
                               training_args.learning_rate / dynamic_param_ratio,
                               training_args.weight_decay,
@@ -115,6 +131,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     global_step = 0
     model.model.train()
     for epoch in range(0, training_args.epochs):
+        model.model.train()
         print("*** Train ***")
         print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
         for step, batch in enumerate(train_dataloader):
@@ -131,17 +148,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T

                 model.model.zero_grad()

-                if global_step % 10 == 0:
-                    print(loss)
+                if global_step % 5 == 0:
+                    print(f"Train Loss {loss.item()}")

-                if global_step % 10 == 0 and training_args.max_instant_params != 0:
+                if global_step % 50 == 0 and training_args.max_instant_params != 0:
                     lr_scheduler.optimizer = None
                     del optimizer
                     model.reshuffleActive()
                     model.balanceActive()
                     log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                     optimizer = get_optimizer(model.dynamicParameters(),
-                                              model.staticParameters(),
+                                              model.staticParameters() if training_args.train_non_linear_layers else None,
                                               training_args.learning_rate,
                                               training_args.learning_rate / dynamic_param_ratio,
                                               training_args.weight_decay,
@@ -152,14 +169,19 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
             global_step += 1
             progress_bar.update()

+            if global_step > 0:
                 if global_step % training_args.save_steps == 0:
                     save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
+                if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                    evaluate(model, eval_dataloader)

             if training_args.flush_allocator:
                 torch.cuda.empty_cache()
+        if training_args.do_eval and training_args.eval_steps == -1:
+            evaluate(model, eval_dataloader)

     # Evaluation
     if training_args.do_eval:
-        print("*** Evaluate ***")
+        evaluate(model, eval_dataloader)

     save_model(model.model, global_step, training_args.output_dir)