Files
QRotaryTraining/train_dynamic.py
2024-03-06 17:50:40 +01:00

339 lines
15 KiB
Python

import transformers
from transformers import AutoModelForCausalLM, get_scheduler
from peft.utils import _get_submodules
import torch
from torch.utils import tensorboard
import os
import shutil
import math
from tqdm.auto import tqdm
from random import randint
from typing import Tuple
from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module
from convertinglinear import ConvertingLinear
from tokenizer import get_tokenizer
def find_all_linear_module_names(model):
    """Return the qualified names of all linear sub-modules of ``model``.

    A module counts as linear if it is a ``torch.nn.Linear`` or a
    ``ConvertingLinear``. The module named ``'lm_head'`` is excluded
    (original note: "needed for 16-bit").

    Names are returned in ``model.named_modules()`` order. Iterating a ``set`` of
    strings would give an order that changes between processes (hash
    randomization). ``freeze_random_modules`` picks layers from this list by
    random index, so a stable order is required for reproducible runs.
    """
    return [
        name
        for name, module in model.named_modules()
        if (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear))
        and name != 'lm_head'  # needed for 16-bit
    ]
def find_all_outher_module_names(model):
    """Return the names of all sub-modules of ``model`` that are *not* linear layers.

    A module is linear if it is a ``torch.nn.Linear`` or a ``ConvertingLinear``.
    Unlike ``find_all_linear_module_names``, this does not special-case
    ``lm_head``.

    The root module (name ``''``) is included unless ``model`` itself is linear.
    Callers that gather ``parameters()`` of these modules recursively will
    therefore also reach the linear layers' parameters.

    Names are returned in ``model.named_modules()`` order, so the result is
    deterministic across runs.
    """
    return [
        name
        for name, module in model.named_modules()
        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear))
    ]
def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing, training_args=None):
    """Load the causal language model described by ``model_args``.

    Args:
        model_args: reads ``model_name_or_path``, ``trust_remote_code`` and
            ``max_instant_params``.
        cache_dir: cache directory forwarded to ``from_pretrained``.
        gradient_checkpointing: if truthy, ``model.gradient_checkpointing_enable()``
            is called on the loaded model.
        training_args: reads ``fp16`` and ``storage_fp16``. When omitted, the
            module-level ``training_args`` defined by this script's ``__main__``
            block is used. This keeps existing three-argument callers working.

    Returns:
        The loaded model. It is loaded with ``device_map=None`` and the
        ``"flash_attention_2"`` attention implementation.

    Raises:
        ValueError: if ``training_args`` is neither passed nor defined at module level.
    """
    if training_args is None:
        training_args = globals().get("training_args")
        if training_args is None:
            raise ValueError("get_model needs training_args: pass it explicitly or run this file as a script")

    # float16 is only used when dynamic freezing is enabled (max_instant_params > 0).
    # Otherwise the model is always loaded in float32. The printed dtype is the
    # one actually passed to from_pretrained.
    use_fp16 = model_args.max_instant_params > 0 and (training_args.fp16 or training_args.storage_fp16)
    dtype = torch.float16 if use_fp16 else torch.float32
    print(f'loading base model {model_args.model_name_or_path} in {dtype}...')

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=cache_dir,
        torch_dtype=dtype,
        trust_remote_code=model_args.trust_remote_code,
        device_map=None,
        attn_implementation="flash_attention_2"
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
    return model
@torch.no_grad()
def recursive_setattr(obj, attr, value):
    """Assign ``value`` to the dotted attribute path ``attr`` below ``obj``.

    For example, ``recursive_setattr(m, "a.b.c", v)`` performs ``m.a.b.c = v``.
    All components but the last are resolved with ``getattr``; the last one is
    assigned with ``setattr``.
    """
    *path, leaf = attr.split('.')
    holder = obj
    for component in path:
        holder = getattr(holder, component)
    setattr(holder, leaf, value)
