working full training
This commit is contained in:
		
							parent
							
								
									7a47fcdcc0
								
							
						
					
					
						commit
						11ea9eeaa7
					
				
					 2 changed files with 25 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -2,11 +2,12 @@ import torch
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ConvertingLinear(torch.nn.Linear):
 | 
			
		||||
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
 | 
			
		||||
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
 | 
			
		||||
        super().__init__(in_features, out_features, bias, device, dtype)
 | 
			
		||||
        self.output_dtype = output_dtype
 | 
			
		||||
 | 
			
		||||
    def forward(self, input: torch.Tensor):
 | 
			
		||||
        output_dtype = input.dtype
 | 
			
		||||
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
 | 
			
		||||
        output_device = input.device
 | 
			
		||||
        if input.device != self.weight.device:
 | 
			
		||||
            input = input.to(self.weight.device)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,9 +7,11 @@ from torch.utils import tensorboard
 | 
			
		|||
import os
 | 
			
		||||
import shutil
 | 
			
		||||
import math
 | 
			
		||||
import collections
 | 
			
		||||
from tqdm.auto import tqdm
 | 
			
		||||
from random import randint
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
import collections
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from arguments import DataArguments, ModelArguments, TrainingArguments
 | 
			
		||||
from datamodules import create_data_module_s2s, create_data_module
 | 
			
		||||
| 
						 | 
				
			
			@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
 | 
			
		|||
        attn_implementation="flash_attention_2"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # for name, module in model.named_modules():
 | 
			
		||||
    #     if 'norm' in name:
 | 
			
		||||
    #         module = module.to(torch.float32)
 | 
			
		||||
    modules = dict(model.named_modules())
 | 
			
		||||
    keys = find_all_linear_module_names(model)
 | 
			
		||||
    for key in keys:
 | 
			
		||||
        modules[key].weight = modules[key].weight.to(dtype)
 | 
			
		||||
        if modules[key].bias is not None:
 | 
			
		||||
            modules[key].bias = modules[key].bias.to(dtype)
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 | 
			
		|||
def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
 | 
			
		||||
                  weight_decay: float, eps: float, adam8bit: bool):
 | 
			
		||||
 | 
			
		||||
    all_keys = dynamic_module_names + static_module_names
 | 
			
		||||
    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
 | 
			
		||||
    if len(duplicated) > 0:
 | 
			
		||||
        print("duplicated items:")
 | 
			
		||||
        for item in duplicated:
 | 
			
		||||
            print(item)
 | 
			
		||||
        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
 | 
			
		||||
 | 
			
		||||
    parameters = list()
 | 
			
		||||
    modules = dict(model.named_modules())
 | 
			
		||||
    for key in dynamic_module_names:
 | 
			
		||||
        parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
 | 
			
		||||
    param_ids = set([id(p['params']) for p in parameters])
 | 
			
		||||
    for key in static_module_names:
 | 
			
		||||
        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
 | 
			
		||||
        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
 | 
			
		||||
        for p in modules[key].parameters():
 | 
			
		||||
            param_ids.add(id(p))
 | 
			
		||||
 | 
			
		||||
    if not adam8bit:
 | 
			
		||||
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
 | 
			
		||||
| 
						 | 
				
			
			@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
 | 
			
		|||
 | 
			
		||||
def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
 | 
			
		||||
    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
 | 
			
		||||
 | 
			
		||||
    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 | 
			
		||||
 | 
			
		||||
    if data_args.dataset.endswith("json"):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue