import transformers
import torch
from torch.utils import tensorboard
import os
import shutil
import math
from tqdm.auto import tqdm
import gc
import sys

from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import get_data_loaders
from tokenizer import get_tokenizer
from dyntrainmodel import DyntrainModel


def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
    output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
    checkpoint_dir = os.path.join(output_dir, output_chkpt_dir)
    print(f"saving model to {checkpoint_dir}")

    # Stash and clear the sampling parameters around save_pretrained(), then restore them,
    # so the exported generation config stays consistent.
    temperature = model.generation_config.temperature
    top_p = model.generation_config.top_p
    model.generation_config.temperature = None
    model.generation_config.top_p = None

    model.save_pretrained(checkpoint_dir)

    model.generation_config.temperature = temperature
    model.generation_config.top_p = top_p

    if max_checkpoints > 0:
        # Keep at most max_checkpoints "step_*" directories in the base output directory,
        # deleting the oldest one when the limit is exceeded.
        files = [f for f in os.listdir(output_dir)
                 if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]

        def extract_step(filename):
            tokens = filename.split('_')
            return int(tokens[1])

        if len(files) > max_checkpoints:
            min_step = min(map(extract_step, files))
            delete_checkpoint_dir = os.path.join(output_dir, f"step_{min_step}")
            print(f"there are more than {max_checkpoints} checkpoints saved, deleting {delete_checkpoint_dir}")
            shutil.rmtree(delete_checkpoint_dir)


def get_optimizer(dynamic_parameters: list[torch.nn.Parameter],
                  static_parameters: list[torch.nn.Parameter] | None,
                  lr: float, static_lr: float, weight_decay: float, eps: float, adam8bit: bool):
    # Dynamic (reshuffled) parameters use the base learning rate; static parameters
    # get their own, scaled-down learning rate.
    parameters = list[dict]()
    parameters.extend({'params': p} for p in dynamic_parameters if p.requires_grad)
    param_ids = set([id(p['params']) for p in parameters])
    if static_parameters is not None:
        for param in static_parameters:
            if param.requires_grad and id(param) not in param_ids:
                parameters.append({'params': param, 'lr': static_lr})
                param_ids.add(id(param))

    if not adam8bit:
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
    else:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, bitsandbytes must be available")
        optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
    return optimizer


def move_optimizer_param(param, device: torch.device, device_map: dict):
    # Recursively move optimizer state tensors to the given device. When device is None,
    # move each tensor back to the device recorded in device_map.
    if isinstance(param, torch.Tensor):
        move_device = device if device is not None else device_map[id(param)]
        assert device is not None or move_device != torch.device("cpu")
        old_device = param.device
        param.data = param.data.to(move_device)
        if param._grad is not None:
            param._grad.data = param._grad.data.to(move_device)
        if device is not None and id(param) not in device_map:
            device_map[id(param)] = old_device
            assert old_device != torch.device("cpu")
    elif isinstance(param, dict):
        for subparam in param.values():
            move_optimizer_param(subparam, device, device_map)


def suspend_optimizer(optimizer) -> dict:
    # Move all optimizer state to the CPU to free VRAM, remembering where each tensor lived.
    device_map = dict()
    for param in optimizer.state.values():
        move_optimizer_param(param, torch.device("cpu"), device_map)
    return device_map


def resume_optimizer(optimizer, device_map: dict):
    # Move optimizer state back to the devices recorded by suspend_optimizer().
    for param in optimizer.state.values():
        move_optimizer_param(param, None, device_map)


def evaluate(model: DyntrainModel, tokenizer,
             dataloader: torch.utils.data.DataLoader, globalstep: int,
             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
    with torch.no_grad():
        loss = torch.zeros((1), device="cuda:0")
        model.model.eval()
        for batch in tqdm(dataloader, desc="Doing eval"):
            for key in batch:
                batch[key] = batch[key].to("cuda:0")
            outputs = model.model(**batch)
            loss += outputs.loss
        loss = loss / len(dataloader)
        log_writer.add_scalar("Loss/Eval", loss, globalstep)
        print(f"Eval Loss {loss.item()}")

        if eval_prompt is not None:
            input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
            attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
            outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
                                           max_new_tokens=100, min_new_tokens=100)
            response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            print(f"Eval generation: {response_decoded}")
            log_writer.add_text("Text/Eval", response_decoded, globalstep)
        model.model.train()


