wip refactor

uvos 2024-03-13 19:45:52 +01:00
parent 11ea9eeaa7
commit 5acb6809ed
5 changed files with 292 additions and 232 deletions

convertinglinear.py (deleted)

@@ -1,30 +0,0 @@
import torch
class ConvertingLinear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.output_dtype = output_dtype
def forward(self, input: torch.Tensor):
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
output_device = input.device
if input.device != self.weight.device:
input = input.to(self.weight.device)
if input.dtype != self.weight.dtype:
input = input.to(self.weight.dtype)
output = torch.nn.Linear.forward(self, input)
if torch.isnan(output).any() or self.weight.dtype != torch.float32:
breakpoint()
return output.to(output_device).to(output_dtype)
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear):
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
out_features=in_module.out_features,
bias=in_module.bias is not None,
device=in_module.weight.device,
dtype=in_module.weight.dtype)
new_module.weight = in_module.weight
new_module.bias = in_module.bias
return new_module

dyntrainmodel.py (new file, 190 lines added)

@@ -0,0 +1,190 @@
from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import ConvertingLinear, Linear
from random import randint
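# DyntrainModel wraps a causal LM and partitions its Linear layers into contiguous
# groups. Frozen groups are stored in fp16; active groups are kept in fp32 and
# trainable. reshuffleActive() periodically re-picks a random set of active groups
# so that roughly target_active_params parameters are trainable at any one time.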
def find_all_linear_module_names(model) -> list[str]:
module_names = set()
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
module_names.add(name)
if 'lm_head' in module_names: # needed for 16-bit
module_names.remove('lm_head')
return list(module_names)
def find_all_outher_module_names(model) -> list[str]:
module_names = set()
for name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
module_names.add(name)
return list(module_names)
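# A LinearGroup is a run of consecutive Linear modules that is frozen, converted
# and moved between devices as a unit. The first and last members must be
# ConvertingLinear so dtype/device conversion happens at the group boundaries.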
class LinearGroup:
def __init__(self, model, group_names: list):
self.modules = list()
model_modules = dict(model.named_modules())
for name in group_names:
self.modules.append(model_modules[name])
assert isinstance(self.modules[0], ConvertingLinear)
assert isinstance(self.modules[-1], ConvertingLinear)
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
    for module in self.modules:
        module.inplaceTo(dtype, device)
    if output_device is not None:
        self.modules[-1].setOutputDevice(output_device)
def setFrozen(self, frozen: bool) -> None:
for module in self.modules:
module.setFrozen(frozen)
def isFrozen(self) -> bool:
return self.modules[0].isFrozen()
def parameters(self) -> list[torch.nn.Parameter]:
params = list()
for module in self.modules:
params.extend(module.parameters())
return params
def paramCount(self) -> int:
return sum(p.numel() for p in self.parameters())
class DyntrainModel:
def __init__(self, model_name_or_path: str, cache_dir: str,
target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
torch_dtype=torch.float32,
trust_remote_code=trust_remote_code,
device_map=None,
)
self.linear_groups = list()
self.target_active_params = target_active_params
self._prepare()
self.reshuffleActive()
@staticmethod
def _get_nonlinear_names(layer: torch.nn.Module) -> list[str]:
names = list()
modules = dict(layer.named_modules())
for key in modules.keys():
if not isinstance(modules[key], torch.nn.Linear):
names.append(key)
return names
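# Walk the modules in registration order and collect consecutive runs of Linear
# layers (excluding lm_head); each run becomes one linear group.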
@staticmethod
def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
linear_groups = list()
list_counter = 0
in_sequence = False
modules = dict(layer.named_modules())
for key in modules.keys():
if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
if not in_sequence:
linear_groups.append(list())
in_sequence = True
linear_groups[list_counter].append(key)
elif in_sequence:
in_sequence = False
list_counter = list_counter + 1
return linear_groups
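# Boundary layers of each group become ConvertingLinear: the first casts the
# incoming activations to the group's fp16 dtype, the last casts the result back
# to fp32 for the surrounding modules; inner layers stay plain Linear because the
# dtype is already consistent inside the group.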
def _prepare(self) -> None:
modules = dict(self.model.named_modules())
linear_groups = DyntrainModel._get_linear_group_names(self.model)
for group in linear_groups:
replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
if len(group) > 2:
for index in range(1, len(group) - 1):
replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
self.linear_groups.append(LinearGroup(self.model, group))
def dynamicParameters(self) -> list:
parameters = list()
for group in self.linear_groups:
parameters.extend(group.parameters())
return parameters
def staticParameters(self) -> list:
modules = dict(self.model.named_modules())
dynamic_param_ids = set([id(p) for p in self.dynamicParameters()])
parameters = list()
for key in modules.keys():
for param in modules[key].parameters():
if id(param) not in dynamic_param_ids:
parameters.append(param)
return parameters
def dynamicParameterCount(self) -> int:
return sum(p.numel() for p in self.dynamicParameters())
def staticParameterCount(self) -> int:
return sum(p.numel() for p in self.staticParameters())
def activeParameterCount(self) -> int:
    total_params = self.dynamicParameters() + self.staticParameters()
    return sum(p.numel() for p in total_params if p.requires_grad)
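# Freeze every group, then unfreeze randomly chosen groups until the trainable
# parameter budget (target_active_params) is reached; frozen groups are then
# stored in fp16 and active groups in fp32.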
def reshuffleActive(self) -> None:
    for group in self.linear_groups:
        group.setFrozen(True)
    indices = list(range(0, len(self.linear_groups)))
    params = self.staticParameterCount()
    while params < self.target_active_params and len(indices) > 0:
        i = randint(0, len(indices) - 1)
        self.linear_groups[indices[i]].setFrozen(False)
        params += self.linear_groups[indices[i]].paramCount()
        indices.pop(i)
    for group in self.linear_groups:
        if group.isFrozen():
            group.inplaceTo(dtype=torch.float16)
        else:
            group.inplaceTo(dtype=torch.float32)
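# Estimate how many parameters fit per device from total GPU memory, place all
# non-linear modules and as many linear groups as fit on the primary device, and
# spread the remaining groups across the secondary devices with their outputs
# routed back to the primary device.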
def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
modules = dict(self.model.named_modules())
total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
static_param_count = self.staticParameterCount()
total_parameter_count = static_param_count + self.dynamicParameterCount()
params_per_byte = total_parameter_count / float(total_memory)
print(f"{1/params_per_byte} bytes available per parameter")
for key in DyntrainModel._get_nonlinear_names(self.model):
    replace_module(self.model, key, modules[key].to(primary_device))
group_index = 0
params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
primary_params = static_param_count
while params_for_primary > primary_params and group_index < len(self.linear_groups):
self.linear_groups[group_index].inplaceTo(device=primary_device)
primary_params += self.linear_groups[group_index].paramCount()
group_index += 1
for device in secondary_devices[:-1]:
    params_for_device = torch.cuda.get_device_properties(device).total_memory * params_per_byte
    params = 0
    while params_for_device > params and group_index < len(self.linear_groups):
        self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
        params += self.linear_groups[group_index].paramCount()
        group_index += 1
while group_index < len(self.linear_groups):
    self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
    group_index += 1

modules.py (new file, 68 lines added)

@@ -0,0 +1,68 @@
import torch
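# Linear extends torch.nn.Linear with freeze/unfreeze and in-place dtype/device
# conversion helpers; ConvertingLinear additionally casts its input to the weight's
# dtype/device and its output to a configurable output dtype/device.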
class Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear):
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
out_features=in_module.out_features,
bias=in_module.bias is not None,
device=in_module.weight.device,
dtype=in_module.weight.dtype)
new_module.weight = in_module.weight
new_module.bias = in_module.bias
return new_module
def setFrozen(self, frozen: bool):
self.weight.requires_grad = not frozen
if self.bias is not None:
self.bias.requires_grad = not frozen
def isFrozen(self) -> bool:
return not self.weight.requires_grad
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
    frozen = self.isFrozen()
    if dtype is not None:
        self.weight = torch.nn.Parameter(self.weight.to(dtype))
        if self.bias is not None:
            self.bias = torch.nn.Parameter(self.bias.to(dtype))
    if device is not None:
        self.to(device)
    self.setFrozen(frozen)  # re-wrapping in a new Parameter resets requires_grad, so restore it
class ConvertingLinear(Linear):
def __init__(self,
in_features, out_features, bias=True, device=None, dtype=None,
output_dtype=None, output_device=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.output_dtype = output_dtype
self.output_device = output_device
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
out_features=in_module.out_features,
bias=in_module.bias is not None,
device=in_module.weight.device,
dtype=in_module.weight.dtype)
new_module.output_dtype = output_dtype
new_module.output_device = output_device
new_module.weight = in_module.weight
new_module.bias = in_module.bias
return new_module
def setOutputDevice(self, output_device: torch.device):
self.output_device = output_device
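# forward: move and cast the input to match the weight, run the standard Linear
# forward, then cast the result to the configured output device/dtype (falling
# back to the input's own device and dtype).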
def forward(self, input: torch.Tensor):
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
output_device = input.device if self.output_device is None else self.output_device
if input.device != self.weight.device:
input = input.to(self.weight.device)
if input.dtype != self.weight.dtype:
input = input.to(self.weight.dtype)
output = torch.nn.Linear.forward(self, input)
assert not torch.isnan(output).any(), "ConvertingLinear produced NaN outputs"
return output.to(output_device).to(output_dtype)

@@ -1,156 +1,18 @@
import transformers
from transformers import AutoModelForCausalLM, get_scheduler
from peft.utils import _get_submodules
from transformers import get_scheduler
import torch
from torch.utils import tensorboard
import os
import shutil
import math
import collections
from tqdm.auto import tqdm
from random import randint
import collections
from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module
from convertinglinear import ConvertingLinear
from tokenizer import get_tokenizer
def find_all_linear_module_names(model):
module_names = set()
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
module_names.add(name)
if 'lm_head' in module_names: # needed for 16-bit
module_names.remove('lm_head')
return list(module_names)
def find_all_outher_module_names(model):
module_names = set()
for name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
module_names.add(name)
return list(module_names)
def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
print(f'loading base model {model_args.model_name_or_path} in {dtype}...')
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=cache_dir,
torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
trust_remote_code=model_args.trust_remote_code,
device_map=None,
attn_implementation="flash_attention_2"
)
modules = dict(model.named_modules())
keys = find_all_linear_module_names(model)
for key in keys:
modules[key].weight = modules[key].weight.to(dtype)
if modules[key].bias is not None:
modules[key].bias = modules[key].bias.to(dtype)
return model
@torch.no_grad()
def recursive_setattr(obj, attr, value):
attr = attr.split('.', 1)
if len(attr) == 1:
setattr(obj, attr[0], value)
else:
recursive_setattr(getattr(obj, attr[0]), attr[1], value)
@torch.no_grad()
def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
new_module = torch.nn.Linear(module.in_features,
module.out_features,
module.bias is not None,
module.weight.device,
dtype)
new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None
new_module.weight.requires_grad = not frozen
if new_module.bias is not None:
new_module.bias.requires_grad = not frozen
return new_module
@torch.no_grad()
def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
if type(module) is torch.nn.Linear:
if frozen:
module.weight.requires_grad = False
if module.bias is not None:
module.bias.requires_grad = False
return module.to(dtype).to(device)
else:
new_module = ConvertingLinear.fromLinear(module).to(dtype)
new_module.weight.requires_grad = True
if new_module.bias is not None:
new_module.bias.requires_grad = True
return new_module.to(device)
elif type(module) is ConvertingLinear:
if not frozen:
module.weight.requires_grad = True
if module.bias is not None:
module.bias.requires_grad = True
assert False
return module.to(dtype).to(device)
else:
new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=dtype)
new_module.weight = torch.nn.Parameter(module.weight.to(dtype))
new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None
new_module.weight.requires_grad = False
if new_module.bias is not None:
new_module.bias.requires_grad = False
return new_module.to(device)
else:
assert False
@torch.no_grad()
def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
modules = dict(model.named_modules())
linear_names = find_all_linear_module_names(model)
for key in linear_names:
if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad:
parent, target, target_name = _get_submodules(model, key)
setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device))
modules = dict(model.named_modules())
active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
if active_paramter_count > target_params:
raise RuntimeError("Enough paramters must be available to train at least one linear layer")
while active_paramter_count < target_params and len(linear_names) > 0:
i = randint(0, len(linear_names) - 1)
parent, target, target_name = _get_submodules(model, linear_names[i])
new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device)
setattr(parent, target_name, new_module)
active_paramter_count += modules[linear_names[i]].weight.numel()
if modules[linear_names[i]].bias is not None:
active_paramter_count += modules[linear_names[i]].bias.numel()
linear_names.pop(i)
modules = dict()
assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)
return active_paramter_count
from dyntrainmodel import DyntrainModel
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
@@ -172,26 +34,15 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
shutil.rmtree(delete_checkpoit_dir)
def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter], lr: float, static_lr: float,
weight_decay: float, eps: float, adam8bit: bool):
all_keys = dynamic_module_names + static_module_names
duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
if len(duplicated) > 0:
print("duplicated items:")
for item in duplicated:
print(item)
raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
parameters = list()
modules = dict(model.named_modules())
for key in dynamic_module_names:
parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
param_ids = set([id(p['params']) for p in parameters])
for key in static_module_names:
parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
for p in modules[key].parameters():
param_ids.add(id(p))
for param in static_parameters:
if param.requires_grad and id(param) not in param_ids:
parameters.append({'params': param, 'lr': static_lr})
param_ids.add(id(param))
if not adam8bit:
optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
@@ -204,18 +55,17 @@ def get_optimizer(model, dynamic_module_names: list, static_module_names: list,
return optimizer
def compute_dynamic_parameter_ratio(model):
modules = dict(model.named_modules())
active_linear_parameters = 0
total_linear_parameters = 0
for key in find_all_linear_module_names(model):
active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
return math.ceil(total_linear_parameters / active_linear_parameters)
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
primary_device = torch.device(training_args.primary_device)
secondary_device = torch.device(training_args.secondary_device)
log_writer = tensorboard.SummaryWriter()
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
model.toDevices(primary_device, [secondary_device])
def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
parameter_count = sum(p.numel() for p in model.model.parameters())
active_parameter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
print(f"Training model with {parameter_count/1e6}m parameters and {active_parameter_count/1e6}m instantaneous active parameters")
tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
@@ -239,48 +89,24 @@ def prepare(model_args: ModelArguments, data_args: DataArguments, training_args:
batch_size=training_args.per_device_train_batch_size
) if dataset['eval_dataset'] is not None else None
if model_args.max_instant_params != 0:
print(f"Target params {model_args.max_instant_params}m")
freeze_random_modules(model, model_args.max_instant_params * 1e6,
torch.float16 if training_args.storage_fp16 else torch.float32,
frozen_device=primary_device, active_device=secondary_device)
paramter_count = sum(p.numel() for p in model.parameters())
active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}")
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
total_steps = steps_per_epoch * training_args.epochs
optimizer = get_optimizer(model, find_all_linear_module_names(model),
find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
optimizer = get_optimizer(model.dynamicParameters(),
model.staticParameters(),
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler = get_scheduler(
name=training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=total_steps
)
return model, optimizer, lr_scheduler, train_dataloader
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
primary_device = torch.device(training_args.primary_device)
secondary_device = torch.device(training_args.secondary_device)
log_writer = tensorboard.SummaryWriter()
model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)
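# prepare() builds the DyntrainModel, optimizer, scheduler and dataloader once; the
# loop below periodically (every 10 steps when max_instant_params is set) calls
# reshuffleActive() and rebuilds the optimizer so it tracks the newly unfrozen parameters.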
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
total_steps = steps_per_epoch * training_args.epochs
dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
if training_args.do_train:
progress_bar = tqdm(range(total_steps))
@@ -307,13 +133,12 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
print(loss)
if global_step % 10 == 0 and model_args.max_instant_params != 0:
param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
torch.float16 if training_args.storage_fp16 else torch.float32,
frozen_device=primary_device,
active_device=secondary_device)
log_writer.add_scalar("Parameters/train", param_count, global_step)
optimizer = get_optimizer(model, find_all_linear_module_names(model),
find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
lr_scheduler.optimizer = None
del optimizer
model.reshuffleActive()
log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
optimizer = get_optimizer(model.dynamicParameters(),
model.staticParameters(),
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,

utils.py (new file, 7 lines added)

@@ -0,0 +1,7 @@
from peft.utils import _get_submodules
import torch
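# Replace the submodule addressed by the dotted name `key` (as produced by
# model.named_modules()) by looking up its parent with peft's _get_submodules and
# overwriting the attribute.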
def replace_module(model, key: str, module: torch.nn.Module):
parent, target, target_name = _get_submodules(model, key)
setattr(parent, target_name, module)