293 lines
14 KiB
Python
293 lines
14 KiB
Python
|
|
# QRotaryTraining - A novel method for fully training all parameters of large
|
|
# language models (llms) while using less device memory than traditional methods.
|
|
# Copyright (C) 2024 Carl Philipp Klemm
|
|
#
|
|
# This file is part of QRotaryTraining.
|
|
#
|
|
# QRotaryTraining is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# QRotaryTraining is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with QRotaryTraining. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import transformers
|
|
import torch
|
|
from torch.utils import tensorboard
|
|
import os
|
|
import shutil
|
|
import math
|
|
from tqdm.auto import tqdm
|
|
import gc
|
|
import sys
|
|
|
|
from arguments import DataArguments, ModelArguments, TrainingArguments
|
|
from datamodules import get_data_loaders
|
|
from tokenizer import get_tokenizer
|
|
|
|
from dyntrainmodel import DyntrainModel
|
|
|
|
|
|
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
    """Save *model* under *output_dir* and optionally prune old checkpoints.

    The model is written to ``<output_dir>/step_<global_step>``, or directly
    to ``output_dir`` when ``global_step`` is negative.

    Args:
        model: Object exposing ``generation_config`` (with ``temperature`` and
            ``top_p`` attributes) and ``save_pretrained(path)``.
        global_step: Step number used to name the checkpoint directory.
        output_dir: Directory that holds all ``step_<n>`` checkpoints.
        max_checkpoints: When greater than zero, the oldest ``step_<n>``
            directories are removed until at most this many remain.
    """
    output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
    # Keep the base directory and the checkpoint path apart: pruning below has
    # to scan the base directory, not the checkpoint that was just written.
    checkpoint_path = os.path.join(output_dir, output_chkpt_dir)

    print(f"saveing model to {output_chkpt_dir}")

    # The sampling parameters are blanked for the duration of the save and
    # restored afterwards.
    # NOTE(review): presumably this keeps save_pretrained from rejecting the
    # generation config - confirm against the transformers version in use.
    generation_config = model.generation_config
    temperature = generation_config.temperature
    top_p = generation_config.top_p
    generation_config.temperature = None
    generation_config.top_p = None
    try:
        model.save_pretrained(checkpoint_path)
    finally:
        # Restore even when saving fails so the in-memory model is unchanged.
        generation_config.temperature = temperature
        generation_config.top_p = top_p

    if max_checkpoints <= 0:
        return

    # Collect (step, directory name) for every well-formed checkpoint directory;
    # entries such as "step_backup" are ignored instead of crashing int().
    checkpoints = []
    for entry in os.listdir(output_dir):
        prefix, _, step_text = entry.partition("_")
        if prefix != "step" or not step_text.isdecimal():
            continue
        if os.path.isdir(os.path.join(output_dir, entry)):
            checkpoints.append((int(step_text), entry))

    # Remove the lowest-numbered checkpoints until the limit is met.
    surplus = max(len(checkpoints) - max_checkpoints, 0)
    for _, entry in sorted(checkpoints)[:surplus]:
        delete_checkpoit_dir = os.path.join(output_dir, entry)
        print(f"there are more than {max_checkpoints} checkpints saved, deleting {delete_checkpoit_dir}")
        shutil.rmtree(delete_checkpoit_dir)
|
|
|
|
|
|
def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter] | None, lr: float, static_lr: float,
|
|
weight_decay: float, eps: float, adam8bit: bool):
|
|
parameters = list[dict]()
|
|
parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
|
|
param_ids = set([id(p['params']) for p in parameters])
|
|
if static_parameters is not None:
|
|
for param in static_parameters:
|
|
if param.requires_grad and id(param) not in param_ids:
|
|
parameters.append({'params': param, 'lr': static_lr})
|
|
param_ids.add(id(param))
|
|
|
|
if not adam8bit:
|
|
optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
|
|
else:
|
|
try:
|
|
import bitsandbytes as bnb
|
|
except ImportError:
|
|
raise ImportError("To use 8-bit Adam, bitsandbytes must be available")
|
|
optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
|
|
return optimizer
|
|
|
|
|
|
def move_optimizer_param(param, device: torch.device, device_map: dict):
    """Move an optimizer state entry, recursing through nested dicts.

    With an explicit ``device`` the tensor is moved there and its previous
    device is stored in ``device_map`` under ``id(tensor)`` (first visit
    only). With ``device=None`` the tensor is moved back to the device stored
    in ``device_map``. Values that are neither tensors nor dicts are ignored.
    """
    if isinstance(param, dict):
        for entry in param.values():
            move_optimizer_param(entry, device, device_map)
        return
    if not isinstance(param, torch.Tensor):
        return

    cpu = torch.device("cpu")
    restoring = device is None
    key = id(param)

    target = device_map[key] if restoring else device
    # A restore must never target the CPU.
    assert not restoring or target != cpu

    origin = param.device
    param.data = param.data.to(target)
    if param._grad is not None:
        param._grad.data = param._grad.data.to(target)

    if not restoring and key not in device_map:
        device_map[key] = origin
        # NOTE(review): a tensor that already lives on the CPU when it is first
        # recorded trips this assertion - confirm optimizer state is never
        # CPU-resident.
        assert origin != cpu
