Fix group_texts not grouping texts to a single length when the number of samples is less than the number of threads used
@@ -4,6 +4,7 @@ import typing
 import datasets
 import itertools
 import transformers
+import os
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
 
@@ -237,7 +238,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
     train_dataset_tokenized = train_dataset_tokenized.map(
         lambda example: group_texts(example, data_args.block_size),
         batched=True,
-        num_proc=32,
+        num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
         load_from_cache_file=True,
         desc=f"Grouping texts in chunks of {data_args.block_size}")
 
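
Why the hard-coded num_proc broke grouping: datasets.Dataset.map with batched=True and num_proc=N splits the work so that each worker calls group_texts on its own batches. The repo's group_texts is not part of this diff; the sketch below assumes it follows the common pattern from the transformers run_clm example, where a batch whose concatenated length falls below block_size is emitted as a single short sequence instead of being dropped. With 32 workers and only a handful of samples, every batch is tiny, so short sequences leak out and not all outputs end up block_size long.

    from itertools import chain

    def group_texts(examples, block_size):
        # Concatenate every column (input_ids, attention_mask, ...) into one long list.
        concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the tail that does not fill a whole block -- but only when at
        # least one full block exists. If a worker's batch concatenates to
        # fewer than block_size tokens, this branch is skipped and one short
        # sequence leaks out.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        return {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }

    # group_texts({"input_ids": [[1, 2, 3], [4, 5]]}, block_size=2)
    #   -> {"input_ids": [[1, 2], [3, 4]]}        (trailing 5 dropped, all blocks equal)
    # group_texts({"input_ids": [[1, 2, 3], [4, 5]]}, block_size=8)
    #   -> {"input_ids": [[1, 2, 3, 4, 5]]}       (shorter than block_size -- the bug)
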
@@ -251,10 +252,15 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
     eval_dataset_tokenized = eval_dataset_tokenized.map(
         lambda example: group_texts(example, data_args.block_size),
         batched=True,
-        num_proc=32,
+        num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
         load_from_cache_file=True,
         desc=f"Grouping texts in chunks of {data_args.block_size}")
 
+    for ids in train_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+    for ids in eval_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
         eval_dataset=eval_dataset_tokenized if do_eval else None,
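
The replacement expression scales workers with the data: capped by os.cpu_count(), capped at roughly one worker per block_size * 10 tokenized rows, and floored at 1 so small datasets stay in a single process. A standalone sketch of the same heuristic (illustrative sizes, hypothetical helper name):

    import os

    def pick_num_proc(num_rows: int, block_size: int) -> int:
        # Mirrors the expression in the diff: never more workers than CPUs,
        # never more than one worker per block_size * 10 rows, never fewer than 1.
        return max(1, min(os.cpu_count(), int(num_rows / (block_size * 10))))

    # e.g. on a 32-core machine with block_size=512:
    #   pick_num_proc(100, 512)     -> 1   (tiny dataset: one worker, one batch stream)
    #   pick_num_proc(200_000, 512) -> 32  (large dataset: capped by cpu_count)
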
@@ -101,7 +101,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
-    if data_args.dataset.endswith("json"):
+    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     elif data_args.data_from_hub:
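
On the second file: a path ending in ".jsonl" ends in "sonl", not "json", so the old single endswith check missed it. The two-clause fix works; note that str.endswith also accepts a tuple of suffixes, which would fold this into one call:

    # Equivalent one-call form; str.endswith accepts a tuple of suffixes.
    for name in ("data.json", "data.jsonl", "data.txt"):
        print(name, name.endswith(("json", "jsonl")))  # True, True, False
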
@@ -109,6 +109,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+
     dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
     train_dataloader = torch.utils.data.DataLoader(
         dataset['train_dataset'],
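
The new length assertions guard exactly what this DataLoader relies on: under PyTorch's default collation, every input_ids sequence in a batch must have the same length to be stacked into one tensor. A minimal illustration of that failure mode (hypothetical tensors, not code from this repo):

    import torch
    from torch.utils.data import default_collate

    same = [{"input_ids": torch.zeros(8, dtype=torch.long)} for _ in range(4)]
    print(default_collate(same)["input_ids"].shape)  # torch.Size([4, 8])

    ragged = [{"input_ids": torch.zeros(n, dtype=torch.long)} for n in (8, 5)]
    # default_collate(ragged) raises RuntimeError ("stack expects each tensor
    # to be equal size") -- the kind of error uneven blocks produced before
    # this fix.
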