Inital commit
This commit is contained in:
		
						commit
						7a47fcdcc0
					
				
					 5 changed files with 716 additions and 0 deletions
				
			
		
							
								
								
									
										95
									
								
								arguments.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								arguments.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,95 @@
 | 
				
			||||||
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class DataArguments:
 | 
				
			||||||
 | 
					    eval_dataset_size: int = field(
 | 
				
			||||||
 | 
					        default=512, metadata={"help": "Size of validation dataset."}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    source_max_len: int = field(
 | 
				
			||||||
 | 
					        default=512,
 | 
				
			||||||
 | 
					        metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    train_on_source: Optional[bool] = field(
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        metadata={"help": "Wether to train on the input in addition to the target text when in s2s mode."}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    target_max_len: int = field(
 | 
				
			||||||
 | 
					        default=256,
 | 
				
			||||||
 | 
					        metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    dataset: str = field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    block_size: int = field(
 | 
				
			||||||
 | 
					        default=512,
 | 
				
			||||||
 | 
					        metadata={"help": "size of the blocks the text is split into for training"},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class ModelArguments:
 | 
				
			||||||
 | 
					    model_name_or_path: Optional[str] = field(
 | 
				
			||||||
 | 
					        default="EleutherAI/pythia-12b"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    tokenizer: Optional[str] = field(
 | 
				
			||||||
 | 
					        default=None
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    trust_remote_code: Optional[bool] = field(
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    max_instant_params: int = field(
 | 
				
			||||||
 | 
					        default=0,
 | 
				
			||||||
 | 
					        metadata={"help": "Maximum amount of paramters to optimize per step in millions"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    noresize: Optional[bool] = field(
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        metadata={"help": "Never resize tokenizer embeddings"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class TrainingArguments():
 | 
				
			||||||
 | 
					    cache_dir: Optional[str] = field(
 | 
				
			||||||
 | 
					        default=None
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    adam8bit: bool = field(
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        metadata={"help": "Use 8-bit adam."}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    report_to: str = field(
 | 
				
			||||||
 | 
					        default='none',
 | 
				
			||||||
 | 
					        metadata={"help": "To use wandb or something else for reporting."}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
 | 
				
			||||||
 | 
					    ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
 | 
				
			||||||
 | 
					    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
 | 
				
			||||||
 | 
					    per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
 | 
				
			||||||
 | 
					    gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
 | 
				
			||||||
 | 
					    epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
 | 
				
			||||||
 | 
					    weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
 | 
				
			||||||
 | 
					    learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'})
 | 
				
			||||||
 | 
					    adam_epsilon: float = field(default=1e-7, metadata={"help": 'Adam epsilon'})
 | 
				
			||||||
 | 
					    remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'})
 | 
				
			||||||
 | 
					    max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
 | 
				
			||||||
 | 
					    gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
 | 
				
			||||||
 | 
					    fp16: bool = field(default=False, metadata={"help": 'Train in 16 bit mixed precision'})
 | 
				
			||||||
 | 
					    do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
 | 
				
			||||||
 | 
					    do_eval: bool = field(default=False, metadata={"help": 'To eval or not to eval, that is the question?'})
 | 
				
			||||||
 | 
					    lr_scheduler_type: str = field(default='constant',
 | 
				
			||||||
 | 
					                                   metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
 | 
				
			||||||
 | 
					    warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
 | 
				
			||||||
 | 
					    logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
 | 
				
			||||||
 | 
					    group_by_length: bool = field(default=False,
 | 
				
			||||||
 | 
					                                  metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
 | 
				
			||||||
 | 
					    storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'})
 | 
				
			||||||
 | 
					    save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
 | 
				
			||||||
 | 
					    max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
 | 
				
			||||||
 | 
					    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
 | 
				
			||||||
 | 
					    primary_device: str = field(default="cuda:0", metadata={"help": 'The primary device to use'})
 | 
				
			||||||
 | 
					    secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
 | 
				
			||||||
 | 
					    train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'})
 | 
				
			||||||
 | 
					    flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'})
 | 
				
			||||||
							
								
								
									
										29
									
								
								convertinglinear.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								convertinglinear.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,29 @@
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ConvertingLinear(torch.nn.Linear):
 | 
				
			||||||
 | 
					    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
 | 
				
			||||||
 | 
					        super().__init__(in_features, out_features, bias, device, dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, input: torch.Tensor):
 | 
				
			||||||
 | 
					        output_dtype = input.dtype
 | 
				
			||||||
 | 
					        output_device = input.device
 | 
				
			||||||
 | 
					        if input.device != self.weight.device:
 | 
				
			||||||
 | 
					            input = input.to(self.weight.device)
 | 
				
			||||||
 | 
					        if input.dtype != self.weight.dtype:
 | 
				
			||||||
 | 
					            input = input.to(self.weight.dtype)
 | 
				
			||||||
 | 
					        output = torch.nn.Linear.forward(self, input)
 | 
				
			||||||
 | 
					        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
 | 
				
			||||||
 | 
					            breakpoint()
 | 
				
			||||||
 | 
					        return output.to(output_device).to(output_dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def fromLinear(cls, in_module: torch.nn.Linear):
 | 
				
			||||||
 | 
					        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
 | 
				
			||||||
 | 
					                                              out_features=in_module.out_features,
 | 
				
			||||||
 | 
					                                              bias=in_module.bias is not None,
 | 
				
			||||||
 | 
					                                              device=in_module.weight.device,
 | 
				
			||||||
 | 
					                                              dtype=in_module.weight.dtype)
 | 
				
			||||||
 | 
					        new_module.weight = in_module.weight
 | 
				
			||||||
 | 
					        new_module.bias = in_module.bias
 | 
				
			||||||
 | 
					        return new_module
 | 
				
			||||||
							
								
								
									
										192
									
								
								datamodules.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								datamodules.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,192 @@
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import typing
 | 
				
			||||||
 | 
					import datasets
 | 
				
			||||||
 | 
					import itertools
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from torch.nn.utils.rnn import pad_sequence
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from arguments import DataArguments
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					IGNORE_INDEX = -100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def group_texts(examples, block_size: int):
 | 
				
			||||||
 | 
					    # Concatenate all texts.
 | 
				
			||||||
 | 
					    concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
 | 
				
			||||||
 | 
					    total_length = len(concatenated_examples[list(examples.keys())[0]])
 | 
				
			||||||
 | 
					    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
 | 
				
			||||||
 | 
					    # customize this part to your needs.
 | 
				
			||||||
 | 
					    if total_length >= block_size:
 | 
				
			||||||
 | 
					        total_length = (total_length // block_size) * block_size
 | 
				
			||||||
 | 
					    # Split by chunks of max_len.
 | 
				
			||||||
 | 
					    result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
 | 
				
			||||||
 | 
					    result["labels"] = result["input_ids"].copy()
 | 
				
			||||||
 | 
					    return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class DataCollatorForCausalLM(object):
 | 
				
			||||||
 | 
					    tokenizer: transformers.PreTrainedTokenizer
 | 
				
			||||||
 | 
					    source_max_len: int
 | 
				
			||||||
 | 
					    target_max_len: int
 | 
				
			||||||
 | 
					    train_on_source: bool
 | 
				
			||||||
 | 
					    predict_with_generate: bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
 | 
				
			||||||
 | 
					        # Extract elements
 | 
				
			||||||
 | 
					        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
 | 
				
			||||||
 | 
					        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
 | 
				
			||||||
 | 
					        # Tokenize
 | 
				
			||||||
 | 
					        tokenized_sources_with_prompt = self.tokenizer(
 | 
				
			||||||
 | 
					            sources,
 | 
				
			||||||
 | 
					            max_length=self.source_max_len,
 | 
				
			||||||
 | 
					            truncation=True,
 | 
				
			||||||
 | 
					            add_special_tokens=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        tokenized_targets = self.tokenizer(
 | 
				
			||||||
 | 
					            targets,
 | 
				
			||||||
 | 
					            max_length=self.target_max_len,
 | 
				
			||||||
 | 
					            truncation=True,
 | 
				
			||||||
 | 
					            add_special_tokens=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Build the input and labels for causal LM
 | 
				
			||||||
 | 
					        input_ids = []
 | 
				
			||||||
 | 
					        labels = []
 | 
				
			||||||
 | 
					        for tokenized_source, tokenized_target in zip(
 | 
				
			||||||
 | 
					            tokenized_sources_with_prompt['input_ids'],
 | 
				
			||||||
 | 
					            tokenized_targets['input_ids']
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            if not self.predict_with_generate:
 | 
				
			||||||
 | 
					                input_ids.append(torch.tensor(tokenized_source + tokenized_target))
 | 
				
			||||||
 | 
					                if not self.train_on_source:
 | 
				
			||||||
 | 
					                    labels.append(
 | 
				
			||||||
 | 
					                        torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                input_ids.append(torch.tensor(tokenized_source))
 | 
				
			||||||
 | 
					        # Apply padding
 | 
				
			||||||
 | 
					        padding_value = None
 | 
				
			||||||
 | 
					        if self.tokenizer.pad_token_id is not None:
 | 
				
			||||||
 | 
					            padding_value = self.tokenizer.pad_token_id
 | 
				
			||||||
 | 
					        elif self.tokenizer.eos_token_id is not None:
 | 
				
			||||||
 | 
					            padding_value = self.tokenizer.eos_token_id
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise RuntimeError("Model dose not have a pad or eos token")
 | 
				
			||||||
 | 
					        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
 | 
				
			||||||
 | 
					        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
 | 
				
			||||||
 | 
					        data_dict = {
 | 
				
			||||||
 | 
					            'input_ids': input_ids,
 | 
				
			||||||
 | 
					            'attention_mask': input_ids.ne(padding_value),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if labels is not None:
 | 
				
			||||||
 | 
					            data_dict['labels'] = labels
 | 
				
			||||||
 | 
					        return data_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
 | 
				
			||||||
 | 
					    except FileNotFoundError as ex:
 | 
				
			||||||
 | 
					        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if do_eval or do_predict:
 | 
				
			||||||
 | 
					        if 'eval' in dataset:
 | 
				
			||||||
 | 
					            eval_dataset = dataset['eval']
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
 | 
				
			||||||
 | 
					            dataset = dataset.train_test_split(
 | 
				
			||||||
 | 
					                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            eval_dataset = dataset['test']
 | 
				
			||||||
 | 
					            eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if 'train' in dataset:
 | 
				
			||||||
 | 
					        train_dataset = dataset['train']
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        train_dataset = dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    data_collator = DataCollatorForCausalLM(
 | 
				
			||||||
 | 
					        tokenizer=tokenizer,
 | 
				
			||||||
 | 
					        source_max_len=data_args.source_max_len,
 | 
				
			||||||
 | 
					        target_max_len=data_args.target_max_len,
 | 
				
			||||||
 | 
					        train_on_source=data_args.train_on_source,
 | 
				
			||||||
 | 
					        predict_with_generate=False  # args.predict_with_generate,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return dict(
 | 
				
			||||||
 | 
					        train_dataset=train_dataset if do_train else None,
 | 
				
			||||||
 | 
					        eval_dataset=eval_dataset if do_eval else None,
 | 
				
			||||||
 | 
					        predict_dataset=eval_dataset if do_predict else None,
 | 
				
			||||||
 | 
					        data_collator=data_collator
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
 | 
				
			||||||
 | 
					    except FileNotFoundError as ex:
 | 
				
			||||||
 | 
					        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if data_args.block_size > tokenizer.model_max_length:
 | 
				
			||||||
 | 
					        raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_newline_fn(example):
 | 
				
			||||||
 | 
					        example['text'] = example['text'] + '\n'
 | 
				
			||||||
 | 
					        return example
 | 
				
			||||||
 | 
					    dataset = dataset.map(add_newline_fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    eval_dataset = None
 | 
				
			||||||
 | 
					    if do_eval or do_predict:
 | 
				
			||||||
 | 
					        if 'eval' in dataset:
 | 
				
			||||||
 | 
					            eval_dataset = dataset['eval']
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
 | 
				
			||||||
 | 
					            dataset = dataset.train_test_split(
 | 
				
			||||||
 | 
					                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            eval_dataset = dataset['test']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if 'train' in dataset:
 | 
				
			||||||
 | 
					        train_dataset = dataset['train']
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        train_dataset = dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    train_dataset_tokenized = train_dataset.map(
 | 
				
			||||||
 | 
					        lambda example: tokenizer(example['text']),
 | 
				
			||||||
 | 
					        batched=True,
 | 
				
			||||||
 | 
					        remove_columns='text',
 | 
				
			||||||
 | 
					        num_proc=32,
 | 
				
			||||||
 | 
					        load_from_cache_file=True)
 | 
				
			||||||
 | 
					    train_dataset_tokenized = train_dataset_tokenized.map(
 | 
				
			||||||
 | 
					        lambda example: group_texts(example, data_args.block_size),
 | 
				
			||||||
 | 
					        batched=True,
 | 
				
			||||||
 | 
					        num_proc=32,
 | 
				
			||||||
 | 
					        load_from_cache_file=True,
 | 
				
			||||||
 | 
					        desc=f"Grouping texts in chunks of {data_args.block_size}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    eval_dataset_tokenized = None
 | 
				
			||||||
 | 
					    if eval_dataset is not None:
 | 
				
			||||||
 | 
					        eval_dataset_tokenized = eval_dataset.map(
 | 
				
			||||||
 | 
					            lambda example: tokenizer(example['text']),
 | 
				
			||||||
 | 
					            batched=True,
 | 
				
			||||||
 | 
					            remove_columns='text',
 | 
				
			||||||
 | 
					            num_proc=32)
 | 
				
			||||||
 | 
					        eval_dataset_tokenized = eval_dataset_tokenized.map(
 | 
				
			||||||
 | 
					            lambda example: group_texts(example, data_args.block_size),
 | 
				
			||||||
 | 
					            batched=True,
 | 
				
			||||||
 | 
					            num_proc=32,
 | 
				
			||||||
 | 
					            load_from_cache_file=True,
 | 
				
			||||||
 | 
					            desc=f"Grouping texts in chunks of {data_args.block_size}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return dict(
 | 
				
			||||||
 | 
					        train_dataset=train_dataset_tokenized if do_train else None,
 | 
				
			||||||
 | 
					        eval_dataset=eval_dataset_tokenized if do_eval else None,
 | 
				
			||||||
 | 
					        predict_dataset=eval_dataset_tokenized if do_predict else None,
 | 
				
			||||||
 | 
					        data_collator=transformers.default_data_collator
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
							
								
								
									
										62
									
								
								tokenizer.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								tokenizer.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,62 @@
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from arguments import ModelArguments
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_PAD_TOKEN = "[PAD]"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def smart_tokenizer_and_embedding_resize(
 | 
				
			||||||
 | 
					    special_tokens_dict: dict,
 | 
				
			||||||
 | 
					    tokenizer: transformers.PreTrainedTokenizer,
 | 
				
			||||||
 | 
					    model: transformers.PreTrainedModel,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """Resize tokenizer and embedding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
 | 
				
			||||||
 | 
					    model.resize_token_embeddings(len(tokenizer))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if num_new_tokens > 0:
 | 
				
			||||||
 | 
					        input_embeddings_data = model.get_input_embeddings().weight.data
 | 
				
			||||||
 | 
					        output_embeddings_data = model.get_output_embeddings().weight.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
 | 
				
			||||||
 | 
					        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
 | 
				
			||||||
 | 
					        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_tokenizer(model, cache_dir, model_args: ModelArguments):
 | 
				
			||||||
 | 
					    print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
 | 
				
			||||||
 | 
					    tokenizer = transformers.AutoTokenizer.from_pretrained(
 | 
				
			||||||
 | 
					        model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
 | 
				
			||||||
 | 
					        cache_dir=cache_dir,
 | 
				
			||||||
 | 
					        padding_side="right",
 | 
				
			||||||
 | 
					        use_fast=False,
 | 
				
			||||||
 | 
					        eos_token="[EOS]",
 | 
				
			||||||
 | 
					        tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
 | 
				
			||||||
 | 
					        trust_remote_code=model_args.trust_remote_code
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    if tokenizer._pad_token is None and not model_args.noresize:
 | 
				
			||||||
 | 
					        smart_tokenizer_and_embedding_resize(
 | 
				
			||||||
 | 
					            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
 | 
				
			||||||
 | 
					            tokenizer=tokenizer,
 | 
				
			||||||
 | 
					            model=model,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    if 'llama' in model_args.model_name_or_path or isinstance(tokenizer, transformers.LlamaTokenizer):
 | 
				
			||||||
 | 
					        # LLaMA tokenizer may not have correct special tokens set.
 | 
				
			||||||
 | 
					        # Check and add them if missing to prevent them from being parsed into different tokens.
 | 
				
			||||||
 | 
					        # Note that these are present in the vocabulary.
 | 
				
			||||||
 | 
					        # Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
 | 
				
			||||||
 | 
					        print('Adding special tokens.')
 | 
				
			||||||
 | 
					        tokenizer.add_special_tokens({
 | 
				
			||||||
 | 
					            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
 | 
				
			||||||
 | 
					            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
 | 
				
			||||||
 | 
					            "unk_token": tokenizer.convert_ids_to_tokens(
 | 
				
			||||||
 | 
					                model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					    return tokenizer
 | 
				
			||||||
							
								
								
									
										338
									
								
								train_dynamic.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										338
									
								
								train_dynamic.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,338 @@
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					from transformers import AutoModelForCausalLM, get_scheduler
 | 
				
			||||||
 | 
					from peft.utils import _get_submodules
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.utils import tensorboard
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import shutil
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					from tqdm.auto import tqdm
 | 
				
			||||||
 | 
					from random import randint
 | 
				
			||||||
 | 
					from typing import Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from arguments import DataArguments, ModelArguments, TrainingArguments
 | 
				
			||||||
 | 
					from datamodules import create_data_module_s2s, create_data_module
 | 
				
			||||||
 | 
					from convertinglinear import ConvertingLinear
 | 
				
			||||||
 | 
					from tokenizer import get_tokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def find_all_linear_module_names(model):
 | 
				
			||||||
 | 
					    module_names = set()
 | 
				
			||||||
 | 
					    for name, module in model.named_modules():
 | 
				
			||||||
 | 
					        if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
 | 
				
			||||||
 | 
					            module_names.add(name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if 'lm_head' in module_names:  # needed for 16-bit
 | 
				
			||||||
 | 
					        module_names.remove('lm_head')
 | 
				
			||||||
 | 
					    return list(module_names)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def find_all_outher_module_names(model):
 | 
				
			||||||
 | 
					    module_names = set()
 | 
				
			||||||
 | 
					    for name, module in model.named_modules():
 | 
				
			||||||
 | 
					        if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
 | 
				
			||||||
 | 
					            module_names.add(name)
 | 
				
			||||||
 | 
					    return list(module_names)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
 | 
				
			||||||
 | 
					    dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
 | 
				
			||||||
 | 
					    print(f'loading base model {model_args.model_name_or_path} in {dtype}...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = AutoModelForCausalLM.from_pretrained(
 | 
				
			||||||
 | 
					        model_args.model_name_or_path,
 | 
				
			||||||
 | 
					        cache_dir=cache_dir,
 | 
				
			||||||
 | 
					        torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
 | 
				
			||||||
 | 
					        trust_remote_code=model_args.trust_remote_code,
 | 
				
			||||||
 | 
					        device_map=None,
 | 
				
			||||||
 | 
					        attn_implementation="flash_attention_2"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # for name, module in model.named_modules():
 | 
				
			||||||
 | 
					    #     if 'norm' in name:
 | 
				
			||||||
 | 
					    #         module = module.to(torch.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def recursive_setattr(obj, attr, value):
 | 
				
			||||||
 | 
					    attr = attr.split('.', 1)
 | 
				
			||||||
 | 
					    if len(attr) == 1:
 | 
				
			||||||
 | 
					        setattr(obj, attr[0], value)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        recursive_setattr(getattr(obj, attr[0]), attr[1], value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
 | 
				
			||||||
 | 
					    new_module = torch.nn.Linear(module.in_features,
 | 
				
			||||||
 | 
					                                 module.out_features,
 | 
				
			||||||
 | 
					                                 module.bias is not None,
 | 
				
			||||||
 | 
					                                 module.weight.device,
 | 
				
			||||||
 | 
					                                 dtype)
 | 
				
			||||||
 | 
					    new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
 | 
				
			||||||
 | 
					    new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None
 | 
				
			||||||
 | 
					    new_module.weight.requires_grad = not frozen
 | 
				
			||||||
 | 
					    if new_module.bias is not None:
 | 
				
			||||||
 | 
					        new_module.bias.requires_grad = not frozen
 | 
				
			||||||
 | 
					    return new_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
 | 
				
			||||||
 | 
					    if type(module) is torch.nn.Linear:
 | 
				
			||||||
 | 
					        if frozen:
 | 
				
			||||||
 | 
					            module.weight.requires_grad = False
 | 
				
			||||||
 | 
					            if module.bias is not None:
 | 
				
			||||||
 | 
					                module.bias.requires_grad = False
 | 
				
			||||||
 | 
					            return module.to(dtype).to(device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            new_module = ConvertingLinear.fromLinear(module).to(dtype)
 | 
				
			||||||
 | 
					            new_module.weight.requires_grad = True
 | 
				
			||||||
 | 
					            if new_module.bias is not None:
 | 
				
			||||||
 | 
					                new_module.bias.requires_grad = True
 | 
				
			||||||
 | 
					            return new_module.to(device)
 | 
				
			||||||
 | 
					    elif type(module) is ConvertingLinear:
 | 
				
			||||||
 | 
					        if not frozen:
 | 
				
			||||||
 | 
					            module.weight.requires_grad = True
 | 
				
			||||||
 | 
					            if module.bias is not None:
 | 
				
			||||||
 | 
					                module.bias.requires_grad = True
 | 
				
			||||||
 | 
					            assert False
 | 
				
			||||||
 | 
					            return module.to(dtype).to(device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
 | 
				
			||||||
 | 
					                                                  out_features=module.out_features,
 | 
				
			||||||
 | 
					                                                  bias=module.bias is not None,
 | 
				
			||||||
 | 
					                                                  device=module.weight.device,
 | 
				
			||||||
 | 
					                                                  dtype=dtype)
 | 
				
			||||||
 | 
					            new_module.weight = torch.nn.Parameter(module.weight.to(dtype))
 | 
				
			||||||
 | 
					            new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None
 | 
				
			||||||
 | 
					            new_module.weight.requires_grad = False
 | 
				
			||||||
 | 
					            if new_module.bias is not None:
 | 
				
			||||||
 | 
					                new_module.bias.requires_grad = False
 | 
				
			||||||
 | 
					            return new_module.to(device)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        assert False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
 | 
				
			||||||
 | 
					    modules = dict(model.named_modules())
 | 
				
			||||||
 | 
					    linear_names = find_all_linear_module_names(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for key in linear_names:
 | 
				
			||||||
 | 
					        if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad:
 | 
				
			||||||
 | 
					            parent, target, target_name = _get_submodules(model, key)
 | 
				
			||||||
 | 
					            setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device))
 | 
				
			||||||
 | 
					    modules = dict(model.named_modules())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					    if active_paramter_count > target_params:
 | 
				
			||||||
 | 
					        raise RuntimeError("Enough paramters must be available to train at least one linear layer")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while active_paramter_count < target_params and len(linear_names) > 0:
 | 
				
			||||||
 | 
					        i = randint(0, len(linear_names) - 1)
 | 
				
			||||||
 | 
					        parent, target, target_name = _get_submodules(model, linear_names[i])
 | 
				
			||||||
 | 
					        new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device)
 | 
				
			||||||
 | 
					        setattr(parent, target_name, new_module)
 | 
				
			||||||
 | 
					        active_paramter_count += modules[linear_names[i]].weight.numel()
 | 
				
			||||||
 | 
					        if modules[linear_names[i]].bias is not None:
 | 
				
			||||||
 | 
					            active_paramter_count += modules[linear_names[i]].bias.numel()
 | 
				
			||||||
 | 
					        linear_names.pop(i)
 | 
				
			||||||
 | 
					    modules = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return active_paramter_count
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
 | 
				
			||||||
 | 
					    output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
 | 
				
			||||||
 | 
					    output_dir = os.path.join(output_dir, output_chkpt_dir)
 | 
				
			||||||
 | 
					    model.save_pretrained(output_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if max_checkpoints > 0:
 | 
				
			||||||
 | 
					        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.starts_with("step_")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def extract_step(filename):
 | 
				
			||||||
 | 
					            tokens = filename.split('_')
 | 
				
			||||||
 | 
					            return int(tokens[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(files) > max_checkpoints:
 | 
				
			||||||
 | 
					            min_step = min(map(extract_step, extract_step))
 | 
				
			||||||
 | 
					            delete_checkpoit_dir = os.path.join(output_dir, f"step_{min_step}")
 | 
				
			||||||
 | 
					            print(f"there are more than {max_checkpoints} checkpints saved, deleting {delete_checkpoit_dir}")
 | 
				
			||||||
 | 
					            shutil.rmtree(delete_checkpoit_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
 | 
				
			||||||
 | 
					                  weight_decay: float, eps: float, adam8bit: bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parameters = list()
 | 
				
			||||||
 | 
					    modules = dict(model.named_modules())
 | 
				
			||||||
 | 
					    for key in dynamic_module_names:
 | 
				
			||||||
 | 
					        parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					    for key in static_module_names:
 | 
				
			||||||
 | 
					        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not adam8bit:
 | 
				
			||||||
 | 
					        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            import bitsandbytes as bnb
 | 
				
			||||||
 | 
					        except ImportError:
 | 
				
			||||||
 | 
					            raise ImportError("To use 8-bit Adam, bitsandbytes must be available")
 | 
				
			||||||
 | 
					        optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
 | 
				
			||||||
 | 
					    return optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compute_dynamic_parameter_ratio(model):
 | 
				
			||||||
 | 
					    modules = dict(model.named_modules())
 | 
				
			||||||
 | 
					    active_linear_parameters = 0
 | 
				
			||||||
 | 
					    total_linear_parameters = 0
 | 
				
			||||||
 | 
					    for key in find_all_linear_module_names(model):
 | 
				
			||||||
 | 
					        active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					        total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
 | 
				
			||||||
 | 
					    return math.ceil(total_linear_parameters / active_linear_parameters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
 | 
				
			||||||
 | 
					    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
 | 
				
			||||||
 | 
					    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if data_args.dataset.endswith("json"):
 | 
				
			||||||
 | 
					        print("Loading dataset in s2s mode")
 | 
				
			||||||
 | 
					        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        print("Loading dataset in txt mode")
 | 
				
			||||||
 | 
					        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
 | 
				
			||||||
 | 
					    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
 | 
				
			||||||
 | 
					    train_dataloader = torch.utils.data.DataLoader(
 | 
				
			||||||
 | 
					        dataset['train_dataset'],
 | 
				
			||||||
 | 
					        shuffle=True,
 | 
				
			||||||
 | 
					        collate_fn=dataset['data_collator'],
 | 
				
			||||||
 | 
					        batch_size=training_args.per_device_train_batch_size
 | 
				
			||||||
 | 
					    ) if dataset['train_dataset'] is not None else None
 | 
				
			||||||
 | 
					    eval_dataloader = torch.utils.data.DataLoader(
 | 
				
			||||||
 | 
					        dataset['eval_dataset'],
 | 
				
			||||||
 | 
					        shuffle=True,
 | 
				
			||||||
 | 
					        collate_fn=dataset['data_collator'],
 | 
				
			||||||
 | 
					        batch_size=training_args.per_device_train_batch_size
 | 
				
			||||||
 | 
					    ) if dataset['eval_dataset'] is not None else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if model_args.max_instant_params != 0:
 | 
				
			||||||
 | 
					        print(f"Target params {model_args.max_instant_params}m")
 | 
				
			||||||
 | 
					        freeze_random_modules(model, model_args.max_instant_params * 1e6,
 | 
				
			||||||
 | 
					                              torch.float16 if training_args.storage_fp16 else torch.float32,
 | 
				
			||||||
 | 
					                              frozen_device=primary_device, active_device=secondary_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    paramter_count = sum(p.numel() for p in model.parameters())
 | 
				
			||||||
 | 
					    active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
 | 
				
			||||||
 | 
					    print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
 | 
				
			||||||
 | 
					    total_steps = steps_per_epoch * training_args.epochs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    optimizer = get_optimizer(model, find_all_linear_module_names(model),
 | 
				
			||||||
 | 
					                              find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
 | 
				
			||||||
 | 
					                              training_args.learning_rate,
 | 
				
			||||||
 | 
					                              training_args.learning_rate / dynamic_param_ratio,
 | 
				
			||||||
 | 
					                              training_args.weight_decay,
 | 
				
			||||||
 | 
					                              training_args.adam_epsilon,
 | 
				
			||||||
 | 
					                              training_args.adam8bit)
 | 
				
			||||||
 | 
					    lr_scheduler = get_scheduler(
 | 
				
			||||||
 | 
					        name=training_args.lr_scheduler_type,
 | 
				
			||||||
 | 
					        optimizer=optimizer,
 | 
				
			||||||
 | 
					        num_warmup_steps=training_args.warmup_steps,
 | 
				
			||||||
 | 
					        num_training_steps=total_steps
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return model, optimizer, lr_scheduler, train_dataloader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
 | 
				
			||||||
 | 
					    primary_device = torch.device(training_args.primary_device)
 | 
				
			||||||
 | 
					    secondary_device = torch.device(training_args.secondary_device)
 | 
				
			||||||
 | 
					    log_writer = tensorboard.SummaryWriter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
 | 
				
			||||||
 | 
					    total_steps = steps_per_epoch * training_args.epochs
 | 
				
			||||||
 | 
					    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if training_args.do_train:
 | 
				
			||||||
 | 
					        progress_bar = tqdm(range(total_steps))
 | 
				
			||||||
 | 
					        global_step = 0
 | 
				
			||||||
 | 
					        model.train()
 | 
				
			||||||
 | 
					        for epoch in range(0, training_args.epochs):
 | 
				
			||||||
 | 
					            print("*** Train ***")
 | 
				
			||||||
 | 
					            print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
 | 
				
			||||||
 | 
					            for step, batch in enumerate(train_dataloader):
 | 
				
			||||||
 | 
					                for key in batch:
 | 
				
			||||||
 | 
					                    batch[key] = batch[key].to("cuda:0")
 | 
				
			||||||
 | 
					                outputs = model(**batch)
 | 
				
			||||||
 | 
					                loss = outputs.loss / training_args.gradient_accumulation_steps
 | 
				
			||||||
 | 
					                log_writer.add_scalar("Loss/train", loss, global_step)
 | 
				
			||||||
 | 
					                loss.backward()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
 | 
				
			||||||
 | 
					                    optimizer.step()
 | 
				
			||||||
 | 
					                    lr_scheduler.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    model.zero_grad()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if global_step % 10 == 0:
 | 
				
			||||||
 | 
					                        print(loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if global_step % 10 == 0 and model_args.max_instant_params != 0:
 | 
				
			||||||
 | 
					                        param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
 | 
				
			||||||
 | 
					                                                            torch.float16 if training_args.storage_fp16 else torch.float32,
 | 
				
			||||||
 | 
					                                                            frozen_device=primary_device,
 | 
				
			||||||
 | 
					                                                            active_device=secondary_device)
 | 
				
			||||||
 | 
					                        log_writer.add_scalar("Parameters/train", param_count, global_step)
 | 
				
			||||||
 | 
					                        optimizer = get_optimizer(model, find_all_linear_module_names(model),
 | 
				
			||||||
 | 
					                                                  find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
 | 
				
			||||||
 | 
					                                                  training_args.learning_rate,
 | 
				
			||||||
 | 
					                                                  training_args.learning_rate / dynamic_param_ratio,
 | 
				
			||||||
 | 
					                                                  training_args.weight_decay,
 | 
				
			||||||
 | 
					                                                  training_args.adam_epsilon,
 | 
				
			||||||
 | 
					                                                  training_args.adam8bit)
 | 
				
			||||||
 | 
					                        lr_scheduler.optimizer = optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    global_step += 1
 | 
				
			||||||
 | 
					                    progress_bar.update()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if global_step % training_args.save_steps == 0:
 | 
				
			||||||
 | 
					                        save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints)
 | 
				
			||||||
 | 
					                if training_args.flush_allocator:
 | 
				
			||||||
 | 
					                    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Evaluation
 | 
				
			||||||
 | 
					    if training_args.do_eval:
 | 
				
			||||||
 | 
					        print("*** Evaluate ***")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    save_model(model, global_step, training_args.output_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
 | 
				
			||||||
 | 
					    model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print("Model Arguments:")
 | 
				
			||||||
 | 
					    print(model_args)
 | 
				
			||||||
 | 
					    print("\nData Arguments:")
 | 
				
			||||||
 | 
					    print(data_args)
 | 
				
			||||||
 | 
					    print("\nTraining Arguments:")
 | 
				
			||||||
 | 
					    print(training_args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformers.utils.logging.enable_default_handler()
 | 
				
			||||||
 | 
					    transformers.utils.logging.enable_explicit_format()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    train(model_args, data_args, training_args)
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue