Add support for Hugging Face Hub datasets and for specifying a prompt for eval

2024-05-07 00:23:12 +02:00
parent 8abea9ef89
commit a74ef976e4
5 changed files with 183 additions and 43 deletions


@@ -19,6 +19,10 @@ class DataArguments:
         default=256,
         metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
     )
+    data_from_hub: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If set, the dataset is assumed to be the name of a Hugging Face Hub dataset"}
+    )
     dataset: str = field(
         default=None,
         metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
@@ -60,10 +64,6 @@ class TrainingArguments():
         default=False,
         metadata={"help": "Use 8-bit adam."}
     )
-    report_to: str = field(
-        default='none',
-        metadata={"help": "To use wandb or something else for reporting."}
-    )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
     output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
@@ -85,7 +85,6 @@ class TrainingArguments():
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
     group_by_length: bool = field(default=False,
                                   metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
-    storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
     max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
     save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
@@ -94,3 +93,5 @@ class TrainingArguments():
     max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of parameters to optimize per step in millions"})
     churn_percent: int = field(default=100, metadata={"help": "The percentage of active parameters to replace when changing active parameters"})
     eval_steps: int = field(default=-1, metadata={"help": "Number of optimization steps after which to compute the evaluation loss"})
+    eval_prompt: str = field(default=None, metadata={"help": "A prompt used during eval to check if the model is learning"})
+    reshufle_steps: int = field(default=50, metadata={"help": "Number of steps to take before changing the active parameters"})
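For context, a minimal launch sketch (not part of this commit) of how the new options could be supplied, assuming the dataclasses are parsed with transformers.HfArgumentParser as in typical Hugging Face training scripts; the script name and dataset below are placeholders:

import transformers
from arguments import DataArguments, ModelArguments, TrainingArguments

# Hypothetical entry-point wiring; the actual parser setup lives elsewhere in the repo.
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Example invocation (placeholder names):
#   python train.py --dataset tiny_shakespeare --data_from_hub True \
#       --eval_steps 100 --eval_prompt "Once upon a time" --reshufle_steps 50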


@@ -27,7 +27,44 @@ def group_texts(examples, block_size: int):

 @dataclass
-class DataCollatorForCausalLM(object):
+class DataCollatorForCausalLMText(object):
+    tokenizer: transformers.PreTrainedTokenizer
+    max_len: int
+
+    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
+        # Extract elements
+        examples = [f"{self.tokenizer.bos_token}{example['text']}{self.tokenizer.eos_token}" for example in instances]
+        # Tokenize
+        tokenized_examples = self.tokenizer(
+            examples,
+            max_length=self.max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        # Build the input and labels for causal LM
+        input_ids = []
+        for tokenized_example in tokenized_examples['input_ids']:
+            input_ids.append(torch.tensor(tokenized_example))
+        # Apply padding
+        padding_value = None
+        if self.tokenizer.pad_token_id is not None:
+            padding_value = self.tokenizer.pad_token_id
+        elif self.tokenizer.eos_token_id is not None:
+            padding_value = self.tokenizer.eos_token_id
+        else:
+            raise RuntimeError("Model does not have a pad or eos token")
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
+        data_dict = {
+            'input_ids': input_ids,
+            'attention_mask': input_ids.ne(padding_value),
+            'labels': input_ids
+        }
+        return data_dict
+
+
+@dataclass
+class DataCollatorForCausalLMs2s(object):
     tokenizer: transformers.PreTrainedTokenizer
     source_max_len: int
     target_max_len: int
@@ -111,7 +148,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
     train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

-    data_collator = DataCollatorForCausalLM(
+    data_collator = DataCollatorForCausalLMs2s(
         tokenizer=tokenizer,
         source_max_len=data_args.source_max_len,
         target_max_len=data_args.target_max_len,
@@ -127,6 +164,40 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )


+def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
+    try:
+        dataset = datasets.load_dataset(data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if do_eval or do_predict:
+        if 'eval' in dataset:
+            eval_dataset = dataset['eval']
+        else:
+            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+            dataset = dataset.train_test_split(
+                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
+            )
+            eval_dataset = dataset['test']
+
+    if 'train' in dataset:
+        train_dataset = dataset['train']
+    else:
+        train_dataset = dataset
+
+    data_collator = DataCollatorForCausalLMText(
+        tokenizer=tokenizer,
+        max_len=data_args.source_max_len,
+    )
+
+    return dict(
+        train_dataset=train_dataset if do_train else None,
+        eval_dataset=eval_dataset if do_eval else None,
+        predict_dataset=eval_dataset if do_predict else None,
+        data_collator=data_collator
+    )
+
+
 def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
@@ -147,7 +218,8 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             eval_dataset = dataset['eval']
         else:
             print('Splitting train dataset in train and validation according to `eval_dataset_size`')
-            dataset = dataset.train_test_split(
+            breakpoint()
+            dataset = dataset['train'].train_test_split(
                 test_size=data_args.eval_dataset_size, shuffle=True, seed=42
             )
             eval_dataset = dataset['test']
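For reference, a small standalone sketch of what the new text collator produces; the gpt2 tokenizer here is only an illustrative choice, not something this commit prescribes:

import transformers
# assuming DataCollatorForCausalLMText is importable from the data modules file above
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
collator = DataCollatorForCausalLMText(tokenizer=tokenizer, max_len=64)

batch = collator([{"text": "hello world"}, {"text": "a somewhat longer example sentence"}])
# Shorter sequences are right padded (gpt2 has no pad token, so its eos id is used),
# attention_mask is input_ids != padding id, and labels mirror input_ids for causal LM loss.
print(batch["input_ids"].shape, batch["attention_mask"].shape)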


@@ -48,14 +48,22 @@ class LinearGroup:
         for module in self.modules:
             module.decompress()

-    def checkDistance(self) -> tuple[float, float]:
-        distance_accum = 0.0
-        error_accum = 0.0
-        for module in self.modules:
-            distance, error = module.checkDistance()
-            distance_accum += distance**2
-            error_accum += error**2
-        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
+        distance_accum = torch.Tensor()
+        error_accum = torch.Tensor()
+        for module in self.modules:
+            distance, error = module.getDistanceAndError()
+            distance = distance.to("cpu")
+            error = error.to("cpu")
+            distance_accum = torch.cat((distance_accum, distance.reshape((distance.numel()))))
+            error_accum = torch.cat((error_accum, error.reshape((error.numel()))))
+        return (distance_accum, error_accum)
+
+    def check(self) -> bool:
+        for module in self.modules:
+            if not module.check():
+                return False
+        return True


 class DyntrainModel:
@@ -160,15 +168,18 @@ class DyntrainModel:
         total_params = self.dynamicParameters() + self.staticParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)

-    def reshuffleActive(self) -> None:
+    def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
+        index = randint(0, len(self.active_linear_groups) - 1)
+        return self.active_linear_groups[index].getDistanceAndError()
+
+    def reshuffleActive(self):
         active_count = len(self.active_linear_groups)
         index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            distance, error = self.active_linear_groups[index].checkDistance()
-            print(f"linear group has moved {distance} with an error of {error}")
             group = self.active_linear_groups.pop(index)
             group.setFrozen(True)
             self.frozen_linear_groups.append(group)
+            assert group.check()

         params = self.activeParameterCount()
@@ -180,6 +191,7 @@ class DyntrainModel:
             group.setFrozen(False)
             params += group.paramCount()
             self.active_linear_groups.append(group)
+            assert group.check()
         print(math.ceil(params / 1e6))

         active_params = self.activeParameterCount()
@@ -248,4 +260,8 @@ class DyntrainModel:
             group_index += 1

         for group in tqdm(linear_groups, desc="Preparing layers"):
-            group.compress()
+            if group.isFrozen():
+                group.compress()
+            else:
+                group.decompress()
+            assert group.check()
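Because getDistanceAndError now returns flattened per-weight tensors instead of scalar norms, callers can log whole distributions. A minimal sketch of that idea, assuming a DyntrainModel named model, a tensorboard SummaryWriter named log_writer, and a global_step counter as in the training script; it mirrors the histogram calls left commented out in the training loop below:

# Sketch: log the drift and quantization error of one randomly sampled linear group.
distance, error = model.getDistanceAndErrorSample()
log_writer.add_histogram("Distances/Train", distance, global_step, max_bins=50)
log_writer.add_histogram("Errors/Train", error, global_step, max_bins=50)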


@@ -35,7 +35,6 @@ class Linear(torch.nn.Linear):
             self.compress()
         else:
             self.decompress()
-        self.weightStart = torch.Tensor(self.weight).clone().detach()

     def isFrozen(self) -> bool:
         return not self.weight.requires_grad
@@ -60,9 +59,15 @@ class Linear(torch.nn.Linear):
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        breakpoint()
         return self

+    def check(self) -> bool:
+        if self.isFrozen() and self.weight.dtype != torch.float16:
+            return False
+        elif not self.isFrozen() and self.weight.dtype != torch.float32:
+            return False
+        return True


 class DynamicConvertingLinear(Linear):
     def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
@@ -116,6 +121,7 @@ class DynamicQantizedLinear(Linear):
         self.bias_state = None
         self.block_size = 128
         self.quant_type = 'nf4'
+        self.weight_start = self.weight.clone().detach()

     @classmethod
     def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
@@ -125,6 +131,7 @@ class DynamicQantizedLinear(Linear):
                          compute_dtype=compute_dtype, output_device=output_device)
         new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
+        new_module.weight_start = new_module.weight.clone().detach()
         return new_module

     def compress(self) -> None:
@@ -134,26 +141,27 @@ class DynamicQantizedLinear(Linear):
             bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
             self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)

-        weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        frozen = self.isFrozen()
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.setFrozen(frozen, False)

     def decompress(self) -> None:
-        if self.weight_quantized is None:
-            raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
-        dtype = self.weight.dtype
-        self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
+        self.weight_quantized = None
+        self.weight_state = None
+        self.bias_quantized = None
+        self.bias_state = None
+        self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))

-    def checkDistance(self) -> tuple[float, float]:
-        if self.weight_quantized is None:
-            raise RuntimeError("checkDistance() called without quantized weights avialable")
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
         original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
-        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
-        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
-        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
-        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
+        dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
+        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        error = (dequantized_original_weight - original_weight).to(torch.float32)
         return (distance, error)

     def setOutputDevice(self, output_device: torch.device):
@@ -200,3 +208,24 @@ class DynamicQantizedLinear(Linear):
         if not frozen:
             super().inplaceTo(device=device)
         self.setFrozen(frozen, False)
+
+    def check(self) -> bool:
+        if self.isFrozen():
+            if torch.device(self.weight.device) != torch.device(self.cold_device):
+                breakpoint()
+                print("Frozen but not cold")
+                return False
+            if self.weight_quantized is None:
+                breakpoint()
+                print("Frozen but not quanted")
+                return False
+        else:
+            if torch.device(self.weight.device) != torch.device(self.active_device):
+                breakpoint()
+                print("Active but not warm")
+                return False
+            if self.weight_quantized is not None:
+                breakpoint()
+                print("Active but still quantized")
+                return False
+        return True


@@ -7,9 +7,10 @@ import os
 import shutil
 import math
 from tqdm.auto import tqdm
+import gc

 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module
+from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
 from tokenizer import get_tokenizer
 from dyntrainmodel import DyntrainModel
@@ -56,7 +57,9 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
     return optimizer


-def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
+def evaluate(model: DyntrainModel, tokenizer,
+             dataloader: torch.utils.data.DataLoader, globalstep: int,
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
     print("*** Eval ***")
     loss = torch.zeros((1), device="cuda:0")
     model.model.eval()
@@ -66,8 +69,17 @@ def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> f
         outputs = model.model(**batch)
         loss += outputs.loss
     loss = loss / len(dataloader)
+    log_writer.add_scalar("Loss/Eval", loss, globalstep)
     print(f"Eval Loss {loss.item()}")

+    if eval_prompt is not None:
+        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
+        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        print(f"Eval generation: {response_decoded}")
+        log_writer.add_text("Text/Eval", response_decoded, globalstep)
+

 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     log_writer = tensorboard.SummaryWriter()
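The eval-prompt check above is the usual generate-and-decode sanity test. A standalone sketch of the same pattern with a plain Hugging Face model; the gpt2 model and the prompt are placeholders, not choices made by this commit:

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The meaning of life is", return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask,
                             do_sample=True, temperature=1.0, max_new_tokens=100)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])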
@@ -90,6 +102,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     if data_args.dataset.endswith("json"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+    elif data_args.data_from_hub:
+        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
@@ -137,12 +151,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         for step, batch in enumerate(train_dataloader):
             for key in batch:
                 batch[key] = batch[key].to("cuda:0")
+
             outputs = model.model(**batch)
             loss = outputs.loss / training_args.gradient_accumulation_steps
-            log_writer.add_scalar("Loss/train", loss, global_step)
             loss.backward()

             if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
+                if global_step % training_args.logging_steps == 0:
+                    log_writer.add_scalar("Loss/train", loss, global_step)
                 optimizer.step()
                 lr_scheduler.step()
@@ -151,9 +167,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 if global_step % 5 == 0:
                     print(f"Train Loss {loss.item()}")

-                if global_step % 50 == 0 and training_args.max_instant_params != 0:
+                if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
+                    print("Reshuffling")
                     lr_scheduler.optimizer = None
                     del optimizer
+                    # distance, error = model.getDistanceAndErrorSample()
+                    # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                    # log_writer.add_histogram("Errors/Train", error, max_bins=50)
                     model.reshuffleActive()
                     model.balanceActive()
                     log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
@@ -173,15 +194,16 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
             if global_step % training_args.save_steps == 0:
                 save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)

             if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                evaluate(model, eval_dataloader)
+                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

             if training_args.flush_allocator:
+                gc.collect()
                 torch.cuda.empty_cache()

     if training_args.do_eval and training_args.eval_steps == -1:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

     # Evaluation
     if training_args.do_eval:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)

     save_model(model.model, global_step, training_args.output_dir)