QRotaryTraining/datamodules.py

# QRotaryTraining - A novel method for fully training all parameters of large
# language models (llms) while using less device memory than traditional methods.
# Copyright (C) 2024 Carl Philipp Klemm
#
# This file is part of QRotaryTraining.
#
# QRotaryTraining is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# QRotaryTraining is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with QRotaryTraining. If not, see <http://www.gnu.org/licenses/>.
import copy
import torch
import typing
import datasets
import itertools
import transformers
import os
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from arguments import DataArguments, DatasetType

IGNORE_INDEX = -100


def group_texts(examples, source_max_len: int):
    # Concatenate all texts.
    concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; we could pad the final chunk instead if the model
    # supported it. Customize this part to your needs.
    if total_length >= source_max_len:
        total_length = (total_length // source_max_len) * source_max_len
    # Split into chunks of source_max_len.
    result = {k: [t[i: i + source_max_len] for i in range(0, total_length, source_max_len)] for k, t in concatenated_examples.items()}
    result["labels"] = result["input_ids"].copy()
    return result
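
# Illustrative example of group_texts on made-up token ids: the two sequences are
# concatenated, the remainder beyond a multiple of source_max_len is dropped, and the
# labels mirror the chunked input_ids.
#   >>> batch = {'input_ids': [[1, 2, 3], [4, 5, 6, 7]], 'attention_mask': [[1] * 3, [1] * 4]}
#   >>> group_texts(batch, source_max_len=2)['input_ids']
#   [[1, 2], [3, 4], [5, 6]]
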
@dataclass
class DataCollatorForCausalLMText(object):
    tokenizer: transformers.PreTrainedTokenizer
    max_len: int

    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
        # Extract elements
        examples = [f"{self.tokenizer.bos_token}{example['text']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize
        tokenized_examples = self.tokenizer(
            examples,
            max_length=self.max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # Build the input and labels for causal LM
        input_ids = []
        for tokenized_example in tokenized_examples['input_ids']:
            input_ids.append(torch.tensor(tokenized_example))
        # Apply padding
        padding_value = None
        if self.tokenizer.pad_token_id is not None:
            padding_value = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            padding_value = self.tokenizer.eos_token_id
        else:
            raise RuntimeError("Model does not have a pad or eos token")
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(padding_value),
            # Labels mirror input_ids; padded positions are not masked out here.
            'labels': input_ids
        }
        return data_dict
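
# Illustrative use of the text collator (the tokenizer and the strings are placeholders):
#   >>> collator = DataCollatorForCausalLMText(tokenizer=tokenizer, max_len=512)
#   >>> batch = collator([{'text': 'first example'}, {'text': 'second example'}])
#   >>> sorted(batch.keys())
#   ['attention_mask', 'input_ids', 'labels']
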
@dataclass
class DataCollatorForCausalLMs2s(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
        # Extract elements
        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_targets = self.tokenizer(
            targets,
            max_length=self.target_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # Build the input and labels for causal LM
        input_ids = []
        labels = []
        for tokenized_source, tokenized_target in zip(
            tokenized_sources_with_prompt['input_ids'],
            tokenized_targets['input_ids']
        ):
            if not self.predict_with_generate:
                input_ids.append(torch.tensor(tokenized_source + tokenized_target))
                if not self.train_on_source:
                    # Mask the prompt so only the target tokens contribute to the loss.
                    labels.append(
                        torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
                    )
                else:
                    labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
            else:
                input_ids.append(torch.tensor(tokenized_source))
        # Apply padding
        padding_value = None
        if self.tokenizer.pad_token_id is not None:
            padding_value = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            padding_value = self.tokenizer.eos_token_id
        else:
            raise RuntimeError("Model does not have a pad or eos token")
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(padding_value),
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict
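
# The s2s collator expects instances with an 'input' (prompt) and an 'output' (completion)
# field; unless train_on_source is set, only the completion contributes to the loss.
# Illustrative call (the tokenizer and the strings are placeholders):
#   >>> collator = DataCollatorForCausalLMs2s(tokenizer=tokenizer, source_max_len=512,
#   ...                                       target_max_len=512, train_on_source=False,
#   ...                                       predict_with_generate=False)
#   >>> batch = collator([{'input': 'Translate: hello', 'output': 'hallo'}])
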
def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
    try:
        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
    except FileNotFoundError as ex:
        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")

    if do_eval or do_predict:
        if 'eval' in dataset:
            eval_dataset = dataset['eval']
        else:
            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
            dataset = dataset['train'].train_test_split(
                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
            )
            eval_dataset = dataset['test']
        eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    if 'train' in dataset:
        train_dataset = dataset['train']
    else:
        train_dataset = dataset
    train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    data_collator = DataCollatorForCausalLMs2s(
        tokenizer=tokenizer,
        source_max_len=data_args.source_max_len,
        target_max_len=data_args.target_max_len,
        train_on_source=data_args.train_on_source,
        predict_with_generate=False  # args.predict_with_generate,
    )
    return dict(
        train_dataset=train_dataset if do_train else None,
        eval_dataset=eval_dataset if do_eval else None,
        predict_dataset=eval_dataset if do_predict else None,
        data_collator=data_collator
    )
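
# The s2s json dataset is expected to provide one object per example with the 'input' and
# 'output' keys read by DataCollatorForCausalLMs2s, e.g.:
#   {"input": "Summarize: ...", "output": "..."}
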
def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
    try:
        dataset = datasets.load_dataset(data_args.dataset)
    except FileNotFoundError as ex:
        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")

    if do_eval or do_predict:
        if 'eval' in dataset:
            eval_dataset = dataset['eval']
        else:
            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
            dataset = dataset['train'].train_test_split(
                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
            )
            eval_dataset = dataset['test']

    if 'train' in dataset:
        train_dataset = dataset['train']
    else:
        train_dataset = dataset

    data_collator = DataCollatorForCausalLMText(
        tokenizer=tokenizer,
        max_len=data_args.source_max_len,
    )
    return dict(
        train_dataset=train_dataset if do_train else None,
        eval_dataset=eval_dataset if do_eval else None,
        predict_dataset=eval_dataset if do_predict else None,
        data_collator=data_collator
    )
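
# Note: the hub data module reuses DataCollatorForCausalLMText, so the dataset loaded from
# the hub must expose a 'text' column for each example.
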
def create_data_module_txt(tokenizer: transformers.PreTrainedTokenizer,
                           data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
    try:
        dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
    except FileNotFoundError as ex:
        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")

    if data_args.source_max_len > tokenizer.model_max_length:
        raise ValueError(f"Max source length of {data_args.source_max_len} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")

    def add_newline_fn(example):
        example['text'] = example['text'] + '\n'
        return example
    dataset = dataset.map(add_newline_fn)

    eval_dataset = None
    if do_eval or do_predict:
        if 'eval' in dataset:
            eval_dataset = dataset['eval']
        else:
            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
            dataset = dataset['train'].train_test_split(test_size=data_args.eval_dataset_size, shuffle=False)
            eval_dataset = dataset['test']

    if 'train' in dataset:
        train_dataset = dataset['train']
    else:
        train_dataset = dataset

    train_dataset_tokenized = train_dataset.map(
        lambda example: tokenizer(example['text']),
        batched=True,
        remove_columns='text',
        num_proc=os.cpu_count(),
        load_from_cache_file=True)
    train_dataset_tokenized = train_dataset_tokenized.map(
        lambda example: group_texts(example, data_args.source_max_len),
        batched=True,
        num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {data_args.source_max_len}")

    eval_dataset_tokenized = None
    if eval_dataset is not None:
        eval_dataset_tokenized = eval_dataset.map(
            lambda example: tokenizer(example['text']),
            batched=True,
            remove_columns='text',
            num_proc=os.cpu_count())
        eval_dataset_tokenized = eval_dataset_tokenized.map(
            lambda example: group_texts(example, data_args.source_max_len),
            batched=True,
            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
            load_from_cache_file=True,
            desc=f"Grouping texts in chunks of {data_args.source_max_len}")

    # Every grouped example must be exactly source_max_len tokens long.
    for ids in train_dataset_tokenized['input_ids']:
        assert len(ids) == data_args.source_max_len
    if eval_dataset_tokenized is not None:
        for ids in eval_dataset_tokenized['input_ids']:
            assert len(ids) == data_args.source_max_len

    return dict(
        train_dataset=train_dataset_tokenized if do_train else None,
        eval_dataset=eval_dataset_tokenized if do_eval else None,
        predict_dataset=eval_dataset_tokenized if do_predict else None,
        data_collator=transformers.default_data_collator
    )
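
# Because every grouped example is asserted to be exactly source_max_len tokens long,
# transformers.default_data_collator can stack the txt-mode batches without any padding logic.
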
def create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict):
    try:
        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
    except FileNotFoundError as ex:
        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")

    if data_args.dataset_chat_template is not None:
        tokenizer.chat_template = data_args.dataset_chat_template

    # Group consecutive messages until a group reaches roughly half the source length.
    target_len = data_args.source_max_len * 0.5
    grouped_chats = list()
    last_len = 0
    for row in tqdm(dataset, desc="Grouping chat messages"):
        content_length = len(tokenizer(row['content'])['input_ids'])
        if last_len + content_length <= target_len and len(grouped_chats) > 0:
            grouped_chats[-1]['chat'].append(row)
            last_len += content_length
        else:
            # Start a new group and count its first message towards the running length.
            last_len = content_length
            grouped_chats.append({'chat': [row]})
    dataset = datasets.Dataset.from_list(grouped_chats)
    dataset = dataset.map(lambda x: {"text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})
    dataset = dataset.remove_columns('chat')

    train_dataset = dataset
    eval_dataset = None
    if do_eval or do_predict:
        print('Splitting train dataset into train and validation according to `eval_dataset_size`')
        dataset_split = dataset.train_test_split(test_size=data_args.eval_dataset_size, shuffle=True)
        train_dataset = dataset_split["train"]
        eval_dataset = dataset_split["test"]

    data_collator = DataCollatorForCausalLMText(
        tokenizer=tokenizer,
        max_len=data_args.source_max_len,
    )
    return dict(
        train_dataset=train_dataset if do_train else None,
        eval_dataset=eval_dataset if do_eval else None,
        predict_dataset=eval_dataset if do_predict else None,
        data_collator=data_collator
    )
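
# The chat json dataset is expected to contain one message per row with at least a 'content'
# field (used for the length-based grouping above); apply_chat_template additionally expects
# the usual 'role'/'content' message layout, e.g.:
#   {"role": "user", "content": "..."}
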
def get_data_loaders(tokenizer, data_args: DataArguments, batch_size: int, eval_batch_size: int,
                     do_train: bool, do_eval: bool, do_predict: bool = False):
    data_type = DatasetType.from_string(data_args.dataset_type)

    if data_type == DatasetType.S2S:
        print("Loading dataset in s2s mode")
        data_module = create_data_module_s2s(tokenizer, data_args, do_train, do_eval, do_predict)
    elif data_type == DatasetType.HUB:
        print("Loading dataset from hub, expecting alpaca style")
        data_module = create_data_module_hub(tokenizer, data_args, do_train, do_eval, do_predict)
    elif data_type == DatasetType.TEXT:
        print("Loading dataset in txt mode")
        data_module = create_data_module_txt(tokenizer, data_args, do_train, do_eval, do_predict)
    elif data_type == DatasetType.CHAT:
        print("Loading dataset in chat mode")
        data_module = create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict)
    else:
        raise RuntimeError("Unknown dataset type")

    train_dataloader = None
    eval_dataloader = None
    if do_train:
        train_dataloader = torch.utils.data.DataLoader(
            data_module['train_dataset'],
            shuffle=True,
            collate_fn=data_module['data_collator'],
            batch_size=batch_size
        )
    if do_eval:
        eval_dataloader = torch.utils.data.DataLoader(
            data_module['eval_dataset'],
            shuffle=True,
            collate_fn=data_module['data_collator'],
            batch_size=eval_batch_size
        )
    return train_dataloader, eval_dataloader
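

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. Assumptions: DataArguments
    # is a dataclass accepting the fields below as keyword arguments, "text" is the string
    # that DatasetType.from_string() maps to DatasetType.TEXT, and "corpus.txt" is a
    # placeholder path to a plain-text training file; adjust all of these to your setup.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    data_args = DataArguments(
        dataset="corpus.txt",
        dataset_type="text",
        source_max_len=512,
        eval_dataset_size=0.05,
    )
    train_loader, eval_loader = get_data_loaders(
        tokenizer, data_args, batch_size=4, eval_batch_size=4,
        do_train=True, do_eval=True)
    for batch in train_loader:
        # In txt mode each batch is a dict of tensors shaped (batch_size, source_max_len).
        print({key: tuple(value.shape) for key, value in batch.items()})
        break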