Add chat datamodules
arguments.py
@@ -1,11 +1,52 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Self
+from enum import Enum
+
+
+class DatasetType(Enum):
+    TEXT = 1
+    S2S = 2
+    HUB = 3
+    CHAT = 4
+
+    @staticmethod
+    def to_string(dtype: Self) -> str:
+        if dtype == DatasetType.TEXT:
+            return "text"
+        elif dtype == DatasetType.S2S:
+            return "s2s"
+        elif dtype == DatasetType.HUB:
+            return "hub"
+        elif dtype == DatasetType.CHAT:
+            return "chat"
+        return "invalid"
+
+    @staticmethod
+    def from_string(string: str):
+        if string == str(DatasetType.TEXT):
+            return DatasetType.TEXT
+        elif string == str(DatasetType.S2S):
+            return DatasetType.S2S
+        elif string == str(DatasetType.HUB):
+            return DatasetType.HUB
+        elif string == str(DatasetType.CHAT):
+            return DatasetType.CHAT
+        return None
+
+    def __str__(self):
+        return DatasetType.to_string(self)
 
 
 @dataclass
 class DataArguments:
     dataset: str = field(
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+        metadata={"help": "The dataset to train on"}
     )
+    dataset_type: str = field(
+        default="text", metadata={"help": f"The type of dataset, set to one of {[e for e in DatasetType]}"}
+    )
+    dataset_chat_template: str | None = field(
+        default=None, metadata={"help": "overrides the chat template to be the one set here"}
+    )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
@@ -26,10 +67,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    block_size: int = field(
-        default=512,
-        metadata={"help": "size of the blocks the text is split into for training"},
-    )
 
 
 @dataclass
@@ -65,8 +102,9 @@ class TrainingArguments():
     )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
+    ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
-    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
+    output_dir: str = field(default='./output', metadata={"help": 'The output dir for checkpoints'})
     per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
     per_device_eval_batch_size: int = field(default=1, metadata={"help": 'The eval batch size per GPU. Increase for better speed.'})
     gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
     epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
     weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
@@ -82,6 +120,7 @@ class TrainingArguments():
                                   metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
     warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
+    logging_dir: str = field(default='./log', metadata={"help": 'The output dir for logs'})
     group_by_length: bool = field(default=False,
                                   metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
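The string helpers above round-trip: __str__ delegates to to_string, and from_string compares against str() of each member, so the names "text", "s2s", "hub", and "chat" map back to their enum values. A minimal usage sketch (assuming the file is importable as arguments, per its filename; note that typing.Self requires Python 3.11+):

    from arguments import DatasetType

    dtype = DatasetType.from_string("chat")  # str(DatasetType.CHAT) == "chat"
    assert dtype is DatasetType.CHAT
    assert str(dtype) == "chat"              # __str__ delegates to to_string
    assert DatasetType.from_string("bogus") is None  # unknown names fall through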
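How these dataclasses are consumed is not shown in this commit; a common pattern, sketched here purely as an assumption, is to parse them with transformers.HfArgumentParser and resolve dataset_type to the enum when dispatching to the matching datamodule:

    # Hypothetical wiring, not part of this commit.
    from transformers import HfArgumentParser

    from arguments import DataArguments, DatasetType, TrainingArguments

    parser = HfArgumentParser((DataArguments, TrainingArguments))
    data_args, train_args = parser.parse_args_into_dataclasses()

    # dataset_type defaults to "text"; from_string maps the string back to the
    # enum so the trainer can select the text/s2s/hub/chat datamodule.
    dtype = DatasetType.from_string(data_args.dataset_type)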