various optimizations

This commit is contained in:
parent 2f35689355
commit c38ac65d5b

4 changed files with 151 additions and 101 deletions

dyntrainmodel.py

@@ -68,7 +68,9 @@ class LinearGroup:
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
-                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, train_static_params: bool,
+                 reshuffle_fraction: float, gradient_checkpointing: bool,
+                 trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -82,6 +84,7 @@ class DyntrainModel:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
         self.devices = list[torch.device]()
         self.inital_reshufle = True
+        self.train_static_params = train_static_params

         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
@@ -167,8 +170,14 @@ class DyntrainModel:
     def staticParameterCount(self) -> int:
         return sum(p.numel() for p in self.staticParameters())

+    def activeDynamicParameterCount(self) -> int:
+        return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)
+
     def activeParameterCount(self) -> int:
-        total_params = self.dynamicParameters() + self.staticParameters()
+        if self.train_static_params:
+            total_params = self.dynamicParameters() + self.staticParameters()
+        else:
+            total_params = self.dynamicParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)

     def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
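
Note: the new train_static_params flag controls whether trainable static parameters (everything outside the dynamically shuffled linear groups) count toward the active-parameter budget. A minimal sketch with stand-in parameter lists (sizes are purely illustrative, not taken from the model):

import torch

dynamic = [torch.nn.Parameter(torch.zeros(1000), requires_grad=True),
           torch.nn.Parameter(torch.zeros(500), requires_grad=False)]  # a frozen linear group
static = [torch.nn.Parameter(torch.zeros(200), requires_grad=True)]    # e.g. norms, embeddings

def active_count(train_static_params: bool) -> int:
    params = dynamic + static if train_static_params else dynamic
    return sum(p.numel() for p in params if p.requires_grad)

print(active_count(True))   # 1200: trainable static parameters count toward the budget
print(active_count(False))  # 1000: only the active dynamic groups count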
@@ -187,7 +196,7 @@ class DyntrainModel:
         params = self.activeParameterCount()

         if params >= self.target_active_params:
-            RuntimeError("Insuficant active parameters to suffle active")
+            raise RuntimeError("Insuficant active parameters to suffle active")
         while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
             i = randint(0, len(self.frozen_linear_groups) - 1)
             group = self.frozen_linear_groups.pop(i)
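
Note: before this change the RuntimeError was constructed and immediately discarded, so the guard never fired; raise makes it actually abort. A tiny illustration of the difference:

def check_budget(params: int, target: int) -> None:
    if params >= target:
        RuntimeError("over budget")        # old: builds the exception object, then drops it
    if params >= target:
        raise RuntimeError("over budget")  # new: actually raises

check_budget(5, 10)     # passes either way
# check_budget(10, 10)  # only the raised version stops execution here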
@@ -199,7 +208,7 @@ class DyntrainModel:

         active_params = self.activeParameterCount()

-        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params

     def activeParamtersByDevice(self) -> list[int]:
         out = [0] * len(self.devices)
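
Note: the post-reshuffle sanity band widens from roughly ±30% to ±40% of target_active_params, presumably because the changed counting above makes the achievable total coarser. For a hypothetical target of 1e9 parameters:

target = 1_000_000_000             # hypothetical target_active_params
print(target * 0.7, target * 1.3)  # old accepted range: 7.0e8 .. 1.3e9
print(target * 0.6, target * 1.4)  # new accepted range: 6.0e8 .. 1.4e9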
@@ -213,7 +222,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = int(memory * 0.8)
+                memory = int(memory * 0.5)
             bits_per_param.append(count / memory)

         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -223,7 +232,7 @@ class DyntrainModel:
             if group.getDevice() is self.devices[max_index]:
                 memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                 if max_index == 0:
-                    memory = int(memory * 0.8)
+                    memory = int(memory * 0.5)
                 swing = group.paramCount() / memory
                 if max_bits_per_param - swing > min_bits_per_param + swing:
                     group.inplaceTo(device=self.devices[min_index])
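
Note: both balancing passes now budget only 50% (was 80%) of device 0's total memory, presumably to leave headroom on the first GPU, which also hosts the batches, eval generation and optimizer work. Rough numbers for a hypothetical 24 GiB card:

total = 24 * 1024**3                       # hypothetical 24 GiB device, illustrative only
print(f"{total * 0.8 / 1024**3:.1f} GiB")  # old balancing budget on device 0: 19.2 GiB
print(f"{total * 0.5 / 1024**3:.1f} GiB")  # new balancing budget on device 0: 12.0 GiB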

modules.py (24 changed lines)

@@ -108,7 +108,7 @@ class DynamicConvertingLinear(Linear):

 class DynamicQantizedLinear(Linear):
     def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device, cold_device: torch.device,
-                 output_dtype=None, compute_dtype=None, output_device=None):
+                 output_dtype=None, compute_dtype=None, output_device=None, cold_dtype=torch.float32):
         super().__init__(in_features, out_features, bias, cold_device, torch.float32)
         self.active_device = active_device
         self.cold_device = cold_device
@@ -120,8 +120,8 @@ class DynamicQantizedLinear(Linear):
         self.bias_quantized = None
         self.bias_state = None
         self.block_size = 128
-        self.quant_type = 'nf4'
-        self.weight_start = self.weight.clone().detach()
+        #self.weight_start = self.weight.clone().detach()
+        self.cold_dtype = cold_dtype

     @classmethod
     def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
@@ -131,19 +131,19 @@ class DynamicQantizedLinear(Linear):
                          compute_dtype=compute_dtype, output_device=output_device)
         new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
-        new_module.weight_start = new_module.weight.clone().detach()
+        #new_module.weight_start = new_module.weight.clone().detach()
         return new_module

     def compress(self) -> None:
-        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
+        weight = self.weight.contiguous().to(torch.float16).to(self.active_device)
         self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
         if self.bias is not None:
-            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
+            bias = self.bias.contiguous().to(torch.float16).to(self.active_device)
             self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)

         frozen = self.isFrozen()
-        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_dtype).to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_dtype).to(self.cold_device)) if self.bias is not None else None
         self.setFrozen(frozen, False)

     def decompress(self) -> None:
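
Note: two changes in compress(): .cuda(...) becomes the device-agnostic .to(...), and the cold copy is cast to the new cold_dtype before being parked on cold_device, so the off-GPU copy can be held at reduced precision. Rough saving for a hypothetical 4096x4096 layer:

import torch

w32 = torch.zeros(4096, 4096, dtype=torch.float32)  # illustrative cold copy at float32
w16 = w32.to(torch.float16)                          # the same copy with cold_dtype=torch.float16

mib = 1024**2
print(w32.numel() * w32.element_size() / mib)  # 64.0 MiB per layer
print(w16.numel() * w16.element_size() / mib)  # 32.0 MiB per layer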
@@ -151,16 +151,16 @@ class DynamicQantizedLinear(Linear):
         self.weight_state = None
         self.bias_quantized = None
         self.bias_state = None
-        self.weight_start = self.weight.clone().detach().to(self.cold_device)
-        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
+        #self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device).to(torch.float32))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device).to(torch.float32))

     def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
         original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
         dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
-        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        distance = torch.zeros((2)) #(self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
         error = (dequantized_original_weight - original_weight).to(torch.float32)
         return (distance, error)

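
Note: decompress() now promotes weight and bias back to float32 on the active device, since the cold copy may have been narrowed to cold_dtype by compress(); with weight_start tracking commented out, the distance term of getDistanceAndError() is a zeros placeholder and only the quantization error remains meaningful. A round-trip sketch with stand-in tensors:

import torch

cold_dtype = torch.float16
w = torch.randn(4, 4)              # active weight, float32
cold = w.to(cold_dtype)            # compress(): parked on the cold device at cold_dtype
restored = cold.to(torch.float32)  # decompress(): promoted back for training
print((w - restored).abs().max())  # small cast error, bounded by float16 precision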

tokenizer.py

@@ -30,13 +30,13 @@ def smart_tokenizer_and_embedding_resize(


 def get_tokenizer(model, cache_dir, model_args: ModelArguments):
-    print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
+    tokenizer_path = model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path
+    print(f'Tokenizer: {tokenizer_path}')
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
+        tokenizer_path,
         cache_dir=cache_dir,
         padding_side="right",
         use_fast=False,
-        eos_token="[EOS]",
         tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
         trust_remote_code=model_args.trust_remote_code
     )

train_dynamic.py (133 changed lines)

@@ -1,6 +1,4 @@
 import transformers
-from transformers import get_scheduler
-
 import torch
 from torch.utils import tensorboard
 import os
@@ -8,9 +6,10 @@ import shutil
 import math
 from tqdm.auto import tqdm
 import gc
+import sys

 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
+from datamodules import get_data_loaders
 from tokenizer import get_tokenizer

 from dyntrainmodel import DyntrainModel
@@ -19,7 +18,16 @@ from dyntrainmodel import DyntrainModel
 def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
     output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
     output_dir = os.path.join(output_dir, output_chkpt_dir)
+
+    print(f"saveing model to {output_chkpt_dir}")
+
+    temperature = model.generation_config.temperature
+    top_p = model.generation_config.top_p
+    model.generation_config.temperature = None
+    model.generation_config.top_p = None
     model.save_pretrained(output_dir)
+    model.generation_config.temperature = temperature
+    model.generation_config.top_p = top_p

     if max_checkpoints > 0:
         files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
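
Note: the temperature/top_p shuffle around save_pretrained() is presumably there because recent transformers releases validate the generation config on save and complain when sampling-only fields are set while do_sample is False; whether that is a warning or an error depends on the installed version. A standalone illustration of the pattern:

from transformers import GenerationConfig

cfg = GenerationConfig(do_sample=False, temperature=0.9, top_p=0.6)  # may warn on recent versions
temperature, top_p = cfg.temperature, cfg.top_p
cfg.temperature = None  # cleared so the saved config validates cleanly
cfg.top_p = None
# ... save_pretrained(...) would run here ...
cfg.temperature, cfg.top_p = temperature, top_p  # restored for later generation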
@@ -57,13 +65,42 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
     return optimizer


+def move_optimizer_param(param, device: torch.device, device_map: dict):
+    if isinstance(param, torch.Tensor):
+        move_device = device if device is not None else device_map[id(param)]
+        assert device is not None or move_device != torch.device("cpu")
+        old_device = param.device
+        param.data = param.data.to(move_device)
+        if param._grad is not None:
+            param._grad.data = param._grad.data.to(move_device)
+        if device is not None and id(param) not in device_map:
+            device_map[id(param)] = old_device
+            assert old_device != torch.device("cpu")
+    elif isinstance(param, dict):
+        for subparam in param.values():
+            move_optimizer_param(subparam, device, device_map)
+
+
+def suspend_optimizer(optimizer) -> dict:
+    device_map = dict()
+    for param in optimizer.state.values():
+        move_optimizer_param(param, torch.device("cpu"), device_map)
+    return device_map
+
+
+def resume_optimizer(optimizer, device_map: dict):
+    for param in optimizer.state.values():
+        move_optimizer_param(param, None, device_map)
+
+
 def evaluate(model: DyntrainModel, tokenizer,
              dataloader: torch.utils.data.DataLoader, globalstep: int,
              log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
-    print("*** Eval ***")
-    loss = torch.zeros((1), device="cuda:0")
-    model.model.eval()
-    for batch in dataloader:
-        for key in batch:
-            batch[key] = batch[key].to("cuda:0")
-        outputs = model.model(**batch)
+    with torch.no_grad():
+        loss = torch.zeros((1), device="cuda:0")
+        model.model.eval()
+
+        for batch in tqdm(dataloader, desc="Doing eval"):
+            for key in batch:
+                batch[key] = batch[key].to("cuda:0")
+            outputs = model.model(**batch)
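
Note: suspend_optimizer() walks optimizer.state, moves every tensor (and its gradient) to the CPU while recording each tensor's original device keyed by id(), and resume_optimizer() replays that map; this frees the Adam moment buffers from VRAM while evaluation runs. A toy look at the state AdamW actually keeps (CPU-only, purely illustrative):

import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
opt.step()                            # populates exp_avg / exp_avg_sq per parameter

for state in opt.state.values():      # one dict of tensors per parameter
    for name, value in state.items():
        if isinstance(value, torch.Tensor):
            print(name, tuple(value.shape), value.device)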
@@ -71,23 +108,42 @@ def evaluate(model: DyntrainModel, tokenizer,
         loss = loss / len(dataloader)
         log_writer.add_scalar("Loss/Eval", loss, globalstep)
         print(f"Eval Loss {loss.item()}")
-    return loss.item()

         if eval_prompt is not None:
             input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
             attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
-        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
+            outputs = model.model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1,
+                                           max_new_tokens=100, min_new_tokens=100)
             response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
             print(f"Eval generation: {response_decoded}")
             log_writer.add_text("Text/Eval", response_decoded, globalstep)
+        model.model.train()
+
+
+def max_vram_allocated():
+    max_vram_alloc = 0
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = max(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc
+
+
+def min_vram_allocated():
+    max_vram_alloc = sys.maxsize
+    for i in range(0, torch.cuda.device_count()):
+        max_vram_alloc = min(torch.cuda.memory_allocated(i), max_vram_alloc)
+    return max_vram_alloc


 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
-    log_writer = tensorboard.SummaryWriter()
+    log_writer = tensorboard.SummaryWriter(log_dir=training_args.logging_dir)

-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
-                          reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
-                          quantize=model_args.quantize)
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir,
+                          quantize=model_args.quantize,
+                          target_active_params=int(training_args.max_instant_params * 1e6),
+                          train_static_params=training_args.train_non_linear_layers,
+                          reshuffle_fraction=training_args.churn_percent / 100.0,
+                          gradient_checkpointing=True,
+                          trust_remote_code=True)
     devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
     model.toDevices(devices)
     model.reshuffleActive()
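
Note: max_vram_allocated()/min_vram_allocated() scan every visible CUDA device with torch.cuda.memory_allocated(), which counts tensors currently held by the caching allocator; the figure nvidia-smi reports corresponds more closely to torch.cuda.memory_reserved(). A standalone version of the same probe:

import sys
import torch

max_alloc, min_alloc = 0, sys.maxsize        # min_alloc stays sys.maxsize on a CPU-only machine
for i in range(torch.cuda.device_count()):
    max_alloc = max(torch.cuda.memory_allocated(i), max_alloc)
    min_alloc = min(torch.cuda.memory_allocated(i), min_alloc)
print(f"Max: {max_alloc / 1024**3:.2f}GB Min: {min_alloc / 1024**3:.2f}GB")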
@@ -101,29 +157,10 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T

     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)

-    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
-        print("Loading dataset in s2s mode")
-        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    elif data_args.data_from_hub:
-        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-    else:
-        print("Loading dataset in txt mode")
-        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
-
-    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset['train_dataset'],
-        shuffle=True,
-        collate_fn=dataset['data_collator'],
-        batch_size=training_args.per_device_train_batch_size
-    ) if dataset['train_dataset'] is not None else None
-    if training_args.do_eval:
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset['eval_dataset'],
-            shuffle=True,
-            collate_fn=dataset['data_collator'],
-            batch_size=training_args.per_device_train_batch_size
-        )
+    train_dataloader, eval_dataloader = get_data_loaders(tokenizer, data_args,
+                                                         training_args.per_device_train_batch_size,
+                                                         training_args.per_device_eval_batch_size,
+                                                         training_args.do_train, training_args.do_eval)

     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
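
Note: dataset and DataLoader construction is consolidated into datamodules.get_data_loaders. The removed call sites suggest a helper of roughly this shape (hypothetical reconstruction; the real implementation lives in datamodules.py and may differ):

import torch

def get_data_loaders(tokenizer, data_args, train_batch_size, eval_batch_size, do_train, do_eval):
    # choose create_data_module_s2s / create_data_module_hub / create_data_module based on data_args
    data_module = ...
    def make(dataset, batch_size):
        return torch.utils.data.DataLoader(dataset, shuffle=True,
                                            collate_fn=data_module['data_collator'],
                                            batch_size=batch_size)
    train_loader = make(data_module['train_dataset'], train_batch_size) if do_train else None
    eval_loader = make(data_module['eval_dataset'], eval_batch_size) if do_eval else None
    return train_loader, eval_loader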
@@ -137,7 +174,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                               training_args.adam_epsilon,
                               training_args.adam8bit)

-    lr_scheduler = get_scheduler(
+    lr_scheduler = transformers.get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.warmup_steps,
@@ -149,13 +186,11 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         global_step = 0
         model.model.train()
         for epoch in range(0, training_args.epochs):
-            model.model.train()
             print("*** Train ***")
-            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
+            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0**3):.2f}')
             for step, batch in enumerate(train_dataloader):
                 for key in batch:
                     batch[key] = batch[key].to("cuda:0")
-
                 outputs = model.model(**batch)
                 loss = outputs.loss / training_args.gradient_accumulation_steps
                 loss.backward()
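
Note: the duplicate model.model.train() inside the epoch loop is dropped, and the start-of-epoch VRAM readout switches from an unlabeled MiB figure to GiB with two decimals:

allocated = 17_179_869_184                 # hypothetical 16 GiB of allocated tensors
print(allocated / (1024.0 * 1024.0))       # old readout: 16384.0 (MiB, no unit shown)
print(f"{allocated / (1024.0 ** 3):.2f}")  # new readout: 16.00 (GiB)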
@@ -166,11 +201,12 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                     optimizer.step()
                     lr_scheduler.step()

+                    progress_bar.set_postfix_str(f"Loss: {loss.item():.2f} Max: {max_vram_allocated()/(1024.0**3):.2f}GB"
+                                                 f" Min: {min_vram_allocated()/(1024.0**3):.2f}GB")
+
                     model.model.zero_grad()

-                    if global_step % 5 == 0:
-                        print(f"Train Loss {loss.item()}")
-
+                    if global_step > 0:
                         if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                             print("Reshuffleing")
                             lr_scheduler.optimizer = None
@@ -191,21 +227,26 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                                                       training_args.adam8bit)
                             lr_scheduler.optimizer = optimizer

-                    global_step += 1
-                    progress_bar.update()
-
-                if global_step > 0:
                         if global_step % training_args.save_steps == 0:
                             save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                         if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                            device_map = suspend_optimizer(optimizer)
                             evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                            resume_optimizer(optimizer, device_map)
+
+                    global_step += 1
+                    progress_bar.update()
+
                 if training_args.flush_allocator:
                     gc.collect()
                     torch.cuda.empty_cache()
             if training_args.do_eval and training_args.eval_steps == -1:
+                device_map = suspend_optimizer(optimizer)
                 evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                resume_optimizer(optimizer, device_map)
+
+    del optimizer

-    # Evaluation
     if training_args.do_eval:
         evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

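
Note: each evaluation is now wrapped in suspend_optimizer()/resume_optimizer(), and del optimizer drops the Adam state entirely before the final eval; the flush_allocator path above shows the companion step of actually returning freed blocks to the allocator:

import gc
import torch

# del optimizer             # make the optimizer state unreachable first (as in the diff)
gc.collect()                # drop Python-level references
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached blocks so evaluation/generation can use them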