fix automatic train test spliting

This commit is contained in:
uvos 2024-05-08 22:07:11 +02:00
parent 65482b55a6
commit bc5321cb33

View File

@ -135,7 +135,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
eval_dataset = dataset['eval']
else:
print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split(
dataset = dataset['train'].train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42
)
eval_dataset = dataset['test']
@ -175,7 +175,7 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
eval_dataset = dataset['eval']
else:
print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split(
dataset = dataset['train'].train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42
)
eval_dataset = dataset['test']