wip refactor
parent 11ea9eeaa7
commit 5acb6809ed
@@ -1,30 +0,0 @@
import torch


class ConvertingLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.Linear.forward(self, input)
        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
            breakpoint()
        return output.to(output_device).to(output_dtype)

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module
190  dyntrainmodel.py  Normal file
@@ -0,0 +1,190 @@
from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import ConvertingLinear, Linear
from random import randint


def find_all_linear_module_names(model) -> list[str]:
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
            module_names.add(name)

    if 'lm_head' in module_names:  # needed for 16-bit
        module_names.remove('lm_head')
    return list(module_names)


def find_all_outher_module_names(model) -> list[str]:
    module_names = set()
    for name, module in model.named_modules():
        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
            module_names.add(name)
    return list(module_names)

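# A LinearGroup is one run of consecutive Linear layers treated as a unit: its members are
# frozen/unfrozen, re-typed and moved between devices together (groups are built in
# DyntrainModel._prepare).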
class LinearGroup:
    def __init__(self, model, group_names: list):
        self.modules = list()
        model_modules = dict(model.named_modules())
        for name in group_names:
            self.modules.append(model_modules[name])
        assert isinstance(self.modules[0], ConvertingLinear)
        assert isinstance(self.modules[-1], ConvertingLinear)

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
        for module in self.modules:
            module.inplaceTo(dtype, device)
        self.modules[-1].setOutputDevice(output_device)

    def setFrozen(self, frozen: bool) -> None:
        for module in self.modules:
            module.setFrozen(frozen)

    def isFrozen(self) -> bool:
        return self.modules[0].isFrozen()

    def parameters(self) -> list[torch.nn.Parameter]:
        params = list()
        for module in self.modules:
            params.extend(module.parameters())
        return params

    def paramCount(self) -> int:
        return sum(p.numel() for p in self.parameters())

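# DyntrainModel wraps a causal LM and splits its Linear layers into LinearGroups so that a
# randomly chosen subset (bounded by target_active_params) is trained in fp32 while the
# remaining groups stay frozen in fp16, optionally spread across several GPUs.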
class DyntrainModel:
    def __init__(self, model_name_or_path: str, cache_dir: str,
                 target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            trust_remote_code=trust_remote_code,
            device_map=None,
        )
        self.linear_groups = list()
        self.target_active_params = target_active_params

        self._prepare()
        self.reshuffleActive()

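    # Used as DyntrainModel._get_nonlinear_names(model): returns the names of every
    # submodule that is not a torch.nn.Linear (or subclass); these modules are never
    # grouped and are kept on the primary device.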
    def _get_nonlinear_names(layer: torch.nn.Module):
        names = list()
        modules = dict(layer.named_modules())

        for key in modules.keys():
            if not isinstance(modules[key], torch.nn.Linear):
                names.append(key)
        return names

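    # Walks named_modules() in order and collects each run of directly consecutive
    # torch.nn.Linear layers (excluding lm_head) into one list of module names.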
    def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
        linear_groups = list()
        list_counter = 0
        in_sequence = False
        modules = dict(layer.named_modules())

        for key in modules.keys():
            if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
                if not in_sequence:
                    linear_groups.append(list())
                    in_sequence = True
                linear_groups[list_counter].append(key)
            elif in_sequence:
                in_sequence = False
                list_counter = list_counter + 1
        return linear_groups

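    # Converts each group to fp16 storage: the first and last layer become ConvertingLinear
    # (the last one casts activations back to fp32 for the surrounding fp32 modules), the
    # layers in between become plain fp16 Linear modules.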
    def _prepare(self) -> None:
        modules = dict(self.model.named_modules())
        linear_groups = DyntrainModel._get_linear_group_names(self.model)

        for group in linear_groups:
            replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
            replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
            if len(group) > 2:
                for index in range(1, len(group) - 1):
                    replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
            self.linear_groups.append(LinearGroup(self.model, group))

    def dynamicParameters(self) -> list:
        parameters = list()
        for group in self.linear_groups:
            parameters.extend(group.parameters())
        return parameters

    def staticParameters(self) -> list:
        modules = dict(self.model.named_modules())
        dynamic_param_ids = set([id(p) for p in self.dynamicParameters()])
        parameters = list()
        for key in modules.keys():
            for param in modules[key].parameters():
                if id(param) not in dynamic_param_ids:
                    parameters.append(param)
        return parameters

    def dynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters())

    def staticParameterCount(self) -> int:
        return sum(p.numel() for p in self.staticParameters())

    def activeParameterCount(self) -> int:
        total_params = self.dynamicParameters() + self.staticParameters()
        return sum(p.numel() for p in total_params if p.requires_grad)

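    # Freezes every group, then randomly unfreezes groups until the static parameter count
    # plus the unfrozen groups reaches target_active_params; frozen groups are stored in
    # fp16, active groups in fp32.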
    def reshuffleActive(self) -> None:
        for group in self.linear_groups:
            group.setFrozen(True)

        indecies = list(range(0, len(self.linear_groups)))
        params = self.staticParameterCount()
        while params < self.target_active_params and len(indecies) > 0:
            i = randint(0, len(indecies) - 1)
            self.linear_groups[indecies[i]].setFrozen(False)
            params += self.linear_groups[indecies[i]].paramCount()
            indecies.pop(i)

        for group in self.linear_groups:
            if group.isFrozen():
                group.inplaceTo(dtype=torch.float16)
            else:
                group.inplaceTo(dtype=torch.float32)
            print(group.modules[0].weight.dtype)

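    # Splits the model across GPUs roughly in proportion to their memory: non-Linear
    # modules plus an initial share of the groups go to the primary device, the remaining
    # groups are placed on the secondary devices with their outputs routed back to the
    # primary device.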
    def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
        modules = dict(self.model.named_modules())
        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
        total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
        static_param_count = self.staticParameterCount()
        total_parameter_count = static_param_count + self.dynamicParameterCount()
        params_per_byte = total_parameter_count / float(total_memory)
        print(f"{1/params_per_byte} bytes available per parameter")

        breakpoint()

        for key in DyntrainModel._get_nonlinear_names(self.model):
            replace_module(self.model, key, modules[key].to(primary_device))

        breakpoint()

        group_index = 0
        params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
        primary_params = static_param_count
        while params_for_primary > primary_params and group_index < len(self.linear_groups):
            self.linear_groups[group_index].inplaceTo(device=primary_device)
            primary_params += self.linear_groups[group_index].paramCount()
            group_index += 1

        for device in secondary_devices[:-1]:
            params_for_device = torch.cuda.get_device_properties(device).total_memory * params_per_byte
            params = 0
            while params_for_device > params and group_index < len(self.linear_groups):
                self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
                params += self.linear_groups[group_index].paramCount()
                group_index += 1

        while group_index < len(self.linear_groups):
            self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
            group_index += 1

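
For review context, a minimal sketch of how the new DyntrainModel API is meant to be driven (model name, parameter budget and devices below are placeholders, not values taken from this commit):

    import torch
    from dyntrainmodel import DyntrainModel

    model = DyntrainModel("some-org/some-model", cache_dir=None,
                          target_active_params=int(100e6),
                          gradient_checkpointing=False, trust_remote_code=False)
    model.toDevices(torch.device("cuda:0"), [torch.device("cuda:1")])  # place layers by memory size
    model.reshuffleActive()        # pick a fresh random set of trainable groups
    print(model.activeParameterCount())
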
68  modules.py  Normal file
@@ -0,0 +1,68 @@
import torch


class Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
        self.weight.requires_grad = not frozen
        if self.bias is not None:
            self.bias.requires_grad = not frozen

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
        frozen = self.isFrozen()
        if dtype is not None:
            self.weight = torch.nn.Parameter(self.weight.to(dtype))
            if self.bias is not None:
                self.bias = torch.nn.Parameter(self.bias.to(dtype))
        if device is not None:
            self.weight = torch.nn.Parameter(self.weight.to(device))
            if self.bias is not None:
                self.bias = torch.nn.Parameter(self.bias.to(device))
        self.setFrozen(frozen)

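# A Linear layer that casts its input to the dtype/device of its weight before the matmul
# and casts the result to a configurable output dtype/device afterwards, so fp16 layers can
# sit inside an otherwise fp32 model (and on another GPU) transparently.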
class ConvertingLinear(Linear):
    def __init__(self,
                 in_features, out_features, bias=True, device=None, dtype=None,
                 output_dtype=None, output_device=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype
        self.output_device = output_device

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.output_dtype = output_dtype
        new_module.output_device = output_device
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device if self.output_device is None else self.output_device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.Linear.forward(self, input)
        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
            breakpoint()
        return output.to(output_device).to(output_dtype)

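A short sketch of the conversion behaviour these classes provide (toy sizes; fp32 weights are used so the wip debug breakpoint() in forward(), which fires for any non-fp32 weight, stays out of the way):

    import torch
    from modules import ConvertingLinear

    base = torch.nn.Linear(16, 16)  # fp32 weights
    layer = ConvertingLinear.fromLinear(base, output_dtype=torch.float16)
    y = layer(torch.randn(4, 16))   # matmul runs in fp32, result is cast to fp16
    assert y.dtype == torch.float16
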
229  train_dynamic.py
@@ -1,156 +1,18 @@
import transformers
from transformers import AutoModelForCausalLM, get_scheduler
from peft.utils import _get_submodules
from transformers import get_scheduler

import torch
from torch.utils import tensorboard
import os
import shutil
import math
import collections
from tqdm.auto import tqdm
from random import randint
import collections


from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module
from convertinglinear import ConvertingLinear
from tokenizer import get_tokenizer


def find_all_linear_module_names(model):
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
            module_names.add(name)

    if 'lm_head' in module_names:  # needed for 16-bit
        module_names.remove('lm_head')
    return list(module_names)


def find_all_outher_module_names(model):
    module_names = set()
    for name, module in model.named_modules():
        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
            module_names.add(name)
    return list(module_names)


def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
    dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
    print(f'loading base model {model_args.model_name_or_path} in {dtype}...')

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=cache_dir,
        torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
        trust_remote_code=model_args.trust_remote_code,
        device_map=None,
        attn_implementation="flash_attention_2"
    )

    modules = dict(model.named_modules())
    keys = find_all_linear_module_names(model)
    for key in keys:
        modules[key].weight = modules[key].weight.to(dtype)
        if modules[key].bias is not None:
            modules[key].bias = modules[key].bias.to(dtype)

    return model


@torch.no_grad()
def recursive_setattr(obj, attr, value):
    attr = attr.split('.', 1)
    if len(attr) == 1:
        setattr(obj, attr[0], value)
    else:
        recursive_setattr(getattr(obj, attr[0]), attr[1], value)


@torch.no_grad()
def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
    new_module = torch.nn.Linear(module.in_features,
                                 module.out_features,
                                 module.bias is not None,
                                 module.weight.device,
                                 dtype)
    new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
    new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None
    new_module.weight.requires_grad = not frozen
    if new_module.bias is not None:
        new_module.bias.requires_grad = not frozen
    return new_module


@torch.no_grad()
def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
    if type(module) is torch.nn.Linear:
        if frozen:
            module.weight.requires_grad = False
            if module.bias is not None:
                module.bias.requires_grad = False
            return module.to(dtype).to(device)
        else:
            new_module = ConvertingLinear.fromLinear(module).to(dtype)
            new_module.weight.requires_grad = True
            if new_module.bias is not None:
                new_module.bias.requires_grad = True
            return new_module.to(device)
    elif type(module) is ConvertingLinear:
        if not frozen:
            module.weight.requires_grad = True
            if module.bias is not None:
                module.bias.requires_grad = True
            assert False
            return module.to(dtype).to(device)
        else:
            new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
                                                  out_features=module.out_features,
                                                  bias=module.bias is not None,
                                                  device=module.weight.device,
                                                  dtype=dtype)
            new_module.weight = torch.nn.Parameter(module.weight.to(dtype))
            new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None
            new_module.weight.requires_grad = False
            if new_module.bias is not None:
                new_module.bias.requires_grad = False
            return new_module.to(device)
    else:
        assert False


@torch.no_grad()
def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
    modules = dict(model.named_modules())
    linear_names = find_all_linear_module_names(model)

    for key in linear_names:
        if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad:
            parent, target, target_name = _get_submodules(model, key)
            setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device))
    modules = dict(model.named_modules())

    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if active_paramter_count > target_params:
        raise RuntimeError("Enough paramters must be available to train at least one linear layer")

    while active_paramter_count < target_params and len(linear_names) > 0:
        i = randint(0, len(linear_names) - 1)
        parent, target, target_name = _get_submodules(model, linear_names[i])
        new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device)
        setattr(parent, target_name, new_module)
        active_paramter_count += modules[linear_names[i]].weight.numel()
        if modules[linear_names[i]].bias is not None:
            active_paramter_count += modules[linear_names[i]].bias.numel()
        linear_names.pop(i)
        modules = dict()

    assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)

    return active_paramter_count
from dyntrainmodel import DyntrainModel


def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
@@ -172,26 +34,15 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
        shutil.rmtree(delete_checkpoit_dir)


def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter], lr: float, static_lr: float,
                  weight_decay: float, eps: float, adam8bit: bool):

    all_keys = dynamic_module_names + static_module_names
    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
    if len(duplicated) > 0:
        print("duplicated items:")
        for item in duplicated:
            print(item)
        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")

    parameters = list()
    modules = dict(model.named_modules())
    for key in dynamic_module_names:
        parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
    parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
    param_ids = set([id(p['params']) for p in parameters])
    for key in static_module_names:
        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
        for p in modules[key].parameters():
            param_ids.add(id(p))
    for param in static_parameters:
        if param.requires_grad and id(param) not in param_ids:
            parameters.append({'params': param, 'lr': static_lr})
            param_ids.add(id(param))

    if not adam8bit:
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
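
The refactored get_optimizer relies on torch.optim parameter groups to give static parameters a separate learning rate; a minimal standalone sketch of that mechanism (toy modules, learning rates are placeholders):

    import torch

    dynamic = torch.nn.Linear(8, 8)      # trained at the full learning rate
    static = torch.nn.Embedding(10, 8)   # trained at a reduced learning rate
    groups = [{'params': p} for p in dynamic.parameters() if p.requires_grad]
    groups += [{'params': p, 'lr': 1e-5} for p in static.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(groups, lr=1e-4, weight_decay=0.0, eps=1e-8)
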
@@ -204,18 +55,17 @@ def get_optimizer(model, dynamic_module_names: list, static_module_names: list,
    return optimizer


def compute_dynamic_parameter_ratio(model):
    modules = dict(model.named_modules())
    active_linear_parameters = 0
    total_linear_parameters = 0
    for key in find_all_linear_module_names(model):
        active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
        total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
    return math.ceil(total_linear_parameters / active_linear_parameters)


def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    primary_device = torch.device(training_args.primary_device)
    secondary_device = torch.device(training_args.secondary_device)
    log_writer = tensorboard.SummaryWriter()

    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
    model.toDevices(primary_device, [secondary_device])


def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
    paramter_count = sum(p.numel() for p in model.model.parameters())
    active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantaneous active parameters")

    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)

@@ -239,48 +89,24 @@ def prepare(model_args: ModelArguments, data_args: DataArguments, training_args:
        batch_size=training_args.per_device_train_batch_size
    ) if dataset['eval_dataset'] is not None else None

    if model_args.max_instant_params != 0:
        print(f"Target params {model_args.max_instant_params}m")
        freeze_random_modules(model, model_args.max_instant_params * 1e6,
                              torch.float16 if training_args.storage_fp16 else torch.float32,
                              frozen_device=primary_device, active_device=secondary_device)

    paramter_count = sum(p.numel() for p in model.parameters())
    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantaneous active parameters")

    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
    print(f"dynamic parameter ratio: 1/{dynamic_param_ratio}")

    dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = get_optimizer(model, find_all_linear_module_names(model),
                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
    optimizer = get_optimizer(model.dynamicParameters(),
                              model.staticParameters(),
                              training_args.learning_rate,
                              training_args.learning_rate / dynamic_param_ratio,
                              training_args.weight_decay,
                              training_args.adam_epsilon,
                              training_args.adam8bit)

    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=total_steps
    )
    return model, optimizer, lr_scheduler, train_dataloader


def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    primary_device = torch.device(training_args.primary_device)
    secondary_device = torch.device(training_args.secondary_device)
    log_writer = tensorboard.SummaryWriter()

    model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)

    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
    total_steps = steps_per_epoch * training_args.epochs
    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)

    if training_args.do_train:
        progress_bar = tqdm(range(total_steps))
@@ -307,13 +133,12 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
            print(loss)

            if global_step % 10 == 0 and model_args.max_instant_params != 0:
                param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
                                                    torch.float16 if training_args.storage_fp16 else torch.float32,
                                                    frozen_device=primary_device,
                                                    active_device=secondary_device)
                log_writer.add_scalar("Parameters/train", param_count, global_step)
                optimizer = get_optimizer(model, find_all_linear_module_names(model),
                                          find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
                lr_scheduler.optimizer = None
                del optimizer
                model.reshuffleActive()
                log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                optimizer = get_optimizer(model.dynamicParameters(),
                                          model.staticParameters(),
                                          training_args.learning_rate,
                                          training_args.learning_rate / dynamic_param_ratio,
                                          training_args.weight_decay,