add quantized linear, refactor model for its soon-to-be addition
This commit is contained in:
parent 38a7f7cfc4
commit 3fa1fc254f
dyntrainmodel.py (103 changes)
@@ -1,30 +1,11 @@
 from transformers import AutoModelForCausalLM
 import torch
 from utils import replace_module
-from modules import ConvertingLinear, Linear
+from modules import DynamicConvertingLinear, Linear
 from random import randint
 import math
 
 
-def find_all_linear_module_names(model) -> list[str]:
-    module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
-            module_names.add(name)
-
-    if 'lm_head' in module_names: # needed for 16-bit
-        module_names.remove('lm_head')
-    return list(module_names)
-
-
-def find_all_outher_module_names(model) -> list[str]:
-    module_names = set()
-    for name, module in model.named_modules():
-        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
-            module_names.add(name)
-    return list(module_names)
-
-
 class LinearGroup:
     def __init__(self, model, group_names: list):
         self.modules = list()
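The import swap above replaces ConvertingLinear with DynamicConvertingLinear. modules.py is not part of this diff, so the snippet below is only a rough sketch, assuming such a layer casts its input to the current weight dtype and its output to a configurable dtype; the class name, constructor, and fromLinear signature are inferred from how the class is called further down, not taken from the repository.

# Hypothetical sketch only -- not the code from modules.py.
import torch

class SketchDynamicConvertingLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, output_dtype=torch.float32):
        super().__init__(in_features, out_features, bias)
        self.output_dtype = output_dtype

    @classmethod
    def fromLinear(cls, linear: torch.nn.Linear, output_dtype=torch.float32):
        # Reuse the existing parameters instead of allocating new ones.
        new = cls(linear.in_features, linear.out_features, linear.bias is not None, output_dtype)
        new.weight = linear.weight
        new.bias = linear.bias
        return new

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast the activation to whatever dtype the weights currently hold
        # (float16 while frozen, float32 while trainable), then cast the result.
        y = super().forward(x.to(self.weight.dtype))
        return y.to(self.output_dtype)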
@@ -61,7 +42,7 @@ class LinearGroup:
 
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str,
-                 target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -69,16 +50,30 @@
             trust_remote_code=trust_remote_code,
             device_map=None
         )
-        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
-        self.linear_groups = list()
         self.target_active_params = target_active_params
 
+        self.reshuffle_fraction = reshuffle_fraction
+        if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
+            raise RuntimeError("reshuffle_fraction must be between 0.1 and 1.0")
         self.devices = list()
 
         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-
-        self._prepare()
+        modules = dict(self.model.named_modules())
+        self.frozen_linear_groups = list()
+        self.active_linear_groups = list()
 
+        linear_group_names = DyntrainModel._get_linear_group_names(self.model)
+        for group in linear_group_names:
+            for key in group:
+                if DyntrainModel.isModuleIn16bitOutlist(key):
+                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
+                else:
+                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
+            self.frozen_linear_groups.append(LinearGroup(self.model, group))
+        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
+        for group in self.frozen_linear_groups:
+            group.setFrozen(True)
+        self.reshuffleActive()
 
     def _get_nonlinear_names(layer: torch.nn.Module):
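For reference, a hedged usage sketch of the constructor signature refactored above; the checkpoint name, cache directory, and parameter budget are placeholder values, not anything taken from this repository.

# Placeholder values throughout; only the argument names come from the diff above.
model = DyntrainModel(
    "some-org/some-llm",               # hypothetical checkpoint id
    cache_dir="/tmp/hf-cache",         # hypothetical cache directory
    target_active_params=int(1e9),     # roughly 1B parameters trainable at a time
    reshuffle_fraction=0.25,           # must lie in [0.1, 1.0] or __init__ raises
    gradient_checkpointing=True,
)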
@@ -117,21 +112,9 @@
                          "v_proj"})
         return key in whitelist
 
-    def _prepare(self) -> None:
-        modules = dict(self.model.named_modules())
-        linear_groups = DyntrainModel._get_linear_group_names(self.model)
-
-        for group in linear_groups:
-            for key in group:
-                if DyntrainModel.isModuleIn16bitOutlist(key):
-                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
-                else:
-                    replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
-            self.linear_groups.append(LinearGroup(self.model, group))
-
     def dynamicParameters(self) -> list:
         parameters = list()
-        for group in self.linear_groups:
+        for group in self.frozen_linear_groups + self.active_linear_groups:
             parameters.extend(group.parameters())
         return parameters
 
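dynamicParameters() now walks both the frozen and the active groups. A hedged sketch of how its output would presumably be fed to an optimizer; the optimizer type and learning rate are placeholders, and model is the DyntrainModel instance from the earlier sketch.

import torch

dyn_params = model.dynamicParameters()                   # parameters of every linear group
trainable = [p for p in dyn_params if p.requires_grad]   # only the currently active groups
optimizer = torch.optim.AdamW(trainable, lr=1e-4)        # placeholder optimizer choice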
@@ -156,23 +139,24 @@
         return sum(p.numel() for p in total_params if p.requires_grad)
 
     def reshuffleActive(self) -> None:
-        for group in self.linear_groups:
+        active_count = len(self.active_linear_groups)
+        while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
+            group = self.active_linear_groups.pop(0)
             group.setFrozen(True)
+            self.frozen_linear_groups.append(group)
 
-        indecies = list(range(0, len(self.linear_groups)))
-        params = self.staticParameterCount()
-        while params < self.target_active_params and len(indecies) > 0:
-            i = randint(0, len(indecies) - 1)
-            self.linear_groups[indecies[i]].setFrozen(False)
-            params += self.linear_groups[indecies[i]].paramCount()
-            indecies.pop(i)
+        params = self.activeParameterCount()
 
+        if params >= self.target_active_params:
+            raise RuntimeError("Insufficient active parameters to shuffle active")
+        while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
+            i = randint(0, len(self.frozen_linear_groups) - 1)
+            group = self.frozen_linear_groups.pop(i)
+            group.setFrozen(False)
+            params += group.paramCount()
+            self.active_linear_groups.append(group)
+        print(math.ceil(params / 1e6))
 
-        for group in self.linear_groups:
-            if group.isFrozen():
-                group.inplaceTo(dtype=torch.float16)
-            else:
-                group.inplaceTo(dtype=torch.float32)
         active_params = self.activeParameterCount()
 
         assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
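The rewritten reshuffleActive() retires the oldest slice of active groups and then randomly promotes frozen groups until the parameter budget is met. The standalone sketch below restates that policy on plain (name, param_count) tuples so it runs without the rest of the class; all names and counts are made up.

from random import randint

def reshuffle(active, frozen, target_active_params, reshuffle_fraction):
    # Retire the oldest reshuffle_fraction of the active groups back to frozen.
    active_count = len(active)
    while len(active) > active_count * (1 - reshuffle_fraction):
        frozen.append(active.pop(0))

    # Randomly promote frozen groups until the active parameter budget is reached.
    params = sum(count for _, count in active)
    while params < target_active_params and len(frozen) > 0:
        group = frozen.pop(randint(0, len(frozen) - 1))
        active.append(group)
        params += group[1]
    return active, frozen

active = [("layer0.q_proj", 4_000_000), ("layer0.k_proj", 4_000_000)]
frozen = [("layer1.q_proj", 4_000_000), ("layer1.v_proj", 4_000_000)]
active, frozen = reshuffle(active, frozen, target_active_params=10_000_000, reshuffle_fraction=0.5)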
@@ -182,9 +166,8 @@
         for index in range(0, len(self.devices)):
             device_groups.append(list())
 
-        for group in self.linear_groups:
-            if not group.isFrozen():
-                device_groups[self.devices.index(group.getDevice())].append(group)
+        for group in self.active_linear_groups:
+            device_groups[self.devices.index(group.getDevice())].append(group)
 
         min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
         max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
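The hunk above now counts only active groups per device and measures the least and most loaded devices. A minimal sketch of that measurement, assuming device_groups is a list of per-device lists; the rebalancing step at the end is a guess at the likely follow-up, not code shown in this diff.

device_groups = [["g0", "g1", "g2"], ["g3"], ["g4", "g5"]]   # made-up active groups per device

min_index, min_count = min(enumerate(len(g) for g in device_groups), key=lambda x: x[1])
max_index, max_count = max(enumerate(len(g) for g in device_groups), key=lambda x: x[1])

# Guessed follow-up: shift one group from the busiest device to the idlest one.
if max_count - min_count > 1:
    device_groups[min_index].append(device_groups[max_index].pop())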
@@ -207,15 +190,17 @@
         for key in DyntrainModel._get_nonlinear_names(self.model):
             replace_module(self.model, key, modules[key].to(devices[0]))
 
+        linear_groups = self.active_linear_groups + self.frozen_linear_groups
+
         group_index = 0
         for device in devices[:-1]:
             params_for_device = torch.cuda.get_device_properties(device).total_memory * params_per_byte
             params = 0
-            while params_for_device > params and group_index < len(self.linear_groups):
-                self.linear_groups[group_index].inplaceTo(device=device)
-                params += self.linear_groups[group_index].paramCount()
+            while params_for_device > params and group_index < len(linear_groups):
+                linear_groups[group_index].inplaceTo(device=device)
+                params += linear_groups[group_index].paramCount()
                 group_index += 1
 
-        while group_index < len(self.linear_groups):
-            self.linear_groups[group_index].inplaceTo(device=devices[-1])
+        while group_index < len(linear_groups):
+            linear_groups[group_index].inplaceTo(device=devices[-1])
             group_index += 1
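The final hunk walks the combined active-plus-frozen group list and fills each device up to a parameter budget derived from its total memory, spilling the remainder onto the last device. The sketch below replays that loop with made-up numbers; the group sizes and per-device budgets are placeholders.

group_param_counts = [30_000_000] * 20           # pretend every group holds 30M parameters
device_budgets = [400_000_000, 400_000_000]      # parameter budgets for devices[:-1]

placement = []                                    # device index chosen for each group
group_index = 0
for device, budget in enumerate(device_budgets):
    params = 0
    while budget > params and group_index < len(group_param_counts):
        placement.append(device)
        params += group_param_counts[group_index]
        group_index += 1

# Whatever is left lands on the last device, mirroring the trailing while loop.
while group_index < len(group_param_counts):
    placement.append(len(device_budgets))
    group_index += 1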