@torch.no_grad()
def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
    """Return a new ``torch.nn.Linear`` holding an independent copy of ``module``'s weights.

    The copied parameters are cast to ``dtype``, placed on ``device``, and get
    ``requires_grad = not frozen``. ``module`` itself is not modified.

    Args:
        module: linear layer exposing ``in_features``, ``out_features``, ``weight``
            and ``bias`` (``bias`` may be ``None``).
        frozen: if ``True`` the new parameters do not require gradients.
        dtype: dtype of the new parameters.
        device: device of the new parameters.

    Returns:
        The new ``torch.nn.Linear``.
    """
    # skip_init: the parameters are overwritten right below, so skip the
    # default random initialisation.
    new_module = torch.nn.utils.skip_init(torch.nn.Linear,
                                          in_features=module.in_features,
                                          out_features=module.out_features,
                                          bias=module.bias is not None,
                                          device=device,
                                          dtype=dtype)
    # copy=True guarantees fresh storage even when no cast or move is needed.
    new_module.weight = torch.nn.Parameter(
        module.weight.detach().to(device=device, dtype=dtype, copy=True),
        requires_grad=not frozen)
    if module.bias is not None:
        new_module.bias = torch.nn.Parameter(
            module.bias.detach().to(device=device, dtype=dtype, copy=True),
            requires_grad=not frozen)
    return new_module
@torch.no_grad()
def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
    """Convert a linear layer to its frozen or trainable representation.

    Frozen layers are plain ``torch.nn.Linear`` with ``requires_grad=False``.
    Trainable layers are ``ConvertingLinear`` with ``requires_grad=True``.

    Transitions:
        * ``Linear`` with ``frozen=True``: modified in place (gradients disabled,
          cast to ``dtype``, moved to ``device``) and returned.
        * ``Linear`` with ``frozen=False``: a new module built by
          ``ConvertingLinear.fromLinear`` is cast, made trainable, moved and returned.
        * ``ConvertingLinear`` with ``frozen=True``: a new plain ``Linear``, whose
          parameters are ``module``'s weights cast to ``dtype``, is returned on ``device``.
        * ``ConvertingLinear`` with ``frozen=False``: not supported.

    Exact type checks (``type(...) is``) are used rather than ``isinstance``.
    NOTE(review): presumably because ConvertingLinear may subclass Linear; confirm.

    Raises:
        ValueError: if asked to unfreeze a module that is already a ``ConvertingLinear``.
        TypeError: if ``module`` is neither a ``torch.nn.Linear`` nor a ``ConvertingLinear``.
    """
    if type(module) is torch.nn.Linear:
        if frozen:
            module.weight.requires_grad = False
            if module.bias is not None:
                module.bias.requires_grad = False
            # Module.to is in place and returns the module itself; one call casts and moves.
            return module.to(device=device, dtype=dtype)
        new_module = ConvertingLinear.fromLinear(module).to(dtype)
        new_module.weight.requires_grad = True
        if new_module.bias is not None:
            new_module.bias.requires_grad = True
        return new_module.to(device)

    if type(module) is ConvertingLinear:
        if not frozen:
            # Raise explicitly: an `assert` would be stripped under `python -O`.
            raise ValueError("module is already trainable; unfreezing a ConvertingLinear is not supported")
        # skip_init: the parameters are replaced right below.
        new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
                                              out_features=module.out_features,
                                              bias=module.bias is not None,
                                              device=device,
                                              dtype=dtype)
        new_module.weight = torch.nn.Parameter(module.weight.to(device=device, dtype=dtype), requires_grad=False)
        if module.bias is not None:
            new_module.bias = torch.nn.Parameter(module.bias.to(device=device, dtype=dtype), requires_grad=False)
        return new_module

    # Without this, the caller would install None in place of the layer.
    raise TypeError(f"expected torch.nn.Linear or ConvertingLinear, got {type(module).__name__}")