|
|
|
|
|
|
def suspend_optimizer(optimizer) -> dict:
    """Move every entry of ``optimizer.state`` to the CPU.

    Returns:
        The dict filled in by ``move_optimizer_param``, mapping ``id(tensor)``
        to the device the tensor was on before the move. Pass it to
        ``resume_optimizer`` to undo the move.
    """
    cpu = torch.device("cpu")
    previous_devices = {}
    for state_entry in optimizer.state.values():
        move_optimizer_param(state_entry, cpu, previous_devices)
    return previous_devices
|
|
|
|
|
|
def resume_optimizer(optimizer, device_map: dict):
    """Move every entry of ``optimizer.state`` back to its recorded device.

    ``device_map`` is the dict returned by ``suspend_optimizer``.
    """
    state = optimizer.state
    for key in state:
        # device=None makes move_optimizer_param look the target up in device_map.
        move_optimizer_param(state[key], None, device_map)
|
|
|
|
|
|
def evaluate(model: DyntrainModel, tokenizer,
             dataloader: torch.utils.data.DataLoader, globalstep: int,
             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
    """Compute the mean evaluation loss and optionally sample a generation.

    The wrapped model is put into eval mode, every batch of ``dataloader`` is
    run without gradients, and the mean of the per-batch losses is logged as
    ``Loss/Eval``. When ``eval_prompt`` is given, 100 new tokens are sampled
    from it and logged as ``Text/Eval``. The model is returned to train mode
    afterwards, even if evaluation raises.

    Args:
        model: Wrapper exposing the network as ``model.model`` and its devices
            as ``model.devices``.
        tokenizer: Tokenizer used to encode ``eval_prompt`` and decode the
            generated tokens.
        dataloader: Evaluation batches; each batch is a mapping of tensors
            that is passed to the model as keyword arguments.
        globalstep: Step under which the results are logged.
        log_writer: TensorBoard writer receiving the loss and generated text.
        eval_prompt: Optional prompt for a sample generation.
    """
    # One input device for both the loss pass and the generation; previously
    # the batches went to a hard-coded "cuda:0" while the prompt went to
    # model.devices[0].
    device = model.devices[0]

    model.model.eval()
    try:
        with torch.no_grad():
            loss = torch.zeros((1), device=device)
            for batch in tqdm(dataloader, desc="Doing eval"):
                for key in batch:
                    batch[key] = batch[key].to(device)
                outputs = model.model(**batch)
                loss += outputs.loss

            batch_count = len(dataloader)
            if batch_count > 0:
                loss = loss / batch_count
                log_writer.add_scalar("Loss/Eval", loss, globalstep)
                print(f"Eval Loss {loss.item()}")
            else:
                # Avoid logging the NaN that 0 / 0 would produce.
                print("Eval dataloader is empty, skipping eval loss")

            if eval_prompt is not None:
                input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(device)
                attention_mask = torch.ones(input_ids.shape, device=device, requires_grad=False)
                outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
                                               max_new_tokens=100, min_new_tokens=100)
                response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                print(f"Eval generation: {response_decoded}")
                log_writer.add_text("Text/Eval", response_decoded, globalstep)
    finally:
        # Never leave the model stuck in eval mode for the training loop.
        model.model.train()
|
|
|
|
|
|
def max_vram_allocated():
    """Return the largest ``torch.cuda.memory_allocated`` value over all CUDA devices.

    Returns 0 when no CUDA device is present.
    """
    per_device = (torch.cuda.memory_allocated(index) for index in range(torch.cuda.device_count()))
    return max(per_device, default=0)
|
|
|
|
|
|
def min_vram_allocated():
    """Return the smallest ``torch.cuda.memory_allocated`` value over all CUDA devices.

    Returns ``sys.maxsize`` when no CUDA device is present.
    """
    per_device = (torch.cuda.memory_allocated(index) for index in range(torch.cuda.device_count()))
    return min(per_device, default=sys.maxsize)
|
|
|
|
|
|
def _build_optimizer(model, training_args, dynamic_param_ratio):
    """Create the optimizer for the model's currently active parameters."""
    return get_optimizer(model.dynamicParameters(),
                         model.staticParameters() if training_args.train_non_linear_layers else None,
                         training_args.learning_rate,
                         training_args.learning_rate / dynamic_param_ratio,
                         training_args.weight_decay,
                         training_args.adam_epsilon,
                         training_args.adam8bit)


def _evaluate_with_optimizer_suspended(model, tokenizer, eval_dataloader, global_step, log_writer,
                                       training_args, optimizer):
    """Run evaluate() while the optimizer state is moved to the CPU."""
    device_map = suspend_optimizer(optimizer)
    evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
    resume_optimizer(optimizer, device_map)


