Make fattn optional
This commit is contained in:
parent
216e59c105
commit
275c1a1852
6 changed files with 57 additions and 4 deletions
|
|
@ -55,6 +55,11 @@ void AceStepWorker::setLowVramMode(bool enabled)
|
||||||
m_lowVramMode = enabled;
|
m_lowVramMode = enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores the flash-attention preference for this worker in m_flashAttention.
// The stored value is copied into lmParams.use_fa / synthParams.use_fa by
// loadModels(), loadLm() and loadSynth() right before ace_lm_load() /
// ace_synth_load() are called.
// NOTE(review): this setter does not reload anything itself, so the new value
// presumably only applies to the next model load — confirm against callers.
void AceStepWorker::setFlashAttention(bool enabled)
|
||||||
|
{
|
||||||
|
m_flashAttention = enabled;
|
||||||
|
}
|
||||||
|
|
||||||
bool AceStepWorker::isGenerating(SongItem* song)
|
bool AceStepWorker::isGenerating(SongItem* song)
|
||||||
{
|
{
|
||||||
if (!m_busy.load() && song)
|
if (!m_busy.load() && song)
|
||||||
|
|
@ -440,7 +445,7 @@ bool AceStepWorker::loadModels()
|
||||||
ace_lm_default_params(&lmParams);
|
ace_lm_default_params(&lmParams);
|
||||||
lmParams.model_path = m_lmModelPathBytes.constData();
|
lmParams.model_path = m_lmModelPathBytes.constData();
|
||||||
lmParams.use_fsm = true;
|
lmParams.use_fsm = true;
|
||||||
lmParams.use_fa = true;
|
lmParams.use_fa = m_flashAttention;
|
||||||
|
|
||||||
m_lmContext = ace_lm_load(&lmParams);
|
m_lmContext = ace_lm_load(&lmParams);
|
||||||
if (!m_lmContext)
|
if (!m_lmContext)
|
||||||
|
|
@ -455,7 +460,7 @@ bool AceStepWorker::loadModels()
|
||||||
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
||||||
synthParams.dit_path = m_ditPathBytes.constData();
|
synthParams.dit_path = m_ditPathBytes.constData();
|
||||||
synthParams.vae_path = m_vaePathBytes.constData();
|
synthParams.vae_path = m_vaePathBytes.constData();
|
||||||
synthParams.use_fa = true;
|
synthParams.use_fa = m_flashAttention;
|
||||||
|
|
||||||
m_synthContext = ace_synth_load(&synthParams);
|
m_synthContext = ace_synth_load(&synthParams);
|
||||||
if (!m_synthContext)
|
if (!m_synthContext)
|
||||||
|
|
@ -494,7 +499,7 @@ bool AceStepWorker::loadLm()
|
||||||
ace_lm_default_params(&lmParams);
|
ace_lm_default_params(&lmParams);
|
||||||
lmParams.model_path = m_lmModelPathBytes.constData();
|
lmParams.model_path = m_lmModelPathBytes.constData();
|
||||||
lmParams.use_fsm = true;
|
lmParams.use_fsm = true;
|
||||||
lmParams.use_fa = true;
|
lmParams.use_fa = m_flashAttention;
|
||||||
|
|
||||||
m_lmContext = ace_lm_load(&lmParams);
|
m_lmContext = ace_lm_load(&lmParams);
|
||||||
if (!m_lmContext)
|
if (!m_lmContext)
|
||||||
|
|
@ -524,7 +529,7 @@ bool AceStepWorker::loadSynth()
|
||||||
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
||||||
synthParams.dit_path = m_ditPathBytes.constData();
|
synthParams.dit_path = m_ditPathBytes.constData();
|
||||||
synthParams.vae_path = m_vaePathBytes.constData();
|
synthParams.vae_path = m_vaePathBytes.constData();
|
||||||
synthParams.use_fa = true;
|
synthParams.use_fa = m_flashAttention;
|
||||||
|
|
||||||
m_synthContext = ace_synth_load(&synthParams);
|
m_synthContext = ace_synth_load(&synthParams);
|
||||||
if (!m_synthContext)
|
if (!m_synthContext)
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,10 @@ public:
|
||||||
void setLowVramMode(bool enabled);
|
void setLowVramMode(bool enabled);
|
||||||
bool isLowVramMode() const { return m_lowVramMode; }
|
bool isLowVramMode() const { return m_lowVramMode; }
|
||||||
|
|
||||||
|
// Flash attention mode
|
||||||
|
void setFlashAttention(bool enabled);
|
||||||
|
// Returns the stored flash-attention preference (m_flashAttention, which is
// initialized to true in its member declaration).
bool isFlashAttention() const { return m_flashAttention; }
|
||||||
|
|
||||||
// Request a new song generation
|
// Request a new song generation
|
||||||
bool requestGeneration(SongItem song, QString requestTemplate);
|
bool requestGeneration(SongItem song, QString requestTemplate);
|
||||||
|
|
||||||
|
|
@ -75,6 +79,7 @@ private:
|
||||||
std::atomic<bool> m_cancelRequested{false};
|
std::atomic<bool> m_cancelRequested{false};
|
||||||
std::atomic<bool> m_modelsLoaded{false};
|
std::atomic<bool> m_modelsLoaded{false};
|
||||||
bool m_lowVramMode = false;
|
bool m_lowVramMode = false;
|
||||||
|
bool m_flashAttention = true;
|
||||||
|
|
||||||
// Current request data
|
// Current request data
|
||||||
SongItem m_currentSong;
|
SongItem m_currentSong;
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,11 @@ bool AdvancedSettingsDialog::getLowVramMode() const
|
||||||
return ui->lowVramCheckBox->isChecked();
|
return ui->lowVramCheckBox->isChecked();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reports whether the dialog's flashAttentionCheckBox is currently checked.
// Returns true when checked, false otherwise.
bool AdvancedSettingsDialog::getFlashAttention() const
|
||||||
|
{
|
||||||
|
return ui->flashAttentionCheckBox->isChecked();
|
||||||
|
}
|
||||||
|
|
||||||
void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr)
|
void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr)
|
||||||
{
|
{
|
||||||
ui->jsonTemplateEdit->setPlainText(templateStr);
|
ui->jsonTemplateEdit->setPlainText(templateStr);
|
||||||
|
|
@ -90,6 +95,11 @@ void AdvancedSettingsDialog::setLowVramMode(bool enabled)
|
||||||
ui->lowVramCheckBox->setChecked(enabled);
|
ui->lowVramCheckBox->setChecked(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets the checked state of the dialog's flashAttentionCheckBox to `enabled`.
// Counterpart of getFlashAttention(), which reads the same checkbox.
void AdvancedSettingsDialog::setFlashAttention(bool enabled)
|
||||||
|
{
|
||||||
|
ui->flashAttentionCheckBox->setChecked(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked()
|
void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked()
|
||||||
{
|
{
|
||||||
QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text());
|
QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text());
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ public:
|
||||||
QString getDiTModelPath() const;
|
QString getDiTModelPath() const;
|
||||||
QString getVAEModelPath() const;
|
QString getVAEModelPath() const;
|
||||||
bool getLowVramMode() const;
|
bool getLowVramMode() const;
|
||||||
|
bool getFlashAttention() const;
|
||||||
|
|
||||||
// Setters for settings
|
// Setters for settings
|
||||||
void setJsonTemplate(const QString &templateStr);
|
void setJsonTemplate(const QString &templateStr);
|
||||||
|
|
@ -39,6 +40,7 @@ public:
|
||||||
void setDiTModelPath(const QString &path);
|
void setDiTModelPath(const QString &path);
|
||||||
void setVAEModelPath(const QString &path);
|
void setVAEModelPath(const QString &path);
|
||||||
void setLowVramMode(bool enabled);
|
void setLowVramMode(bool enabled);
|
||||||
|
void setFlashAttention(bool enabled);
|
||||||
|
|
||||||
private slots:
|
private slots:
|
||||||
void on_aceStepBrowseButton_clicked();
|
void on_aceStepBrowseButton_clicked();
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,26 @@
|
||||||
</widget>
|
</widget>
|
||||||
</item>
|
</item>
|
||||||
<item>
|
<item>
|
||||||
|
<!-- Checkbox for the flash-attention option; checked by default. Its state is
     read and written by AdvancedSettingsDialog::getFlashAttention() and
     AdvancedSettingsDialog::setFlashAttention(). -->
<widget class="QCheckBox" name="flashAttentionCheckBox">
|
||||||
|
<property name="text">
|
||||||
|
<string>Flash Attention</string>
|
||||||
|
</property>
|
||||||
|
<property name="checked">
|
||||||
|
<bool>true</bool>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<widget class="QLabel" name="flashAttentionLabel">
|
||||||
|
<property name="text">
|
||||||
|
<string>Use flash attention for faster generation. Disable if experiencing issues.</string>
|
||||||
|
</property>
|
||||||
|
<property name="wordWrap">
|
||||||
|
<bool>true</bool>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
<spacer name="verticalSpacer">
|
<spacer name="verticalSpacer">
|
||||||
<property name="orientation">
|
<property name="orientation">
|
||||||
<enum>Qt::Vertical</enum>
|
<enum>Qt::Vertical</enum>
|
||||||
|
|
|
||||||
|
|
@ -159,6 +159,10 @@ void MainWindow::loadSettings()
|
||||||
// Load low VRAM mode
|
// Load low VRAM mode
|
||||||
bool lowVram = settings.value("lowVramMode", false).toBool();
|
bool lowVram = settings.value("lowVramMode", false).toBool();
|
||||||
aceStep->setLowVramMode(lowVram);
|
aceStep->setLowVramMode(lowVram);
|
||||||
|
|
||||||
|
// Load flash attention setting
|
||||||
|
bool flashAttention = settings.value("flashAttention", true).toBool();
|
||||||
|
aceStep->setFlashAttention(flashAttention);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MainWindow::saveSettings()
|
void MainWindow::saveSettings()
|
||||||
|
|
@ -181,6 +185,9 @@ void MainWindow::saveSettings()
|
||||||
// Save low VRAM mode
|
// Save low VRAM mode
|
||||||
settings.setValue("lowVramMode", aceStep->isLowVramMode());
|
settings.setValue("lowVramMode", aceStep->isLowVramMode());
|
||||||
|
|
||||||
|
// Save flash attention setting
|
||||||
|
settings.setValue("flashAttention", aceStep->isFlashAttention());
|
||||||
|
|
||||||
settings.setValue("firstRun", false);
|
settings.setValue("firstRun", false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -382,6 +389,7 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
||||||
dialog.setDiTModelPath(ditModelPath);
|
dialog.setDiTModelPath(ditModelPath);
|
||||||
dialog.setVAEModelPath(vaeModelPath);
|
dialog.setVAEModelPath(vaeModelPath);
|
||||||
dialog.setLowVramMode(aceStep->isLowVramMode());
|
dialog.setLowVramMode(aceStep->isLowVramMode());
|
||||||
|
dialog.setFlashAttention(aceStep->isFlashAttention());
|
||||||
|
|
||||||
if (dialog.exec() == QDialog::Accepted)
|
if (dialog.exec() == QDialog::Accepted)
|
||||||
{
|
{
|
||||||
|
|
@ -408,6 +416,9 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
||||||
// Update low VRAM mode
|
// Update low VRAM mode
|
||||||
aceStep->setLowVramMode(dialog.getLowVramMode());
|
aceStep->setLowVramMode(dialog.getLowVramMode());
|
||||||
|
|
||||||
|
// Update flash attention setting
|
||||||
|
aceStep->setFlashAttention(dialog.getFlashAttention());
|
||||||
|
|
||||||
saveSettings();
|
saveSettings();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue