various optimizations

This commit is contained in:
2024-07-20 21:47:18 +02:00
parent 2f35689355
commit c38ac65d5b
4 changed files with 151 additions and 101 deletions

View File

@ -1,6 +1,4 @@
import transformers
from transformers import get_scheduler
import torch
from torch.utils import tensorboard
import os
@ -8,9 +6,10 @@ import shutil
import math
from tqdm.auto import tqdm
import gc
import sys
from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
from datamodules import get_data_loaders
from tokenizer import get_tokenizer
from dyntrainmodel import DyntrainModel
@ -19,7 +18,16 @@ from dyntrainmodel import DyntrainModel
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
output_dir = os.path.join(output_dir, output_chkpt_dir)
print(f"saveing model to {output_chkpt_dir}")
temperature = model.generation_config.temperature
top_p = model.generation_config.top_p
model.generation_config.temperature = None
model.generation_config.top_p = None
model.save_pretrained(output_dir)
model.generation_config.temperature = temperature
model.generation_config.top_p = top_p
if max_checkpoints > 0:
files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
@ -57,37 +65,85 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
return optimizer
def move_optimizer_param(param, device: torch.device, device_map: dict):
if isinstance(param, torch.Tensor):
move_device = device if device is not None else device_map[id(param)]
assert device is not None or move_device != torch.device("cpu")
old_device = param.device
param.data = param.data.to(move_device)
if param._grad is not None:
param._grad.data = param._grad.data.to(move_device)
if device is not None and id(param) not in device_map:
device_map[id(param)] = old_device
assert old_device != torch.device("cpu")
elif isinstance(param, dict):
for subparam in param.values():
move_optimizer_param(subparam, device, device_map)
def suspend_optimizer(optimizer) -> dict:
device_map = dict()
for param in optimizer.state.values():
move_optimizer_param(param, torch.device("cpu"), device_map)
return device_map
def resume_optimizer(optimizer, device_map: dict):
for param in optimizer.state.values():
move_optimizer_param(param, None, device_map)
def evaluate(model: DyntrainModel, tokenizer,
dataloader: torch.utils.data.DataLoader, globalstep: int,
log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
print("*** Eval ***")
loss = torch.zeros((1), device="cuda:0")
model.model.eval()
for batch in dataloader:
for key in batch:
batch[key] = batch[key].to("cuda:0")
outputs = model.model(**batch)
loss += outputs.loss
loss = loss / len(dataloader)
log_writer.add_scalar("Loss/Eval", loss, globalstep)
print(f"Eval Loss {loss.item()}")
return loss.item()
with torch.no_grad():
loss = torch.zeros((1), device="cuda:0")
model.model.eval()
if eval_prompt is not None:
input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(f"Eval generation: {response_decoded}")
log_writer.add_text("Text/Eval", response_decoded, globalstep)
for batch in tqdm(dataloader, desc="Doing eval"):
for key in batch:
batch[key] = batch[key].to("cuda:0")
outputs = model.model(**batch)
loss += outputs.loss
loss = loss / len(dataloader)
log_writer.add_scalar("Loss/Eval", loss, globalstep)
print(f"Eval Loss {loss.item()}")
if eval_prompt is not None:
input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
max_new_tokens=100, min_new_tokens=100)
response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(f"Eval generation: {response_decoded}")
log_writer.add_text("Text/Eval", response_decoded, globalstep)
model.model.train()
def max_vram_allocated():
max_vram_alloc = 0
for i in range(0, torch.cuda.device_count()):
max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
return max_vram_alloc
def min_vram_allocated():
max_vram_alloc = sys.maxsize
for i in range(0, torch.cuda.device_count()):
max_vram_alloc = min(torch.cuda.memory_allocated(i), max_vram_alloc)
return max_vram_alloc
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
log_writer = tensorboard.SummaryWriter()
log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
quantize=model_args.quantize)
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
quantize=model_args.quantize,
target_active_params=int(training_args.max_instant_params * 1e6),
train_static_params=training_args.train_non_linear_layers,
reshuffle_fraction=training_args.churn_percent / 100.0,
gradient_checkpointing=True,
trust_remote_code=True)
devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
model.toDevices(devices)
model.reshuffleActive()
@ -96,34 +152,15 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
paramter_count = sum(p.numel() for p in model.model.parameters())
active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m "
f"instantanous active paramters of which {static_parameter_count} are static")
tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
print("Loading dataset in s2s mode")
data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
elif data_args.data_from_hub:
data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
else:
print("Loading dataset in txt mode")
data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
train_dataloader = torch.utils.data.DataLoader(
dataset['train_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
) if dataset['train_dataset'] is not None else None
if training_args.do_eval:
eval_dataloader = torch.utils.data.DataLoader(
dataset['eval_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
)
train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
training_args.per_device_train_batch_size,
training_args.per_device_eval_batch_size,
training_args.do_train, training_args.do_eval)
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
@ -137,7 +174,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler = get_scheduler(
lr_scheduler = transformers.get_scheduler(
name=training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
@ -149,13 +186,11 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
global_step = 0
model.model.train()
for epoch in range(0, training_args.epochs):
model.model.train()
print("*** Train ***")
print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
for step, batch in enumerate(train_dataloader):
for key in batch:
batch[key] = batch[key].to("cuda:0")
outputs = model.model(**batch)
loss = outputs.loss / training_args.gradient_accumulation_steps
loss.backward()
@ -166,46 +201,52 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
optimizer.step()
lr_scheduler.step()
progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")
model.model.zero_grad()
if global_step % 5 == 0:
print(f"Train Loss {loss.item()}")
if global_step > 0:
if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
print("Reshuffleing")
lr_scheduler.optimizer = None
del optimizer
# distance, error = model.getDistanceAndErrorSample()
# log_writer.add_histogram("Distances/Train", distance, max_bins=50)
# log_writer.add_histogram("Errors/Train", error, max_bins=50)
if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
print("Reshuffleing")
lr_scheduler.optimizer = None
del optimizer
# distance, error = model.getDistanceAndErrorSample()
# log_writer.add_histogram("Distances/Train", distance, max_bins=50)
# log_writer.add_histogram("Errors/Train", error, max_bins=50)
model.reshuffleActive()
model.balanceActive()
log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
optimizer = get_optimizer(model.dynamicParameters(),
model.staticParameters() if training_args.train_non_linear_layers else None,
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler.optimizer = optimizer
model.reshuffleActive()
model.balanceActive()
log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
optimizer = get_optimizer(model.dynamicParameters(),
model.staticParameters() if training_args.train_non_linear_layers else None,
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler.optimizer = optimizer
if global_step % training_args.save_steps == 0:
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
device_map = suspend_optimizer(optimizer)
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
resume_optimizer(optimizer, device_map)
global_step += 1
progress_bar.update()
if global_step > 0:
if global_step % training_args.save_steps == 0:
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
if training_args.flush_allocator:
gc.collect()
torch.cuda.empty_cache()
if training_args.do_eval and training_args.eval_steps == -1:
device_map = suspend_optimizer(optimizer)
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
resume_optimizer(optimizer, device_map)
del optimizer
# Evaluation
if training_args.do_eval:
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)