Add chat datamodules

This commit is contained in:
uvos 2024-07-20 21:46:34 +02:00
parent 0b39ba0843
commit 2f35689355
2 changed files with 148 additions and 28 deletions

View File

@ -1,11 +1,52 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Self
from enum import Enum
class DatasetType(Enum):
TEXT = 1
S2S = 2
HUB = 3
CHAT = 4
@staticmethod
def to_string(dtype: Self) -> str:
if dtype == DatasetType.TEXT:
return "text"
elif dtype == DatasetType.S2S:
return "s2s"
elif dtype == DatasetType.HUB:
return "hub"
elif dtype == DatasetType.CHAT:
return "chat"
return "invalid"
@staticmethod
def from_string(string: str):
if string == str(DatasetType.TEXT):
return DatasetType.TEXT
elif string == str(DatasetType.S2S):
return DatasetType.S2S
elif string == str(DatasetType.HUB):
return DatasetType.HUB
elif string == str(DatasetType.CHAT):
return DatasetType.CHAT
return None
def __str__(self):
return DatasetType.to_string(self)
@dataclass
class DataArguments:
dataset: str = field(
metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
metadata={"help": "The dataset to train on"}
)
dataset_type: str = field(
default="text", metadata={"help": f"The type of dataset, set to one of {[e for e in DatasetType]}"}
)
dataset_chat_template: str | None = field(
default=None, metadata={"help": "overrides the chat template to be the one set here"}
)
eval_dataset_size: int = field(
default=512, metadata={"help": "Size of validation dataset."}
@ -26,10 +67,6 @@ class DataArguments:
default=False,
metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
)
block_size: int = field(
default=512,
metadata={"help": "size of the blocks the text is split into for training"},
)
@dataclass
@ -65,8 +102,9 @@ class TrainingArguments():
)
resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
output_dir: str = field(default='./output', metadata={"help": 'The output dir for checkpoints'})
per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
per_device_eval_batch_size: int = field(default=1, metadata={"help": 'The eval batch size per GPU. Increase for better speed.'})
gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
@ -82,6 +120,7 @@ class TrainingArguments():
metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
logging_dir: str = field(default='./log', metadata={"help": 'The output dir for logs'})
group_by_length: bool = field(default=False,
metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})

View File

