Fix mypy warnings

Carl Philipp Klemm
2024-05-07 19:48:40 +02:00
parent a74ef976e4
commit 68f748e99e
4 changed files with 34 additions and 30 deletions

View File

@@ -4,6 +4,9 @@ from typing import Optional
 @dataclass
 class DataArguments:
+    dataset: str = field(
+        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+    )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
     )
@@ -23,10 +26,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    dataset: str = field(
-        default=None,
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
-    )
     block_size: int = field(
         default=512,
         metadata={"help": "size of the blocks the text is split into for training"},
@@ -35,7 +34,7 @@ class DataArguments:
 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str = field(
        default="EleutherAI/pythia-12b"
     )
     tokenizer: Optional[str] = field(
@@ -49,7 +48,7 @@ class ModelArguments:
         default=False,
         metadata={"help": "Never resize tokenizer embeddings"}
     )
-    quantize: Optional[bool] = field (
+    quantize: bool = field(
         default=False,
         metadata={"help": "Quantize parameters not currently be actively trained"}
     )
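
Note (not part of the commit): the dataset field moves to the top of DataArguments because its old declaration, dataset: str with default=None, is exactly what mypy flags here: None is not a str, so either the annotation must become Optional[str] or the default has to go. The commit drops the default, and since dataclass fields without defaults must precede fields with defaults, the field ends up above eval_dataset_size. A minimal sketch of the pattern, using a hypothetical Args dataclass rather than the project's own:

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class Args:  # hypothetical example, not the project's DataArguments
        dataset: str = field(  # required: no default, so a plain str annotation is fine
            metadata={"help": "path to the training data"}
        )
        tokenizer: Optional[str] = field(  # may legitimately be None
            default=None, metadata={"help": "override the tokenizer name"}
        )
        quantize: bool = field(  # always holds a concrete bool, so Optional[bool] is unnecessary
            default=False, metadata={"help": "quantize frozen parameters"}
        )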

View File

@@ -16,7 +16,7 @@ class LinearGroup:
         for module in self.modules:
             assert isinstance(module, Linear)
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None, output_device: torch.device | None = None) -> None:
         for module in self.modules:
             module.inplaceTo(dtype, device)
         self.modules[-1].setOutputDevice(output_device)
@@ -67,7 +67,7 @@ class LinearGroup:
 class DyntrainModel:
-    def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
+    def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
                  target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
@@ -80,19 +80,19 @@ class DyntrainModel:
         self.reshuffle_fraction = reshuffle_fraction
         if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
-        self.devices = list()
+        self.devices = list[torch.device]()
         self.inital_reshufle = True
         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-        self.frozen_linear_groups = list()
-        self.active_linear_groups = list()
+        self.frozen_linear_groups = list[LinearGroup]()
+        self.active_linear_groups = list[LinearGroup]()
         linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
         for group in linear_group_names:
             for key in group:
-                replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
+                replace_module(self.model, key, self._getModule(key, quantize, torch.device("cuda:0"), torch.device("cpu")))
             self.frozen_linear_groups.append(LinearGroup(self.model, group))
         self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
         for group in self.frozen_linear_groups:
@@ -106,6 +106,7 @@ class DyntrainModel:
         else:
             return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)
+    @staticmethod
     def _getNonlinearNames(layer: torch.nn.Module):
         names = list()
         modules = dict(layer.named_modules())
@@ -115,8 +116,9 @@ class DyntrainModel:
                 names.append(key)
         return names
+    @staticmethod
     def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
-        linear_groups = list()
+        linear_groups = list[list[str]]()
         list_counter = 0
         in_sequence = False
         modules = dict(layer.named_modules())
@@ -132,6 +134,7 @@ class DyntrainModel:
                 list_counter = list_counter + 1
         return linear_groups
+    @staticmethod
     def isModuleIn16bitOutlist(key: str) -> bool:
         key = key.split('.')[-1]
         whitelist = set({
@@ -210,7 +213,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = memory * 0.8
+                memory = int(memory * 0.8)
             bits_per_param.append(count / memory)
         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -220,7 +223,7 @@ class DyntrainModel:
            if group.getDevice() is self.devices[max_index]:
                memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                if max_index == 0:
-                    memory = memory * 0.8
+                    memory = int(memory * 0.8)
                swing = group.paramCount() / memory
                if max_bits_per_param - swing > min_bits_per_param + swing:
                    group.inplaceTo(device=self.devices[min_index])
@@ -230,7 +233,7 @@ class DyntrainModel:
         assert len(devices) > 0
         modules = dict(self.model.named_modules())
         total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
-        total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
+        total_memory -= int(torch.cuda.get_device_properties(devices[0]).total_memory * 0.2)
         static_param_count = self.staticParameterCount()
         total_parameter_count = static_param_count + self.dynamicParameterCount()
         params_per_byte = total_parameter_count / float(total_memory)
@@ -245,9 +248,9 @@ class DyntrainModel:
         group_index = 0
         for i, device in enumerate(devices[:-1]):
-            memory = torch.cuda.get_device_properties(devices).total_memory
+            memory = torch.cuda.get_device_properties(device).total_memory
             if i == 0:
-                memory = memory * 0.8
+                memory = int(memory * 0.8)
             params_for_device = memory * params_per_byte
             params = 0
             while params_for_device > params and group_index < len(linear_groups):
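
Note (not part of the commit): two recurring fixes in this file are typed empty containers and explicit integer casts. A bare list() gives mypy a list[Any], while instantiating the parameterised form, e.g. list[torch.device](), pins the element type without a separate annotation. And because multiplying an int by a float yields a float in Python, memory budgets derived from total_memory * 0.8 are wrapped in int(...) so they stay integer byte counts. A small illustrative sketch; the helper name is made up, not from the repository:

    import torch

    def usable_bytes(device: torch.device, headroom: float = 0.8) -> int:
        # total_memory is an int; multiplying by a float promotes it, so cast back.
        return int(torch.cuda.get_device_properties(device).total_memory * headroom)

    devices = list[torch.device]()  # empty, but the element type is known to mypy
    if torch.cuda.is_available():
        devices.append(torch.device("cuda:0"))
        print(usable_bytes(devices[0]))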

View File

@@ -39,7 +39,7 @@ class Linear(torch.nn.Linear):
     def isFrozen(self) -> bool:
         return not self.weight.requires_grad
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
         frozen = self.isFrozen()
         if dtype is not None:
             self.weight = torch.nn.Parameter(self.weight.to(dtype))
@@ -77,7 +77,7 @@ class DynamicConvertingLinear(Linear):
         self.output_device = output_device
     @classmethod
-    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
+    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype=torch.float32, output_device=None):
         new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                               out_features=in_module.out_features,
                                               bias=in_module.bias is not None,
@@ -124,7 +124,7 @@ class DynamicQantizedLinear(Linear):
         self.weight_start = self.weight.clone().detach()
     @classmethod
-    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
+    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
                    output_dtype=None, compute_dtype=torch.float16, output_device=None):
         new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
                          active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
@@ -193,7 +193,7 @@ class DynamicQantizedLinear(Linear):
         return out.to(output_device).to(output_dtype)
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
         if dtype is not None:
             super().inplaceTo(dtype=dtype)
         if device is not None:
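
Note (not part of the commit): the signature changes in this file address implicit-Optional defaults. A parameter annotated torch.dtype = None is rejected by mypy because None is not a torch.dtype, so the annotation becomes torch.dtype | None = None, the PEP 604 spelling of Optional[torch.dtype]. A minimal sketch with a hypothetical helper:

    import torch

    def maybe_cast(t: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor:
        # None means "leave the dtype unchanged"; the union makes that explicit for mypy.
        return t if dtype is None else t.to(dtype)

    x = maybe_cast(torch.ones(2))                 # stays float32
    y = maybe_cast(torch.ones(2), torch.float16)  # converted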

View File

@@ -22,22 +22,22 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
     model.save_pretrained(output_dir)
     if max_checkpoints > 0:
-        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.starts_with("step_")]
+        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
         def extract_step(filename):
             tokens = filename.split('_')
             return int(tokens[1])
         if len(files) > max_checkpoints:
-            min_step = min(map(extract_step, extract_step))
+            min_step = min(map(extract_step, files))
             delete_checkpoit_dir = os.path.join(output_dir, f"step_{min_step}")
             print(f"there are more than {max_checkpoints} checkpints saved, deleting {delete_checkpoit_dir}")
             shutil.rmtree(delete_checkpoit_dir)
-def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters: list[torch.nn.parameter], lr: float, static_lr: float,
+def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter] | None, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
-    parameters = list()
+    parameters = list[dict]()
     parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
     param_ids = set([id(p['params']) for p in parameters])
     if static_parameters is not None:
@@ -71,6 +71,7 @@ def evaluate(model: DyntrainModel, tokenizer,
     loss = loss / len(dataloader)
     log_writer.add_scalar("Loss/Eval", loss, globalstep)
     print(f"Eval Loss {loss.item()}")
+    return loss.item()
     if eval_prompt is not None:
         input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
@@ -84,7 +85,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     log_writer = tensorboard.SummaryWriter()
-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
                           reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True,
                           quantize=model_args.quantize)
     devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
@@ -95,7 +96,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
-    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters of which {static_parameter_count} are static")
+    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
+          f"instantanous active paramters of which {static_parameter_count} are static")
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
@@ -122,7 +124,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     ) if dataset['eval_dataset'] is not None else None
     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
-    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
+    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
     total_steps = steps_per_epoch * training_args.epochs
     optimizer = get_optimizer(model.dynamicParameters(),
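
Note (not part of the commit): besides the startswith and map(..., files) bug fixes, the get_optimizer signature above corrects torch.nn.parameter (a module) to torch.nn.Parameter (the class) and marks static_parameters as possibly None. A simplified sketch of the parameter-grouping idea under those annotations; this is illustrative, not the project's actual get_optimizer:

    import torch

    def make_param_groups(dynamic_parameters: list[torch.nn.Parameter],
                          static_parameters: list[torch.nn.Parameter] | None,
                          lr: float, static_lr: float) -> list[dict]:
        # One group per trainable tensor; static parameters get their own learning rate.
        groups = list[dict]()
        groups.extend({'params': p, 'lr': lr} for p in dynamic_parameters if p.requires_grad)
        if static_parameters is not None:
            groups.extend({'params': p, 'lr': static_lr} for p in static_parameters if p.requires_grad)
        return groups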