Add chat datamodules
arguments.py (53 changed lines)
@@ -1,11 +1,52 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Self
+from enum import Enum
+
+
+class DatasetType(Enum):
+    TEXT = 1
+    S2S = 2
+    HUB = 3
+    CHAT = 4
+
+    @staticmethod
+    def to_string(dtype: Self) -> str:
+        if dtype == DatasetType.TEXT:
+            return "text"
+        elif dtype == DatasetType.S2S:
+            return "s2s"
+        elif dtype == DatasetType.HUB:
+            return "hub"
+        elif dtype == DatasetType.CHAT:
+            return "chat"
+        return "invalid"
+
+    @staticmethod
+    def from_string(string: str):
+        if string == str(DatasetType.TEXT):
+            return DatasetType.TEXT
+        elif string == str(DatasetType.S2S):
+            return DatasetType.S2S
+        elif string == str(DatasetType.HUB):
+            return DatasetType.HUB
+        elif string == str(DatasetType.CHAT):
+            return DatasetType.CHAT
+        return None
+
+    def __str__(self):
+        return DatasetType.to_string(self)
 
 
 @dataclass
 class DataArguments:
     dataset: str = field(
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+        metadata={"help": "The dataset to train on"}
+    )
+    dataset_type: str = field(
+        default="text", metadata={"help": f"The type of dataset, set to one of {[e for e in DatasetType]}"}
+    )
+    dataset_chat_template: str | None = field(
+        default=None, metadata={"help": "overrides the chat template to be the one set here"}
     )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
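A note on the new enum: from_string compares against str(DatasetType.X), and __str__ routes through to_string, so the accepted strings are exactly "text", "s2s", "hub" and "chat". A minimal round-trip check (illustrative only, not part of the commit):

# Round-trip of the DatasetType string mapping added above.
from arguments import DatasetType

assert str(DatasetType.CHAT) == "chat"
assert DatasetType.from_string("chat") is DatasetType.CHAT
assert DatasetType.to_string(DatasetType.HUB) == "hub"
assert DatasetType.from_string("bogus") is None  # unknown names fall through to None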
@@ -26,10 +67,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    block_size: int = field(
-        default=512,
-        metadata={"help": "size of the blocks the text is split into for training"},
-    )
 
 
 @dataclass
@@ -65,8 +102,9 @@ class TrainingArguments():
     )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
-    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
+    output_dir: str = field(default='./output', metadata={"help": 'The output dir for checkpoints'})
     per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
+    per_device_eval_batch_size: int = field(default=1, metadata={"help": 'The eval batch size per GPU. Increase for better speed.'})
     gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
     epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
     weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
@@ -82,6 +120,7 @@ class TrainingArguments():
                        metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
     warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
+    logging_dir: str = field(default='./log', metadata={"help": 'The output dir for logs'})
     group_by_length: bool = field(default=False,
                                   metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
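For context, here is a sketch of how these dataclasses could be filled from the command line. It assumes the training entry point uses transformers.HfArgumentParser, which is not shown in this diff, and the flag values are placeholders:

# Hypothetical argument-parsing sketch; assumes an HfArgumentParser-based train script.
import transformers
from arguments import DataArguments, TrainingArguments, DatasetType

parser = transformers.HfArgumentParser((DataArguments, TrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses(
    ["--dataset", "chats.json", "--dataset_type", "chat", "--eval_dataset_size", "64"]
)
assert DatasetType.from_string(data_args.dataset_type) is DatasetType.CHAT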
datamodules.py (123 changed lines)
@@ -7,22 +7,23 @@ import transformers
 import os
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
 
 
-from arguments import DataArguments
+from arguments import DataArguments, DatasetType
 
 IGNORE_INDEX = -100
 
 
-def group_texts(examples, block_size: int):
+def group_texts(examples, source_max_len: int):
     # Concatenate all texts.
     concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
     # customize this part to your needs.
-    if total_length >= block_size:
-        total_length = (total_length // block_size) * block_size
+    if total_length >= source_max_len:
+        total_length = (total_length // source_max_len) * source_max_len
     # Split by chunks of max_len.
-    result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
+    result = {k: [t[i: i + source_max_len] for i in range(0, total_length, source_max_len)] for k, t in concatenated_examples.items()}
     result["labels"] = result["input_ids"].copy()
     return result
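To make the renamed helper concrete, this is what group_texts does to a toy batched example (illustrative values only; in the pipeline it receives tokenizer output):

# group_texts concatenates the batched token lists and re-chunks them into fixed-size blocks.
examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}
out = group_texts(examples, source_max_len=4)
# 9 tokens total -> trailing remainder dropped -> two chunks of 4
assert out["input_ids"] == [[1, 2, 3, 4], [5, 6, 7, 8]]
assert out["labels"] == out["input_ids"]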
@@ -199,14 +200,15 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )
 
 
-def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
+def create_data_module_txt(tokenizer: transformers.PreTrainedTokenizer,
+                           data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
     except FileNotFoundError as ex:
         raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
 
-    if data_args.block_size > tokenizer.model_max_length:
-        raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
+    if data_args.source_max_len > tokenizer.model_max_length:
+        raise ValueError(f"Max source length of {data_args.source_max_len} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
 
     def add_newline_fn(example):
         example['text'] = example['text'] + '\n'
@@ -219,9 +221,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         eval_dataset = dataset['eval']
     else:
         print('Splitting train dataset in train and validation according to `eval_dataset_size`')
-        dataset = dataset['train'].train_test_split(
-            test_size=data_args.eval_dataset_size, shuffle=True, seed=42
-        )
+        dataset = dataset['train'].train_test_split(test_size=data_args.eval_dataset_size, shuffle=False)
         eval_dataset = dataset['test']
 
     if 'train' in dataset:
@@ -233,14 +233,14 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             lambda example: tokenizer(example['text']),
             batched=True,
             remove_columns='text',
-            num_proc=32,
+            num_proc=os.cpu_count(),
             load_from_cache_file=True)
         train_dataset_tokenized = train_dataset_tokenized.map(
-            lambda example: group_texts(example, data_args.block_size),
+            lambda example: group_texts(example, data_args.source_max_len),
             batched=True,
-            num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+            num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
             load_from_cache_file=True,
-            desc=f"Grouping texts in chunks of {data_args.block_size}")
+            desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
         eval_dataset_tokenized = None
         if eval_dataset is not None:
@@ -248,18 +248,18 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             lambda example: tokenizer(example['text']),
             batched=True,
             remove_columns='text',
-            num_proc=32)
+            num_proc=os.cpu_count())
         eval_dataset_tokenized = eval_dataset_tokenized.map(
-            lambda example: group_texts(example, data_args.block_size),
+            lambda example: group_texts(example, data_args.source_max_len),
             batched=True,
-            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
             load_from_cache_file=True,
-            desc=f"Grouping texts in chunks of {data_args.block_size}")
+            desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
     for ids in train_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
     for ids in eval_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
 
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
@@ -267,3 +267,84 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         predict_dataset=eval_dataset_tokenized if do_predict else None,
         data_collator=transformers.default_data_collator
     )
+
+
+def create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict):
+    try:
+        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if data_args.dataset_chat_template is not None:
+        tokenizer.chat_template = data_args.dataset_chat_template
+
+    target_len = data_args.source_max_len * 0.5
+    grouped_chats = list()
+    last_len = 0
+    for row in tqdm(dataset, desc="Grouping chat messages"):
+        content_length = len(tokenizer(row['content'])['input_ids'])
+        if last_len + content_length <= target_len and len(grouped_chats) > 0:
+            grouped_chats[-1]['chat'].append(row)
+            last_len += content_length
+        else:
+            last_len = content_length
+            grouped_chats.append({'chat': [row]})
+    dataset = datasets.Dataset.from_list(grouped_chats)
+    dataset = dataset.map(lambda x: {"text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})
+    dataset = dataset.remove_columns('chat')
+
+    eval_dataset = None
+    if do_eval or do_predict:
+        print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+        dataset_split = dataset.train_test_split(test_size=data_args.eval_dataset_size, shuffle=True)
+        train_dataset = dataset_split["train"]
+        eval_dataset = dataset_split["test"]
+
+    data_collator = DataCollatorForCausalLMText(
+        tokenizer=tokenizer,
+        max_len=data_args.source_max_len,
+    )
+    return dict(
+        train_dataset=train_dataset if do_train else None,
+        eval_dataset=eval_dataset,
+        predict_dataset=eval_dataset,
+        data_collator=data_collator
+    )
+
+
+def get_data_loaders(tokenizer, data_args: DataArguments, batch_size: int, eval_batch_size: int,
+                     do_train: bool, do_eval: bool, do_predict: bool = False):
+    data_type = DatasetType.from_string(data_args.dataset_type)
+    if data_type == DatasetType.S2S:
+        print("Loading dataset in s2s mode")
+        data_module = create_data_module_s2s(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.HUB:
+        print("Loading dataset from hub, expecting alpaca style")
+        data_module = create_data_module_hub(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.TEXT:
+        print("Loading dataset in txt mode")
+        data_module = create_data_module_txt(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.CHAT:
+        print("Loading dataset in chat mode")
+        data_module = create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict)
+    else:
+        raise RuntimeError("Unknown dataset type")
+
+    train_dataloader = None
+    eval_dataloader = None
+
+    if do_train:
+        train_dataloader = torch.utils.data.DataLoader(
+            data_module['train_dataset'],
+            shuffle=True,
+            collate_fn=data_module['data_collator'],
+            batch_size=batch_size
+        )
+    if do_eval:
+        eval_dataloader = torch.utils.data.DataLoader(
+            data_module['eval_dataset'],
+            shuffle=True,
+            collate_fn=data_module['data_collator'],
+            batch_size=eval_batch_size
+        )
+    return train_dataloader, eval_dataloader
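Taken together, the chat path packs consecutive rows into groups of up to roughly half of source_max_len tokens, renders each group with the tokenizer's chat template, and hands the result to DataCollatorForCausalLMText. A wiring sketch under stated assumptions: the model name and chats.json are placeholders, the JSON rows are expected to carry 'role' and 'content' keys (apply_chat_template needs them), and the DataArguments fields not passed below (e.g. source_max_len) are assumed to have defaults elsewhere in the class:

# Hypothetical end-to-end use of the new chat datamodule; names below are placeholders.
import transformers
from arguments import DataArguments
from datamodules import get_data_loaders

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
data_args = DataArguments(dataset="chats.json", dataset_type="chat", eval_dataset_size=64)
train_loader, eval_loader = get_data_loaders(
    tokenizer, data_args,
    batch_size=1, eval_batch_size=1,
    do_train=True, do_eval=True,
)
for batch in train_loader:
    print(batch["input_ids"].shape)  # assumes the collator returns padded tensors
    break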