From 2f35689355984f74fff48a207d3c51a474eba6db Mon Sep 17 00:00:00 2001
From: uvos
Date: Sat, 20 Jul 2024 21:46:34 +0200
Subject: [PATCH] Add chat datamodules

---
 arguments.py   |  53 ++++++++++++++++++---
 datamodules.py | 123 ++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 148 insertions(+), 28 deletions(-)

diff --git a/arguments.py b/arguments.py
index 6f645a5..042c1b3 100644
--- a/arguments.py
+++ b/arguments.py
@@ -1,11 +1,52 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Self
+from enum import Enum
+
+
+class DatasetType(Enum):
+    TEXT = 1
+    S2S = 2
+    HUB = 3
+    CHAT = 4
+
+    @staticmethod
+    def to_string(dtype: Self) -> str:
+        if dtype == DatasetType.TEXT:
+            return "text"
+        elif dtype == DatasetType.S2S:
+            return "s2s"
+        elif dtype == DatasetType.HUB:
+            return "hub"
+        elif dtype == DatasetType.CHAT:
+            return "chat"
+        return "invalid"
+
+    @staticmethod
+    def from_string(string: str):
+        if string == str(DatasetType.TEXT):
+            return DatasetType.TEXT
+        elif string == str(DatasetType.S2S):
+            return DatasetType.S2S
+        elif string == str(DatasetType.HUB):
+            return DatasetType.HUB
+        elif string == str(DatasetType.CHAT):
+            return DatasetType.CHAT
+        return None
+
+    def __str__(self):
+        return DatasetType.to_string(self)
 
 
 @dataclass
 class DataArguments:
     dataset: str = field(
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+        metadata={"help": "The dataset to train on"}
+    )
+    dataset_type: str = field(
+        default="text", metadata={"help": f"The type of dataset, set to one of {[e for e in DatasetType]}"}
+    )
+    dataset_chat_template: str | None = field(
+        default=None, metadata={"help": "overrides the chat template to be the one set here"}
     )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
@@ -26,10 +67,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    block_size: int = field(
-        default=512,
-        metadata={"help": "size of the blocks the text is split into for training"},
-    )
 
 
 @dataclass
@@ -65,8 +102,9 @@ class TrainingArguments():
     )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
-    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
+    output_dir: str = field(default='./output', metadata={"help": 'The output dir for checkpoints'})
     per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
+    per_device_eval_batch_size: int = field(default=1, metadata={"help": 'The eval batch size per GPU. Increase for better speed.'})
     gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
     epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
     weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
@@ -82,6 +120,7 @@ class TrainingArguments():
         metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
     warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
+    logging_dir: str = field(default='./log', metadata={"help": 'The output dir for logs'})
     group_by_length: bool = field(default=False, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
 
diff --git a/datamodules.py b/datamodules.py
index d0dd173..2fb4d78 100644
--- a/datamodules.py
+++ b/datamodules.py
@@ -7,22 +7,23 @@ import transformers
 import os
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
 
-from arguments import DataArguments
+from arguments import DataArguments, DatasetType
 
 
 IGNORE_INDEX = -100
 
 
-def group_texts(examples, block_size: int):
+def group_texts(examples, source_max_len: int):
     # Concatenate all texts.
     concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
     # customize this part to your needs.
-    if total_length >= block_size:
-        total_length = (total_length // block_size) * block_size
+    if total_length >= source_max_len:
+        total_length = (total_length // source_max_len) * source_max_len
     # Split by chunks of max_len.
-    result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
+    result = {k: [t[i: i + source_max_len] for i in range(0, total_length, source_max_len)] for k, t in concatenated_examples.items()}
     result["labels"] = result["input_ids"].copy()
     return result
@@ -199,14 +200,15 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )
 
 
-def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
+def create_data_module_txt(tokenizer: transformers.PreTrainedTokenizer,
+                           data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
     except FileNotFoundError as ex:
         raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
 
-    if data_args.block_size > tokenizer.model_max_length:
-        raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
+    if data_args.source_max_len > tokenizer.model_max_length:
+        raise ValueError(f"Max source length of {data_args.source_max_len} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
 
     def add_newline_fn(example):
         example['text'] = example['text'] + '\n'
@@ -219,9 +221,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             eval_dataset = dataset['eval']
         else:
             print('Splitting train dataset in train and validation according to `eval_dataset_size`')
-            dataset = dataset['train'].train_test_split(
-                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
-            )
+            dataset = dataset['train'].train_test_split(test_size=data_args.eval_dataset_size, shuffle=False)
             eval_dataset = dataset['test']
 
     if 'train' in dataset:
@@ -233,14 +233,14 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             lambda example: tokenizer(example['text']),
             batched=True,
             remove_columns='text',
-            num_proc=32,
+            num_proc=os.cpu_count(),
             load_from_cache_file=True)
         train_dataset_tokenized = train_dataset_tokenized.map(
-            lambda example: group_texts(example, data_args.block_size),
+            lambda example: group_texts(example, data_args.source_max_len),
             batched=True,
-            num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+            num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
             load_from_cache_file=True,
-            desc=f"Grouping texts in chunks of {data_args.block_size}")
+            desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
     eval_dataset_tokenized = None
     if eval_dataset is not None:
@@ -248,18 +248,18 @@
             lambda example: tokenizer(example['text']),
             batched=True,
             remove_columns='text',
-            num_proc=32)
+            num_proc=os.cpu_count())
         eval_dataset_tokenized = eval_dataset_tokenized.map(
-            lambda example: group_texts(example, data_args.block_size),
+            lambda example: group_texts(example, data_args.source_max_len),
             batched=True,
-            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
             load_from_cache_file=True,
-            desc=f"Grouping texts in chunks of {data_args.block_size}")
+            desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
     for ids in train_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
     for ids in eval_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
 
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
@@ -267,3 +267,84 @@
         predict_dataset=eval_dataset_tokenized if do_predict else None,
         data_collator=transformers.default_data_collator
     )
