various optimizations

2024-07-20 21:47:18 +02:00
parent 2f35689355
commit c38ac65d5b
4 changed files with 151 additions and 101 deletions

View File

@@ -68,7 +68,9 @@ class LinearGroup:
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
-                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, train_static_params: bool,
+                 reshuffle_fraction: float, gradient_checkpointing: bool,
+                 trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -82,6 +84,7 @@ class DyntrainModel:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
         self.devices = list[torch.device]()
         self.inital_reshufle = True
+        self.train_static_params = train_static_params

         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
@@ -167,8 +170,14 @@ class DyntrainModel:
     def staticParameterCount(self) -> int:
         return sum(p.numel() for p in self.staticParameters())

+    def activeDynamicParameterCount(self) -> int:
+        return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)
+
     def activeParameterCount(self) -> int:
-        total_params = self.dynamicParameters() + self.staticParameters()
+        if self.train_static_params:
+            total_params = self.dynamicParameters() + self.staticParameters()
+        else:
+            total_params = self.dynamicParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)

     def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
@@ -187,7 +196,7 @@ class DyntrainModel:
         params = self.activeParameterCount()
         if params >= self.target_active_params:
-            RuntimeError("Insuficant active parameters to suffle active")
+            raise RuntimeError("Insuficant active parameters to suffle active")
         while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
             i = randint(0, len(self.frozen_linear_groups) - 1)
             group = self.frozen_linear_groups.pop(i)
@@ -199,7 +208,7 @@ class DyntrainModel:
         active_params = self.activeParameterCount()
-        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params

     def activeParamtersByDevice(self) -> list[int]:
         out = [0] * len(self.devices)
@@ -213,7 +222,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = int(memory * 0.8)
+                memory = int(memory * 0.5)
             bits_per_param.append(count / memory)

         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -223,7 +232,7 @@ class DyntrainModel:
             if group.getDevice() is self.devices[max_index]:
                 memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                 if max_index == 0:
-                    memory = int(memory * 0.8)
+                    memory = int(memory * 0.5)
                 swing = group.paramCount() / memory
                 if max_bits_per_param - swing > min_bits_per_param + swing:
                     group.inplaceTo(device=self.devices[min_index])
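For orientation, reshuffleActive() above keeps unfreezing randomly chosen linear groups until a target budget of active parameters is reached. A rough, self-contained sketch of that budgeting loop, with plain nn.Linear layers standing in for the repository's LinearGroup bookkeeping (the model, sizes and budget here are illustrative, not taken from the repository):

import torch
from random import randint

# Toy stand-in for the frozen/active group bookkeeping: unfreeze randomly
# chosen linear layers until a target number of trainable parameters is reached.
model = torch.nn.Sequential(*[torch.nn.Linear(256, 256) for _ in range(16)])
target_active_params = 300_000

for p in model.parameters():
    p.requires_grad = False

frozen_layers = [m for m in model if isinstance(m, torch.nn.Linear)]
active = 0
while active < target_active_params and len(frozen_layers) > 0:
    layer = frozen_layers.pop(randint(0, len(frozen_layers) - 1))
    for p in layer.parameters():
        p.requires_grad = True
    active = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Active parameters: {active}")

The real implementation additionally tracks which device each group lives on and rebalances them, which the toy version omits.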

View File

@@ -108,7 +108,7 @@ class DynamicConvertingLinear(Linear):
 class DynamicQantizedLinear(Linear):
     def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
-                 output_dtype=None, compute_dtype=None, output_device=None):
+                 output_dtype=None, compute_dtype=None, output_device=None, cold_dtype=torch.float32):
         super().__init__(in_features, out_features, bias, cold_device, torch.float32)
         self.active_device = active_device
         self.cold_device = cold_device
@@ -120,8 +120,8 @@ class DynamicQantizedLinear(Linear):
         self.bias_quantized = None
         self.bias_state = None
         self.block_size = 128
-        self.quant_type = 'nf4'
-        self.weight_start = self.weight.clone().detach()
+        #self.weight_start = self.weight.clone().detach()
+        self.cold_dtype = cold_dtype

     @classmethod
     def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
@@ -131,19 +131,19 @@ class DynamicQantizedLinear(Linear):
                          compute_dtype=compute_dtype, output_device=output_device)
         new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
-        new_module.weight_start = new_module.weight.clone().detach()
+        #new_module.weight_start = new_module.weight.clone().detach()
         return new_module

     def compress(self) -> None:
-        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
+        weight = self.weight.contiguous().to(torch.float16).to(self.active_device)
         self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
         if self.bias is not None:
-            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
+            bias = self.bias.contiguous().to(torch.float16).to(self.active_device)
             self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)

         frozen = self.isFrozen()
-        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_dtype).to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_dtype).to(self.cold_device)) if self.bias is not None else None
         self.setFrozen(frozen, False)

     def decompress(self) -> None:
@@ -151,16 +151,16 @@ class DynamicQantizedLinear(Linear):
         self.weight_state = None
         self.bias_quantized = None
         self.bias_state = None
-        self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        #self.weight_start = self.weight.clone().detach().to(self.cold_device)

-        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device).to(torch.float32))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device).to(torch.float32))

     def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
         original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
         dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
-        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        distance = torch.zeros((2)) #(self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
         error = (dequantized_original_weight - original_weight).to(torch.float32)
         return (distance, error)
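The compress()/decompress() pair above round-trips the layer weights through bitsandbytes blockwise quantization, and getDistanceAndError() measures the error that round trip introduces. A minimal standalone sketch of the same round trip, assuming bitsandbytes is installed and a CUDA device is available (the tensor shape is illustrative):

import torch
import bitsandbytes as bnb

# Illustrative tensor; block_size matches the module's self.block_size above.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda:0")
block_size = 128

# Blockwise 8-bit quantization: returns the packed tensor and its quantization state.
weight_quantized, weight_state = bnb.functional.quantize_blockwise(weight, blocksize=block_size)

# Dequantize and measure the round-trip error, analogous to getDistanceAndError().
weight_dequantized = bnb.functional.dequantize_blockwise(weight_quantized, weight_state).to(weight.dtype)
error = (weight_dequantized - weight).to(torch.float32)
print(f"max abs quantization error: {error.abs().max().item():.6f}")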

View File

@@ -30,13 +30,13 @@ def smart_tokenizer_and_embedding_resize(
 def get_tokenizer(model, cache_dir, model_args: ModelArguments):
-    print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
+    tokenizer_path = model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path
+    print(f'Tokenizer: {tokenizer_path}')
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
+        tokenizer_path,
         cache_dir=cache_dir,
         padding_side="right",
         use_fast=False,
-        eos_token="[EOS]",
         tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
         trust_remote_code=model_args.trust_remote_code
     )

View File

@@ -1,6 +1,4 @@
 import transformers
-from transformers import get_scheduler
 import torch
 from torch.utils import tensorboard
 import os
@@ -8,9 +6,10 @@ import shutil
 import math
 from tqdm.auto import tqdm
 import gc
+import sys

 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
+from datamodules import get_data_loaders
 from tokenizer import get_tokenizer
 from dyntrainmodel import DyntrainModel
@@ -19,7 +18,16 @@ from dyntrainmodel import DyntrainModel
 def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
     output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
     output_dir = os.path.join(output_dir, output_chkpt_dir)
+    print(f"saveing model to {output_chkpt_dir}")
+
+    temperature = model.generation_config.temperature
+    top_p = model.generation_config.top_p
+    model.generation_config.temperature = None
+    model.generation_config.top_p = None
     model.save_pretrained(output_dir)
+    model.generation_config.temperature = temperature
+    model.generation_config.top_p = top_p

     if max_checkpoints > 0:
         files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
@@ -57,37 +65,85 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
     return optimizer


+def move_optimizer_param(param, device: torch.device, device_map: dict):
+    if isinstance(param, torch.Tensor):
+        move_device = device if device is not None else device_map[id(param)]
+        assert device is not None or move_device != torch.device("cpu")
+        old_device = param.device
+        param.data = param.data.to(move_device)
+        if param._grad is not None:
+            param._grad.data = param._grad.data.to(move_device)
+        if device is not None and id(param) not in device_map:
+            device_map[id(param)] = old_device
+            assert old_device != torch.device("cpu")
+    elif isinstance(param, dict):
+        for subparam in param.values():
+            move_optimizer_param(subparam, device, device_map)
+
+
+def suspend_optimizer(optimizer) -> dict:
+    device_map = dict()
+    for param in optimizer.state.values():
+        move_optimizer_param(param, torch.device("cpu"), device_map)
+    return device_map
+
+
+def resume_optimizer(optimizer, device_map: dict):
+    for param in optimizer.state.values():
+        move_optimizer_param(param, None, device_map)
+
+
 def evaluate(model: DyntrainModel, tokenizer,
              dataloader: torch.utils.data.DataLoader, globalstep: int,
              log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
-    print("*** Eval ***")
-    loss = torch.zeros((1), device="cuda:0")
-    model.model.eval()
-    for batch in dataloader:
-        for key in batch:
-            batch[key] = batch[key].to("cuda:0")
-        outputs = model.model(**batch)
-        loss += outputs.loss
-    loss = loss / len(dataloader)
-    log_writer.add_scalar("Loss/Eval", loss, globalstep)
-    print(f"Eval Loss {loss.item()}")
-    return loss.item()
-
-    if eval_prompt is not None:
-        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
-        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
-        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
-        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-        print(f"Eval generation: {response_decoded}")
-        log_writer.add_text("Text/Eval", response_decoded, globalstep)
+    with torch.no_grad():
+        loss = torch.zeros((1), device="cuda:0")
+        model.model.eval()
+
+        for batch in tqdm(dataloader, desc="Doing eval"):
+            for key in batch:
+                batch[key] = batch[key].to("cuda:0")
+            outputs = model.model(**batch)
+            loss += outputs.loss
+        loss = loss / len(dataloader)
+        log_writer.add_scalar("Loss/Eval", loss, globalstep)
+        print(f"Eval Loss {loss.item()}")
+
+        if eval_prompt is not None:
+            input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+            attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+            outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
+                                           max_new_tokens=100, min_new_tokens=100)
+            response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+            print(f"Eval generation: {response_decoded}")
+            log_writer.add_text("Text/Eval", response_decoded, globalstep)
+
+    model.model.train()
+
+
+def max_vram_allocated():
+    max_vram_alloc = 0
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc
+
+
+def min_vram_allocated():
+    max_vram_alloc = sys.maxsize
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = min(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc
+
+
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    log_writer = tensorboard.SummaryWriter()
+    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)

-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
-                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
-                          quantize=model_args.quantize)
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
+                          quantize=model_args.quantize,
+                          target_active_params=int(training_args.max_instant_params * 1e6),
+                          train_static_params=training_args.train_non_linear_layers,
+                          reshuffle_fraction=training_args.churn_percent / 100.0,
+                          gradient_checkpointing=True,
+                          trust_remote_code=True)
     devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
     model.toDevices(devices)
     model.reshuffleActive()
@@ -96,34 +152,15 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
-    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
+    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m "
           f"instantanous active paramters of which {static_parameter_count} are static")

     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

-    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
-        print("Loading dataset in s2s mode")
-        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    elif data_args.data_from_hub:
-        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    else:
-        print("Loading dataset in txt mode")
-        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset['train_dataset'],
-        shuffle=True,
-        collate_fn=dataset['data_collator'],
-        batch_size=training_args.per_device_train_batch_size
-    ) if dataset['train_dataset'] is not None else None
-    if training_args.do_eval:
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset['eval_dataset'],
-            shuffle=True,
-            collate_fn=dataset['data_collator'],
-            batch_size=training_args.per_device_train_batch_size
-        )
+    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
+                                                         training_args.per_device_train_batch_size,
+                                                         training_args.per_device_eval_batch_size,
+                                                         training_args.do_train, training_args.do_eval)

     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
@@ -137,7 +174,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                               training_args.adam_epsilon,
                               training_args.adam8bit)

-    lr_scheduler = get_scheduler(
+    lr_scheduler = transformers.get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.warmup_steps,
@@ -149,13 +186,11 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     global_step = 0
     model.model.train()
     for epoch in range(0, training_args.epochs):
-        model.model.train()
         print("*** Train ***")
-        print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
+        print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
         for step, batch in enumerate(train_dataloader):
             for key in batch:
                 batch[key] = batch[key].to("cuda:0")
             outputs = model.model(**batch)
             loss = outputs.loss / training_args.gradient_accumulation_steps
             loss.backward()
@@ -166,46 +201,52 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 optimizer.step()
                 lr_scheduler.step()

+                progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
+                                             f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")
                 model.model.zero_grad()

-                if global_step % 5 == 0:
-                    print(f"Train Loss {loss.item()}")
-
-                if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
-                    print("Reshuffleing")
-                    lr_scheduler.optimizer = None
-                    del optimizer
-
-                    # distance, error = model.getDistanceAndErrorSample()
-                    # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
-                    # log_writer.add_histogram("Errors/Train", error, max_bins=50)
-
-                    model.reshuffleActive()
-                    model.balanceActive()
-                    log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
-                    optimizer = get_optimizer(model.dynamicParameters(),
-                                              model.staticParameters() if training_args.train_non_linear_layers else None,
-                                              training_args.learning_rate,
-                                              training_args.learning_rate / dynamic_param_ratio,
-                                              training_args.weight_decay,
-                                              training_args.adam_epsilon,
-                                              training_args.adam8bit)
-                    lr_scheduler.optimizer = optimizer
+                if global_step > 0:
+                    if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
+                        print("Reshuffleing")
+                        lr_scheduler.optimizer = None
+                        del optimizer
+
+                        # distance, error = model.getDistanceAndErrorSample()
+                        # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                        # log_writer.add_histogram("Errors/Train", error, max_bins=50)
+
+                        model.reshuffleActive()
+                        model.balanceActive()
+                        log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
+                        optimizer = get_optimizer(model.dynamicParameters(),
+                                                  model.staticParameters() if training_args.train_non_linear_layers else None,
+                                                  training_args.learning_rate,
+                                                  training_args.learning_rate / dynamic_param_ratio,
+                                                  training_args.weight_decay,
+                                                  training_args.adam_epsilon,
+                                                  training_args.adam8bit)
+                        lr_scheduler.optimizer = optimizer
+
+                    if global_step % training_args.save_steps == 0:
+                        save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
+                    if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                        device_map = suspend_optimizer(optimizer)
+                        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                        resume_optimizer(optimizer, device_map)

             global_step += 1
             progress_bar.update()

-            if global_step > 0:
-                if global_step % training_args.save_steps == 0:
-                    save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
-                if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
-                    evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
-
             if training_args.flush_allocator:
                 gc.collect()
                 torch.cuda.empty_cache()

     if training_args.do_eval and training_args.eval_steps == -1:
+        device_map = suspend_optimizer(optimizer)
         evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+        resume_optimizer(optimizer, device_map)
+
+    del optimizer
+
+    # Evaluation
     if training_args.do_eval:
         evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
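As a side note on the suspend_optimizer()/resume_optimizer() helpers introduced in this commit: they park the optimizer's state tensors on the CPU while evaluation runs, then move them back before training continues. A minimal, self-contained sketch of that pattern with a plain AdamW optimizer (the toy model, sizes and function name are illustrative, not taken from the repository):

import torch

# Toy model and optimizer; the real code operates on the DyntrainModel's parameters.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Take one step so the optimizer actually holds state tensors (exp_avg, exp_avg_sq).
model(torch.randn(8, 1024, device="cuda")).sum().backward()
optimizer.step()

def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Walk every per-parameter state dict and move its tensors to the target device.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

move_optimizer_state(optimizer, torch.device("cpu"))   # suspend: frees VRAM during eval
# ... run evaluation here ...
move_optimizer_state(optimizer, torch.device("cuda"))  # resume before the next optimizer.step()

The repository's version additionally records each tensor's original device in a device_map so state spread across several GPUs returns to the right one; this sketch assumes a single CUDA device.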