diff --git a/arguments.py b/arguments.py
index 410f772..9c35cca 100644
--- a/arguments.py
+++ b/arguments.py
@@ -19,6 +19,10 @@ class DataArguments:
         default=256,
         metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
     )
+    data_from_hub: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If set, the dataset argument is treated as the name of a Hugging Face Hub dataset"}
+    )
     dataset: str = field(
         default=None,
         metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
@@ -60,10 +64,6 @@ class TrainingArguments():
         default=False,
         metadata={"help": "Use 8-bit adam."}
     )
-    report_to: str = field(
-        default='none',
-        metadata={"help": "To use wandb or something else for reporting."}
-    )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
     output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
@@ -85,7 +85,6 @@ class TrainingArguments():
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
     group_by_length: bool = field(default=False, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
-    storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
     max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
     save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
@@ -94,3 +93,5 @@ class TrainingArguments():
     max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of paramters to optimize per step in millions"})
     churn_percent: int = field(default=100, metadata={"help": "The percentage of active parameters to replace when changeing active parameters"})
     eval_steps: int = field(default=-1, metadata={"help": "Number of optimization steps after wich to compute the evaluation loss"})
+    eval_prompt: str = field(default=None, metadata={"help": "A prompt used during eval to check whether the model is learning"})
+    reshufle_steps: int = field(default=50, metadata={"help": "Number of steps to take before changing the active parameters"})
diff --git a/datamodules.py b/datamodules.py
index 15eaf54..0e36a6d 100644
--- a/datamodules.py
+++ b/datamodules.py
@@ -27,7 +27,44 @@ def group_texts(examples, block_size: int):
 
 
 @dataclass
-class DataCollatorForCausalLM(object):
+class DataCollatorForCausalLMText(object):
+    tokenizer: transformers.PreTrainedTokenizer
+    max_len: int
+
+    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
+        # Extract elements
+        examples = [f"{self.tokenizer.bos_token}{example['text']}{self.tokenizer.eos_token}" for example in instances]
+        # Tokenize
+        tokenized_examples = self.tokenizer(
+            examples,
+            max_length=self.max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        # Build the input and labels for causal LM
+        input_ids = []
+        for tokenized_example in tokenized_examples['input_ids']:
+            input_ids.append(torch.tensor(tokenized_example))
+        # Apply padding, preferring the pad token and falling back to eos
+        padding_value = None
+        if self.tokenizer.pad_token_id is not None:
+            padding_value = self.tokenizer.pad_token_id
+        elif self.tokenizer.eos_token_id is not None:
+            padding_value = self.tokenizer.eos_token_id
+        else:
+            raise RuntimeError("Model does not have a pad or eos token")
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
+
+        data_dict = {
+            'input_ids': input_ids,
+            'attention_mask': input_ids.ne(padding_value),
+            'labels': input_ids
+        }
+        return data_dict
+
+
+@dataclass
+class DataCollatorForCausalLMs2s(object):
     tokenizer: transformers.PreTrainedTokenizer
     source_max_len: int
     target_max_len: int
@@ -102,7 +139,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
             test_size=data_args.eval_dataset_size, shuffle=True, seed=42
         )
         eval_dataset = dataset['test']
-        eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
+        eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 
     if 'train' in dataset:
         train_dataset = dataset['train']
@@ -111,7 +148,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
 
         train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 
-    data_collator = DataCollatorForCausalLM(
+    data_collator = DataCollatorForCausalLMs2s(
         tokenizer=tokenizer,
         source_max_len=data_args.source_max_len,
         target_max_len=data_args.target_max_len,
@@ -127,6 +164,40 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )
 
 
+def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
+    try:
+        dataset = datasets.load_dataset(data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if do_eval or do_predict:
+        if 'eval' in dataset:
+            eval_dataset = dataset['eval']
+        else:
+            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+            dataset = dataset['train'].train_test_split(
+                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
+            )
+            eval_dataset = dataset['test']
+
+    if 'train' in dataset:
+        train_dataset = dataset['train']
+    else:
+        train_dataset = dataset
+
+    data_collator = DataCollatorForCausalLMText(
+        tokenizer=tokenizer,
+        max_len=data_args.source_max_len,
+    )
+
+    return dict(
+        train_dataset=train_dataset if do_train else None,
+        eval_dataset=eval_dataset if do_eval else None,
+        predict_dataset=eval_dataset if do_predict else None,
+        data_collator=data_collator
+    )
+
+
 def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
@@ -147,7 +218,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         eval_dataset = dataset['eval']
     else:
         print('Splitting train dataset in train and validation according to `eval_dataset_size`')
-        dataset = dataset.train_test_split(
+        dataset = dataset['train'].train_test_split(
             test_size=data_args.eval_dataset_size, shuffle=True, seed=42
         )
         eval_dataset = dataset['test']
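For reference, a minimal standalone sketch of how the new `DataCollatorForCausalLMText` could be exercised (assumes the repo root is on `PYTHONPATH`; the `gpt2` checkpoint is only an illustrative choice because its tokenizer has an eos token but no pad token, so it hits the eos fallback above):

```python
# Standalone sketch, not part of the patch: batch two plain-text examples with the new collator.
import transformers
from datamodules import DataCollatorForCausalLMText

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # no pad token, so eos is used for padding
collator = DataCollatorForCausalLMText(tokenizer=tokenizer, max_len=32)
batch = collator([{"text": "hello world"}, {"text": "a somewhat longer example sentence"}])
# labels mirror input_ids; positions whose id equals the padding id are masked out in attention_mask
print(batch["input_ids"].shape, batch["attention_mask"].sum(dim=1))
```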
diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index c436983..3fd6cb7 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -48,14 +48,22 @@ class LinearGroup:
         for module in self.modules:
             module.decompress()
 
-    def checkDistance(self) -> tuple[float, float]:
-        distance_accum = 0.0
-        error_accum = 0.0
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
+        distance_accum = torch.Tensor()
+        error_accum = torch.Tensor()
         for module in self.modules:
-            distance, error = module.checkDistance()
-            distance_accum += distance**2
-            error_accum += error**2
-        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+            distance, error = module.getDistanceAndError()
+            distance = distance.to("cpu")
+            error = error.to("cpu")
+            distance_accum = torch.cat((distance_accum, distance.reshape((distance.numel()))))
+            error_accum = torch.cat((error_accum, error.reshape((error.numel()))))
+        return (distance_accum, error_accum)
+
+    def check(self) -> bool:
+        for module in self.modules:
+            if not module.check():
+                return False
+        return True
 
 
 class DyntrainModel:
@@ -160,15 +168,18 @@ class DyntrainModel:
         total_params = self.dynamicParameters() + self.staticParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)
 
-    def reshuffleActive(self) -> None:
+    def getDistanceAndErrorSample(self) -> tuple[torch.Tensor, torch.Tensor]:
+        index = randint(0, len(self.active_linear_groups) - 1)
+        return self.active_linear_groups[index].getDistanceAndError()
+
+    def reshuffleActive(self):
         active_count = len(self.active_linear_groups)
         index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            distance, error = self.active_linear_groups[index].checkDistance()
-            print(f"linear group has moved {distance} with an error of {error}")
            group = self.active_linear_groups.pop(index)
            group.setFrozen(True)
            self.frozen_linear_groups.append(group)
+            assert group.check()
 
        params = self.activeParameterCount()
@@ -180,6 +191,7 @@ class DyntrainModel:
                group.setFrozen(False)
                params += group.paramCount()
                self.active_linear_groups.append(group)
+                assert group.check()
        print(math.ceil(params / 1e6))
 
        active_params = self.activeParameterCount()
@@ -248,4 +260,8 @@ class DyntrainModel:
            group_index += 1
 
        for group in tqdm(linear_groups, desc="Perpareing layers"):
-            group.compress()
+            if group.isFrozen():
+                group.compress()
+            else:
+                group.decompress()
+            assert group.check()
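The rotation that `reshuffleActive()` and `balanceActive()` implement is easier to see in isolation. Below is a toy sketch, not the repo's `DyntrainModel`: plain `torch.nn.Linear` layers with hypothetical sizes and fractions, showing the same idea of retiring part of the active set and promoting frozen layers in their place:

```python
# Toy illustration of rotating which layers are trainable; all names and sizes are made up.
import random
import torch

layers = [torch.nn.Linear(64, 64) for _ in range(10)]
active, frozen = list(layers[:3]), list(layers[3:])
for layer in layers:
    layer.requires_grad_(layer in active)

def reshuffle(active, frozen, churn_fraction=0.5):
    for _ in range(int(len(active) * churn_fraction)):
        promoted = frozen.pop(random.randrange(len(frozen)))
        promoted.requires_grad_(True)   # promote a frozen layer
        retired = active.pop(0)         # retire the oldest active layer
        retired.requires_grad_(False)
        frozen.append(retired)
        active.append(promoted)

reshuffle(active, frozen)
print(sum(p.requires_grad for layer in layers for p in layer.parameters()))  # 3 active layers -> 6 trainable tensors
```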
diff --git a/modules.py b/modules.py
index 6e34858..e8a4eae 100644
--- a/modules.py
+++ b/modules.py
@@ -35,7 +35,6 @@ class Linear(torch.nn.Linear):
            self.compress()
        else:
            self.decompress()
-        self.weightStart = torch.Tensor(self.weight).clone().detach()
 
    def isFrozen(self) -> bool:
        return not self.weight.requires_grad
@@ -60,9 +59,15 @@ class Linear(torch.nn.Linear):
 
    @wraps(torch.nn.Module.to)
    def to(self, *args, **kwargs):
-        breakpoint()
        return self
 
+    def check(self) -> bool:
+        if self.isFrozen() and self.weight.dtype != torch.float16:
+            return False
+        elif not self.isFrozen() and self.weight.dtype != torch.float32:
+            return False
+        return True
+
 
 class DynamicConvertingLinear(Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
@@ -116,6 +121,7 @@ class DynamicQantizedLinear(Linear):
        self.bias_state = None
        self.block_size = 128
        self.quant_type = 'nf4'
+        self.weight_start = self.weight.clone().detach()
 
    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
@@ -125,6 +131,7 @@ class DynamicQantizedLinear(Linear):
                               compute_dtype=compute_dtype, output_device=output_device)
        new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
        new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
+        new_module.weight_start = new_module.weight.clone().detach()
        return new_module
@@ -134,26 +141,27 @@ def compress(self) -> None:
            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
            self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
 
-        weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        frozen = self.isFrozen()
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.setFrozen(frozen, False)
 
    def decompress(self) -> None:
-        if self.weight_quantized is None:
-            raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
-        dtype = self.weight.dtype
-        self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
+        self.weight_quantized = None
+        self.weight_state = None
+        self.bias_quantized = None
+        self.bias_state = None
+        self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
        if self.bias_quantized:
-            self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
 
-    def checkDistance(self) -> tuple[float, float]:
-        if self.weight_quantized is None:
-            raise RuntimeError("checkDistance() called without quantized weights avialable")
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
        original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
        quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
-        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
-        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
-        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
-        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
+        dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
+        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        error = (dequantized_original_weight - original_weight).to(torch.float32)
        return (distance, error)
 
    def setOutputDevice(self, output_device: torch.device):
@@ -200,3 +208,24 @@
        if not frozen:
            super().inplaceTo(device=device)
        self.setFrozen(frozen, False)
+
+    def check(self) -> bool:
+        if self.isFrozen():
+            if torch.device(self.weight.device) != torch.device(self.cold_device):
+                breakpoint()
+                print("Frozen but not cold")
+                return False
+            if self.weight_quantized is None:
+                breakpoint()
+                print("Frozen but not quantized")
+                return False
+        else:
+            if torch.device(self.weight.device) != torch.device(self.active_device):
+                breakpoint()
+                print("Active but not warm")
+                return False
+            if self.weight_quantized is not None:
+                breakpoint()
+                print("Active but still quantized")
+                return False
+        return True
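For context, the error term that `getDistanceAndError()` measures is just the blockwise quantization round-trip error, while the distance term compares the live weight against the `weight_start` snapshot. A minimal sketch of that round trip (assumes a CUDA device and the `bitsandbytes` package; shapes are arbitrary):

```python
# Standalone sketch of the quantize/dequantize round trip used by compress()/getDistanceAndError().
import torch
import bitsandbytes as bnb

weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
quantized, state = bnb.functional.quantize_blockwise(weight, blocksize=128)
restored = bnb.functional.dequantize_blockwise(quantized, state).to(weight.dtype)
print((restored - weight).abs().max())  # worst-case error introduced by 8-bit blockwise storage
```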
diff --git a/train_dynamic.py b/train_dynamic.py
index 941c3ca..c04e787 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -7,9 +7,10 @@
 import os
 import shutil
 import math
 from tqdm.auto import tqdm
+import gc
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module
+from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
 from tokenizer import get_tokenizer
 from dyntrainmodel import DyntrainModel
@@ -56,7 +57,9 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
    return optimizer
 
 
-def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
+def evaluate(model: DyntrainModel, tokenizer,
+             dataloader: torch.utils.data.DataLoader, globalstep: int,
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
    print("*** Eval ***")
    loss = torch.zeros((1), device="cuda:0")
    model.model.eval()
@@ -66,8 +69,17 @@ def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> f
            outputs = model.model(**batch)
            loss += outputs.loss
    loss = loss / len(dataloader)
+    log_writer.add_scalar("Loss/Eval", loss, globalstep)
    print(f"Eval Loss {loss.item()}")
 
+    if eval_prompt is not None:
+        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
+        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        print(f"Eval generation: {response_decoded}")
+        log_writer.add_text("Text/Eval", response_decoded, globalstep)
+
 
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
    log_writer = tensorboard.SummaryWriter()
@@ -90,6 +102,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
    if data_args.dataset.endswith("json"):
        print("Loading dataset in s2s mode")
        data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+    elif data_args.data_from_hub:
+        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
    else:
        print("Loading dataset in txt mode")
        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
@@ -137,12 +151,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
        for step, batch in enumerate(train_dataloader):
            for key in batch:
                batch[key] = batch[key].to("cuda:0")
+
            outputs = model.model(**batch)
            loss = outputs.loss / training_args.gradient_accumulation_steps
-            log_writer.add_scalar("Loss/train", loss, global_step)
            loss.backward()
 
            if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
+                if global_step % training_args.logging_steps == 0:
+                    log_writer.add_scalar("Loss/train", loss, global_step)
                optimizer.step()
                lr_scheduler.step()
 
@@ -151,9 +167,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
            if global_step % 5 == 0:
                print(f"Train Loss {loss.item()}")
 
-            if global_step % 50 == 0 and training_args.max_instant_params != 0:
+            if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
+                print("Reshuffling")
                lr_scheduler.optimizer = None
                del optimizer
+                # distance, error = model.getDistanceAndErrorSample()
+                # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                # log_writer.add_histogram("Errors/Train", error, max_bins=50)
+
                model.reshuffleActive()
                model.balanceActive()
                log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
@@ -173,15 +194,16 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
            if global_step % training_args.save_steps == 0:
                save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
            if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                evaluate(model, eval_dataloader)
+                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
            if training_args.flush_allocator:
+                gc.collect()
                torch.cuda.empty_cache()
 
        if training_args.do_eval and training_args.eval_steps == -1:
-            evaluate(model, eval_dataloader)
+            evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
    # Evaluation
    if training_args.do_eval:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
    save_model(model.model, global_step, training_args.output_dir)
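The `eval_prompt` path in `evaluate()` is the standard tokenize, `generate`, `batch_decode` loop. For reference, the same flow against a plain Hugging Face causal LM (the tiny checkpoint is only an illustrative choice; the training script goes through `DyntrainModel` and its device placement instead):

```python
# Standalone sketch of the eval-prompt generation flow; "sshleifer/tiny-gpt2" is an arbitrary small model.
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask,
                             do_sample=True, temperature=1.0, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```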