+
+
+def create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict):
+    try:
+        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if data_args.dataset_chat_template is not None:
+        tokenizer.chat_template = data_args.dataset_chat_template
+
+    target_len = data_args.source_max_len * 0.5
+    grouped_chats = list()
+    last_len = 0
+    for row in tqdm(dataset, desc="Grouping chat messages"):
+        content_length = len(tokenizer(row['content'])['input_ids'])
+        if last_len + content_length <= target_len and len(grouped_chats) > 0:
+            grouped_chats[-1]['chat'].append(row)
+            last_len += content_length
+        else:
+            last_len = content_length
+            grouped_chats.append({'chat': [row]})
+    dataset = datasets.Dataset.from_list(grouped_chats)
+    dataset = dataset.map(lambda x: {"text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})
+    dataset = dataset.remove_columns('chat')
+
+    train_dataset = dataset
+    eval_dataset = None
+    if do_eval or do_predict:
+        print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+        dataset_split = dataset.train_test_split(test_size=data_args.eval_dataset_size, shuffle=True)
shuffle=True) + train_dataset = dataset_split["train"] + eval_dataset = dataset_split["test"] + + data_collator = DataCollatorForCausalLMText( + tokenizer=tokenizer, + max_len=data_args.source_max_len, + ) + return dict( + train_dataset=train_dataset if do_train else None, + eval_dataset=eval_dataset, + predict_dataset=eval_dataset, + data_collator=data_collator + ) + + +def get_data_loaders(tokenizer, data_args: DataArguments, batch_size: int, eval_batch_size: int, + do_train: bool, do_eval: bool, do_predict: bool = False): + data_type = DatasetType.from_string(data_args.dataset_type) + if data_type == DatasetType.S2S: + print("Loading dataset in s2s mode") + data_module = create_data_module_s2s(tokenizer, data_args, do_train, do_eval, do_predict) + elif data_type == DatasetType.HUB: + print("Loading dataset from hub, expecting alpaca style") + data_module = create_data_module_hub(tokenizer, data_args, do_train, do_eval, do_predict) + elif data_type == DatasetType.TEXT: + print("Loading dataset in txt mode") + data_module = create_data_module_txt(tokenizer, data_args, do_train, do_eval, do_predict) + elif data_type == DatasetType.CHAT: + print("Loading dataset in chat mode") + data_module = create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict) + else: + raise RuntimeError("Unkown dataset type") + + train_dataloader = None + eval_dataloader = None + + if do_train: + train_dataloader = torch.utils.data.DataLoader( + data_module['train_dataset'], + shuffle=True, + collate_fn=data_module['data_collator'], + batch_size=batch_size + ) + if do_eval: + eval_dataloader = torch.utils.data.DataLoader( + data_module['eval_dataset'], + shuffle=True, + collate_fn=data_module['data_collator'], + batch_size=eval_batch_size + ) + return train_dataloader, eval_dataloader