add qunatized linear, refactor model for it soon to be addition
This commit is contained in:
@ -41,10 +41,6 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
|
metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
|
||||||
)
|
)
|
||||||
max_instant_params: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Maximum amount of paramters to optimize per step in millions"}
|
|
||||||
)
|
|
||||||
noresize: Optional[bool] = field(
|
noresize: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Never resize tokenizer embeddings"}
|
metadata={"help": "Never resize tokenizer embeddings"}
|
||||||
@ -93,3 +89,5 @@ class TrainingArguments():
|
|||||||
secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
|
secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
|
||||||
train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
|
train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
|
||||||
flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'})
|
flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'})
|
||||||
|
max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of paramters to optimize per step in millions"})
|
||||||
|
churn_percent: int = field(default=0, metadata={"help": "The percentage of active parameters to replace when changeing active parameters"})
|
||||||
|
101
dyntrainmodel.py
101
dyntrainmodel.py
@ -1,30 +1,11 @@
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
import torch
|
import torch
|
||||||
from utils import replace_module
|
from utils import replace_module
|
||||||
from modules import ConvertingLinear, Linear
|
from modules import DynamicConvertingLinear, Linear
|
||||||
from random import randint
|
from random import randint
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_module_names(model) -> list[str]:
|
|
||||||
module_names = set()
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
|
|
||||||
module_names.add(name)
|
|
||||||
|
|
||||||
if 'lm_head' in module_names: # needed for 16-bit
|
|
||||||
module_names.remove('lm_head')
|
|
||||||
return list(module_names)
|
|
||||||
|
|
||||||
|
|
||||||
def find_all_outher_module_names(model) -> list[str]:
|
|
||||||
module_names = set()
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
|
|
||||||
module_names.add(name)
|
|
||||||
return list(module_names)
|
|
||||||
|
|
||||||
|
|
||||||
class LinearGroup:
|
class LinearGroup:
|
||||||
def __init__(self, model, group_names: list):
|
def __init__(self, model, group_names: list):
|
||||||
self.modules = list()
|
self.modules = list()
|
||||||
@ -61,7 +42,7 @@ class LinearGroup:
|
|||||||
|
|
||||||
class DyntrainModel:
|
class DyntrainModel:
|
||||||
def __init__(self, model_name_or_path: str, cache_dir: str,
|
def __init__(self, model_name_or_path: str, cache_dir: str,
|
||||||
target_active_params: int, gradient_checkpointing: bool, trust_remote_code: bool = False):
|
target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@ -69,16 +50,30 @@ class DyntrainModel:
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
device_map=None
|
device_map=None
|
||||||
)
|
)
|
||||||
self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
|
|
||||||
self.linear_groups = list()
|
|
||||||
self.target_active_params = target_active_params
|
self.target_active_params = target_active_params
|
||||||
|
self.reshuffle_fraction = reshuffle_fraction
|
||||||
|
if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
|
||||||
|
raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
|
||||||
self.devices = list()
|
self.devices = list()
|
||||||
|
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
|
|
||||||
self._prepare()
|
modules = dict(self.model.named_modules())
|
||||||
|
self.frozen_linear_groups = list()
|
||||||
|
self.active_linear_groups = list()
|
||||||
|
|
||||||
|
linear_group_names = DyntrainModel._get_linear_group_names(self.model)
|
||||||
|
for group in linear_group_names:
|
||||||
|
for key in group:
|
||||||
|
if DyntrainModel.isModuleIn16bitOutlist(key):
|
||||||
|
replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
|
||||||
|
else:
|
||||||
|
replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
|
||||||
|
self.frozen_linear_groups.append(LinearGroup(self.model, group))
|
||||||
|
self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
|
||||||
|
for group in self.frozen_linear_groups:
|
||||||
|
group.setFrozen(True)
|
||||||
self.reshuffleActive()
|
self.reshuffleActive()
|
||||||
|
|
||||||
def _get_nonlinear_names(layer: torch.nn.Module):
|
def _get_nonlinear_names(layer: torch.nn.Module):
|
||||||
@ -117,21 +112,9 @@ class DyntrainModel:
|
|||||||
"v_proj"})
|
"v_proj"})
|
||||||
return key in whitelist
|
return key in whitelist
|
||||||
|
|
||||||
def _prepare(self) -> None:
|
|
||||||
modules = dict(self.model.named_modules())
|
|
||||||
linear_groups = DyntrainModel._get_linear_group_names(self.model)
|
|
||||||
|
|
||||||
for group in linear_groups:
|
|
||||||
for key in group:
|
|
||||||
if DyntrainModel.isModuleIn16bitOutlist(key):
|
|
||||||
replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
|
|
||||||
else:
|
|
||||||
replace_module(self.model, key, ConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
|
|
||||||
self.linear_groups.append(LinearGroup(self.model, group))
|
|
||||||
|
|
||||||
def dynamicParameters(self) -> list:
|
def dynamicParameters(self) -> list:
|
||||||
parameters = list()
|
parameters = list()
|
||||||
for group in self.linear_groups:
|
for group in self.frozen_linear_groups + self.active_linear_groups:
|
||||||
parameters.extend(group.parameters())
|
parameters.extend(group.parameters())
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
@ -156,23 +139,24 @@ class DyntrainModel:
|
|||||||
return sum(p.numel() for p in total_params if p.requires_grad)
|
return sum(p.numel() for p in total_params if p.requires_grad)
|
||||||
|
|
||||||
def reshuffleActive(self) -> None:
|
def reshuffleActive(self) -> None:
|
||||||
for group in self.linear_groups:
|
active_count = len(self.active_linear_groups)
|
||||||
|
while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
|
||||||
|
group = self.active_linear_groups.pop(0)
|
||||||
group.setFrozen(True)
|
group.setFrozen(True)
|
||||||
|
self.frozen_linear_groups.append(group)
|
||||||
|
|
||||||
indecies = list(range(0, len(self.linear_groups)))
|
params = self.activeParameterCount()
|
||||||
params = self.staticParameterCount()
|
|
||||||
while params < self.target_active_params and len(indecies) > 0:
|
if params >= self.target_active_params:
|
||||||
i = randint(0, len(indecies) - 1)
|
RuntimeError("Insuficant active parameters to suffle active")
|
||||||
self.linear_groups[indecies[i]].setFrozen(False)
|
while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
|
||||||
params += self.linear_groups[indecies[i]].paramCount()
|
i = randint(0, len(self.frozen_linear_groups) - 1)
|
||||||
indecies.pop(i)
|
group = self.frozen_linear_groups.pop(i)
|
||||||
|
group.setFrozen(False)
|
||||||
|
params += group.paramCount()
|
||||||
|
self.active_linear_groups.append(group)
|
||||||
print(math.ceil(params / 1e6))
|
print(math.ceil(params / 1e6))
|
||||||
|
|
||||||
for group in self.linear_groups:
|
|
||||||
if group.isFrozen():
|
|
||||||
group.inplaceTo(dtype=torch.float16)
|
|
||||||
else:
|
|
||||||
group.inplaceTo(dtype=torch.float32)
|
|
||||||
active_params = self.activeParameterCount()
|
active_params = self.activeParameterCount()
|
||||||
|
|
||||||
assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
|
assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
|
||||||
@ -182,8 +166,7 @@ class DyntrainModel:
|
|||||||
for index in range(0, len(self.devices)):
|
for index in range(0, len(self.devices)):
|
||||||
device_groups.append(list())
|
device_groups.append(list())
|
||||||
|
|
||||||
for group in self.linear_groups:
|
for group in self.active_linear_groups:
|
||||||
if not group.isFrozen():
|
|
||||||
device_groups[self.devices.index(group.getDevice())].append(group)
|
device_groups[self.devices.index(group.getDevice())].append(group)
|
||||||
|
|
||||||
min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
|
min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
|
||||||
@ -207,15 +190,17 @@ class DyntrainModel:
|
|||||||
for key in DyntrainModel._get_nonlinear_names(self.model):
|
for key in DyntrainModel._get_nonlinear_names(self.model):
|
||||||
replace_module(self.model, key, modules[key].to(devices[0]))
|
replace_module(self.model, key, modules[key].to(devices[0]))
|
||||||
|
|
||||||
|
linear_groups = self.active_linear_groups + self.frozen_linear_groups
|
||||||
|
|
||||||
group_index = 0
|
group_index = 0
|
||||||
for device in devices[:-1]:
|
for device in devices[:-1]:
|
||||||
params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
|
params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
|
||||||
params = 0
|
params = 0
|
||||||
while params_for_device > params and group_index < len(self.linear_groups):
|
while params_for_device > params and group_index < len(linear_groups):
|
||||||
self.linear_groups[group_index].inplaceTo(device=device)
|
linear_groups[group_index].inplaceTo(device=device)
|
||||||
params += self.linear_groups[group_index].paramCount()
|
params += linear_groups[group_index].paramCount()
|
||||||
group_index += 1
|
group_index += 1
|
||||||
|
|
||||||
while group_index < len(self.linear_groups):
|
while group_index < len(linear_groups):
|
||||||
self.linear_groups[group_index].inplaceTo(device=devices[-1])
|
linear_groups[group_index].inplaceTo(device=devices[-1])
|
||||||
group_index += 1
|
group_index += 1
|
||||||
|
129
modules.py
129
modules.py
@ -1,4 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch.multiprocessing as multiprocessing
|
||||||
|
from typing import overload, Optional, Union
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
class Linear(torch.nn.Linear):
|
class Linear(torch.nn.Linear):
|
||||||
@ -34,12 +38,22 @@ class Linear(torch.nn.Linear):
|
|||||||
self.weight = torch.nn.Parameter(self.weight.to(device))
|
self.weight = torch.nn.Parameter(self.weight.to(device))
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias = torch.nn.Parameter(self.bias.to(device))
|
self.bias = torch.nn.Parameter(self.bias.to(device))
|
||||||
self.setFrozen(frozen)
|
Linear.setFrozen(self, frozen)
|
||||||
|
|
||||||
|
def _apply(self, fn, recurse: bool = True):
|
||||||
|
if fn.__name__ == "convert":
|
||||||
|
return self
|
||||||
|
else:
|
||||||
|
return super()._apply(fn, recurse)
|
||||||
|
|
||||||
|
@wraps(torch.nn.Module.to)
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
breakpoint()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ConvertingLinear(Linear):
|
class DynamicConvertingLinear(Linear):
|
||||||
def __init__(self,
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
|
||||||
in_features, out_features, bias=True, device=None, dtype=None,
|
|
||||||
output_dtype=None, output_device=None):
|
output_dtype=None, output_device=None):
|
||||||
super().__init__(in_features, out_features, bias, device, dtype)
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
self.output_dtype = output_dtype
|
self.output_dtype = output_dtype
|
||||||
@ -58,6 +72,14 @@ class ConvertingLinear(Linear):
|
|||||||
new_module.bias = in_module.bias
|
new_module.bias = in_module.bias
|
||||||
return new_module
|
return new_module
|
||||||
|
|
||||||
|
def setFrozen(self, frozen: bool):
|
||||||
|
super().setFrozen(frozen)
|
||||||
|
|
||||||
|
if frozen:
|
||||||
|
self.inplaceTo(torch.float16)
|
||||||
|
else:
|
||||||
|
self.inplaceTo(torch.float32)
|
||||||
|
|
||||||
def setOutputDevice(self, output_device: torch.device):
|
def setOutputDevice(self, output_device: torch.device):
|
||||||
self.output_device = output_device
|
self.output_device = output_device
|
||||||
|
|
||||||
@ -69,6 +91,101 @@ class ConvertingLinear(Linear):
|
|||||||
if input.dtype != self.weight.dtype:
|
if input.dtype != self.weight.dtype:
|
||||||
input = input.to(self.weight.dtype)
|
input = input.to(self.weight.dtype)
|
||||||
output = torch.nn.Linear.forward(self, input)
|
output = torch.nn.Linear.forward(self, input)
|
||||||
if torch.isnan(output).any():
|
|
||||||
breakpoint()
|
|
||||||
return output.to(output_device).to(output_dtype)
|
return output.to(output_device).to(output_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicQantizedLinear(Linear):
|
||||||
|
def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
|
||||||
|
output_dtype=None, compute_dtype=None, output_device=None):
|
||||||
|
super().__init__(in_features, out_features, bias, cold_device, torch.float32)
|
||||||
|
self.active_device = active_device
|
||||||
|
self.cold_device = cold_device
|
||||||
|
self.output_device = output_device
|
||||||
|
self.output_dtype = output_dtype
|
||||||
|
self.compute_dtype = compute_dtype
|
||||||
|
self.weight_quantized = None
|
||||||
|
self.weight_state = None
|
||||||
|
self.bias_quantized = None
|
||||||
|
self.bias_state = None
|
||||||
|
self.block_size = 128
|
||||||
|
self.quant_type = 'nf4'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
|
||||||
|
output_dtype=None, compute_dtype=torch.float16, output_device=None):
|
||||||
|
new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
|
||||||
|
active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
|
||||||
|
compute_dtype=compute_dtype, output_device=output_device)
|
||||||
|
new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
|
||||||
|
new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
|
||||||
|
return new_module
|
||||||
|
|
||||||
|
def quantize(self):
|
||||||
|
weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
|
||||||
|
self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
|
||||||
|
compress_statistics=False, quant_type=self.quant_type)
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
|
||||||
|
self.bias_quantized, self.bias_state = bnb.functional.quantize_4bit(bias, blocksize=self.block_size,
|
||||||
|
compress_statistics=False, quant_type=self.quant_type)
|
||||||
|
|
||||||
|
weight = torch.nn.Parameter(self.weight.to(self.cold_device))
|
||||||
|
bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
|
||||||
|
|
||||||
|
def dequantize(self):
|
||||||
|
if self.weight_quantized is None:
|
||||||
|
raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
|
||||||
|
dtype = self.weight.dtype
|
||||||
|
self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
|
||||||
|
if self.bias_quantized:
|
||||||
|
self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
|
||||||
|
self.weight_quantized = None
|
||||||
|
self.weight_state = None
|
||||||
|
self.bias_quantized = None
|
||||||
|
self.bias_state = None
|
||||||
|
|
||||||
|
def checkDistance(self) -> float:
|
||||||
|
if self.weight_quantized is None:
|
||||||
|
raise RuntimeError("checkDistance() called without quantized weights avialable")
|
||||||
|
original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
|
||||||
|
quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
|
||||||
|
blocksize=self.block_size,
|
||||||
|
compress_statistics=True,
|
||||||
|
quant_type=self.quant_type)
|
||||||
|
dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
|
||||||
|
dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
|
||||||
|
return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) / dequantized_original_weight.numel()).item()
|
||||||
|
|
||||||
|
def setOutputDevice(self, output_device: torch.device):
|
||||||
|
self.output_device = output_device
|
||||||
|
|
||||||
|
def setFrozen(self, frozen: bool) -> None:
|
||||||
|
if frozen == self.isFrozen():
|
||||||
|
return
|
||||||
|
|
||||||
|
super().setFrozen(frozen)
|
||||||
|
|
||||||
|
if frozen:
|
||||||
|
self.quantize()
|
||||||
|
else:
|
||||||
|
self.dequantize()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
|
||||||
|
output_device = x.device if self.output_device is None else self.output_device
|
||||||
|
|
||||||
|
if not self.isFrozen():
|
||||||
|
if x.device != self.weight.device:
|
||||||
|
x = x.to(self.weight.device)
|
||||||
|
if x.dtype != self.weight.dtype:
|
||||||
|
x = x.to(self.weight.dtype)
|
||||||
|
return super().forward(x).to(output_device).to(output_dtype)
|
||||||
|
else:
|
||||||
|
if self.weight_quantized is None:
|
||||||
|
raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
|
||||||
|
bias = None
|
||||||
|
if self.bias_quantized is not None:
|
||||||
|
bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
|
||||||
|
out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)
|
||||||
|
|
||||||
|
return out.to(output_device).to(output_dtype)
|
||||||
|
@ -60,7 +60,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
secondary_device = torch.device(training_args.secondary_device)
|
secondary_device = torch.device(training_args.secondary_device)
|
||||||
log_writer = tensorboard.SummaryWriter()
|
log_writer = tensorboard.SummaryWriter()
|
||||||
|
|
||||||
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
|
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
|
||||||
|
reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
|
||||||
model.toDevices([primary_device, secondary_device])
|
model.toDevices([primary_device, secondary_device])
|
||||||
model.balanceActive()
|
model.balanceActive()
|
||||||
|
|
||||||
@ -133,7 +134,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
|||||||
if global_step % 10 == 0:
|
if global_step % 10 == 0:
|
||||||
print(loss)
|
print(loss)
|
||||||
|
|
||||||
if global_step % 10 == 0 and model_args.max_instant_params != 0:
|
if global_step % 10 == 0 and training_args.max_instant_params != 0:
|
||||||
lr_scheduler.optimizer = None
|
lr_scheduler.optimizer = None
|
||||||
del optimizer
|
del optimizer
|
||||||
model.reshuffleActive()
|
model.reshuffleActive()
|
||||||
|
19
utils.py
19
utils.py
@ -5,3 +5,22 @@ import torch
|
|||||||
def replace_module(model, key: str, module: torch.nn.Module):
|
def replace_module(model, key: str, module: torch.nn.Module):
|
||||||
parent, target, target_name = _get_submodules(model, key)
|
parent, target, target_name = _get_submodules(model, key)
|
||||||
setattr(parent, target_name, module)
|
setattr(parent, target_name, module)
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_linear_module_names(model) -> list[str]:
|
||||||
|
module_names = set()
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
module_names.add(name)
|
||||||
|
|
||||||
|
if 'lm_head' in module_names: # needed for 16-bit
|
||||||
|
module_names.remove('lm_head')
|
||||||
|
return list(module_names)
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_outher_module_names(model) -> list[str]:
|
||||||
|
module_names = set()
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not isinstance(module, torch.nn.Linear):
|
||||||
|
module_names.add(name)
|
||||||
|
return list(module_names)
|
||||||
|
Reference in New Issue
Block a user