# QRotaryTraining - A novel method for fully training all parameters of large
# language models (LLMs) while using less device memory than traditional methods.
# Copyright (C) 2024 Carl Philipp Klemm
#
# This file is part of QRotaryTraining.
#
# QRotaryTraining is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# QRotaryTraining is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with QRotaryTraining. If not, see <https://www.gnu.org/licenses/>.

from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import DynamicConvertingLinear, Linear, DynamicQantizedLinear
from random import randint
import math
from tqdm import tqdm


class LinearGroup:
    def __init__(self, model, group_names: list):
        self.modules = list()
        model_modules = dict(model.named_modules())
        for name in group_names:
            self.modules.append(model_modules[name])
        for module in self.modules:
            assert isinstance(module, Linear)

    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None,
                  output_device: torch.device | None = None) -> None:
        for module in self.modules:
            module.inplaceTo(dtype, device)
        self.modules[-1].setOutputDevice(output_device)

    def setFrozen(self, frozen: bool, convert: bool = True) -> None:
        for module in self.modules:
            module.setFrozen(frozen, convert)

    def isFrozen(self) -> bool:
        return self.modules[0].isFrozen()

    def parameters(self) -> list[torch.nn.Parameter]:
        params = list()
        for module in self.modules:
            params.extend(module.parameters())
        return params

    def paramCount(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def getDevice(self) -> torch.device:
        return self.modules[0].weight.device

    def compress(self) -> None:
        for module in self.modules:
            module.compress()

    def decompress(self) -> None:
        for module in self.modules:
            module.decompress()

    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
        distance_accum = torch.Tensor()
        error_accum = torch.Tensor()
        for module in self.modules:
            distance, error = module.getDistanceAndError()
            distance = distance.to("cpu")
            error = error.to("cpu")
            distance_accum = torch.cat((distance_accum, distance.reshape(distance.numel())))
            error_accum = torch.cat((error_accum, error.reshape(error.numel())))
        return (distance_accum, error_accum)

    def check(self) -> bool:
        for module in self.modules:
            if not module.check():
                return False
        return True


class DyntrainModel:
    def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
                 target_active_params: int, train_static_params: bool,
                 reshuffle_fraction: float, gradient_checkpointing: bool,
                 trust_remote_code: bool = False):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            trust_remote_code=trust_remote_code,
            device_map=None
        )
        self.target_active_params = target_active_params
        self.reshuffle_fraction = reshuffle_fraction
        if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
            raise RuntimeError("reshuffle_fraction must be between 0.1 and 1.0")
        self.devices = list[torch.device]()
        self.inital_reshufle = True
        self.train_static_params = train_static_params

        if gradient_checkpointing:
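            # use_reentrant=False selects PyTorch's non-reentrant checkpointing
            # implementation, which, unlike the reentrant one, does not require the
            # checkpointed inputs themselves to have requires_grad=True.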
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        self.frozen_linear_groups = list[LinearGroup]()
        self.active_linear_groups = list[LinearGroup]()

        linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
        for group in linear_group_names:
            for key in group:
                replace_module(self.model, key,
                               self._getModule(key, quantize, torch.device("cuda:0"), torch.device("cpu")))
            self.frozen_linear_groups.append(LinearGroup(self.model, group))
        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
        for group in self.frozen_linear_groups:
            group.setFrozen(True, False)

    def _getModule(self, key: str, quantize: bool, active_device: torch.device,
                   cold_device: torch.device):
        output_dtype = torch.float16 if DyntrainModel.isModuleIn16bitOutlist(key) else torch.float32
        modules = dict(self.model.named_modules())
        if quantize:
            return DynamicQantizedLinear.fromLinear(modules[key], active_device, cold_device,
                                                    output_dtype, torch.float16)
        else:
            return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16),
                                                      output_dtype=output_dtype)

    @staticmethod
    def _getNonlinearNames(layer: torch.nn.Module) -> list[str]:
        names = list()
        modules = dict(layer.named_modules())
        for key in modules.keys():
            if (not isinstance(modules[key], torch.nn.Linear)
                    and len(list(modules[key].children())) == 0) or key == "lm_head":
                names.append(key)
        return names

    @staticmethod
    def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
        linear_groups = list[list[str]]()
        list_counter = 0
        in_sequence = False
        modules = dict(layer.named_modules())
        for key in modules.keys():
            if isinstance(modules[key], torch.nn.Linear) and key != "lm_head":
                if not in_sequence:
                    linear_groups.append(list())
                    in_sequence = True
                linear_groups[list_counter].append(key)
            elif in_sequence:
                in_sequence = False
                list_counter = list_counter + 1
        return linear_groups

    @staticmethod
    def isModuleIn16bitOutlist(key: str) -> bool:
        key = key.split('.')[-1]
        whitelist = {"gate_proj", "up_proj", "q_proj", "k_proj", "v_proj"}
        return key in whitelist

    def dynamicParameters(self) -> list[torch.nn.Parameter]:
        parameters = list()
        for group in self.frozen_linear_groups + self.active_linear_groups:
            parameters.extend(group.parameters())
        return parameters

    def staticParameters(self) -> list[torch.nn.Parameter]:
        dynamic_param_ids = set(id(p) for p in self.dynamicParameters())
        parameters = list()
        # self.model.parameters() yields each parameter exactly once, even though
        # a parameter is reachable from several nested modules
        for param in self.model.parameters():
            if id(param) not in dynamic_param_ids:
                parameters.append(param)
        return parameters

    def dynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters())

    def staticParameterCount(self) -> int:
        return sum(p.numel() for p in self.staticParameters())

    def activeDynamicParameterCount(self) -> int:
        return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)

    def activeParameterCount(self) -> int:
        if self.train_static_params:
            total_params = self.dynamicParameters() + self.staticParameters()
        else:
            total_params = self.dynamicParameters()
        return sum(p.numel() for p in total_params if p.requires_grad)

    def getDistanceAndErrorSample(self) -> tuple[torch.Tensor, torch.Tensor]:
        index = randint(0, len(self.active_linear_groups) - 1)
        return self.active_linear_groups[index].getDistanceAndError()

    def reshuffleActive(self) -> None:
        active_count = len(self.active_linear_groups)
        while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
            group = self.active_linear_groups.pop(0)
            group.setFrozen(True)
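            # the group is frozen (its weights stop receiving gradients) and returned
            # to the frozen pool so a later reshuffle can activate it again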
            self.frozen_linear_groups.append(group)
            assert group.check()

        params = self.activeParameterCount()
        if params >= self.target_active_params:
            raise RuntimeError("Insufficient active parameter budget to shuffle in new active groups")
        while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
            i = randint(0, len(self.frozen_linear_groups) - 1)
            group = self.frozen_linear_groups.pop(i)
            group.setFrozen(False)
            params += group.paramCount()
            self.active_linear_groups.append(group)
            assert group.check()
        print(f"{math.ceil(params / 1e6)}M parameters active after reshuffle")

        active_params = self.activeParameterCount()
        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params

    def activeParametersByDevice(self) -> list[int]:
        out = [0] * len(self.devices)
        for group in self.active_linear_groups:
            out[self.devices.index(group.getDevice())] += group.paramCount()
        return out

    def balanceActive(self) -> None:
        active_counts = self.activeParametersByDevice()
        params_per_byte = list()
        for i, count in enumerate(active_counts):
            memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
            if i == 0:
                memory = int(memory * 0.5)
            params_per_byte.append(count / memory)

        max_index, max_params_per_byte = max(enumerate(params_per_byte), key=lambda x: x[1])
        min_index, min_params_per_byte = min(enumerate(params_per_byte), key=lambda x: x[1])

        for group in self.active_linear_groups:
            if group.getDevice() == self.devices[max_index]:
                memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                if max_index == 0:
                    memory = int(memory * 0.5)
                swing = group.paramCount() / memory
                if max_params_per_byte - swing > min_params_per_byte + swing:
                    group.inplaceTo(device=self.devices[min_index])
                    self.balanceActive()

    def toDevices(self, devices: list[torch.device]) -> None:
        assert len(devices) > 0
        modules = dict(self.model.named_modules())
        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
        total_memory -= int(torch.cuda.get_device_properties(devices[0]).total_memory * 0.2)
        static_param_count = self.staticParameterCount()
        total_parameter_count = static_param_count + self.dynamicParameterCount()
        params_per_byte = total_parameter_count / float(total_memory)
        print(f"{math.floor(1 / params_per_byte)} bytes available per parameter")

        self.devices = devices

        for key in DyntrainModel._getNonlinearNames(self.model):
            replace_module(self.model, key, modules[key].to(devices[0]))

        linear_groups = self.active_linear_groups + self.frozen_linear_groups
        group_index = 0
        for i, device in enumerate(devices[:-1]):
            memory = torch.cuda.get_device_properties(device).total_memory
            if i == 0:
                memory = int(memory * 0.8)
            params_for_device = memory * params_per_byte
            params = 0
            while params_for_device > params and group_index < len(linear_groups):
                linear_groups[group_index].inplaceTo(device=device)
                params += linear_groups[group_index].paramCount()
                group_index += 1

        while group_index < len(linear_groups):
            linear_groups[group_index].inplaceTo(device=devices[-1])
            group_index += 1

        for group in tqdm(linear_groups, desc="Preparing layers"):
            if group.isFrozen():
                group.compress()
            else:
                group.decompress()
            assert group.check()
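

# A minimal usage sketch, illustrative only: the checkpoint name and all
# hyperparameter values below are placeholders rather than defaults shipped with
# QRotaryTraining. It builds a DyntrainModel, spreads it across all visible CUDA
# devices, activates an initial set of linear groups and rebalances them.
if __name__ == "__main__":
    model = DyntrainModel(
        "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint, any Llama-style model id works
        cache_dir=None,
        quantize=True,
        target_active_params=int(1e9),
        train_static_params=False,
        reshuffle_fraction=0.4,
        gradient_checkpointing=True,
    )
    devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    model.toDevices(devices)
    model.reshuffleActive()
    model.balanceActive()
    print(f"Active parameters: {model.activeParameterCount()} of "
          f"{model.dynamicParameterCount() + model.staticParameterCount()} total")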