Inactive parameter quantization support

This commit is contained in:
parent 3fa1fc254f
commit c33964371c

4 changed files with 161 additions and 78 deletions

arguments.py (11 changed lines)
@@ -45,6 +45,10 @@ class ModelArguments:
        default=False,
        metadata={"help": "Never resize tokenizer embeddings"}
    )
    quantize: Optional[bool] = field(
        default=False,
        metadata={"help": "Quantize parameters not currently being actively trained"}
    )


@dataclass
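
The new flag follows the existing field(default=..., metadata={"help": ...}) pattern, so it is picked up by whatever parser consumes these dataclasses. A minimal usage sketch, assuming the project feeds them to transformers.HfArgumentParser (the parser itself is not part of this diff):

    # Hypothetical sketch, not part of the commit: read the new flag and pass it on.
    from transformers import HfArgumentParser
    from arguments import ModelArguments, TrainingArguments

    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args.quantize)  # False unless --quantize True is given on the command line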

@@ -85,9 +89,8 @@ class TrainingArguments():
    save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
    max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
    primary_device: str = field(default="cuda:0", metadata={"help": 'The primary device to use'})
    secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
    train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
    train_non_linear_layers: Optional[bool] = field(default=False, metadata={"help": 'train non linear layers'})
    flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on each iteration'})
    max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of parameters to optimize per step in millions"})
    churn_percent: int = field(default=0, metadata={"help": "The percentage of active parameters to replace when changing active parameters"})
    churn_percent: int = field(default=100, metadata={"help": "The percentage of active parameters to replace when changing active parameters"})
    eval_steps: int = field(default=-1, metadata={"help": "Number of optimization steps after which to compute the evaluation loss"})

dyntrainmodel.py (103 changed lines)

@@ -1,9 +1,10 @@
from transformers import AutoModelForCausalLM
import torch
from utils import replace_module
from modules import DynamicConvertingLinear, Linear
from modules import DynamicConvertingLinear, Linear, DynamicQantizedLinear
from random import randint
import math
from tqdm import tqdm


class LinearGroup:

@@ -20,9 +21,9 @@ class LinearGroup:
            module.inplaceTo(dtype, device)
        self.modules[-1].setOutputDevice(output_device)

    def setFrozen(self, frozen: bool) -> None:
    def setFrozen(self, frozen: bool, convert: bool = True) -> None:
        for module in self.modules:
            module.setFrozen(frozen)
            module.setFrozen(frozen, convert)

    def isFrozen(self) -> bool:
        return self.modules[0].isFrozen()

@@ -39,9 +40,26 @@ class LinearGroup:
    def getDevice(self) -> torch.device:
        return self.modules[0].weight.device

    def compress(self) -> None:
        for module in self.modules:
            module.compress()

    def decompress(self) -> None:
        for module in self.modules:
            module.decompress()

    def checkDistance(self) -> tuple[float, float]:
        distance_accum = 0.0
        error_accum = 0.0
        for module in self.modules:
            distance, error = module.checkDistance()
            distance_accum += distance**2
            error_accum += error**2
        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))


class DyntrainModel:
    def __init__(self, model_name_or_path: str, cache_dir: str,
    def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
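
For reference, checkDistance() above aggregates the per-module values as a root mean square: with per-module distances d_i over n modules it returns sqrt(d_1^2 + ... + d_n^2) / sqrt(n), and likewise for the errors. A tiny illustrative check of that identity:

    # Illustrative only: sqrt(sum(d_i^2)) / sqrt(n) equals the root mean square sqrt(mean(d_i^2)).
    import math

    distances = [0.3, 0.1, 0.2]
    group_value = math.sqrt(sum(d**2 for d in distances)) / math.sqrt(len(distances))
    assert abs(group_value - math.sqrt(sum(d**2 for d in distances) / len(distances))) < 1e-12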

@@ -55,28 +73,32 @@ class DyntrainModel:
        if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
            raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
        self.devices = list()
        self.inital_reshufle = True

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        modules = dict(self.model.named_modules())
        self.frozen_linear_groups = list()
        self.active_linear_groups = list()

        linear_group_names = DyntrainModel._get_linear_group_names(self.model)
        linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
        for group in linear_group_names:
            for key in group:
                if DyntrainModel.isModuleIn16bitOutlist(key):
                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float16))
                else:
                    replace_module(self.model, key, DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=torch.float32))
                replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
            self.frozen_linear_groups.append(LinearGroup(self.model, group))
        self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
        for group in self.frozen_linear_groups:
            group.setFrozen(True)
        self.reshuffleActive()
            group.setFrozen(True, False)

    def _get_nonlinear_names(layer: torch.nn.Module):
    def _getModule(self, key: str, quantize: bool, active_device: torch.device, cold_device: torch.device):
        output_dtype = torch.float16 if DyntrainModel.isModuleIn16bitOutlist(key) else torch.float32
        modules = dict(self.model.named_modules())
        if quantize:
            return DynamicQantizedLinear.fromLinear(modules[key], active_device, cold_device, output_dtype, torch.float16)
        else:
            return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)

    def _getNonlinearNames(layer: torch.nn.Module):
        names = list()
        modules = dict(layer.named_modules())
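
The new _getModule() helper is where the quantize flag takes effect: per layer it either builds the existing float16-converting wrapper or the new quantized one, which keeps a cold copy on cold_device and a 4-bit active copy on active_device. A hedged illustration of the two branches for a single layer (the layers, devices and dtypes here are placeholders; a CUDA device and the project's modules.py are assumed):

    # Hypothetical illustration, not part of the commit: what the two branches return for one layer.
    import torch
    from modules import DynamicConvertingLinear, DynamicQantizedLinear

    layer_a = torch.nn.Linear(32, 32)
    layer_b = torch.nn.Linear(32, 32)

    converting = DynamicConvertingLinear.fromLinear(layer_a.to(torch.float16), output_dtype=torch.float32)
    quantized = DynamicQantizedLinear.fromLinear(layer_b, "cuda:0", "cpu", torch.float32, torch.float16)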

@@ -85,7 +107,7 @@ class DyntrainModel:
                names.append(key)
        return names

    def _get_linear_group_names(layer: torch.nn.Module) -> list[list[str]]:
    def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
        linear_groups = list()
        list_counter = 0
        in_sequence = False

@@ -140,8 +162,11 @@ class DyntrainModel:

    def reshuffleActive(self) -> None:
        active_count = len(self.active_linear_groups)
        index = 0
        while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
            group = self.active_linear_groups.pop(0)
            distance, error = self.active_linear_groups[index].checkDistance()
            print(f"linear group has moved {distance} with an error of {error}")
            group = self.active_linear_groups.pop(index)
            group.setFrozen(True)
            self.frozen_linear_groups.append(group)

@@ -161,25 +186,39 @@ class DyntrainModel:

        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params

    def activeParamtersByDevice(self) -> list[int]:
        out = [0] * len(self.devices)
        for group in self.active_linear_groups:
            out[self.devices.index(group.getDevice())] += group.paramCount()
        return out

    def balanceActive(self) -> None:
        device_groups = list()
        for index in range(0, len(self.devices)):
            device_groups.append(list())
        active_counts = self.activeParamtersByDevice()
        bits_per_param = list()
        for i, count in enumerate(active_counts):
            memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
            if i == 0:
                memory = memory * 0.8
            bits_per_param.append(count / memory)

        max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
        min_index, min_bits_per_param = min(enumerate(active_counts), key=lambda x: x[1])

        for group in self.active_linear_groups:
            device_groups[self.devices.index(group.getDevice())].append(group)

        min_index, min_count = min(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])
        max_index, max_count = max(enumerate(len(grouplist) for grouplist in device_groups), key=lambda x: x[1])

        if max_count - 2 > min_count:
            device_groups[max_index][0].inplaceTo(device=self.devices[min_index])
            self.balanceActive()
            if group.getDevice() is self.devices[max_index]:
                memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                if max_index == 0:
                    memory = memory * 0.8
                swing = group.paramCount() / memory
                if max_bits_per_param - swing > min_bits_per_param + swing:
                    group.inplaceTo(device=self.devices[min_index])
                    self.balanceActive()

    def toDevices(self, devices: list[torch.device]) -> None:
        assert len(devices) > 0
        modules = dict(self.model.named_modules())
        total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
        total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
        static_param_count = self.staticParameterCount()
        total_parameter_count = static_param_count + self.dynamicParameterCount()
        params_per_byte = total_parameter_count / float(total_memory)
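
The rebalancing above now weights each device's active parameter count by its total memory (with 20% of device 0 held back) and moves one group from the most heavily loaded device to the least loaded one only while that strictly helps. A detached, torch-free sketch of the same greedy idea, with made-up loads and capacities:

    # Illustrative sketch only: move fixed-size chunks of load from the densest to the
    # sparsest device while the move still strictly reduces the spread.
    def rebalance(load: list[float], capacity: list[float], step: float) -> list[float]:
        while True:
            density = [l / c for l, c in zip(load, capacity)]
            hi = density.index(max(density))
            lo = density.index(min(density))
            swing = step / capacity[hi]
            if not density[hi] - swing > density[lo] + swing:
                return load
            load[hi] -= step
            load[lo] += step

    print(rebalance([900.0, 100.0], [1000.0, 1000.0], step=50.0))  # -> [550.0, 450.0]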

@@ -187,14 +226,17 @@ class DyntrainModel:

        self.devices = devices

        for key in DyntrainModel._get_nonlinear_names(self.model):
        for key in DyntrainModel._getNonlinearNames(self.model):
            replace_module(self.model, key, modules[key].to(devices[0]))

        linear_groups = self.active_linear_groups + self.frozen_linear_groups

        group_index = 0
        for device in devices[:-1]:
            params_for_device = torch.cuda.get_device_properties(devices).total_memory * params_per_byte
        for i, device in enumerate(devices[:-1]):
            memory = torch.cuda.get_device_properties(devices).total_memory
            if i == 0:
                memory = memory * 0.8
            params_for_device = memory * params_per_byte
            params = 0
            while params_for_device > params and group_index < len(linear_groups):
                linear_groups[group_index].inplaceTo(device=device)

@@ -204,3 +246,6 @@ class DyntrainModel:
        while group_index < len(linear_groups):
            linear_groups[group_index].inplaceTo(device=devices[-1])
            group_index += 1

        for group in tqdm(linear_groups, desc="Preparing layers"):
            group.compress()

modules.py (73 changed lines)

@@ -20,10 +20,23 @@ class Linear(torch.nn.Linear):
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
    def compress(self) -> None:
        self.inplaceTo(torch.float16)

    def decompress(self) -> None:
        self.inplaceTo(torch.float32)

    def setFrozen(self, frozen: bool, convert: bool = True):
        self.weight.requires_grad = not frozen
        if self.bias is not None:
            self.bias.requires_grad = not frozen
        if convert:
            if frozen:
                breakpoint()
                self.compress()
            else:
                self.decompress()
                self.weightStart = torch.Tensor(self.weight).clone().detach()

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad
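
With this change, freezing a Linear with convert left at True also switches its storage precision: compress() drops the weights to float16 for the frozen phase, while unfreezing decompresses back to float32 and stores a weightStart copy of the restored weights. A rough standalone analogue with a plain torch.nn.Linear, purely illustrative:

    # Rough analogue, not the project's class: what setFrozen(True/False, convert=True) amounts to.
    import torch

    layer = torch.nn.Linear(16, 16)

    # freeze + compress: no gradients, parameters stored in float16
    for p in layer.parameters():
        p.requires_grad_(False)
    layer.half()

    # unfreeze + decompress: float32 again, gradients back on, snapshot kept around
    layer.float()
    for p in layer.parameters():
        p.requires_grad_(True)
    weight_start = layer.weight.detach().clone()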

@@ -38,7 +51,7 @@ class Linear(torch.nn.Linear):
            self.weight = torch.nn.Parameter(self.weight.to(device))
            if self.bias is not None:
                self.bias = torch.nn.Parameter(self.bias.to(device))
        Linear.setFrozen(self, frozen)
        Linear.setFrozen(self, frozen, False)

    def _apply(self, fn, recurse: bool = True):
        if fn.__name__ == "convert":

@@ -72,17 +85,12 @@ class DynamicConvertingLinear(Linear):
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
        super().setFrozen(frozen)

        if frozen:
            self.inplaceTo(torch.float16)
        else:
            self.inplaceTo(torch.float32)

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def checkDistance(self) -> tuple[float, float]:
        return (10.0, 0.0)

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device if self.output_device is None else self.output_device

@@ -120,7 +128,7 @@ class DynamicQantizedLinear(Linear):
        new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
        return new_module

    def quantize(self):
    def compress(self) -> None:
        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
        self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
                                                                                compress_statistics=False, quant_type=self.quant_type)

@@ -132,19 +140,15 @@ class DynamicQantizedLinear(Linear):
        weight = torch.nn.Parameter(self.weight.to(self.cold_device))
        bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None

    def dequantize(self):
    def decompress(self) -> None:
        if self.weight_quantized is None:
            raise RuntimeError("forward() called in quantized state before quantized weights are available")
            raise RuntimeError("decompress() called in quantized state before quantized weights are available")
        dtype = self.weight.dtype
        self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
        if self.bias_quantized:
            self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
        self.weight_quantized = None
        self.weight_state = None
        self.bias_quantized = None
        self.bias_state = None

    def checkDistance(self) -> float:
    def checkDistance(self) -> tuple[float, float]:
        if self.weight_quantized is None:
            raise RuntimeError("checkDistance() called without quantized weights available")
        original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)

@@ -154,22 +158,13 @@ class DynamicQantizedLinear(Linear):
                                                                                           quant_type=self.quant_type)
        dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
        dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
        return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) / dequantized_original_weight.numel()).item()
        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
        return (distance, error)

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def setFrozen(self, frozen: bool) -> None:
        if frozen == self.isFrozen():
            return

        super().setFrozen(frozen)

        if frozen:
            self.quantize()
        else:
            self.dequantize()

    def forward(self, x: torch.Tensor):
        output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
        output_device = x.device if self.output_device is None else self.output_device
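
The quantized path is built on the bitsandbytes 4-bit primitives already visible above. A hedged round-trip sketch of what compress() and checkDistance() rely on, assuming a CUDA device and the bitsandbytes package are available (block size and tensor shape are arbitrary here):

    # Hedged sketch, not part of the commit: quantize a half-precision weight to 4 bit,
    # dequantize it again, and measure the per-element error the same way checkDistance() does.
    import torch
    import bitsandbytes as bnb

    weight = torch.randn(256, 256, dtype=torch.float16, device="cuda:0")
    quantized, state = bnb.functional.quantize_4bit(weight, blocksize=64,
                                                    compress_statistics=False, quant_type="fp4")
    restored = bnb.functional.dequantize_fp4(quantized, state).to(weight.dtype)

    error = (torch.linalg.vector_norm(restored - weight).to(torch.float32) / weight.numel()).item()
    print(error)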

@@ -183,9 +178,27 @@ class DynamicQantizedLinear(Linear):
        else:
            if self.weight_quantized is None:
                raise RuntimeError("forward() called in quantized state before quantized weights are available")
            if x.device != self.weight_quantized.device:
                x = x.to(self.weight_quantized.device)
            bias = None
            if self.bias_quantized is not None:
                bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
            out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)

            return out.to(output_device).to(output_dtype)

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
        if dtype is not None:
            super().inplaceTo(dtype=dtype)
        if device is not None:
            frozen = self.isFrozen()
            self.active_device = device
            if self.weight_quantized is not None:
                self.weight_quantized = self.weight_quantized.to(device)
                self.weight_state = self.weight_state.to(device)
                if self.bias_quantized is not None:
                    self.bias_quantized = self.bias_quantized.to(device)
                    self.bias_state = self.bias_state.to(device)
            if not frozen:
                super().inplaceTo(device=device)
            self.setFrozen(frozen, False)

@@ -39,10 +39,11 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
    parameters = list()
    parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
    param_ids = set([id(p['params']) for p in parameters])
    for param in static_parameters:
        if param.requires_grad and id(param) not in param_ids:
            parameters.append({'params': param, 'lr': static_lr})
            param_ids.add(id(param))
    if static_parameters is not None:
        for param in static_parameters:
            if param.requires_grad and id(param) not in param_ids:
                parameters.append({'params': param, 'lr': static_lr})
                param_ids.add(id(param))

    if not adam8bit:
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)

@@ -55,19 +56,34 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
    return optimizer


def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
    print("*** Eval ***")
    loss = torch.zeros((1), device="cuda:0")
    model.model.eval()
    for batch in dataloader:
        for key in batch:
            batch[key] = batch[key].to("cuda:0")
        outputs = model.model(**batch)
        loss += outputs.loss
    loss = loss / len(dataloader)
    print(f"Eval Loss {loss.item()}")


def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    primary_device = torch.device(training_args.primary_device)
    secondary_device = torch.device(training_args.secondary_device)
    log_writer = tensorboard.SummaryWriter()

    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
    model.toDevices([primary_device, secondary_device])
                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
                          quantize=model_args.quantize)
    devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
    model.toDevices(devices)
    model.reshuffleActive()
    model.balanceActive()

    paramter_count = sum(p.numel() for p in model.model.parameters())
    active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantaneous active parameters")
    static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantaneous active parameters of which {static_parameter_count} are static")

    tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

@@ -96,7 +112,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
    total_steps = steps_per_epoch * training_args.epochs

    optimizer = get_optimizer(model.dynamicParameters(),
                              model.staticParameters(),
                              model.staticParameters() if training_args.train_non_linear_layers else None,
                              training_args.learning_rate,
                              training_args.learning_rate / dynamic_param_ratio,
                              training_args.weight_decay,

@@ -115,6 +131,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
        global_step = 0
        model.model.train()
        for epoch in range(0, training_args.epochs):
            model.model.train()
            print("*** Train ***")
            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
            for step, batch in enumerate(train_dataloader):

@@ -131,17 +148,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T

                    model.model.zero_grad()

                    if global_step % 10 == 0:
                        print(loss)
                    if global_step % 5 == 0:
                        print(f"Train Loss {loss.item()}")

                    if global_step % 10 == 0 and training_args.max_instant_params != 0:
                    if global_step % 50 == 0 and training_args.max_instant_params != 0:
                        lr_scheduler.optimizer = None
                        del optimizer
                        model.reshuffleActive()
                        model.balanceActive()
                        log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
                        optimizer = get_optimizer(model.dynamicParameters(),
                                                  model.staticParameters(),
                                                  model.staticParameters() if training_args.train_non_linear_layers else None,
                                                  training_args.learning_rate,
                                                  training_args.learning_rate / dynamic_param_ratio,
                                                  training_args.weight_decay,

@@ -152,14 +169,19 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                    global_step += 1
                    progress_bar.update()

                if global_step > 0:
                    if global_step % training_args.save_steps == 0:
                        save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                    if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
                        evaluate(model, eval_dataloader)
                if training_args.flush_allocator:
                    torch.cuda.empty_cache()
            if training_args.do_eval and training_args.eval_steps == -1:
                evaluate(model, eval_dataloader)

    # Evaluation
    if training_args.do_eval:
        print("*** Evaluate ***")
        evaluate(model, eval_dataloader)

    save_model(model.model, global_step, training_args.output_dir)