def max_vram_allocated():
    max_vram_alloc = 0
    for i in range(0, torch.cuda.device_count()):
        max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
    return max_vram_alloc


def min_vram_allocated():
    min_vram_alloc = sys.maxsize
    for i in range(0, torch.cuda.device_count()):
        min_vram_alloc = min(torch.cuda.memory_allocated(i), min_vram_alloc)
    return min_vram_alloc


def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)

    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
                          quantize=model_args.quantize,
                          target_active_params=int(training_args.max_instant_params * 1e6),
                          train_static_params=training_args.train_non_linear_layers,
                          reshuffle_fraction=training_args.churn_percent / 100.0,
                          gradient_checkpointing=True,
                          trust_remote_code=True)
    devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
    model.toDevices(devices)
    model.reshuffleActive()
    model.balanceActive()

    parameter_count = sum(p.numel() for p in model.model.parameters())
    active_parameter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
    print(f"Training model with {parameter_count / 1e6}m parameters and {active_parameter_count / 1e6}m "
          f"instantaneously active parameters, of which {static_parameter_count} are static")

    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
                                                         training_args.per_device_train_batch_size,
                                                         training_args.per_device_eval_batch_size,
                                                         training_args.do_train, training_args.do_eval)

    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
    steps_per_epoch = (math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
                       if train_dataloader is not None else 1)
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = get_optimizer(model.dynamicParameters(),
                              model.staticParameters() if training_args.train_non_linear_layers else None,
                              training_args.learning_rate,
                              training_args.learning_rate / dynamic_param_ratio,
                              training_args.weight_decay,
                              training_args.adam_epsilon,
                              training_args.adam8bit)

    lr_scheduler = transformers.get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=total_steps
    )

    if training_args.do_train:
        progress_bar = tqdm(range(total_steps))
        global_step = 0
        model.model.train()
        for epoch in range(0, training_args.epochs):
            print("*** Train ***")
            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
            for step, batch in enumerate(train_dataloader):
                for key in batch:
                    batch[key] = batch[key].to("cuda:0")
                outputs = model.model(**batch)
                loss = outputs.loss / training_args.gradient_accumulation_steps
                loss.backward()

                if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
                    if global_step % training_args.logging_steps == 0:
                        log_writer.add_scalar("Loss/train", loss, global_step)

                    optimizer.step()
                    lr_scheduler.step()

                    progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
                                                 f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")

                    model.model.zero_grad()

                    if global_step > 0:
                        if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                            # The set of trainable parameters changes after a reshuffle,
                            # so the optimizer is rebuilt and handed back to the scheduler.
                            print("Reshuffling")
                            lr_scheduler.optimizer = None
                            del optimizer

                            # distance, error = model.getDistanceAndErrorSample()
                            # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
                            # log_writer.add_histogram("Errors/Train", error, max_bins=50)

                            model.reshuffleActive()
                            model.balanceActive()
                            log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                            optimizer = get_optimizer(model.dynamicParameters(),
                                                      model.staticParameters() if training_args.train_non_linear_layers else None,
                                                      training_args.learning_rate,
                                                      training_args.learning_rate / dynamic_param_ratio,
                                                      training_args.weight_decay,
                                                      training_args.adam_epsilon,
                                                      training_args.adam8bit)
                            lr_scheduler.optimizer = optimizer

                        if global_step % training_args.save_steps == 0:
                            save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                        if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
                            # Park the optimizer state on the CPU while evaluating to free VRAM.
                            device_map = suspend_optimizer(optimizer)
                            evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
                            resume_optimizer(optimizer, device_map)

                    global_step += 1
                    progress_bar.update()

                if training_args.flush_allocator:
                    gc.collect()
                    torch.cuda.empty_cache()

            # When eval_steps is -1, evaluate once at the end of every epoch.
            if training_args.do_eval and training_args.eval_steps == -1:
                device_map = suspend_optimizer(optimizer)
                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
                resume_optimizer(optimizer, device_map)

        del optimizer

    if training_args.do_eval:
        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

    save_model(model.model, global_step, training_args.output_dir)

    return


if __name__ == "__main__":
    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)

    print("Model Arguments:")
    print(model_args)
    print("\nData Arguments:")
    print(data_args)
    print("\nTraining Arguments:")
    print(training_args)

    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    train(model_args, data_args, training_args)