wip refactor

Work in progress: ConvertingLinear moves out of its own file into a new modules.py (next to a small Linear subclass that adds freeze and in-place dtype/device helpers), a new dyntrainmodel.py introduces a DyntrainModel wrapper that groups consecutive linear layers, reshuffles which groups are trainable, and spreads them across devices, and train_dynamic.py is rewired to drive training through DyntrainModel instead of the old get_model/freeze_random_modules helpers.

convertinglinear.py (deleted, 30 lines)
@@ -1,30 +0,0 @@
import torch


class ConvertingLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.Linear.forward(self, input)
        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
            breakpoint()
        return output.to(output_device).to(output_dtype)

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module
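
Note on the construction pattern used throughout this commit: fromLinear builds the replacement layer with torch.nn.utils.skip_init so no time is spent initializing weights that are immediately overwritten by the wrapped layer's tensors. A minimal illustration of that call (the shapes here are arbitrary, not from the commit):

import torch

# skip_init constructs the module without running its parameter initialization;
# the caller is expected to assign real tensors afterwards.
m = torch.nn.utils.skip_init(torch.nn.Linear, in_features=16, out_features=8, bias=True)
m.weight = torch.nn.Parameter(torch.zeros(8, 16))   # stand-in for in_module.weight
m.bias = torch.nn.Parameter(torch.zeros(8))         # stand-in for in_module.bias
print(m.weight.shape)                               # torch.Size([8, 16])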

dyntrainmodel.py (new file, 190 lines)
@@ -0,0 +1,190 @@
from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import ConvertingLinear, Linear
from random import randint


def find_all_linear_module_names(model) -> list[str]:
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
            module_names.add(name)

    if 'lm_head' in module_names:  # needed for 16-bit
        module_names.remove('lm_head')
    return list(module_names)


def find_all_outher_module_names(model) -> list[str]:
    module_names = set()
    for name, module in model.named_modules():
        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
            module_names.add(name)
    return list(module_names)


class LinearGroup:
    def __init__(self, model, group_names: list):
        self.modules = list()
        model_modules = dict(model.named_modules())
        for name in group_names:
            self.modules.append(model_modules[name])
        assert isinstance(self.modules[0], ConvertingLinear)
        assert isinstance(self.modules[-1], ConvertingLinear)

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
        for module in self.modules:
            module.inplaceTo(dtype, device)
        self.modules[-1].setOutputDevice(output_device)

    def setFrozen(self, frozen: bool) -> None:
        for module in self.modules:
            module.setFrozen(frozen)

    def isFrozen(self) -> bool:
        return self.modules[0].isFrozen()

    def parameters(self) -> list[torch.nn.Parameter]:
        params = list()
        for module in self.modules:
            params.extend(module.parameters())
        return params

    def paramCount(self) -> int:
        return sum(p.numel() for p in self.parameters())


class DyntrainModel:
    def __init__(self, model_name_or_path: str, cache_dir: str,
                 target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            trust_remote_code=trust_remote_code,
            device_map=None,
        )
        self.linear_groups = list()
        self.target_active_params = target_active_params

        self._prepare()
        self.reshuffleActive()

    def _get_nonlinear_names(layer: torch.nn.Module):
        names = list()
        modules = dict(layer.named_modules())

        for key in modules.keys():
            if not isinstance(modules[key], torch.nn.Linear):
                names.append(key)
        return names

    def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
        linear_groups = list()
        list_counter = 0
        in_sequence = False
        modules = dict(layer.named_modules())

        for key in modules.keys():
            if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
                if not in_sequence:
                    linear_groups.append(list())
                    in_sequence = True
                linear_groups[list_counter].append(key)
            elif in_sequence:
                in_sequence = False
                list_counter = list_counter + 1
        return linear_groups

    def _prepare(self) -> None:
        modules = dict(self.model.named_modules())
        linear_groups = DyntrainModel._get_linear_group_names(self.model)

        for group in linear_groups:
            replace_module(self.model, group[0], ConvertingLinear.fromLinear(modules[group[0]].to(torch.float16), output_dtype=torch.float16))
            replace_module(self.model, group[-1], ConvertingLinear.fromLinear(modules[group[-1]].to(torch.float16), output_dtype=torch.float32))
            if len(group) > 2:
                for index in range(1, len(group) - 1):
                    replace_module(self.model, group[index], Linear.fromLinear(modules[group[index]].to(torch.float16)))
            self.linear_groups.append(LinearGroup(self.model, group))

    def dynamicParameters(self) -> list:
        parameters = list()
        for group in self.linear_groups:
            parameters.extend(group.parameters())
        return parameters

    def staticParameters(self) -> list:
        modules = dict(self.model.named_modules())
        dynamic_param_ids = set([id(p) for p in self.dynamicParameters()])
        parameters = list()
        for key in modules.keys():
            for param in modules[key].parameters():
                if id(param) not in dynamic_param_ids:
                    parameters.append(param)
        return parameters

    def dynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters())

    def staticParameterCount(self) -> int:
        return sum(p.numel() for p in self.staticParameters())

    def activeParameterCount(self) -> int:
        total_params = self.dynamicParameters() + self.staticParameters()
        return sum(p.numel() for p in total_params if total_params)

    def reshuffleActive(self) -> None:
        for group in self.linear_groups:
            group.setFrozen(True)

        indecies = list(range(0, len(self.linear_groups)))
        params = self.staticParameterCount()
        while params < self.target_active_params and len(indecies) > 0:
            i = randint(0, len(indecies) - 1)
            self.linear_groups[indecies[i]].setFrozen(False)
            params += self.linear_groups[indecies[i]].paramCount()
            indecies.pop(i)

        for group in self.linear_groups:
            if group.isFrozen():
                group.inplaceTo(dtype=torch.float16)
            else:
                group.inplaceTo(dtype=torch.float32)
            print(group.modules[0].weight.dtype)

    def toDevices(self, primary_device: torch.device, secondary_devices: list[torch.device]) -> None:
        modules = dict(self.model.named_modules())
        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in secondary_devices)
        total_memory += torch.cuda.get_device_properties(primary_device).total_memory * 0.8
        static_param_count = self.staticParameterCount()
        total_parameter_count = static_param_count + self.dynamicParameterCount()
        params_per_byte = total_parameter_count / float(total_memory)
        print(f"{1/params_per_byte} bytes available per parameter")

        breakpoint()

        for key in DyntrainModel._get_nonlinear_names(self.model):
            replace_module(self.model, key, modules[key].to(primary_device))

        breakpoint()

        group_index = 0
        params_for_primary = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte * 0.8 - static_param_count
        primary_params = static_param_count
        while params_for_primary > primary_params and group_index < len(self.linear_groups):
            self.linear_groups[group_index].inplaceTo(device=primary_device)
            primary_params += self.linear_groups[group_index].paramCount()
            group_index += 1

        for device in secondary_devices[:-1]:
            params_for_device = torch.cuda.get_device_properties(primary_device).total_memory * params_per_byte
            params = 0
            while params_for_device > params and group_index < len(self.linear_groups):
                self.linear_groups[group_index].inplaceTo(device=device, output_device=primary_device)
                params += self.linear_groups[group_index].paramCount()
                group_index += 1

        while group_index < len(self.linear_groups):
            self.linear_groups[group_index].inplaceTo(device=secondary_devices[-1], output_device=primary_device)
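
dyntrainmodel.py also depends on a replace_module helper imported from utils, which is not included in this commit. A minimal sketch of what such a helper might look like, assuming it swaps out a submodule identified by its dotted name; the signature and body here are assumptions, not the repository's actual code:

import torch


def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module) -> None:
    # Walk the dotted path (e.g. "model.layers.0.mlp.up_proj") down to the parent
    # module, then rebind the final attribute to the replacement. (Assumed helper.)
    parent = model
    *path, leaf = name.split('.')
    for part in path:
        parent = getattr(parent, part)
    setattr(parent, leaf, new_module)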

modules.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import torch


class Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
        self.weight.requires_grad = not frozen
        if self.bias is not None:
            self.bias.requires_grad = not frozen

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
        if dtype is not None:
            self.weight = torch.nn.Parameter(self.weight.to(dtype))
        if device is not None:
            self.weight = torch.nn.Parameter(self.weight.to(device))


class ConvertingLinear(Linear):
    def __init__(self,
                 in_features, out_features, bias=True, device=None, dtype=None,
                 output_dtype=None, output_device=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype
        self.output_device = output_device

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.output_dtype = output_dtype
        new_module.output_device = output_device
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device if self.output_device is None else self.output_device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.Linear.forward(self, input)
        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
            breakpoint()
        return output.to(output_device).to(output_dtype)
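
A quick usage sketch of the ConvertingLinear wrapper defined above; the shapes and dtypes are illustrative only. Note that in this WIP state the breakpoint() guard in forward() fires whenever the wrapped weights are not float32, so the sketch keeps the underlying layer in float32 and only converts the activations:

import torch
from modules import ConvertingLinear   # modules.py from this commit

base = torch.nn.Linear(16, 8)                                 # float32 weights
layer = ConvertingLinear.fromLinear(base, output_dtype=torch.float16)

x = torch.randn(4, 16, dtype=torch.float16)                   # caller works in half precision
y = layer(x)                                                  # input cast up to float32 for the matmul, output cast back down
print(y.dtype)                                                # torch.float16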

train_dynamic.py (modified, 229 lines)
@@ -1,156 +1,18 @@
 import transformers
-from transformers import AutoModelForCausalLM, get_scheduler
-from peft.utils import _get_submodules
+from transformers import get_scheduler
 
 import torch
 from torch.utils import tensorboard
 import os
 import shutil
 import math
-import collections
 from tqdm.auto import tqdm
-from random import randint
-import collections
 
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
-from convertinglinear import ConvertingLinear
 from tokenizer import get_tokenizer
 
+from dyntrainmodel import DyntrainModel
 
-
-def find_all_linear_module_names(model):
-    module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
-            module_names.add(name)
-
-    if 'lm_head' in module_names:  # needed for 16-bit
-        module_names.remove('lm_head')
-    return list(module_names)
-
-
-def find_all_outher_module_names(model):
-    module_names = set()
-    for name, module in model.named_modules():
-        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
-            module_names.add(name)
-    return list(module_names)
-
-
-def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
-    dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
-    print(f'loading base model {model_args.model_name_or_path} in {dtype}...')
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=cache_dir,
-        torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
-        trust_remote_code=model_args.trust_remote_code,
-        device_map=None,
-        attn_implementation="flash_attention_2"
-    )
-
-    modules = dict(model.named_modules())
-    keys = find_all_linear_module_names(model)
-    for key in keys:
-        modules[key].weight = modules[key].weight.to(dtype)
-        if modules[key].bias is not None:
-            modules[key].bias = modules[key].bias.to(dtype)
-
-    return model
-
-
-@torch.no_grad()
-def recursive_setattr(obj, attr, value):
-    attr = attr.split('.', 1)
-    if len(attr) == 1:
-        setattr(obj, attr[0], value)
-    else:
-        recursive_setattr(getattr(obj, attr[0]), attr[1], value)
-
-
-@torch.no_grad()
-def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
-    new_module = torch.nn.Linear(module.in_features,
-                                 module.out_features,
-                                 module.bias is not None,
-                                 module.weight.device,
-                                 dtype)
-    new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
-    new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None
-    new_module.weight.requires_grad = not frozen
-    if new_module.bias is not None:
-        new_module.bias.requires_grad = not frozen
-    return new_module
-
-
-@torch.no_grad()
-def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
-    if type(module) is torch.nn.Linear:
-        if frozen:
-            module.weight.requires_grad = False
-            if module.bias is not None:
-                module.bias.requires_grad = False
-            return module.to(dtype).to(device)
-        else:
-            new_module = ConvertingLinear.fromLinear(module).to(dtype)
-            new_module.weight.requires_grad = True
-            if new_module.bias is not None:
-                new_module.bias.requires_grad = True
-            return new_module.to(device)
-    elif type(module) is ConvertingLinear:
-        if not frozen:
-            module.weight.requires_grad = True
-            if module.bias is not None:
-                module.bias.requires_grad = True
-            assert False
-            return module.to(dtype).to(device)
-        else:
-            new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
-                                                  out_features=module.out_features,
-                                                  bias=module.bias is not None,
-                                                  device=module.weight.device,
-                                                  dtype=dtype)
-            new_module.weight = torch.nn.Parameter(module.weight.to(dtype))
-            new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None
-            new_module.weight.requires_grad = False
-            if new_module.bias is not None:
-                new_module.bias.requires_grad = False
-            return new_module.to(device)
-    else:
-        assert False
-
-
-@torch.no_grad()
-def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
-    modules = dict(model.named_modules())
-    linear_names = find_all_linear_module_names(model)
-
-    for key in linear_names:
-        if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad:
-            parent, target, target_name = _get_submodules(model, key)
-            setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device))
-            modules = dict(model.named_modules())
-
-    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    if active_paramter_count > target_params:
-        raise RuntimeError("Enough paramters must be available to train at least one linear layer")
-
-    while active_paramter_count < target_params and len(linear_names) > 0:
-        i = randint(0, len(linear_names) - 1)
-        parent, target, target_name = _get_submodules(model, linear_names[i])
-        new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device)
-        setattr(parent, target_name, new_module)
-        active_paramter_count += modules[linear_names[i]].weight.numel()
-        if modules[linear_names[i]].bias is not None:
-            active_paramter_count += modules[linear_names[i]].bias.numel()
-        linear_names.pop(i)
-        modules = dict()
-
-    assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-    return active_paramter_count
-
-
 def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
@@ -172,26 +34,15 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
         shutil.rmtree(delete_checkpoit_dir)
 
 
-def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
+def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters: list[torch.nn.parameter], lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
-
-    all_keys = dynamic_module_names + static_module_names
-    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
-    if len(duplicated) > 0:
-        print("duplicated items:")
-        for item in duplicated:
-            print(item)
-        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
-
     parameters = list()
-    modules = dict(model.named_modules())
-    for key in dynamic_module_names:
-        parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
     param_ids = set([id(p['params']) for p in parameters])
-    for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
-        for p in modules[key].parameters():
-            param_ids.add(id(p))
+    for param in static_parameters:
+        if param.requires_grad and id(param) not in param_ids:
+            parameters.append({'params': param, 'lr': static_lr})
+            param_ids.add(id(param))
 
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
@@ -204,18 +55,17 @@ def get_optimizer(model, dynamic_module_names: list, static_module_names: list,
     return optimizer
 
 
-def compute_dynamic_parameter_ratio(model):
-    modules = dict(model.named_modules())
-    active_linear_parameters = 0
-    total_linear_parameters = 0
-    for key in find_all_linear_module_names(model):
-        active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
-        total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
-    return math.ceil(total_linear_parameters / active_linear_parameters)
+def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
+    primary_device = torch.device(training_args.primary_device)
+    secondary_device = torch.device(training_args.secondary_device)
+    log_writer = tensorboard.SummaryWriter()
 
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
+    model = model.toDevices(primary_device, [secondary_device])
 
-def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
-    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
+    paramter_count = sum(p.numel() for p in model.model.parameters())
+    active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
+    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
 
     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 
@@ -239,48 +89,24 @@ def prepare(model_args: ModelArguments, data_args: DataArguments, training_args:
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['eval_dataset'] is not None else None
 
-    if model_args.max_instant_params != 0:
-        print(f"Target params {model_args.max_instant_params}m")
-        freeze_random_modules(model, model_args.max_instant_params * 1e6,
-                              torch.float16 if training_args.storage_fp16 else torch.float32,
-                              frozen_device=primary_device, active_device=secondary_device)
-
-    paramter_count = sum(p.numel() for p in model.parameters())
-    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
-
-    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
-    print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}")
+    dynamic_param_ratio = (model.staticParamterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
 
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
     total_steps = steps_per_epoch * training_args.epochs
 
-    optimizer = get_optimizer(model, find_all_linear_module_names(model),
-                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
+    optimizer = get_optimizer(model.dynamicParameters(),
+                              model.staticParameters(),
                               training_args.learning_rate,
                               training_args.learning_rate / dynamic_param_ratio,
                               training_args.weight_decay,
                               training_args.adam_epsilon,
                               training_args.adam8bit)
 
     lr_scheduler = get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.warmup_steps,
         num_training_steps=total_steps
     )
-    return model, optimizer, lr_scheduler, train_dataloader
-
-
-def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    primary_device = torch.device(training_args.primary_device)
-    secondary_device = torch.device(training_args.secondary_device)
-    log_writer = tensorboard.SummaryWriter()
-
-    model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)
-
-    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
-    total_steps = steps_per_epoch * training_args.epochs
-    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
 
     if training_args.do_train:
         progress_bar = tqdm(range(total_steps))
@@ -307,13 +133,12 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 print(loss)
 
                 if global_step % 10 == 0 and model_args.max_instant_params != 0:
-                    param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
-                                                        torch.float16 if training_args.storage_fp16 else torch.float32,
-                                                        frozen_device=primary_device,
-                                                        active_device=secondary_device)
-                    log_writer.add_scalar("Parameters/train", param_count, global_step)
-                    optimizer = get_optimizer(model, find_all_linear_module_names(model),
-                                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
+                    lr_scheduler.optimizer = None
+                    del optimizer
+                    model.reshuffleActive()
+                    log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
+                    optimizer = get_optimizer(model.dynamicParameters(),
+                                              model.staticParameters(),
                                               training_args.learning_rate,
                                               training_args.learning_rate / dynamic_param_ratio,
                                               training_args.weight_decay,
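
The reworked get_optimizer above feeds AdamW a list of per-parameter groups, with statically trained parameters getting a reduced learning rate (lr / dynamic_param_ratio). A self-contained sketch of that pattern; the tensors and the 1/10 ratio below are illustrative stand-ins, not values from this commit:

import torch

# Stand-ins for model.dynamicParameters() and model.staticParameters()
dynamic = [torch.nn.Parameter(torch.randn(4, 4))]
static = [torch.nn.Parameter(torch.randn(4, 4))]

lr = 1e-4
static_lr = lr / 10                                # plays the role of lr / dynamic_param_ratio

groups = [{'params': p} for p in dynamic if p.requires_grad]
groups += [{'params': p, 'lr': static_lr} for p in static if p.requires_grad]

# Groups without an explicit 'lr' fall back to the optimizer-wide default.
opt = torch.optim.AdamW(groups, lr=lr, weight_decay=0.0, eps=1e-8)
print([g['lr'] for g in opt.param_groups])         # [0.0001, 1e-05]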