Inactive parameter quanitzation support

This commit is contained in:
2024-04-07 19:15:42 +02:00
parent 3fa1fc254f
commit c33964371c
4 changed files with 161 additions and 78 deletions

View File

@ -1,9 +1,10 @@
from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import DynamicConvertingLinear, Linear
from modules import DynamicConvertingLinear, Linear, DynamicQantizedLinear
from random import randint
import math
from tqdm import tqdm
class LinearGroup:
@ -20,9 +21,9 @@ class LinearGroup:
module.inplaceTo(dtype, device)
self.modules[-1].setOutputDevice(output_device)
def setFrozen(self, frozen: bool) -> None:
def setFrozen(self, frozen: bool, convert: bool = True) -> None:
for module in self.modules:
module.setFrozen(frozen)
module.setFrozen(frozen, convert)
def isFrozen(self) -> bool:
return self.modules[0].isFrozen()
@ -39,9 +40,26 @@ class LinearGroup:
def getDevice(self) -> torch.device:
return self.modules[0].weight.device
def compress(self) -> None:
for module in self.modules:
module.compress()
def decompress(self) -> None:
for module in self.modules:
module.decompress()
def checkDistance(self) -> tuple[float, float]:
distance_accum = 0.0
error_accum = 0.0
for module in self.modules:
distance, error = module.checkDistance()
distance_accum += distance**2
error_accum += error**2
return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
class DyntrainModel:
def __init__(self, model_name_or_path: str, cache_dir: str,
def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
@ -55,28 +73,32 @@ class DyntrainModel:
if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
self.devices = list()
self.inital_reshufle = True
if gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
modules = dict(self.model.named_modules())
self.frozen_linear_groups = list()
self.active_linear_groups = list()
linear_group_names = DyntrainModel._get_linear_group_names(self.model)
linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
for group in linear_group_names:
for key in group:
if DyntrainModel.isModuleIn16bitOutlist(key):
replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
else:
replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
self.frozen_linear_groups.append(LinearGroup(self.model, group))
self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
for group in self.frozen_linear_groups:
group.setFrozen(True)
self.reshuffleActive()
group.setFrozen(True, False)
def _get_nonlinear_names(layer: torch.nn.Module):
def _getModule(self, key: str, quantize: bool, active_device: torch.device, cold_device: torch.device):
output_dtype = torch.float16 if DyntrainModel.isModuleIn16bitOutlist(key) else torch.float32
modules = dict(self.model.named_modules())
if quantize:
return DynamicQantizedLinear.fromLinear(modules[key], active_device, cold_device, output_dtype, torch.float16)
else:
return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)
def _getNonlinearNames(layer: torch.nn.Module):
names = list()
modules = dict(layer.named_modules())
@ -85,7 +107,7 @@ class DyntrainModel:
names.append(key)
return names
def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
linear_groups = list()
list_counter = 0
in_sequence = False
@ -140,8 +162,11 @@ class DyntrainModel:
def reshuffleActive(self) -> None:
active_count = len(self.active_linear_groups)
index = 0
while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
group = self.active_linear_groups.pop(0)
distance, error = self.active_linear_groups[index].checkDistance()
print(f"linear group has moved {distance} with an error of {error}")
group = self.active_linear_groups.pop(index)
group.setFrozen(True)
self.frozen_linear_groups.append(group)
@ -161,25 +186,39 @@ class DyntrainModel:
assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
def activeParamtersByDevice(self) -> list[int]:
out = [0] * len(self.devices)
for group in self.active_linear_groups:
out[self.devices.index(group.getDevice())] += group.paramCount()
return out
def balanceActive(self) -> None:
device_groups = list()
for index in range(0, len(self.devices)):
device_groups.append(list())
active_counts = self.activeParamtersByDevice()
bits_per_param = list()
for i, count in enumerate(active_counts):
memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
if i == 0:
memory = memory * 0.8
bits_per_param.append(count / memory)
max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
min_index, min_bits_per_param = min(enumerate(active_counts), key=lambda x: x[1])
for group in self.active_linear_groups:
device_groups[self.devices.index(group.getDevice())].append(group)
min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
if max_count - 2 > min_count:
device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
self.balanceActive()
if group.getDevice() is self.devices[max_index]:
memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
if max_index == 0:
memory = memory * 0.8
swing = group.paramCount() / memory
if max_bits_per_param - swing > min_bits_per_param + swing:
group.inplaceTo(device=self.devices[min_index])
self.balanceActive()
def toDevices(self, devices: list[torch.device]) -> None:
assert len(devices) > 0
modules = dict(self.model.named_modules())
total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
static_param_count = self.staticParameterCount()
total_parameter_count = static_param_count + self.dynamicParameterCount()
params_per_byte = total_parameter_count / float(total_memory)
@ -187,14 +226,17 @@ class DyntrainModel:
self.devices = devices
for key in DyntrainModel._get_nonlinear_names(self.model):
for key in DyntrainModel._getNonlinearNames(self.model):
replace_module(self.model, key, modules[key].to(devices[0]))
linear_groups = self.active_linear_groups + self.frozen_linear_groups
group_index = 0
for device in devices[:-1]:
params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
for i, device in enumerate(devices[:-1]):
memory = torch.cuda.get_device_properties(devices).total_memory
if i == 0:
memory = memory * 0.8
params_for_device = memory * params_per_byte
params = 0
while params_for_device > params and group_index < len(linear_groups):
linear_groups[group_index].inplaceTo(device=device)
@ -204,3 +246,6 @@ class DyntrainModel:
while group_index < len(linear_groups):
linear_groups[group_index].inplaceTo(device=devices[-1])
group_index += 1
for group in tqdm(linear_groups, desc="Perpareing layers"):
group.compress()