add quantized linear, refactor model for its soon-to-be addition

This commit is contained in:
2024-03-23 21:38:27 +01:00
parent 38a7f7cfc4
commit 3fa1fc254f
5 changed files with 191 additions and 71 deletions

View File

@ -60,7 +60,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
secondary_device = torch.device(training_args.secondary_device)
log_writer = tensorboard.SummaryWriter()
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
model.toDevices([primary_device, secondary_device])
model.balanceActive()
@ -133,7 +134,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
if global_step % 10 == 0:
print(loss)
if global_step % 10 == 0 and model_args.max_instant_params != 0:
if global_step % 10 == 0 and training_args.max_instant_params != 0:
lr_scheduler.optimizer = None
del optimizer
model.reshuffleActive()