diff --git a/datamodules.py b/datamodules.py
index d14c75b..d0dd173 100644
--- a/datamodules.py
+++ b/datamodules.py
@@ -4,6 +4,7 @@ import typing
 import datasets
 import itertools
 import transformers
+import os
 
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
@@ -237,7 +238,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         train_dataset_tokenized = train_dataset_tokenized.map(
             lambda example: group_texts(example, data_args.block_size),
             batched=True,
-            num_proc=32,
+            num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
             load_from_cache_file=True,
             desc=f"Grouping texts in chunks of {data_args.block_size}")
 
@@ -251,10 +252,15 @@
         eval_dataset_tokenized = eval_dataset_tokenized.map(
             lambda example: group_texts(example, data_args.block_size),
             batched=True,
-            num_proc=32,
+            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
             load_from_cache_file=True,
             desc=f"Grouping texts in chunks of {data_args.block_size}")
 
+    for ids in train_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+    for ids in eval_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
         eval_dataset=eval_dataset_tokenized if do_eval else None,
diff --git a/train_dynamic.py b/train_dynamic.py
index 6425de9..5dcf8cf 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -101,7 +101,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
-    if data_args.dataset.endswith("json"):
+    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     elif data_args.data_from_hub:
@@ -109,6 +109,7 @@
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
 
     train_dataloader = torch.utils.data.DataLoader(
         dataset['train_dataset'],
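
Note on the num_proc change (a sketch, not part of the patch): the point of replacing the hard-coded num_proc=32 is that each datasets.map() worker pays fork and serialization overhead, so small datasets run faster with fewer processes. A minimal standalone version of the same arithmetic follows; the helper name dynamic_num_proc and the fallback to 1 when os.cpu_count() returns None are my additions, not code from the patch.

    import os

    def dynamic_num_proc(num_examples: int, block_size: int) -> int:
        # Scale workers with the size of the tokenized split: roughly one
        # worker per block_size * 10 examples, capped at the machine's
        # core count and floored at 1 so map() never gets num_proc=0.
        cores = os.cpu_count() or 1  # cpu_count() can return None
        return max(1, min(cores, num_examples // (block_size * 10)))

    # Example: 1_000_000 examples at block_size=512 -> min(cores, 195)
    # workers, while 100 examples -> 1 worker.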