@torch.no_grad()
def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
    """Freeze every linear layer, then unfreeze random ones up to a parameter budget.

    First, every linear layer from ``find_all_linear_module_names`` is put into the
    frozen state with ``set_linear_module_frozen(..., True, frozen_dtype,
    frozen_device)``. Then randomly chosen linear layers are made trainable again
    (float32, on ``active_device``). This continues until the number of trainable
    parameters reaches ``target_params`` or no linear layers are left. The last
    layer added may overshoot the budget.

    Args:
        model: model whose linear sub-modules are replaced in place.
        target_params: budget for the total count of parameters with
            ``requires_grad=True``, including trainable non-linear parameters.
        frozen_dtype: dtype for frozen layers.
        frozen_device: device for frozen layers.
        active_device: device for layers made trainable.

    Returns:
        The number of parameters with ``requires_grad=True`` after the call.

    Raises:
        RuntimeError: if the trainable parameters left after freezing all linear
            layers already exceed ``target_params``.
    """
    modules = dict(model.named_modules())
    linear_names = find_all_linear_module_names(model)

    # Pass 1: freeze. Layers that already have the frozen dtype and no gradient are skipped.
    for key in linear_names:
        module = modules[key]
        if module.weight.dtype != frozen_dtype or module.weight.requires_grad:
            parent, _, target_name = _get_submodules(model, key)
            setattr(parent, target_name, set_linear_module_frozen(module, True, frozen_dtype, frozen_device))
    # Each swap replaces one linear layer and, assuming linear layers have no
    # linear sub-modules, leaves all other entries valid. One refresh after the
    # loop is therefore enough (instead of re-enumerating after every swap).
    modules = dict(model.named_modules())

    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if active_paramter_count > target_params:
        raise RuntimeError("Enough paramters must be available to train at least one linear layer")

    # Pass 2: unfreeze random linear layers until the budget is met.
    while active_paramter_count < target_params and linear_names:
        key = linear_names.pop(randint(0, len(linear_names) - 1))
        module = modules[key]
        parent, _, target_name = _get_submodules(model, key)
        setattr(parent, target_name, set_linear_module_frozen(module, False, torch.float32, active_device))
        # The trainable replacement holds as many parameters as the frozen layer it replaces.
        active_paramter_count += module.weight.numel()
        if module.bias is not None:
            active_paramter_count += module.bias.numel()
    # Release the stale name->module map (it still references the replaced layers).
    del modules

    assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)
    return active_paramter_count
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
    """Save ``model`` with ``save_pretrained`` and optionally prune old checkpoints.

    Args:
        model: object providing ``save_pretrained(directory)``.
        global_step: if ``>= 0`` the model is saved to
            ``<output_dir>/step_<global_step>``; if negative it is saved directly
            into ``output_dir``.
        output_dir: root directory that holds the ``step_<n>`` checkpoint directories.
        max_checkpoints: if ``> 0``, after saving, the ``step_<n>`` directories in
            ``output_dir`` with the lowest ``n`` are deleted until at most
            ``max_checkpoints`` remain. ``0`` (the default) disables pruning.
            Entries whose suffix is not a decimal number are never touched.
    """
    prefix = "step_"
    output_chkpt_dir = f"{prefix}{global_step}" if global_step >= 0 else ""
    model.save_pretrained(os.path.join(output_dir, output_chkpt_dir))
    if max_checkpoints <= 0:
        return

    # Checkpoints live next to each other in the *root* output directory.
    checkpoints = []
    for name in os.listdir(output_dir):
        suffix = name[len(prefix):]
        if (name.startswith(prefix) and suffix.isdecimal()
                and os.path.isdir(os.path.join(output_dir, name))):
            checkpoints.append((int(suffix), name))
    # Sort numerically (step_100 is newer than step_40), oldest first.
    checkpoints.sort()

    excess = len(checkpoints) - max_checkpoints
    for _, name in checkpoints[:max(excess, 0)]:
        stale_dir = os.path.join(output_dir, name)
        print(f"there are more than {max_checkpoints} checkpints saved, deleting {stale_dir}")
        shutil.rmtree(stale_dir)
