commit 7a47fcdcc0d874540d20a2df1c5424848a4800e9 Author: uvos Date: Wed Mar 6 17:50:40 2024 +0100 Inital commit diff --git a/arguments.py b/arguments.py new file mode 100644 index 0000000..2879de1 --- /dev/null +++ b/arguments.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class DataArguments: + eval_dataset_size: int = field( + default=512, metadata={"help": "Size of validation dataset."} + ) + source_max_len: int = field( + default=512, + metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + train_on_source: Optional[bool] = field( + default=False, + metadata={"help": "Wether to train on the input in addition to the target text when in s2s mode."} + ) + target_max_len: int = field( + default=256, + metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + dataset: str = field( + default=None, + metadata={"help": "A json file (s2s) or text file with the dataset to train on"} + ) + block_size: int = field( + default=512, + metadata={"help": "size of the blocks the text is split into for training"}, + ) + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field( + default="EleutherAI/pythia-12b" + ) + tokenizer: Optional[str] = field( + default=None + ) + trust_remote_code: Optional[bool] = field( + default=False, + metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} + ) + max_instant_params: int = field( + default=0, + metadata={"help": "Maximum amount of paramters to optimize per step in millions"} + ) + noresize: Optional[bool] = field( + default=False, + metadata={"help": "Never resize tokenizer embeddings"} + ) + + +@dataclass +class TrainingArguments(): + cache_dir: Optional[str] = field( + default=None + ) + adam8bit: bool = field( + default=False, + metadata={"help": "Use 8-bit adam."} + ) + report_to: str = field( + default='none', + metadata={"help": "To use wandb or something else for reporting."} + ) + resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'}) + ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'}) + output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'}) + per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'}) + gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'}) + epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'}) + weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) + learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'}) + adam_epsilon: float = field(default=1e-7, metadata={"help": 'Adam epsilon'}) + remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'}) + max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'}) + gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'}) + fp16: bool = field(default=False, metadata={"help": 'Train in 16 bit mixed precision'}) + do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) + do_eval: bool = field(default=False, metadata={"help": 'To eval or not to eval, that is the question?'}) + lr_scheduler_type: str = field(default='constant', + metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'}) + warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'}) + logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) + group_by_length: bool = field(default=False, + metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) + storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'}) + save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) + max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'}) + save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) + primary_device: str = field(default="cuda:0", metadata={"help": 'The primary device to use'}) + secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'}) + train_non_linear_layers: str = field(default=False, metadata={"help": 'train non linear layers'}) + flush_allocator: bool = field(default=False, metadata={"help": 'flush torches allocator on eatch iteration'}) diff --git a/convertinglinear.py b/convertinglinear.py new file mode 100644 index 0000000..b5c494f --- /dev/null +++ b/convertinglinear.py @@ -0,0 +1,29 @@ +import torch + + +class ConvertingLinear(torch.nn.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + + def forward(self, input: torch.Tensor): + output_dtype = input.dtype + output_device = input.device + if input.device != self.weight.device: + input = input.to(self.weight.device) + if input.dtype != self.weight.dtype: + input = input.to(self.weight.dtype) + output = torch.nn.Linear.forward(self, input) + if torch.isnan(output).any() or self.weight.dtype != torch.float32: + breakpoint() + return output.to(output_device).to(output_dtype) + + @classmethod + def fromLinear(cls, in_module: torch.nn.Linear): + new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features, + out_features=in_module.out_features, + bias=in_module.bias is not None, + device=in_module.weight.device, + dtype=in_module.weight.dtype) + new_module.weight = in_module.weight + new_module.bias = in_module.bias + return new_module diff --git a/datamodules.py b/datamodules.py new file mode 100644 index 0000000..15eaf54 --- /dev/null +++ b/datamodules.py @@ -0,0 +1,192 @@ +import copy +import torch +import typing +import datasets +import itertools +import transformers +from dataclasses import dataclass +from torch.nn.utils.rnn import pad_sequence + +from arguments import DataArguments + +IGNORE_INDEX = -100 + + +def group_texts(examples, block_size: int): + # Concatenate all texts. + concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result["labels"] = result["input_ids"].copy() + return result + + +@dataclass +class DataCollatorForCausalLM(object): + tokenizer: transformers.PreTrainedTokenizer + source_max_len: int + target_max_len: int + train_on_source: bool + predict_with_generate: bool + + def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]: + # Extract elements + sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances] + targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances] + # Tokenize + tokenized_sources_with_prompt = self.tokenizer( + sources, + max_length=self.source_max_len, + truncation=True, + add_special_tokens=False, + ) + tokenized_targets = self.tokenizer( + targets, + max_length=self.target_max_len, + truncation=True, + add_special_tokens=False, + ) + # Build the input and labels for causal LM + input_ids = [] + labels = [] + for tokenized_source, tokenized_target in zip( + tokenized_sources_with_prompt['input_ids'], + tokenized_targets['input_ids'] + ): + if not self.predict_with_generate: + input_ids.append(torch.tensor(tokenized_source + tokenized_target)) + if not self.train_on_source: + labels.append( + torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target)) + ) + else: + labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))) + else: + input_ids.append(torch.tensor(tokenized_source)) + # Apply padding + padding_value = None + if self.tokenizer.pad_token_id is not None: + padding_value = self.tokenizer.pad_token_id + elif self.tokenizer.eos_token_id is not None: + padding_value = self.tokenizer.eos_token_id + else: + raise RuntimeError("Model dose not have a pad or eos token") + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value) + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None + data_dict = { + 'input_ids': input_ids, + 'attention_mask': input_ids.ne(padding_value), + } + if labels is not None: + data_dict['labels'] = labels + return data_dict + + +def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict: + try: + dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset) + except FileNotFoundError as ex: + raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}") + + if do_eval or do_predict: + if 'eval' in dataset: + eval_dataset = dataset['eval'] + else: + print('Splitting train dataset in train and validation according to `eval_dataset_size`') + dataset = dataset.train_test_split( + test_size=data_args.eval_dataset_size, shuffle=True, seed=42 + ) + eval_dataset = dataset['test'] + eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) + + if 'train' in dataset: + train_dataset = dataset['train'] + else: + train_dataset = dataset + + train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) + + data_collator = DataCollatorForCausalLM( + tokenizer=tokenizer, + source_max_len=data_args.source_max_len, + target_max_len=data_args.target_max_len, + train_on_source=data_args.train_on_source, + predict_with_generate=False # args.predict_with_generate, + ) + + return dict( + train_dataset=train_dataset if do_train else None, + eval_dataset=eval_dataset if do_eval else None, + predict_dataset=eval_dataset if do_predict else None, + data_collator=data_collator + ) + + +def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict: + try: + dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]}) + except FileNotFoundError as ex: + raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}") + + if data_args.block_size > tokenizer.model_max_length: + raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}") + + def add_newline_fn(example): + example['text'] = example['text'] + '\n' + return example + dataset = dataset.map(add_newline_fn) + + eval_dataset = None + if do_eval or do_predict: + if 'eval' in dataset: + eval_dataset = dataset['eval'] + else: + print('Splitting train dataset in train and validation according to `eval_dataset_size`') + dataset = dataset.train_test_split( + test_size=data_args.eval_dataset_size, shuffle=True, seed=42 + ) + eval_dataset = dataset['test'] + + if 'train' in dataset: + train_dataset = dataset['train'] + else: + train_dataset = dataset + + train_dataset_tokenized = train_dataset.map( + lambda example: tokenizer(example['text']), + batched=True, + remove_columns='text', + num_proc=32, + load_from_cache_file=True) + train_dataset_tokenized = train_dataset_tokenized.map( + lambda example: group_texts(example, data_args.block_size), + batched=True, + num_proc=32, + load_from_cache_file=True, + desc=f"Grouping texts in chunks of {data_args.block_size}") + + eval_dataset_tokenized = None + if eval_dataset is not None: + eval_dataset_tokenized = eval_dataset.map( + lambda example: tokenizer(example['text']), + batched=True, + remove_columns='text', + num_proc=32) + eval_dataset_tokenized = eval_dataset_tokenized.map( + lambda example: group_texts(example, data_args.block_size), + batched=True, + num_proc=32, + load_from_cache_file=True, + desc=f"Grouping texts in chunks of {data_args.block_size}") + + return dict( + train_dataset=train_dataset_tokenized if do_train else None, + eval_dataset=eval_dataset_tokenized if do_eval else None, + predict_dataset=eval_dataset_tokenized if do_predict else None, + data_collator=transformers.default_data_collator + ) diff --git a/tokenizer.py b/tokenizer.py new file mode 100644 index 0000000..c16f3df --- /dev/null +++ b/tokenizer.py @@ -0,0 +1,62 @@ +import transformers + +from arguments import ModelArguments + + +DEFAULT_PAD_TOKEN = "[PAD]" + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings_data = model.get_input_embeddings().weight.data + output_embeddings_data = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings_data[-num_new_tokens:] = input_embeddings_avg + output_embeddings_data[-num_new_tokens:] = output_embeddings_avg + + +def get_tokenizer(model, cache_dir, model_args: ModelArguments): + print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}') + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path, + cache_dir=cache_dir, + padding_side="right", + use_fast=False, + eos_token="[EOS]", + tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None, + trust_remote_code=model_args.trust_remote_code + ) + if tokenizer._pad_token is None and not model_args.noresize: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer=tokenizer, + model=model, + ) + if 'llama' in model_args.model_name_or_path or isinstance(tokenizer, transformers.LlamaTokenizer): + # LLaMA tokenizer may not have correct special tokens set. + # Check and add them if missing to prevent them from being parsed into different tokens. + # Note that these are present in the vocabulary. + # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. + print('Adding special tokens.') + tokenizer.add_special_tokens({ + "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), + "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), + "unk_token": tokenizer.convert_ids_to_tokens( + model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id + ), + }) + return tokenizer diff --git a/train_dynamic.py b/train_dynamic.py new file mode 100644 index 0000000..0c92f39 --- /dev/null +++ b/train_dynamic.py @@ -0,0 +1,338 @@ +import transformers +from transformers import AutoModelForCausalLM, get_scheduler +from peft.utils import _get_submodules + +import torch +from torch.utils import tensorboard +import os +import shutil +import math +from tqdm.auto import tqdm +from random import randint +from typing import Tuple + +from arguments import DataArguments, ModelArguments, TrainingArguments +from datamodules import create_data_module_s2s, create_data_module +from convertinglinear import ConvertingLinear +from tokenizer import get_tokenizer + + +def find_all_linear_module_names(model): + module_names = set() + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear): + module_names.add(name) + + if 'lm_head' in module_names: # needed for 16-bit + module_names.remove('lm_head') + return list(module_names) + + +def find_all_outher_module_names(model): + module_names = set() + for name, module in model.named_modules(): + if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)): + module_names.add(name) + return list(module_names) + + +def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing): + dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32 + print(f'loading base model {model_args.model_name_or_path} in {dtype}...') + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=cache_dir, + torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32, + trust_remote_code=model_args.trust_remote_code, + device_map=None, + attn_implementation="flash_attention_2" + ) + + # for name, module in model.named_modules(): + # if 'norm' in name: + # module = module.to(torch.float32) + + return model + + +@torch.no_grad() +def recursive_setattr(obj, attr, value): + attr = attr.split('.', 1) + if len(attr) == 1: + setattr(obj, attr[0], value) + else: + recursive_setattr(getattr(obj, attr[0]), attr[1], value) + + +@torch.no_grad() +def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device): + new_module = torch.nn.Linear(module.in_features, + module.out_features, + module.bias is not None, + module.weight.device, + dtype) + new_module.weight = torch.nn.Parameter(module.weight.detach().clone()) + new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None + new_module.weight.requires_grad = not frozen + if new_module.bias is not None: + new_module.bias.requires_grad = not frozen + return new_module + + +@torch.no_grad() +def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device): + if type(module) is torch.nn.Linear: + if frozen: + module.weight.requires_grad = False + if module.bias is not None: + module.bias.requires_grad = False + return module.to(dtype).to(device) + else: + new_module = ConvertingLinear.fromLinear(module).to(dtype) + new_module.weight.requires_grad = True + if new_module.bias is not None: + new_module.bias.requires_grad = True + return new_module.to(device) + elif type(module) is ConvertingLinear: + if not frozen: + module.weight.requires_grad = True + if module.bias is not None: + module.bias.requires_grad = True + assert False + return module.to(dtype).to(device) + else: + new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=dtype) + new_module.weight = torch.nn.Parameter(module.weight.to(dtype)) + new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None + new_module.weight.requires_grad = False + if new_module.bias is not None: + new_module.bias.requires_grad = False + return new_module.to(device) + else: + assert False + + +@torch.no_grad() +def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device): + modules = dict(model.named_modules()) + linear_names = find_all_linear_module_names(model) + + for key in linear_names: + if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad or modules[key].weight.requires_grad: + parent, target, target_name = _get_submodules(model, key) + setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device)) + modules = dict(model.named_modules()) + + active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad) + if active_paramter_count > target_params: + raise RuntimeError("Enough paramters must be available to train at least one linear layer") + + while active_paramter_count < target_params and len(linear_names) > 0: + i = randint(0, len(linear_names) - 1) + parent, target, target_name = _get_submodules(model, linear_names[i]) + new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device) + setattr(parent, target_name, new_module) + active_paramter_count += modules[linear_names[i]].weight.numel() + if modules[linear_names[i]].bias is not None: + active_paramter_count += modules[linear_names[i]].bias.numel() + linear_names.pop(i) + modules = dict() + + assert active_paramter_count == sum(p.numel() for p in model.parameters() if p.requires_grad) + + return active_paramter_count + + +def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0): + output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else "" + output_dir = os.path.join(output_dir, output_chkpt_dir) + model.save_pretrained(output_dir) + + if max_checkpoints > 0: + files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.starts_with("step_")] + + def extract_step(filename): + tokens = filename.split('_') + return int(tokens[1]) + + if len(files) > max_checkpoints: + min_step = min(map(extract_step, extract_step)) + delete_checkpoit_dir = os.path.join(output_dir, f"step_{min_step}") + print(f"there are more than {max_checkpoints} checkpints saved, deleting {delete_checkpoit_dir}") + shutil.rmtree(delete_checkpoit_dir) + + +def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float, + weight_decay: float, eps: float, adam8bit: bool): + + parameters = list() + modules = dict(model.named_modules()) + for key in dynamic_module_names: + parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad) + for key in static_module_names: + parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad) + + if not adam8bit: + optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon) + else: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, bitsandbytes must be available") + optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps) + return optimizer + + +def compute_dynamic_parameter_ratio(model): + modules = dict(model.named_modules()) + active_linear_parameters = 0 + total_linear_parameters = 0 + for key in find_all_linear_module_names(model): + active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad) + total_linear_parameters += sum(p.numel() for p in modules[key].parameters()) + return math.ceil(total_linear_parameters / active_linear_parameters) + + +def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple: + model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device) + tokenizer = get_tokenizer(model, training_args.cache_dir, model_args) + + if data_args.dataset.endswith("json"): + print("Loading dataset in s2s mode") + data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False) + else: + print("Loading dataset in txt mode") + data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False) + dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'} + train_dataloader = torch.utils.data.DataLoader( + dataset['train_dataset'], + shuffle=True, + collate_fn=dataset['data_collator'], + batch_size=training_args.per_device_train_batch_size + ) if dataset['train_dataset'] is not None else None + eval_dataloader = torch.utils.data.DataLoader( + dataset['eval_dataset'], + shuffle=True, + collate_fn=dataset['data_collator'], + batch_size=training_args.per_device_train_batch_size + ) if dataset['eval_dataset'] is not None else None + + if model_args.max_instant_params != 0: + print(f"Target params {model_args.max_instant_params}m") + freeze_random_modules(model, model_args.max_instant_params * 1e6, + torch.float16 if training_args.storage_fp16 else torch.float32, + frozen_device=primary_device, active_device=secondary_device) + + paramter_count = sum(p.numel() for p in model.parameters()) + active_paramter_count = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters") + + dynamic_param_ratio = compute_dynamic_parameter_ratio(model) + print(f"dyanamic parameter ratio: 1/{dynamic_param_ratio}") + + steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) + total_steps = steps_per_epoch * training_args.epochs + + optimizer = get_optimizer(model, find_all_linear_module_names(model), + find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(), + training_args.learning_rate, + training_args.learning_rate / dynamic_param_ratio, + training_args.weight_decay, + training_args.adam_epsilon, + training_args.adam8bit) + lr_scheduler = get_scheduler( + name=training_args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=total_steps + ) + return model, optimizer, lr_scheduler, train_dataloader + + +def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments): + primary_device = torch.device(training_args.primary_device) + secondary_device = torch.device(training_args.secondary_device) + log_writer = tensorboard.SummaryWriter() + + model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device) + + steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) + total_steps = steps_per_epoch * training_args.epochs + dynamic_param_ratio = compute_dynamic_parameter_ratio(model) + + if training_args.do_train: + progress_bar = tqdm(range(total_steps)) + global_step = 0 + model.train() + for epoch in range(0, training_args.epochs): + print("*** Train ***") + print(f'Vram used for model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}') + for step, batch in enumerate(train_dataloader): + for key in batch: + batch[key] = batch[key].to("cuda:0") + outputs = model(**batch) + loss = outputs.loss / training_args.gradient_accumulation_steps + log_writer.add_scalar("Loss/train", loss, global_step) + loss.backward() + + if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader): + optimizer.step() + lr_scheduler.step() + + model.zero_grad() + + if global_step % 10 == 0: + print(loss) + + if global_step % 10 == 0 and model_args.max_instant_params != 0: + param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6, + torch.float16 if training_args.storage_fp16 else torch.float32, + frozen_device=primary_device, + active_device=secondary_device) + log_writer.add_scalar("Parameters/train", param_count, global_step) + optimizer = get_optimizer(model, find_all_linear_module_names(model), + find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(), + training_args.learning_rate, + training_args.learning_rate / dynamic_param_ratio, + training_args.weight_decay, + training_args.adam_epsilon, + training_args.adam8bit) + lr_scheduler.optimizer = optimizer + + global_step += 1 + progress_bar.update() + + if global_step % training_args.save_steps == 0: + save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints) + if training_args.flush_allocator: + torch.cuda.empty_cache() + + # Evaluation + if training_args.do_eval: + print("*** Evaluate ***") + + save_model(model, global_step, training_args.output_dir) + + return + + +if __name__ == "__main__": + hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True) + + print("Model Arguments:") + print(model_args) + print("\nData Arguments:") + print(data_args) + print("\nTraining Arguments:") + print(training_args) + + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + train(model_args, data_args, training_args)