add qunatized linear, refactor model for it soon to be addition
This commit is contained in:
@ -60,7 +60,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
||||
secondary_device = torch.device(training_args.secondary_device)
|
||||
log_writer = tensorboard.SummaryWriter()
|
||||
|
||||
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, model_args.max_instant_params * 1e6, True, True)
|
||||
model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
|
||||
reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True, trust_remote_code=True)
|
||||
model.toDevices([primary_device, secondary_device])
|
||||
model.balanceActive()
|
||||
|
||||
@ -133,7 +134,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
||||
if global_step % 10 == 0:
|
||||
print(loss)
|
||||
|
||||
if global_step % 10 == 0 and model_args.max_instant_params != 0:
|
||||
if global_step % 10 == 0 and training_args.max_instant_params != 0:
|
||||
lr_scheduler.optimizer = None
|
||||
del optimizer
|
||||
model.reshuffleActive()
|
||||
|
Reference in New Issue
Block a user