working full training

commit 11ea9eeaa7
parent 7a47fcdcc0

2 changed files with 25 additions and 7 deletions
@@ -2,11 +2,12 @@ import torch
 
 
 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
 
     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
            input = input.to(self.weight.device)
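The hunk above (the first of the two changed files) threads a new output_dtype argument through ConvertingLinear; the rest of forward() lies outside the hunk, so only the dtype selection is visible. A minimal usage sketch, assuming ConvertingLinear is imported from that file (the layer sizes and dtypes below are invented for illustration, not part of the commit):

import torch

# Sketch only: with the default output_dtype=None the layer keeps the input dtype.
layer = ConvertingLinear(8, 8)
y = layer(torch.randn(2, 8))
# Per the new forward() line, output_dtype here resolves to torch.bfloat16
# instead of the input dtype.
layer_bf16 = ConvertingLinear(8, 8, output_dtype=torch.bfloat16)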
@@ -7,9 +7,11 @@ from torch.utils import tensorboard
 import os
 import shutil
 import math
+import collections
 from tqdm.auto import tqdm
 from random import randint
-from typing import Tuple
+import collections
 
+
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
         attn_implementation="flash_attention_2"
     )
 
-    # for name, module in model.named_modules():
-    #     if 'norm' in name:
-    #         module = module.to(torch.float32)
+    modules = dict(model.named_modules())
+    keys = find_all_linear_module_names(model)
+    for key in keys:
+        modules[key].weight = modules[key].weight.to(dtype)
+        if modules[key].bias is not None:
+            modules[key].bias = modules[key].bias.to(dtype)
 
     return model
 
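find_all_linear_module_names() and dtype are defined outside this hunk. For orientation, a rough sketch of what such a helper usually looks like — an assumption about its shape, not the repository's actual implementation:

import torch

def find_all_linear_module_names(model) -> list:
    # Sketch only: return the dotted names of every nn.Linear submodule so that
    # get_model() above can cast their weights and biases to the target dtype.
    return [name for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear) and name]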
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
 
+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
+
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))
 
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
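The guards added above do two things: module names that appear in both lists abort early, and parameters already placed in a dynamic group are skipped during the static pass, so no tensor lands in two optimizer param groups. A self-contained sketch of the same pattern on a toy model (the Sequential model, the names '0'/'1', and the learning rates are invented for illustration):

import collections
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
dynamic_names, static_names = ['0'], ['1']

# Guard 1: a module name listed in both groups is treated as a configuration error.
all_keys = dynamic_names + static_names
assert not [k for k, v in collections.Counter(all_keys).items() if v > 1]

# Guard 2: track parameter identity so the static pass cannot re-add a tensor
# that already sits in a dynamic param group.
modules = dict(model.named_modules())
parameters = [{'params': p} for k in dynamic_names
              for p in modules[k].parameters() if p.requires_grad]
param_ids = {id(g['params']) for g in parameters}
parameters += [{'params': p, 'lr': 1e-5} for k in static_names
               for p in modules[k].parameters()
               if p.requires_grad and id(p) not in param_ids]

optimizer = torch.optim.AdamW(parameters, lr=1e-4)   # each parameter appears exactly once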
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
 
 def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
     model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
+
     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 
     if data_args.dataset.endswith("json"):