def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                  weight_decay: float, eps: float, adam8bit: bool):
    """Build an AdamW optimizer over the trainable parameters of the named modules.

    Args:
        model: model whose sub-modules are looked up by qualified name.
        dynamic_module_names: modules whose trainable parameters (recursively) are
            optimized with learning rate ``lr``.
        static_module_names: modules optimized with ``static_lr``. Only each
            module's *own* parameters are taken (``recurse=False``). Such lists can
            contain the root module ``''`` and containers, whose recursive
            parameters would duplicate the dynamic ones and make torch reject the
            optimizer.
        lr: learning rate of the dynamic group.
        static_lr: learning rate of the static group.
        weight_decay: AdamW weight decay for both groups.
        eps: AdamW epsilon for both groups.
        adam8bit: use ``bitsandbytes.optim.AdamW8bit`` instead of ``torch.optim.AdamW``.

    Returns:
        The optimizer. It has at most two parameter groups, dynamic first and then
        static; empty groups are omitted. A fixed group layout means that an LR
        scheduler re-attached to a rebuilt optimizer still lines up its per-group
        learning rates with the right groups.

    Raises:
        ImportError: if ``adam8bit`` is set and bitsandbytes is not installed.
    """
    modules = dict(model.named_modules())
    seen = set()

    def trainable_unseen(params):
        # Trainable parameters not already claimed by an earlier group.
        picked = []
        for p in params:
            if p.requires_grad and id(p) not in seen:
                seen.add(id(p))
                picked.append(p)
        return picked

    dynamic_params = trainable_unseen(
        p for key in dynamic_module_names for p in modules[key].parameters())
    static_params = trainable_unseen(
        p for key in static_module_names for p in modules[key].parameters(recurse=False))

    parameters = []
    if dynamic_params:
        parameters.append({'params': dynamic_params})
    if static_params:
        parameters.append({'params': static_params, 'lr': static_lr})

    if not adam8bit:
        return torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
    try:
        import bitsandbytes as bnb
    except ImportError as err:
        raise ImportError("To use 8-bit Adam, bitsandbytes must be available") from err
    return bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
def compute_dynamic_parameter_ratio(model):
    """Return ``ceil(total / trainable)`` over the parameters of all linear layers.

    Only modules reported by ``find_all_linear_module_names`` are counted.
    A parameter is trainable when ``requires_grad`` is set.

    Raises:
        ZeroDivisionError: if no linear-layer parameter is trainable.
    """
    by_name = dict(model.named_modules())
    total = 0
    trainable = 0
    for linear_name in find_all_linear_module_names(model):
        for param in by_name[linear_name].parameters():
            size = param.numel()
            total += size
            if param.requires_grad:
                trainable += size
    return math.ceil(total / trainable)
