Fix group_texts not grouping texts to a single length when the number of samples is less than the number of threads used
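For context, the failure mode named in the title matches the usual group_texts pattern from the Hugging Face causal-LM examples: texts are concatenated and re-split into block_size chunks inside each map() worker, so when there are fewer samples than workers, each shard is too short to fill a block and the output lengths come out empty or uneven. The sketch below is a hypothetical reproduction, not this repository's code; block_size, num_threads, and the single-process fallback guard are all assumptions.

from datasets import Dataset

block_size = 4

def group_texts(examples):
    # Concatenate every text in the batch, then re-split into fixed-size blocks.
    concatenated = sum(examples["input_ids"], [])
    total_length = (len(concatenated) // block_size) * block_size  # drop the tail
    return {
        "input_ids": [
            concatenated[i : i + block_size]
            for i in range(0, total_length, block_size)
        ]
    }

ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6]]})  # only 3 samples

# With multiprocessing, map() splits the dataset into one shard per worker and
# group_texts only ever sees its own shard. Here each shard would hold a single
# 2-token sample, shorter than block_size, so no shard yields a full block.
num_threads = 8
num_proc = num_threads if len(ds) >= num_threads else None  # hypothetical guard

grouped = ds.map(
    group_texts,
    batched=True,
    num_proc=num_proc,  # None -> single process, so all samples group together
    remove_columns=ds.column_names,
)
print(grouped["input_ids"])  # [[1, 2, 3, 4]]; without the guard, typically []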
@@ -101,7 +101,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
-    if data_args.dataset.endswith("json"):
+    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     elif data_args.data_from_hub:
@@ -109,6 +109,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
 
+    dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
     train_dataloader = torch.utils.data.DataLoader(
         dataset['train_dataset'],
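One reason the uniform block length matters for the DataLoader built in the second hunk: torch's default collate stacks the samples in a batch into a single tensor and raises on ragged lengths. A minimal demonstration with made-up data:

import torch
from torch.utils.data import DataLoader

# Two fixed-length blocks, like group_texts should produce after the fix.
uniform = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8])]
# Mixed lengths, like the pre-fix output when samples < worker count.
ragged = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])]

print(next(iter(DataLoader(uniform, batch_size=2))))  # stacks into a 2x4 tensor

try:
    next(iter(DataLoader(ragged, batch_size=2)))
except RuntimeError as err:  # default_collate cannot stack unequal sizes
    print("collate failed:", err)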