# QRotaryTraining/train_dynamic.py
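"""Dynamic fine-tuning driver for QRotaryTraining.

Summary (inferred from this file alone): a DyntrainModel is trained while the
set of active (trainable) parameters is periodically reshuffled and rebalanced
across all visible GPUs, with the optimizer rebuilt around each new active
set. Optimizer state can be suspended to the CPU to free VRAM for evaluation.
"""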
import gc
import math
import os
import shutil
import sys

import torch
import transformers
from torch.utils import tensorboard
from tqdm.auto import tqdm

from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import get_data_loaders
from dyntrainmodel import DyntrainModel
from tokenizer import get_tokenizer

def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
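    """Save `model` under `output_dir/step_<global_step>` (or directly under
    `output_dir` when `global_step` is negative) and, if `max_checkpoints` is
    positive, delete the oldest `step_*` checkpoint once more than
    `max_checkpoints` of them exist."""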
    output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
    output_chkpt_path = os.path.join(output_dir, output_chkpt_dir)
    print(f"saving model to {output_chkpt_path}")

    # temporarily clear the sampling parameters while saving and restore
    # them afterwards, so they are not written into the saved config
    temperature = model.generation_config.temperature
    top_p = model.generation_config.top_p
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.save_pretrained(output_chkpt_path)
    model.generation_config.temperature = temperature
    model.generation_config.top_p = top_p

    if max_checkpoints > 0:
        # prune the oldest step_* checkpoint in output_dir once more than
        # max_checkpoints of them exist
        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]

        def extract_step(filename):
            tokens = filename.split('_')
            return int(tokens[1])

        if len(files) > max_checkpoints:
            min_step = min(map(extract_step, files))
            delete_checkpoint_dir = os.path.join(output_dir, f"step_{min_step}")
            print(f"there are more than {max_checkpoints} checkpoints saved, deleting {delete_checkpoint_dir}")
            shutil.rmtree(delete_checkpoint_dir)

def get_optimizer(dynamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter] | None, lr: float, static_lr: float,
                  weight_decay: float, eps: float, adam8bit: bool):
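    """Build an AdamW optimizer (bitsandbytes AdamW8bit when `adam8bit` is
    set) with dynamic parameters training at `lr` and static parameters at
    `static_lr`; parameters present in both lists are added only once."""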
    parameters = list[dict]()
    parameters.extend({'params': p} for p in dynamic_parameters if p.requires_grad)
    param_ids = set([id(p['params']) for p in parameters])
    if static_parameters is not None:
        for param in static_parameters:
            if param.requires_grad and id(param) not in param_ids:
                parameters.append({'params': param, 'lr': static_lr})
                param_ids.add(id(param))

    if not adam8bit:
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
    else:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, bitsandbytes must be available")
        optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
    return optimizer

def move_optimizer_param(param, device: torch.device | None, device_map: dict):
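    """Move an optimizer state entry (a tensor or a dict of tensors) between
    devices. With a concrete `device` the tensors are moved there and their
    previous devices are recorded in `device_map`; with `device=None` the
    tensors are moved back to the devices recorded in `device_map`."""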
    if isinstance(param, torch.Tensor):
        move_device = device if device is not None else device_map[id(param)]
        assert device is not None or move_device != torch.device("cpu")
        old_device = param.device
        param.data = param.data.to(move_device)
        if param._grad is not None:
            param._grad.data = param._grad.data.to(move_device)
        if device is not None and id(param) not in device_map:
            device_map[id(param)] = old_device
            assert old_device != torch.device("cpu")
    elif isinstance(param, dict):
        for subparam in param.values():
            move_optimizer_param(subparam, device, device_map)

def suspend_optimizer(optimizer) -> dict:
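    """Move all optimizer state to the CPU to free VRAM, returning the device
    map needed to restore it with resume_optimizer()."""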
    device_map = dict()
    for param in optimizer.state.values():
        move_optimizer_param(param, torch.device("cpu"), device_map)
    return device_map

def resume_optimizer(optimizer, device_map: dict):
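    """Move optimizer state back to the devices recorded by suspend_optimizer()."""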
    for param in optimizer.state.values():
        move_optimizer_param(param, None, device_map)

def evaluate(model: DyntrainModel, tokenizer,
             dataloader: torch.utils.data.DataLoader, globalstep: int,
             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
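    """Compute the mean loss over `dataloader` and log it; if `eval_prompt` is
    given, also sample a fixed-length generation from it and log the decoded
    text."""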
    with torch.no_grad():
        loss = torch.zeros((1), device="cuda:0")
        model.model.eval()
        for batch in tqdm(dataloader, desc="Doing eval"):
            for key in batch:
                batch[key] = batch[key].to("cuda:0")
            outputs = model.model(**batch)
            loss += outputs.loss
        loss = loss / len(dataloader)
        log_writer.add_scalar("Loss/Eval", loss, globalstep)
        print(f"Eval Loss {loss.item()}")

        if eval_prompt is not None:
            input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
            attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
            outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
                                           max_new_tokens=100, min_new_tokens=100)
            response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            print(f"Eval generation: {response_decoded}")
            log_writer.add_text("Text/Eval", response_decoded, globalstep)
    model.model.train()

def max_vram_allocated():
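    """Return the largest amount of VRAM currently allocated on any visible GPU, in bytes."""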
    max_vram_alloc = 0
    for i in range(0, torch.cuda.device_count()):
        max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
    return max_vram_alloc

def min_vram_allocated():
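    """Return the smallest amount of VRAM currently allocated on any visible GPU, in bytes."""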
    min_vram_alloc = sys.maxsize
    for i in range(0, torch.cuda.device_count()):
        min_vram_alloc = min(torch.cuda.memory_allocated(i), min_vram_alloc)
    return min_vram_alloc

def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
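    """Run the training loop: build the model, data loaders, optimizer and LR
    scheduler, then train with gradient accumulation, periodically reshuffling
    the active parameters (rebuilding the optimizer each time), checkpointing
    and evaluating along the way."""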
    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)

    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
                          quantize=model_args.quantize,
                          target_active_params=int(training_args.max_instant_params * 1e6),
                          train_static_params=training_args.train_non_linear_layers,
                          reshuffle_fraction=training_args.churn_percent / 100.0,
                          gradient_checkpointing=True,
                          trust_remote_code=True)
    devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
    model.toDevices(devices)
    model.reshuffleActive()
    model.balanceActive()

    parameter_count = sum(p.numel() for p in model.model.parameters())
    active_parameter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
    print(f"Training model with {parameter_count / 1e6}m parameters and {active_parameter_count / 1e6}m "
          f"instantaneous active parameters of which {static_parameter_count} are static")

    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
                                                         training_args.per_device_train_batch_size,
                                                         training_args.per_device_eval_batch_size,
                                                         training_args.do_train, training_args.do_eval)
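
    # The static-parameter learning rate below is scaled down by the ratio of
    # total trainable to dynamic parameters; presumably this compensates for
    # static parameters receiving updates on every step while dynamic ones
    # rotate in and out of the active set (inferred, not documented here).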
    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = get_optimizer(model.dynamicParameters(),
                              model.staticParameters() if training_args.train_non_linear_layers else None,
                              training_args.learning_rate,
                              training_args.learning_rate / dynamic_param_ratio,
                              training_args.weight_decay,
                              training_args.adam_epsilon,
                              training_args.adam8bit)

    lr_scheduler = transformers.get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=total_steps
    )

    global_step = 0
    if training_args.do_train:
        progress_bar = tqdm(range(total_steps))
        model.model.train()
        for epoch in range(0, training_args.epochs):
            print("*** Train ***")
            print(f'VRAM used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
            for step, batch in enumerate(train_dataloader):
                for key in batch:
                    batch[key] = batch[key].to("cuda:0")
                outputs = model.model(**batch)
                loss = outputs.loss / training_args.gradient_accumulation_steps
                loss.backward()

                if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
                    if global_step % training_args.logging_steps == 0:
                        log_writer.add_scalar("Loss/train", loss, global_step)
                    optimizer.step()
                    lr_scheduler.step()

                    progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
                                                 f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")

                    model.model.zero_grad()
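
                    # Reshuffling changes which parameters require gradients, so the
                    # existing optimizer state no longer matches the active set: drop
                    # the optimizer, pick and rebalance a new active set, then rebuild
                    # the optimizer and re-point the scheduler at the new instance.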
                    if global_step > 0:
                        if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                            print("Reshuffling")
                            lr_scheduler.optimizer = None
                            del optimizer

                            # distance, error = model.getDistanceAndErrorSample()
                            # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
                            # log_writer.add_histogram("Errors/Train", error, max_bins=50)

                            model.reshuffleActive()
                            model.balanceActive()
                            log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)

                            optimizer = get_optimizer(model.dynamicParameters(),
                                                      model.staticParameters() if training_args.train_non_linear_layers else None,
                                                      training_args.learning_rate,
                                                      training_args.learning_rate / dynamic_param_ratio,
                                                      training_args.weight_decay,
                                                      training_args.adam_epsilon,
                                                      training_args.adam8bit)
                            lr_scheduler.optimizer = optimizer

                        if global_step % training_args.save_steps == 0:
                            save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                        if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
                            device_map = suspend_optimizer(optimizer)
                            evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
                            resume_optimizer(optimizer, device_map)

                    global_step += 1
                    progress_bar.update()

                if training_args.flush_allocator:
                    gc.collect()
                    torch.cuda.empty_cache()
        if training_args.do_eval and training_args.eval_steps == -1:
            device_map = suspend_optimizer(optimizer)
            evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
            resume_optimizer(optimizer, device_map)

    del optimizer

    if training_args.do_eval:
        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

    save_model(model.model, global_step, training_args.output_dir)
    return

if __name__ == "__main__":
    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)

    print("Model Arguments:")
    print(model_args)
    print("\nData Arguments:")
    print(data_args)
    print("\nTraining Arguments:")
    print(training_args)

    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    train(model_args, data_args, training_args)
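
# Example invocation (hypothetical values; the exact options are defined by the
# ModelArguments/DataArguments/TrainingArguments dataclasses in arguments.py,
# which HfArgumentParser exposes as command-line flags):
#
#   python train_dynamic.py --model_name_or_path <hf-model-id> \
#       --output_dir checkpoints --do_train --do_eval --epochs 1 \
#       --learning_rate 1e-5 --max_instant_params 1000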