def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
    """Load model, tokenizer and data, apply the initial freeze, and build optimizer/scheduler.

    Args:
        model_args: model configuration.
        data_args: dataset configuration.
        training_args: training configuration.
        primary_device: device the whole model is moved to, and where frozen
            linear layers are kept.
        secondary_device: device that trainable linear layers are moved to.

    Returns:
        ``(model, optimizer, lr_scheduler, train_dataloader)``. ``train_dataloader``
        is ``None`` when the data module provides no training dataset. In that case
        the scheduler is built with zero training steps.
    """
    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)

    # Dataset names ending in "json" use the s2s data module; anything else the text one.
    if data_args.dataset.endswith("json"):
        print("Loading dataset in s2s mode")
        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
    else:
        print("Loading dataset in txt mode")
        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)

    train_dataset = data_module['train_dataset']
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=data_module['data_collator'],
        batch_size=training_args.per_device_train_batch_size
    ) if train_dataset is not None else None

    if model_args.max_instant_params != 0:
        print(f"Target params {model_args.max_instant_params}m")
        # max_instant_params is expressed in millions of parameters.
        freeze_random_modules(model, model_args.max_instant_params * 1e6,
                              torch.float16 if training_args.storage_fp16 else torch.float32,
                              frozen_device=primary_device, active_device=secondary_device)

    paramter_count = sum(p.numel() for p in model.parameters())
    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")

    # Non-linear layers get their learning rate divided by this ratio.
    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
    print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}")

    # An optimizer step happens every gradient_accumulation_steps batches (and at epoch end).
    num_batches = len(train_dataloader) if train_dataloader is not None else 0
    steps_per_epoch = math.ceil(num_batches / training_args.gradient_accumulation_steps)
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = get_optimizer(model, find_all_linear_module_names(model),
                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
                              training_args.learning_rate,
                              training_args.learning_rate / dynamic_param_ratio,
                              training_args.weight_decay,
                              training_args.adam_epsilon,
                              training_args.adam8bit)

    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=total_steps
    )

    return model, optimizer, lr_scheduler, train_dataloader
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    """Train with periodically re-sampled trainable linear layers, then save the model.

    Every 10 optimizer steps, if ``model_args.max_instant_params != 0``,
    ``freeze_random_modules`` picks a new random set of trainable linear layers.
    The optimizer is then rebuilt over the new parameters and the LR scheduler is
    re-pointed at it. The per-batch loss and the active parameter count are
    logged to TensorBoard.

    Raises:
        ValueError: if ``training_args.do_train`` is set but no training dataset was loaded.
    """
    primary_device = torch.device(training_args.primary_device)
    secondary_device = torch.device(training_args.secondary_device)

    log_writer = tensorboard.SummaryWriter()
    try:
        model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)

        if training_args.do_train:
            if train_dataloader is None:
                raise ValueError("do_train is set but no training dataset was loaded")
            num_batches = len(train_dataloader)
            accumulation = training_args.gradient_accumulation_steps
            total_steps = math.ceil(num_batches / accumulation) * training_args.epochs
            dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
            frozen_dtype = torch.float16 if training_args.storage_fp16 else torch.float32

            progress_bar = tqdm(range(total_steps))
            global_step = 0
            model.train()
            for _epoch in range(0, training_args.epochs):
                print("*** Train ***")
                print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
                for step, batch in enumerate(train_dataloader):
                    # prepare() moves the model to primary_device, so inputs go there too.
                    for key in batch:
                        batch[key] = batch[key].to(primary_device)
                    outputs = model(**batch)
                    loss = outputs.loss / accumulation
                    log_writer.add_scalar("Loss/train", loss, global_step)
                    loss.backward()

                    if (step + 1) % accumulation == 0 or step + 1 == num_batches:
                        optimizer.step()
                        lr_scheduler.step()
                        model.zero_grad()

                        if global_step % 10 == 0:
                            print(loss)
                            if model_args.max_instant_params != 0:
                                # Swapping layers creates new parameter objects, so a
                                # fresh optimizer is built over them and handed to the
                                # existing scheduler.
                                param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
                                                                    frozen_dtype,
                                                                    frozen_device=primary_device,
                                                                    active_device=secondary_device)
                                log_writer.add_scalar("Parameters/train", param_count, global_step)
                                optimizer = get_optimizer(model, find_all_linear_module_names(model),
                                                          find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
                                                          training_args.learning_rate,
                                                          training_args.learning_rate / dynamic_param_ratio,
                                                          training_args.weight_decay,
                                                          training_args.adam_epsilon,
                                                          training_args.adam8bit)
                                lr_scheduler.optimizer = optimizer

                        global_step += 1
                        progress_bar.update()

                        if global_step % training_args.save_steps == 0:
                            save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints)

                    if training_args.flush_allocator:
                        torch.cuda.empty_cache()
            progress_bar.close()

        # Evaluation
        if training_args.do_eval:
            print("*** Evaluate ***")

        # global_step only exists (and the model only changed) when training ran.
        if training_args.do_train:
            save_model(model, global_step, training_args.output_dir)
    finally:
        # Flush pending TensorBoard events even if training fails.
        log_writer.close()
if __name__ == "__main__":
    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    # model_args / data_args / training_args stay module-level globals: helpers in
    # this file (e.g. get_model) read the global `training_args`.
    model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    if extra_args:
        # return_remaining_strings=True makes unknown options non-fatal. Report them
        # so a mistyped flag is not silently ignored.
        print(f"Warning: ignoring unrecognized arguments: {extra_args}")

    print("Model Arguments:")
    print(model_args)
    print("\nData Arguments:")
    print(data_args)
    print("\nTraining Arguments:")
    print(training_args)

    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    train(model_args, data_args, training_args)