diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index e6a1638..1fa08d6 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -68,7 +68,9 @@ class LinearGroup:
 
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
-                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, train_static_params: bool,
+                 reshuffle_fraction: float, gradient_checkpointing: bool,
+                 trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -82,6 +84,7 @@ class DyntrainModel:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
         self.devices = list[torch.device]()
         self.inital_reshufle = True
+        self.train_static_params = train_static_params
 
         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
@@ -167,8 +170,14 @@ class DyntrainModel:
     def staticParameterCount(self) -> int:
         return sum(p.numel() for p in self.staticParameters())
 
+    def activeDynamicParameterCount(self) -> int:
+        return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)
+
     def activeParameterCount(self) -> int:
-        total_params = self.dynamicParameters() + self.staticParameters()
+        if self.train_static_params:
+            total_params = self.dynamicParameters() + self.staticParameters()
+        else:
+            total_params = self.dynamicParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)
 
     def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
@@ -187,7 +196,7 @@ class DyntrainModel:
 
         params = self.activeParameterCount()
         if params >= self.target_active_params:
-            RuntimeError("Insuficant active parameters to suffle active")
+            raise RuntimeError("Insuficant active parameters to suffle active")
         while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
             i = randint(0, len(self.frozen_linear_groups) - 1)
             group = self.frozen_linear_groups.pop(i)
@@ -199,7 +208,7 @@ class DyntrainModel:
 
         active_params = self.activeParameterCount()
 
-        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params
 
     def activeParamtersByDevice(self) -> list[int]:
         out = [0] * len(self.devices)
@@ -213,7 +222,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = int(memory * 0.8)
+                memory = int(memory * 0.5)
             bits_per_param.append(count / memory)
 
         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -223,7 +232,7 @@ class DyntrainModel:
             if group.getDevice() is self.devices[max_index]:
                 memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                 if max_index == 0:
-                    memory = int(memory * 0.8)
+                    memory = int(memory * 0.5)
                 swing = group.paramCount() / memory
                 if max_bits_per_param - swing > min_bits_per_param + swing:
                     group.inplaceTo(device=self.devices[min_index])
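The balanceActive() hunks above change the memory budget assumed for device 0 from 0.8 to 0.5 of its capacity before comparing parameter density across GPUs. A minimal standalone sketch of that density calculation, assuming CUDA devices are available; the helper name params_per_byte is illustrative and not part of the patch:

import torch

def params_per_byte(active_counts: list[int], devices: list[torch.device]) -> list[float]:
    # Active parameters per byte of VRAM; device 0 is budgeted at half its
    # capacity, mirroring the 0.8 -> 0.5 change in the patch above.
    densities = []
    for i, count in enumerate(active_counts):
        memory = torch.cuda.get_device_properties(devices[i]).total_memory
        if i == 0:
            memory = int(memory * 0.5)
        densities.append(count / memory)
    return densities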
diff --git a/modules.py b/modules.py
index 9ff9dee..6d59399 100644
--- a/modules.py
+++ b/modules.py
@@ -108,7 +108,7 @@ class DynamicConvertingLinear(Linear):
 
 class DynamicQantizedLinear(Linear):
     def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
-                 output_dtype=None, compute_dtype=None, output_device=None):
+                 output_dtype=None, compute_dtype=None, output_device=None, cold_dtype=torch.float32):
         super().__init__(in_features, out_features, bias, cold_device, torch.float32)
         self.active_device = active_device
         self.cold_device = cold_device
@@ -120,8 +120,8 @@ class DynamicQantizedLinear(Linear):
         self.bias_quantized = None
         self.bias_state = None
         self.block_size = 128
-        self.quant_type = 'nf4'
-        self.weight_start = self.weight.clone().detach()
+        #self.weight_start = self.weight.clone().detach()
+        self.cold_dtype = cold_dtype
 
     @classmethod
     def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
@@ -131,19 +131,19 @@ class DynamicQantizedLinear(Linear):
                          compute_dtype=compute_dtype, output_device=output_device)
         new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
-        new_module.weight_start = new_module.weight.clone().detach()
+        #new_module.weight_start = new_module.weight.clone().detach()
         return new_module
 
     def compress(self) -> None:
-        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
+        weight = self.weight.contiguous().to(torch.float16).to(self.active_device)
         self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
         if self.bias is not None:
-            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
+            bias = self.bias.contiguous().to(torch.float16).to(self.active_device)
             self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
 
         frozen = self.isFrozen()
-        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_dtype).to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_dtype).to(self.cold_device)) if self.bias is not None else None
         self.setFrozen(frozen, False)
 
     def decompress(self) -> None:
@@ -151,16 +151,16 @@ class DynamicQantizedLinear(Linear):
         self.weight_state = None
         self.bias_quantized = None
         self.bias_state = None
-        self.weight_start = self.weight.clone().detach().to(self.cold_device)
-        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
+        #self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device).to(torch.float32))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device).to(torch.float32))
 
     def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
         original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
         dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
-        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        distance = torch.zeros((2))  #(self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
         error = (dequantized_original_weight - original_weight).to(torch.float32)
 
         return (distance, error)
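compress() and decompress() above round-trip the weights through bitsandbytes blockwise quantization with a block size of 128, and getDistanceAndError() measures the error that round trip introduces. A minimal sketch of those two calls, assuming a CUDA device and bitsandbytes installed; the tensor name and shape are illustrative only:

import torch
import bitsandbytes as bnb

weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda:0")
# Quantize to a packed representation plus per-block quantization state.
quantized, state = bnb.functional.quantize_blockwise(weight, blocksize=128)
# Dequantize back and measure the introduced error, as getDistanceAndError() does.
restored = bnb.functional.dequantize_blockwise(quantized, state).to(weight.dtype)
error = (restored - weight).to(torch.float32)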
diff --git a/tokenizer.py b/tokenizer.py
index c16f3df..e271933 100644
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -30,13 +30,13 @@ def smart_tokenizer_and_embedding_resize(
 
 
 def get_tokenizer(model, cache_dir, model_args: ModelArguments):
-    print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
+    tokenizer_path = model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path
+    print(f'Tokenizer: {tokenizer_path}')
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
+        tokenizer_path,
         cache_dir=cache_dir,
         padding_side="right",
         use_fast=False,
-        eos_token="[EOS]",
         tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
         trust_remote_code=model_args.trust_remote_code
     )
diff --git a/train_dynamic.py b/train_dynamic.py
index 5dcf8cf..865d623 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -1,6 +1,4 @@
 import transformers
-from transformers import get_scheduler
-
 import torch
 from torch.utils import tensorboard
 import os
@@ -8,9 +6,10 @@ import shutil
 import math
 from tqdm.auto import tqdm
 import gc
+import sys
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
+from datamodules import get_data_loaders
 from tokenizer import get_tokenizer
 from dyntrainmodel import DyntrainModel
 
@@ -19,7 +18,16 @@ from dyntrainmodel import DyntrainModel
 def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
     output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
     output_dir = os.path.join(output_dir, output_chkpt_dir)
+
+    print(f"saveing model to {output_chkpt_dir}")
+
+    temperature = model.generation_config.temperature
+    top_p = model.generation_config.top_p
+    model.generation_config.temperature = None
+    model.generation_config.top_p = None
     model.save_pretrained(output_dir)
+    model.generation_config.temperature = temperature
+    model.generation_config.top_p = top_p
 
     if max_checkpoints > 0:
         files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
@@ -57,37 +65,85 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
     return optimizer
 
 
+def move_optimizer_param(param, device: torch.device, device_map: dict):
+    if isinstance(param, torch.Tensor):
+        move_device = device if device is not None else device_map[id(param)]
+        assert device is not None or move_device != torch.device("cpu")
+        old_device = param.device
+        param.data = param.data.to(move_device)
+        if param._grad is not None:
+            param._grad.data = param._grad.data.to(move_device)
+        if device is not None and id(param) not in device_map:
+            device_map[id(param)] = old_device
+            assert old_device != torch.device("cpu")
+    elif isinstance(param, dict):
+        for subparam in param.values():
+            move_optimizer_param(subparam, device, device_map)
+
+
+def suspend_optimizer(optimizer) -> dict:
+    device_map = dict()
+    for param in optimizer.state.values():
+        move_optimizer_param(param, torch.device("cpu"), device_map)
+    return device_map
+
+
+def resume_optimizer(optimizer, device_map: dict):
+    for param in optimizer.state.values():
+        move_optimizer_param(param, None, device_map)
+
+
 def evaluate(model: DyntrainModel, tokenizer,
              dataloader: torch.utils.data.DataLoader, globalstep: int,
             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
-    print("*** Eval ***")
-    loss = torch.zeros((1), device="cuda:0")
-    model.model.eval()
-    for batch in dataloader:
-        for key in batch:
-            batch[key] = batch[key].to("cuda:0")
-        outputs = model.model(**batch)
-        loss += outputs.loss
-    loss = loss / len(dataloader)
-    log_writer.add_scalar("Loss/Eval", loss, globalstep)
-    print(f"Eval Loss {loss.item()}")
-    return loss.item()
+    with torch.no_grad():
+        loss = torch.zeros((1), device="cuda:0")
+        model.model.eval()
 
-    if eval_prompt is not None:
-        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
-        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
-        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
-        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-        print(f"Eval generation: {response_decoded}")
-        log_writer.add_text("Text/Eval", response_decoded, globalstep)
+        for batch in tqdm(dataloader, desc="Doing eval"):
+            for key in batch:
+                batch[key] = batch[key].to("cuda:0")
+            outputs = model.model(**batch)
+            loss += outputs.loss
+        loss = loss / len(dataloader)
+        log_writer.add_scalar("Loss/Eval", loss, globalstep)
+        print(f"Eval Loss {loss.item()}")
+
+        if eval_prompt is not None:
+            input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+            attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+            outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
+                                           max_new_tokens=100, min_new_tokens=100)
+            response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+            print(f"Eval generation: {response_decoded}")
+            log_writer.add_text("Text/Eval", response_decoded, globalstep)
+
+    model.model.train()
+
+
+def max_vram_allocated():
+    max_vram_alloc = 0
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc
+
+
+def min_vram_allocated():
+    max_vram_alloc = sys.maxsize
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = min(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc
 
 
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    log_writer = tensorboard.SummaryWriter()
+    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)
 
-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
-                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
-                          quantize=model_args.quantize)
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
+                          quantize=model_args.quantize,
+                          target_active_params=int(training_args.max_instant_params * 1e6),
+                          train_static_params=training_args.train_non_linear_layers,
+                          reshuffle_fraction=training_args.churn_percent / 100.0,
+                          gradient_checkpointing=True,
+                          trust_remote_code=True)
     devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
     model.toDevices(devices)
     model.reshuffleActive()
@@ -96,34 +152,15 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
-    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
+    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m "
          f"instantanous active paramters of which {static_parameter_count} are static")
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
-    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
-        print("Loading dataset in s2s mode")
-        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    elif data_args.data_from_hub:
-        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    else:
-        print("Loading dataset in txt mode")
-        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-
-    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset['train_dataset'],
-        shuffle=True,
-        collate_fn=dataset['data_collator'],
-        batch_size=training_args.per_device_train_batch_size
-    ) if dataset['train_dataset'] is not None else None
-    if training_args.do_eval:
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset['eval_dataset'],
-            shuffle=True,
-            collate_fn=dataset['data_collator'],
-            batch_size=training_args.per_device_train_batch_size
-        )
+    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
+                                                         training_args.per_device_train_batch_size,
+                                                         training_args.per_device_eval_batch_size,
+                                                         training_args.do_train, training_args.do_eval)
 
     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
@@ -137,7 +174,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                               training_args.adam_epsilon,
                               training_args.adam8bit)
 
-    lr_scheduler = get_scheduler(
+    lr_scheduler = transformers.get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.warmup_steps,
@@ -149,13 +186,11 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     global_step = 0
     model.model.train()
     for epoch in range(0, training_args.epochs):
-        model.model.train()
         print("*** Train ***")
-        print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
+        print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
         for step, batch in enumerate(train_dataloader):
             for key in batch:
                 batch[key] = batch[key].to("cuda:0")
-
             outputs = model.model(**batch)
             loss = outputs.loss / training_args.gradient_accumulation_steps
             loss.backward()
@@ -166,46 +201,52 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 optimizer.step()
                 lr_scheduler.step()
 
+                progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
+                                             f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")
+
                 model.model.zero_grad()
 
-                if global_step % 5 == 0:
-                    print(f"Train Loss {loss.item()}")
+                if global_step > 0:
+                    if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
+                        print("Reshuffleing")
+                        lr_scheduler.optimizer = None
+                        del optimizer
+                        # distance, error = model.getDistanceAndErrorSample()
+                        # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                        # log_writer.add_histogram("Errors/Train", error, max_bins=50)
 
-                if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
-                    print("Reshuffleing")
-                    lr_scheduler.optimizer = None
-                    del optimizer
-                    # distance, error = model.getDistanceAndErrorSample()
-                    # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
-                    # log_writer.add_histogram("Errors/Train", error, max_bins=50)
+                        model.reshuffleActive()
+                        model.balanceActive()
+                        log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
+                        optimizer = get_optimizer(model.dynamicParameters(),
+                                                  model.staticParameters() if training_args.train_non_linear_layers else None,
+                                                  training_args.learning_rate,
+                                                  training_args.learning_rate / dynamic_param_ratio,
+                                                  training_args.weight_decay,
+                                                  training_args.adam_epsilon,
+                                                  training_args.adam8bit)
+                        lr_scheduler.optimizer = optimizer
 
-                    model.reshuffleActive()
-                    model.balanceActive()
-                    log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
-                    optimizer = get_optimizer(model.dynamicParameters(),
-                                              model.staticParameters() if training_args.train_non_linear_layers else None,
-                                              training_args.learning_rate,
-                                              training_args.learning_rate / dynamic_param_ratio,
-                                              training_args.weight_decay,
-                                              training_args.adam_epsilon,
-                                              training_args.adam8bit)
-                    lr_scheduler.optimizer = optimizer
+                    if global_step % training_args.save_steps == 0:
+                        save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
+                    if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                        device_map = suspend_optimizer(optimizer)
+                        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                        resume_optimizer(optimizer, device_map)
 
                 global_step += 1
                 progress_bar.update()
 
-            if global_step > 0:
-                if global_step % training_args.save_steps == 0:
-                    save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
-                if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
-                    evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
-
             if training_args.flush_allocator:
                 gc.collect()
                 torch.cuda.empty_cache()
 
     if training_args.do_eval and training_args.eval_steps == -1:
+        device_map = suspend_optimizer(optimizer)
         evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+        resume_optimizer(optimizer, device_map)
+
+    del optimizer
 
-    # Evaluation
     if training_args.do_eval:
         evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
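suspend_optimizer() and resume_optimizer() above park the optimizer state on the CPU around checkpoint evaluation and restore it afterwards. A simplified sketch of the same idea, assuming a standard torch.optim optimizer whose per-parameter state entries are dicts of tensors; the helper names are illustrative and not part of the patch:

import torch

def offload_state(optimizer: torch.optim.Optimizer) -> dict:
    # Remember where each state tensor lived, then move it to the CPU.
    device_map = {}
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                device_map[(id(state), key)] = value.device
                state[key] = value.to("cpu")
    return device_map

def restore_state(optimizer: torch.optim.Optimizer, device_map: dict) -> None:
    # Move every offloaded tensor back to the device it came from.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device_map[(id(state), key)])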