@ -7,22 +7,23 @@ import transformers
import os
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from arguments import DataArguments
from arguments import DataArguments, DatasetType
IGNORE_INDEX = -100
def group_texts(examples, block_size: int):
def group_texts(examples, source_max_len: int):
# Concatenate all texts.
concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
if total_length >= source_max_len:
total_length = (total_length // source_max_len) * source_max_len
# Split by chunks of max_len.
result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
result = {k: [t[i: i + source_max_len] for i in range(0, total_length, source_max_len)] for k, t in concatenated_examples.items()}
result["labels"] = result["input_ids"].copy()
return result
@ -199,14 +200,15 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
)
def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
def create_data_module_txt(tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
try:
dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
except FileNotFoundError as ex:
raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
if data_args.block_size > tokenizer.model_max_length:
raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
if data_args.source_max_len > tokenizer.model_max_length:
raise ValueError(f"Max source length of {data_args.source_max_len} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
def add_newline_fn(example):
example['text'] = example['text'] + '\n'
@ -219,9 +221,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
eval_dataset = dataset['eval']
else:
print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset = dataset['train'].train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42
)
dataset = dataset['train'].train_test_split(test_size=data_args.eval_dataset_size, shuffle=False)
eval_dataset = dataset['test']
if 'train' in dataset:
@ -233,14 +233,14 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
lambda example: tokenizer(example['text']),
batched=True,
remove_columns='text',
num_proc=32,
num_proc=os.cpu_count(),
load_from_cache_file=True)
train_dataset_tokenized = train_dataset_tokenized.map(
lambda example: group_texts(example, data_args.block_size),
lambda example: group_texts(example, data_args.source_max_len),
batched=True,
num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {data_args.block_size}")
desc=f"Grouping texts in chunks of {data_args.source_max_len}")
eval_dataset_tokenized = None
if eval_dataset is not None:
@ -248,18 +248,18 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
lambda example: tokenizer(example['text']),
batched=True,
remove_columns='text',
num_proc=32)
num_proc=os.cpu_count())
eval_dataset_tokenized = eval_dataset_tokenized.map(
lambda example: group_texts(example, data_args.block_size),
lambda example: group_texts(example, data_args.source_max_len),
batched=True,
num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {data_args.block_size}")
desc=f"Grouping texts in chunks of {data_args.source_max_len}")
for ids in train_dataset_tokenized['input_ids']:
assert len(ids) == data_args.block_size
assert len(ids) == data_args.source_max_len
for ids in eval_dataset_tokenized['input_ids']:
assert len(ids) == data_args.block_size
assert len(ids) == data_args.source_max_len
return dict(
train_dataset=train_dataset_tokenized if do_train else None,
@ -267,3 +267,84 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
predict_dataset=eval_dataset_tokenized if do_predict else None,
data_collator=transformers.default_data_collator
)
def create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict):
try:
dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
except FileNotFoundError as ex:
raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
if data_args.dataset_chat_template is not None:
tokenizer.chat_template = data_args.dataset_chat_template
target_len = data_args.source_max_len * 0.5
grouped_chats = list()
last_len = 0
for row in tqdm(dataset, desc="Grouping chat messages"):
content_length = len(tokenizer(row['content'])['input_ids'])
if last_len + content_length <= target_len and len(grouped_chats) > 0:
grouped_chats[-1]['chat'].append(row)
last_len += content_length
else:
last_len = 0
grouped_chats.append({'chat': [row]})
dataset = datasets.Dataset.from_list(grouped_chats)
dataset = dataset.map(lambda x: {"text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})
dataset.remove_columns('chat')
eval_dataset = None
if do_eval or do_predict:
print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset_split = dataset.train_test_split(test_size=data_args.eval_dataset_size, shuffle=True)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
data_collator = DataCollatorForCausalLMText(
tokenizer=tokenizer,
max_len=data_args.source_max_len,
)
return dict(
train_dataset=train_dataset if do_train else None,
eval_dataset=eval_dataset,
predict_dataset=eval_dataset,
data_collator=data_collator
)
def get_data_loaders(tokenizer, data_args: DataArguments, batch_size: int, eval_batch_size: int,
do_train: bool, do_eval: bool, do_predict: bool = False):
data_type = DatasetType.from_string(data_args.dataset_type)
if data_type == DatasetType.S2S:
print("Loading dataset in s2s mode")
data_module = create_data_module_s2s(tokenizer, data_args, do_train, do_eval, do_predict)
elif data_type == DatasetType.HUB:
print("Loading dataset from hub, expecting alpaca style")
data_module = create_data_module_hub(tokenizer, data_args, do_train, do_eval, do_predict)
elif data_type == DatasetType.TEXT:
print("Loading dataset in txt mode")
data_module = create_data_module_txt(tokenizer, data_args, do_train, do_eval, do_predict)
elif data_type == DatasetType.CHAT:
print("Loading dataset in chat mode")
data_module = create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict)
else:
raise RuntimeError("Unkown dataset type")
train_dataloader = None
eval_dataloader = None
if do_train:
train_dataloader = torch.utils.data.DataLoader(
data_module['train_dataset'],
shuffle=True,
collate_fn=data_module['data_collator'],
batch_size=batch_size
)
if do_eval:
eval_dataloader = torch.utils.data.DataLoader(
data_module['eval_dataset'],
shuffle=True,
collate_fn=data_module['data_collator'],
batch_size=eval_batch_size
)
return train_dataloader, eval_dataloader