def _train_impl(model_args, data_args, training_args, log_writer):
    """Set up model, data and optimizer, then run the train/eval/save flow."""
    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
                          quantize=model_args.quantize,
                          target_active_params=int(training_args.max_instant_params * 1e6),
                          train_static_params=training_args.train_non_linear_layers,
                          reshuffle_fraction=training_args.churn_percent / 100.0,
                          gradient_checkpointing=True,
                          trust_remote_code=True)
    devices = [torch.device(i) for i in range(torch.cuda.device_count())]
    model.toDevices(devices)
    model.reshuffleActive()
    model.balanceActive()

    parameter_count = sum(p.numel() for p in model.model.parameters())
    active_parameter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
    print(f"Training model with {parameter_count / 1e6}m parameters and {active_parameter_count / 1e6}m "
          f"instantanous active paramters of which {static_parameter_count} are static")

    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
                                                         training_args.per_device_train_batch_size,
                                                         training_args.per_device_eval_batch_size,
                                                         training_args.do_train, training_args.do_eval)

    # Static parameters get the base learning rate divided by this ratio.
    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
    # One optimizer step per gradient_accumulation_steps batches, plus one for
    # a trailing partial group (matches the step condition in the loop below).
    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = _build_optimizer(model, training_args, dynamic_param_ratio)

    lr_scheduler = transformers.get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=total_steps
    )

    # Bound before the training branch so the final evaluation can log under a
    # valid step even when do_train is off.
    global_step = 0

    if training_args.do_train:
        progress_bar = tqdm(range(total_steps))
        model.model.train()
        for _epoch in range(0, training_args.epochs):
            print("*** Train ***")
            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
            for step, batch in enumerate(train_dataloader):
                for key in batch:
                    batch[key] = batch[key].to("cuda:0")
                outputs = model.model(**batch)
                # NOTE(review): the logged/displayed loss below is this scaled
                # value of the last micro-batch only.
                loss = outputs.loss / training_args.gradient_accumulation_steps
                loss.backward()

                if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
                    if global_step % training_args.logging_steps == 0:
                        log_writer.add_scalar("Loss/train", loss, global_step)
                    optimizer.step()
                    lr_scheduler.step()

                    progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
                                                 f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")

                    model.model.zero_grad()

                    if global_step > 0:
                        if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                            print("Reshuffleing")
                            # Drop every reference to the old optimizer before
                            # the active parameter set changes.
                            lr_scheduler.optimizer = None
                            del optimizer
                            # distance, error = model.getDistanceAndErrorSample()
                            # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
                            # log_writer.add_histogram("Errors/Train", error, max_bins=50)

                            model.reshuffleActive()
                            model.balanceActive()
                            log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                            optimizer = _build_optimizer(model, training_args, dynamic_param_ratio)
                            lr_scheduler.optimizer = optimizer

                        if global_step % training_args.save_steps == 0:
                            save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                        if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
                            _evaluate_with_optimizer_suspended(model, tokenizer, eval_dataloader, global_step,
                                                               log_writer, training_args, optimizer)

                    global_step += 1
                    progress_bar.update()

                if training_args.flush_allocator:
                    gc.collect()
                    torch.cuda.empty_cache()

            # eval_steps == -1 requests an evaluation at the end of each epoch.
            if training_args.do_eval and training_args.eval_steps == -1:
                _evaluate_with_optimizer_suspended(model, tokenizer, eval_dataloader, global_step,
                                                   log_writer, training_args, optimizer)

    del optimizer

    if training_args.do_eval:
        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

    # Only a trained model is written out; an eval-only run leaves output_dir alone.
    if training_args.do_train:
        save_model(model.model, global_step, training_args.output_dir)


def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    """Train, evaluate and save a DyntrainModel according to the given arguments.

    Loads the model and data, then (when ``training_args.do_train`` is set)
    runs the training loop with gradient accumulation, periodic reshuffling of
    the active parameters, periodic checkpoints and periodic evaluation. A
    final evaluation runs when ``training_args.do_eval`` is set, and the
    trained model is saved at the end of a training run.

    Args:
        model_args: Model selection and quantization options.
        data_args: Options forwarded to ``get_data_loaders``.
        training_args: Optimizer, schedule, logging and checkpointing options.
    """
    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)
    try:
        _train_impl(model_args, data_args, training_args, log_writer)
    finally:
        # Flush pending events; the writer was previously never closed.
        log_writer.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the three argument dataclasses from the command line; arguments
    # that are not recognised are returned in extra_args instead of raising.
    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)

    # Echo the effective configuration before any work starts.
    for _heading, _parsed in (("Model Arguments:", model_args),
                              ("\nData Arguments:", data_args),
                              ("\nTraining Arguments:", training_args)):
        print(_heading)
        print(_parsed)

    hf_logging = transformers.utils.logging
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()

    train(model_args, data_args, training_args)
|