various optimizations
This commit is contained in:
@ -68,7 +68,9 @@ class LinearGroup:
|
|||||||
|
|
||||||
class DyntrainModel:
|
class DyntrainModel:
|
||||||
def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
|
def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
|
||||||
target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
|
target_active_params: int, train_static_params: bool,
|
||||||
|
reshuffle_fraction: float, gradient_checkpointing: bool,
|
||||||
|
trust_remote_code: bool = False):
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@ -82,6 +84,7 @@ class DyntrainModel:
|
|||||||
raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
|
raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
|
||||||
self.devices = list[torch.device]()
|
self.devices = list[torch.device]()
|
||||||
self.inital_reshufle = True
|
self.inital_reshufle = True
|
||||||
|
self.train_static_params = train_static_params
|
||||||
|
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
@ -167,8 +170,14 @@ class DyntrainModel:
|
|||||||
def staticParameterCount(self) -> int:
|
def staticParameterCount(self) -> int:
|
||||||
return sum(p.numel() for p in self.staticParameters())
|
return sum(p.numel() for p in self.staticParameters())
|
||||||
|
|
||||||
|
def activeDynamicParameterCount(self) -> int:
|
||||||
|
return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)
|
||||||
|
|
||||||
def activeParameterCount(self) -> int:
|
def activeParameterCount(self) -> int:
|
||||||
total_params = self.dynamicParameters() + self.staticParameters()
|
if self.train_static_params:
|
||||||
|
total_params = self.dynamicParameters() + self.staticParameters()
|
||||||
|
else:
|
||||||
|
total_params = self.dynamicParameters()
|
||||||
return sum(p.numel() for p in total_params if p.requires_grad)
|
return sum(p.numel() for p in total_params if p.requires_grad)
|
||||||
|
|
||||||
def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
|
def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
|
||||||
@ -187,7 +196,7 @@ class DyntrainModel:
|
|||||||
params = self.activeParameterCount()
|
params = self.activeParameterCount()
|
||||||
|
|
||||||
if params >= self.target_active_params:
|
if params >= self.target_active_params:
|
||||||
RuntimeError("Insuficant active parameters to suffle active")
|
raise RuntimeError("Insuficant active parameters to suffle active")
|
||||||
while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
|
while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
|
||||||
i = randint(0, len(self.frozen_linear_groups) - 1)
|
i = randint(0, len(self.frozen_linear_groups) - 1)
|
||||||
group = self.frozen_linear_groups.pop(i)
|
group = self.frozen_linear_groups.pop(i)
|
||||||
@ -199,7 +208,7 @@ class DyntrainModel:
|
|||||||
|
|
||||||
active_params = self.activeParameterCount()
|
active_params = self.activeParameterCount()
|
||||||
|
|
||||||
assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
|
assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params
|
||||||
|
|
||||||
def activeParamtersByDevice(self) -> list[int]:
|
def activeParamtersByDevice(self) -> list[int]:
|
||||||
out = [0] * len(self.devices)
|
out = [0] * len(self.devices)
|
||||||
@ -213,7 +222,7 @@ class DyntrainModel:
|
|||||||
for i, count in enumerate(active_counts):
|
for i, count in enumerate(active_counts):
|
||||||
memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
|
memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
|
||||||
if i == 0:
|
if i == 0:
|
||||||
memory = int(memory * 0.8)
|
memory = int(memory * 0.5)
|
||||||
bits_per_param.append(count / memory)
|
bits_per_param.append(count / memory)
|
||||||
|
|
||||||
max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
|
max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
|
||||||
@ -223,7 +232,7 @@ class DyntrainModel:
|
|||||||
if group.getDevice() is self.devices[max_index]:
|
if group.getDevice() is self.devices[max_index]:
|
||||||
memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
|
memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
|
||||||
if max_index == 0:
|
if max_index == 0:
|
||||||
memory = int(memory * 0.8)
|
memory = int(memory * 0.5)
|
||||||
swing = group.paramCount() / memory
|
swing = group.paramCount() / memory
|
||||||
if max_bits_per_param - swing > min_bits_per_param + swing:
|
if max_bits_per_param - swing > min_bits_per_param + swing:
|
||||||
group.inplaceTo(device=self.devices[min_index])
|
group.inplaceTo(device=self.devices[min_index])
|
||||||
|
24
modules.py
24
modules.py
@ -108,7 +108,7 @@ class DynamicConvertingLinear(Linear):
|
|||||||
|
|
||||||
class DynamicQantizedLinear(Linear):
|
class DynamicQantizedLinear(Linear):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
|
def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
|
||||||
output_dtype=None, compute_dtype=None, output_device=None):
|
output_dtype=None, compute_dtype=None, output_device=None, cold_dtype=torch.float32):
|
||||||
super().__init__(in_features, out_features, bias, cold_device, torch.float32)
|
super().__init__(in_features, out_features, bias, cold_device, torch.float32)
|
||||||
self.active_device = active_device
|
self.active_device = active_device
|
||||||
self.cold_device = cold_device
|
self.cold_device = cold_device
|
||||||
@ -120,8 +120,8 @@ class DynamicQantizedLinear(Linear):
|
|||||||
self.bias_quantized = None
|
self.bias_quantized = None
|
||||||
self.bias_state = None
|
self.bias_state = None
|
||||||
self.block_size = 128
|
self.block_size = 128
|
||||||
self.quant_type = 'nf4'
|
#self.weight_start = self.weight.clone().detach()
|
||||||
self.weight_start = self.weight.clone().detach()
|
self.cold_dtype = cold_dtype
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
|
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
|
||||||
@ -131,19 +131,19 @@ class DynamicQantizedLinear(Linear):
|
|||||||
compute_dtype=compute_dtype, output_device=output_device)
|
compute_dtype=compute_dtype, output_device=output_device)
|
||||||
new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
|
new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
|
||||||
new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
|
new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
|
||||||
new_module.weight_start = new_module.weight.clone().detach()
|
#new_module.weight_start = new_module.weight.clone().detach()
|
||||||
return new_module
|
return new_module
|
||||||
|
|
||||||
def compress(self) -> None:
|
def compress(self) -> None:
|
||||||
weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
|
weight = self.weight.contiguous().to(torch.float16).to(self.active_device)
|
||||||
self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
|
self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
|
bias = self.bias.contiguous().to(torch.float16).to(self.active_device)
|
||||||
self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
|
self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
|
||||||
|
|
||||||
frozen = self.isFrozen()
|
frozen = self.isFrozen()
|
||||||
self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
|
self.weight = torch.nn.Parameter(self.weight.to(self.cold_dtype).to(self.cold_device))
|
||||||
self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
|
self.bias = torch.nn.Parameter(self.bias.to(self.cold_dtype).to(self.cold_device)) if self.bias is not None else None
|
||||||
self.setFrozen(frozen, False)
|
self.setFrozen(frozen, False)
|
||||||
|
|
||||||
def decompress(self) -> None:
|
def decompress(self) -> None:
|
||||||
@ -151,16 +151,16 @@ class DynamicQantizedLinear(Linear):
|
|||||||
self.weight_state = None
|
self.weight_state = None
|
||||||
self.bias_quantized = None
|
self.bias_quantized = None
|
||||||
self.bias_state = None
|
self.bias_state = None
|
||||||
self.weight_start = self.weight.clone().detach().to(self.cold_device)
|
#self.weight_start = self.weight.clone().detach().to(self.cold_device)
|
||||||
self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
|
self.weight = torch.nn.Parameter(self.weight.to(self.active_device).to(torch.float32))
|
||||||
if self.bias_quantized:
|
if self.bias_quantized:
|
||||||
self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
|
self.bias = torch.nn.Parameter(self.bias.to(self.active_device).to(torch.float32))
|
||||||
|
|
||||||
def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
|
def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
|
original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
|
||||||
quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
|
quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
|
||||||
dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
|
dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
|
||||||
distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
|
distance = torch.zeros((2)) #(self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
|
||||||
error = (dequantized_original_weight - original_weight).to(torch.float32)
|
error = (dequantized_original_weight - original_weight).to(torch.float32)
|
||||||
return (distance, error)
|
return (distance, error)
|
||||||
|
|
||||||
|
@ -30,13 +30,13 @@ def smart_tokenizer_and_embedding_resize(
|
|||||||
|
|
||||||
|
|
||||||
def get_tokenizer(model, cache_dir, model_args: ModelArguments):
|
def get_tokenizer(model, cache_dir, model_args: ModelArguments):
|
||||||
print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
|
tokenizer_path = model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path
|
||||||
|
print(f'Tokenizer: {tokenizer_path}')
|
||||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||||
model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
|
tokenizer_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
padding_side="right",
|
padding_side="right",
|
||||||
use_fast=False,
|
use_fast=False,
|
||||||
eos_token="[EOS]",
|
|
||||||
tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
|
tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
|
||||||
trust_remote_code=model_args.trust_remote_code
|
trust_remote_code=model_args.trust_remote_code
|
||||||
)
|
)
|
||||||
|
201
train_dynamic.py
201
train_dynamic.py
@ -1,6 +1,4 @@
|
|||||||
import transformers
|
import transformers
|
||||||
from transformers import get_scheduler
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils import tensorboard
|
from torch.utils import tensorboard
|
||||||
import os
|
import os
|
||||||
@ -8,9 +6,10 @@ import shutil
|
|||||||
import math
|
import math
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
import gc
|
import gc
|
||||||
|
import sys
|
||||||
|
|
||||||
from arguments import DataArguments, ModelArguments, TrainingArguments
|
from arguments import DataArguments, ModelArguments, TrainingArguments
|
||||||
from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
|
from datamodules import get_data_loaders
|
||||||
from tokenizer import get_tokenizer
|
from tokenizer import get_tokenizer
|
||||||
|
|
||||||
from dyntrainmodel import DyntrainModel
|
from dyntrainmodel import DyntrainModel
|
||||||
@ -19,7 +18,16 @@ from dyntrainmodel import DyntrainModel
|
|||||||
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
|
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
|
||||||
output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
|
output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
|
||||||
output_dir = os.path.join(output_dir, output_chkpt_dir)
|
output_dir = os.path.join(output_dir, output_chkpt_dir)
|
||||||
|
|
||||||
|
print(f"saveing model to {output_chkpt_dir}")
|
||||||
|
|
||||||
|
temperature = model.generation_config.temperature
|
||||||
|
top_p = model.generation_config.top_p
|
||||||
|
model.generation_config.temperature = None
|
||||||
|
model.generation_config.top_p = None
|
||||||
model.save_pretrained(output_dir)
|
model.save_pretrained(output_dir)
|
||||||
|
model.generation_config.temperature = temperature
|
||||||
|
model.generation_config.top_p = top_p
|
||||||
|
|
||||||
if max_checkpoints > 0:
|
if max_checkpoints > 0:
|
||||||
files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
|
files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
|
||||||
@ -57,37 +65,85 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
|
|||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def move_optimizer_param(param, device: torch.device, device_map: dict):
|
||||||
|
if isinstance(param, torch.Tensor):
|
||||||
|
move_device = device if device is not None else device_map[id(param)]
|
||||||
|
assert device is not None or move_device != torch.device("cpu")
|
||||||
|
old_device = param.device
|
||||||
|
param.data = param.data.to(move_device)
|
||||||
|
if param._grad is not None:
|
||||||
|
param._grad.data = param._grad.data.to(move_device)
|
||||||
|
if device is not None and id(param) not in device_map:
|
||||||
|
device_map[id(param)] = old_device
|
||||||
|
assert old_device != torch.device("cpu")
|
||||||
|
elif isinstance(param, dict):
|
||||||
|
for subparam in param.values():
|
||||||
|
move_optimizer_param(subparam, device, device_map)
|
||||||
|
|
||||||
|
|
||||||
|
def suspend_optimizer(optimizer) -> dict:
|
||||||
|
device_map = dict()
|
||||||
|
for param in optimizer.state.values():
|
||||||
|
move_optimizer_param(param, torch.device("cpu"), device_map)
|
||||||
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
|
def resume_optimizer(optimizer, device_map: dict):
|
||||||
|
for param in optimizer.state.values():
|
||||||
|
move_optimizer_param(param, None, device_map)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model: DyntrainModel, tokenizer,
|
def evaluate(model: DyntrainModel, tokenizer,
|
||||||
dataloader: torch.utils.data.DataLoader, globalstep: int,
|
dataloader: torch.utils.data.DataLoader, globalstep: int,
|
||||||
log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
|
log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
|
||||||
print("*** Eval ***")
|
with torch.no_grad():
|
||||||
loss = torch.zeros((1), device="cuda:0")
|
loss = torch.zeros((1), device="cuda:0")
|
||||||
model.model.eval()
|
model.model.eval()
|
||||||
for batch in dataloader:
|
|
||||||
for key in batch:
|
|
||||||
batch[key] = batch[key].to("cuda:0")
|
|
||||||
outputs = model.model(**batch)
|
|
||||||
loss += outputs.loss
|
|
||||||
loss = loss / len(dataloader)
|
|
||||||
log_writer.add_scalar("Loss/Eval", loss, globalstep)
|
|
||||||
print(f"Eval Loss {loss.item()}")
|
|
||||||
return loss.item()
|
|
||||||
|
|
||||||
if eval_prompt is not None:
|
for batch in tqdm(dataloader, desc="Doing eval"):
|
||||||
input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
|
for key in batch:
|
||||||
attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
|
batch[key] = batch[key].to("cuda:0")
|
||||||
outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
|
outputs = model.model(**batch)
|
||||||
response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
loss += outputs.loss
|
||||||
print(f"Eval generation: {response_decoded}")
|
loss = loss / len(dataloader)
|
||||||
log_writer.add_text("Text/Eval", response_decoded, globalstep)
|
log_writer.add_scalar("Loss/Eval", loss, globalstep)
|
||||||
|
print(f"Eval Loss {loss.item()}")
|
||||||
|
|
||||||
|
if eval_prompt is not None:
|
||||||
|
input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
|
||||||
|
attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
|
||||||
|
outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
|
||||||
|
max_new_tokens=100, min_new_tokens=100)
|
||||||
|
response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||||
|
print(f"Eval generation: {response_decoded}")
|
||||||
|
log_writer.add_text("Text/Eval", response_decoded, globalstep)
|
||||||
|
model.model.train()
|
||||||
|
|
||||||
|
|
||||||
|
def max_vram_allocated():
|
||||||
|
max_vram_alloc = 0
|
||||||
|
for i in range(0, torch.cuda.device_count()):
|
||||||
|
max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
|
||||||
|
return max_vram_alloc
|
||||||
|
|
||||||
|
|
||||||
|
def min_vram_allocated():
|
||||||
|
max_vram_alloc = sys.maxsize
|
||||||
|
for i in range(0, torch.cuda.device_count()):
|
||||||
|
max_vram_alloc = min(torch.cuda.memory_allocated(i), max_vram_alloc)
|
||||||
|
return max_vram_alloc
|
||||||
|
|
||||||
|
|
||||||
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
|
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
|
||||||
log_writer = tensorboard.SummaryWriter()
|
log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)
|
||||||
|
|
||||||
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
|
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
|
||||||
reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
|
quantize=model_args.quantize,
|
||||||
quantize=model_args.quantize)
|
target_active_params=int(training_args.max_instant_params * 1e6),
|
||||||
|
train_static_params=training_args.train_non_linear_layers,
|
||||||
|
reshuffle_fraction=training_args.churn_percent / 100.0,
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
trust_remote_code=True)
|
||||||
devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
|
devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
|
||||||
model.toDevices(devices)
|
model.toDevices(devices)
|
||||||
model.reshuffleActive()
|
model.reshuffleActive()
|
||||||
@ -96,34 +152,15 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
paramter_count = sum(p.numel() for p in model.model.parameters())
|
paramter_count = sum(p.numel() for p in model.model.parameters())
|
||||||
active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
|
active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
|
||||||
static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
|
static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
|
||||||
print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
|
print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m "
|
||||||
f"instantanous active paramters of which {static_parameter_count} are static")
|
f"instantanous active paramters of which {static_parameter_count} are static")
|
||||||
|
|
||||||
tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
|
tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
|
||||||
|
|
||||||
if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
|
train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
|
||||||
print("Loading dataset in s2s mode")
|
training_args.per_device_train_batch_size,
|
||||||
data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
|
training_args.per_device_eval_batch_size,
|
||||||
elif data_args.data_from_hub:
|
training_args.do_train, training_args.do_eval)
|
||||||
data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
|
|
||||||
else:
|
|
||||||
print("Loading dataset in txt mode")
|
|
||||||
data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
|
|
||||||
|
|
||||||
dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset['train_dataset'],
|
|
||||||
shuffle=True,
|
|
||||||
collate_fn=dataset['data_collator'],
|
|
||||||
batch_size=training_args.per_device_train_batch_size
|
|
||||||
) if dataset['train_dataset'] is not None else None
|
|
||||||
if training_args.do_eval:
|
|
||||||
eval_dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset['eval_dataset'],
|
|
||||||
shuffle=True,
|
|
||||||
collate_fn=dataset['data_collator'],
|
|
||||||
batch_size=training_args.per_device_train_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
|
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
|
||||||
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
|
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
|
||||||
@ -137,7 +174,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
training_args.adam_epsilon,
|
training_args.adam_epsilon,
|
||||||
training_args.adam8bit)
|
training_args.adam8bit)
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
lr_scheduler = transformers.get_scheduler(
|
||||||
name=training_args.lr_scheduler_type,
|
name=training_args.lr_scheduler_type,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
num_warmup_steps=training_args.warmup_steps,
|
num_warmup_steps=training_args.warmup_steps,
|
||||||
@ -149,13 +186,11 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
global_step = 0
|
global_step = 0
|
||||||
model.model.train()
|
model.model.train()
|
||||||
for epoch in range(0, training_args.epochs):
|
for epoch in range(0, training_args.epochs):
|
||||||
model.model.train()
|
|
||||||
print("*** Train ***")
|
print("*** Train ***")
|
||||||
print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
|
print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to("cuda:0")
|
batch[key] = batch[key].to("cuda:0")
|
||||||
|
|
||||||
outputs = model.model(**batch)
|
outputs = model.model(**batch)
|
||||||
loss = outputs.loss / training_args.gradient_accumulation_steps
|
loss = outputs.loss / training_args.gradient_accumulation_steps
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -166,46 +201,52 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
|
||||||
|
f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")
|
||||||
|
|
||||||
model.model.zero_grad()
|
model.model.zero_grad()
|
||||||
|
|
||||||
if global_step % 5 == 0:
|
if global_step > 0:
|
||||||
print(f"Train Loss {loss.item()}")
|
if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
|
||||||
|
print("Reshuffleing")
|
||||||
|
lr_scheduler.optimizer = None
|
||||||
|
del optimizer
|
||||||
|
# distance, error = model.getDistanceAndErrorSample()
|
||||||
|
# log_writer.add_histogram("Distances/Train", distance, max_bins=50)
|
||||||
|
# log_writer.add_histogram("Errors/Train", error, max_bins=50)
|
||||||
|
|
||||||
if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
|
model.reshuffleActive()
|
||||||
print("Reshuffleing")
|
model.balanceActive()
|
||||||
lr_scheduler.optimizer = None
|
log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
|
||||||
del optimizer
|
optimizer = get_optimizer(model.dynamicParameters(),
|
||||||
# distance, error = model.getDistanceAndErrorSample()
|
model.staticParameters() if training_args.train_non_linear_layers else None,
|
||||||
# log_writer.add_histogram("Distances/Train", distance, max_bins=50)
|
training_args.learning_rate,
|
||||||
# log_writer.add_histogram("Errors/Train", error, max_bins=50)
|
training_args.learning_rate / dynamic_param_ratio,
|
||||||
|
training_args.weight_decay,
|
||||||
|
training_args.adam_epsilon,
|
||||||
|
training_args.adam8bit)
|
||||||
|
lr_scheduler.optimizer = optimizer
|
||||||
|
|
||||||
model.reshuffleActive()
|
if global_step % training_args.save_steps == 0:
|
||||||
model.balanceActive()
|
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
|
||||||
log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
|
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
|
||||||
optimizer = get_optimizer(model.dynamicParameters(),
|
device_map = suspend_optimizer(optimizer)
|
||||||
model.staticParameters() if training_args.train_non_linear_layers else None,
|
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||||
training_args.learning_rate,
|
resume_optimizer(optimizer, device_map)
|
||||||
training_args.learning_rate / dynamic_param_ratio,
|
|
||||||
training_args.weight_decay,
|
|
||||||
training_args.adam_epsilon,
|
|
||||||
training_args.adam8bit)
|
|
||||||
lr_scheduler.optimizer = optimizer
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
|
|
||||||
if global_step > 0:
|
|
||||||
if global_step % training_args.save_steps == 0:
|
|
||||||
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
|
|
||||||
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
|
|
||||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
|
||||||
if training_args.flush_allocator:
|
if training_args.flush_allocator:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if training_args.do_eval and training_args.eval_steps == -1:
|
if training_args.do_eval and training_args.eval_steps == -1:
|
||||||
|
device_map = suspend_optimizer(optimizer)
|
||||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||||
|
resume_optimizer(optimizer, device_map)
|
||||||
|
|
||||||
|
del optimizer
|
||||||
|
|
||||||
# Evaluation
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user