# QRotaryTraining/dyntrainmodel.py
# QRotaryTraining - A novel method for fully training all parameters of large
# language models (LLMs) while using less device memory than traditional methods.
# Copyright (C) 2024 Carl Philipp Klemm
#
# This file is part of QRotaryTraining.
#
# QRotaryTraining is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# QRotaryTraining is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with QRotaryTraining. If not, see <http://www.gnu.org/licenses/>.
import math
from random import randint

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from modules import DynamicConvertingLinear, Linear, DynamicQantizedLinear
from utils import replace_module
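

# A LinearGroup ties together a run of consecutive Linear modules so they can
# be frozen, converted, compressed and moved between devices as one unit.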
class LinearGroup:
    def __init__(self, model, group_names: list):
        self.modules = list()
        model_modules = dict(model.named_modules())
        for name in group_names:
            self.modules.append(model_modules[name])
        for module in self.modules:
            assert isinstance(module, Linear)

    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None, output_device: torch.device | None = None) -> None:
        for module in self.modules:
            module.inplaceTo(dtype, device)
        self.modules[-1].setOutputDevice(output_device)

    def setFrozen(self, frozen: bool, convert: bool = True) -> None:
        for module in self.modules:
            module.setFrozen(frozen, convert)

    def isFrozen(self) -> bool:
        return self.modules[0].isFrozen()

    def parameters(self) -> list[torch.nn.Parameter]:
        params = list()
        for module in self.modules:
            params.extend(module.parameters())
        return params

    def paramCount(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def getDevice(self) -> torch.device:
        return self.modules[0].weight.device

    def compress(self) -> None:
        for module in self.modules:
            module.compress()

    def decompress(self) -> None:
        for module in self.modules:
            module.decompress()
    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
        distance_accum = torch.Tensor()
        error_accum = torch.Tensor()
        for module in self.modules:
            distance, error = module.getDistanceAndError()
            distance = distance.to("cpu")
            error = error.to("cpu")
            distance_accum = torch.cat((distance_accum, distance.reshape(distance.numel())))
            error_accum = torch.cat((error_accum, error.reshape(error.numel())))
        return (distance_accum, error_accum)
    def check(self) -> bool:
        for module in self.modules:
            if not module.check():
                return False
        return True
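

# DyntrainModel wraps a causal LM for dynamic training: it rotates which groups
# of linear layers are trainable while the rest stay frozen (and, when quantized,
# compressed), keeping the active parameter count near a configured target.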
class DyntrainModel:
    def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
                 target_active_params: int, train_static_params: bool,
                 reshuffle_fraction: float, gradient_checkpointing: bool,
                 trust_remote_code: bool = False):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            trust_remote_code=trust_remote_code,
            device_map=None
        )
        self.target_active_params = target_active_params
        self.reshuffle_fraction = reshuffle_fraction
        if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
            raise RuntimeError("reshuffle_fraction must be between 0.1 and 1.0")
        self.devices = list[torch.device]()
        self.initial_reshuffle = True
        self.train_static_params = train_static_params

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        self.frozen_linear_groups = list[LinearGroup]()
        self.active_linear_groups = list[LinearGroup]()

        linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
        for group in linear_group_names:
            for key in group:
                replace_module(self.model, key, self._getModule(key, quantize, torch.device("cuda:0"), torch.device("cpu")))
            self.frozen_linear_groups.append(LinearGroup(self.model, group))
        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
        for group in self.frozen_linear_groups:
            group.setFrozen(True, False)

    def _getModule(self, key: str, quantize: bool, active_device: torch.device, cold_device: torch.device):
        output_dtype = torch.float16 if DyntrainModel.isModuleIn16bitOutlist(key) else torch.float32
        modules = dict(self.model.named_modules())
        if quantize:
            return DynamicQantizedLinear.fromLinear(modules[key], active_device, cold_device, output_dtype, torch.float16)
        else:
            return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)

    @staticmethod
    def _getNonlinearNames(layer: torch.nn.Module):
        names = list()
        modules = dict(layer.named_modules())
        for key in modules.keys():
            if (not isinstance(modules[key], torch.nn.Linear) and len(list(modules[key].children())) == 0) or key == "lm_head":
                names.append(key)
        return names

    @staticmethod
    def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
        # Collect the names of consecutive Linear modules (lm_head excluded) into
        # groups; each run of adjacent Linears becomes one freeze/activate unit.
        linear_groups = list[list[str]]()
        list_counter = 0
        in_sequence = False
        modules = dict(layer.named_modules())
        for key in modules.keys():
            if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
                if not in_sequence:
                    linear_groups.append(list())
                    in_sequence = True
                linear_groups[list_counter].append(key)
            elif in_sequence:
                in_sequence = False
                list_counter = list_counter + 1
        return linear_groups

    @staticmethod
    def isModuleIn16bitOutlist(key: str) -> bool:
        key = key.split('.')[-1]
        whitelist = {
            "gate_proj",
            "up_proj",
            "q_proj",
            "k_proj",
            "v_proj"}
        return key in whitelist

    def dynamicParameters(self) -> list:
        parameters = list()
        for group in self.frozen_linear_groups + self.active_linear_groups:
            parameters.extend(group.parameters())
        return parameters

    def staticParameters(self) -> list:
        modules = dict(self.model.named_modules())
        dynamic_param_ids = set(id(p) for p in self.dynamicParameters())
        parameters = list()
        seen_ids = set()
        for key in modules.keys():
            for param in modules[key].parameters():
                # named_modules() yields parents and children, so the same
                # parameter is visited repeatedly; deduplicate by id.
                if id(param) not in dynamic_param_ids and id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    parameters.append(param)
        return parameters

    def dynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters())

    def staticParameterCount(self) -> int:
        return sum(p.numel() for p in self.staticParameters())

    def activeDynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)

    def activeParameterCount(self) -> int:
        if self.train_static_params:
            total_params = self.dynamicParameters() + self.staticParameters()
        else:
            total_params = self.dynamicParameters()
        return sum(p.numel() for p in total_params if p.requires_grad)

    def getDistanceAndErrorSample(self) -> tuple[torch.Tensor, torch.Tensor]:
        index = randint(0, len(self.active_linear_groups) - 1)
        return self.active_linear_groups[index].getDistanceAndError()

    def reshuffleActive(self) -> None:
        # Freeze the oldest reshuffle_fraction share of the active groups, then
        # activate randomly chosen frozen groups until the active parameter
        # count reaches the target again.
        active_count = len(self.active_linear_groups)
        index = 0
        while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
            group = self.active_linear_groups.pop(index)
            group.setFrozen(True)
            self.frozen_linear_groups.append(group)
            assert group.check()

        params = self.activeParameterCount()
        if params >= self.target_active_params:
            raise RuntimeError("Too many parameters remain active after freezing, cannot reshuffle")
        while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
            i = randint(0, len(self.frozen_linear_groups) - 1)
            group = self.frozen_linear_groups.pop(i)
            group.setFrozen(False)
            params += group.paramCount()
            self.active_linear_groups.append(group)
            assert group.check()
        print(f"Active parameters: {math.ceil(params / 1e6)}M")

        active_params = self.activeParameterCount()
        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params

    def activeParametersByDevice(self) -> list[int]:
        out = [0] * len(self.devices)
        for group in self.active_linear_groups:
            out[self.devices.index(group.getDevice())] += group.paramCount()
        return out

    def balanceActive(self) -> None:
        # Balance active parameters across devices, weighting each device by its
        # total memory: move a group from the most to the least loaded device
        # whenever doing so still improves the balance, then recurse.
        active_counts = self.activeParametersByDevice()
        params_per_byte = list()
        for i, count in enumerate(active_counts):
            memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
            if i == 0:
                memory = int(memory * 0.5)
            params_per_byte.append(count / memory)

        max_index, max_params_per_byte = max(enumerate(params_per_byte), key=lambda x: x[1])
        min_index, min_params_per_byte = min(enumerate(params_per_byte), key=lambda x: x[1])

        for group in self.active_linear_groups:
            if group.getDevice() is self.devices[max_index]:
                memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                if max_index == 0:
                    memory = int(memory * 0.5)
                swing = group.paramCount() / memory
                if max_params_per_byte - swing > min_params_per_byte + swing:
                    group.inplaceTo(device=self.devices[min_index])
                    # Rebalance from fresh per-device counts instead of
                    # continuing the loop with stale max/min values.
                    self.balanceActive()
                    return

    def toDevices(self, devices: list[torch.device]) -> None:
        assert len(devices) > 0
        modules = dict(self.model.named_modules())
        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
        total_memory -= int(torch.cuda.get_device_properties(devices[0]).total_memory * 0.2)
        static_param_count = self.staticParameterCount()
        total_parameter_count = static_param_count + self.dynamicParameterCount()
        params_per_byte = total_parameter_count / float(total_memory)
        print(f"{math.floor(1 / params_per_byte)} bytes available per parameter")

        self.devices = devices

        for key in DyntrainModel._getNonlinearNames(self.model):
            replace_module(self.model, key, modules[key].to(devices[0]))

        linear_groups = self.active_linear_groups + self.frozen_linear_groups

        group_index = 0
        for i, device in enumerate(devices[:-1]):
            memory = torch.cuda.get_device_properties(device).total_memory
            if i == 0:
                memory = int(memory * 0.8)
            params_for_device = memory * params_per_byte
            params = 0
            while params_for_device > params and group_index < len(linear_groups):
                linear_groups[group_index].inplaceTo(device=device)
                params += linear_groups[group_index].paramCount()
                group_index += 1
        while group_index < len(linear_groups):
            linear_groups[group_index].inplaceTo(device=devices[-1])
            group_index += 1

        for group in tqdm(linear_groups, desc="Preparing layers"):
            if group.isFrozen():
                group.compress()
            else:
                group.decompress()
            assert group.check()
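

# Example usage (an illustrative sketch, not part of this module: the checkpoint
# name, parameter values and device list below are placeholder assumptions):
#
#   model = DyntrainModel("some-llama-style-model", cache_dir=None, quantize=True,
#                         target_active_params=int(500e6), train_static_params=False,
#                         reshuffle_fraction=0.2, gradient_checkpointing=True)
#   model.toDevices([torch.device("cuda:0"), torch.device("cuda:1")])
#   model.balanceActive()
#   # ...train for a while, then rotate which groups are trainable:
#   model.reshuffleActive()
#   model.balanceActive()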