Add low vram mode (unloads models)
This commit is contained in:
parent
14dec9f335
commit
216e59c105
7 changed files with 597 additions and 122 deletions
|
|
@ -28,6 +28,10 @@ AceStepWorker::~AceStepWorker()
|
||||||
|
|
||||||
void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath)
|
void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath)
|
||||||
{
|
{
|
||||||
|
// Check if paths actually changed
|
||||||
|
bool pathsChanged = (m_lmModelPath != lmPath || m_textEncoderPath != textEncoderPath ||
|
||||||
|
m_ditPath != ditPath || m_vaePath != vaePath);
|
||||||
|
|
||||||
m_lmModelPath = lmPath;
|
m_lmModelPath = lmPath;
|
||||||
m_textEncoderPath = textEncoderPath;
|
m_textEncoderPath = textEncoderPath;
|
||||||
m_ditPath = ditPath;
|
m_ditPath = ditPath;
|
||||||
|
|
@ -38,6 +42,17 @@ void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QStri
|
||||||
m_textEncoderPathBytes = textEncoderPath.toUtf8();
|
m_textEncoderPathBytes = textEncoderPath.toUtf8();
|
||||||
m_ditPathBytes = ditPath.toUtf8();
|
m_ditPathBytes = ditPath.toUtf8();
|
||||||
m_vaePathBytes = vaePath.toUtf8();
|
m_vaePathBytes = vaePath.toUtf8();
|
||||||
|
|
||||||
|
// If paths changed and models are loaded, unload them so they'll be reloaded with new paths
|
||||||
|
if (pathsChanged && m_modelsLoaded.load())
|
||||||
|
{
|
||||||
|
unloadModels();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AceStepWorker::setLowVramMode(bool enabled)
|
||||||
|
{
|
||||||
|
m_lowVramMode = enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AceStepWorker::isGenerating(SongItem* song)
|
bool AceStepWorker::isGenerating(SongItem* song)
|
||||||
|
|
@ -124,22 +139,26 @@ bool AceStepWorker::checkCancel(void* data)
|
||||||
|
|
||||||
void AceStepWorker::runGeneration()
|
void AceStepWorker::runGeneration()
|
||||||
{
|
{
|
||||||
// Load models if needed
|
// Convert SongItem to AceRequest
|
||||||
if (!loadModels())
|
AceRequest req = songToRequest(m_currentSong, m_requestTemplate);
|
||||||
|
AceRequest lmOutput;
|
||||||
|
request_init(&lmOutput);
|
||||||
|
|
||||||
|
if (m_lowVramMode)
|
||||||
|
{
|
||||||
|
// Low VRAM mode: load LM → run LM → unload LM → load Synth → run Synth → unload Synth
|
||||||
|
|
||||||
|
// Step 1: Load LM and generate
|
||||||
|
emit progressUpdate(10);
|
||||||
|
|
||||||
|
if (!loadLm())
|
||||||
{
|
{
|
||||||
m_busy.store(false);
|
m_busy.store(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert SongItem to AceRequest
|
|
||||||
AceRequest req = songToRequest(m_currentSong, m_requestTemplate);
|
|
||||||
|
|
||||||
// Step 1: LM generates lyrics and audio codes
|
|
||||||
emit progressUpdate(30);
|
emit progressUpdate(30);
|
||||||
|
|
||||||
AceRequest lmOutput;
|
|
||||||
request_init(&lmOutput);
|
|
||||||
|
|
||||||
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
|
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
|
||||||
nullptr, nullptr,
|
nullptr, nullptr,
|
||||||
checkCancel, this,
|
checkCancel, this,
|
||||||
|
|
@ -147,6 +166,7 @@ void AceStepWorker::runGeneration()
|
||||||
|
|
||||||
if (m_cancelRequested.load())
|
if (m_cancelRequested.load())
|
||||||
{
|
{
|
||||||
|
unloadLm();
|
||||||
emit generationCanceled(m_currentSong);
|
emit generationCanceled(m_currentSong);
|
||||||
m_busy.store(false);
|
m_busy.store(false);
|
||||||
return;
|
return;
|
||||||
|
|
@ -154,6 +174,7 @@ void AceStepWorker::runGeneration()
|
||||||
|
|
||||||
if (lmResult != 0)
|
if (lmResult != 0)
|
||||||
{
|
{
|
||||||
|
unloadLm();
|
||||||
emit generationError("LM generation failed or was canceled");
|
emit generationError("LM generation failed or was canceled");
|
||||||
m_busy.store(false);
|
m_busy.store(false);
|
||||||
return;
|
return;
|
||||||
|
|
@ -162,10 +183,20 @@ void AceStepWorker::runGeneration()
|
||||||
// Update song with generated lyrics
|
// Update song with generated lyrics
|
||||||
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
|
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
|
||||||
|
|
||||||
// Step 2: Synth generates audio
|
// Unload LM to free VRAM
|
||||||
|
unloadLm();
|
||||||
|
|
||||||
|
// Step 2: Load Synth and generate audio
|
||||||
|
emit progressUpdate(50);
|
||||||
|
|
||||||
|
if (!loadSynth())
|
||||||
|
{
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
emit progressUpdate(60);
|
emit progressUpdate(60);
|
||||||
|
|
||||||
AceAudio* audioOut = nullptr;
|
|
||||||
AceAudio outputAudio;
|
AceAudio outputAudio;
|
||||||
outputAudio.samples = nullptr;
|
outputAudio.samples = nullptr;
|
||||||
outputAudio.n_samples = 0;
|
outputAudio.n_samples = 0;
|
||||||
|
|
@ -177,6 +208,9 @@ void AceStepWorker::runGeneration()
|
||||||
1, &outputAudio,
|
1, &outputAudio,
|
||||||
checkCancel, this);
|
checkCancel, this);
|
||||||
|
|
||||||
|
// Unload Synth to free VRAM
|
||||||
|
unloadSynth();
|
||||||
|
|
||||||
if (m_cancelRequested.load())
|
if (m_cancelRequested.load())
|
||||||
{
|
{
|
||||||
emit generationCanceled(m_currentSong);
|
emit generationCanceled(m_currentSong);
|
||||||
|
|
@ -257,6 +291,143 @@ void AceStepWorker::runGeneration()
|
||||||
emit songGenerated(m_currentSong);
|
emit songGenerated(m_currentSong);
|
||||||
|
|
||||||
m_busy.store(false);
|
m_busy.store(false);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Normal mode: load all models at start, unload at end
|
||||||
|
|
||||||
|
// Load models if needed
|
||||||
|
if (!loadModels())
|
||||||
|
{
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: LM generates lyrics and audio codes
|
||||||
|
emit progressUpdate(30);
|
||||||
|
|
||||||
|
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
|
||||||
|
nullptr, nullptr,
|
||||||
|
checkCancel, this,
|
||||||
|
LM_MODE_GENERATE);
|
||||||
|
|
||||||
|
if (m_cancelRequested.load())
|
||||||
|
{
|
||||||
|
emit generationCanceled(m_currentSong);
|
||||||
|
unloadModels();
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lmResult != 0)
|
||||||
|
{
|
||||||
|
emit generationError("LM generation failed or was canceled");
|
||||||
|
unloadModels();
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update song with generated lyrics
|
||||||
|
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
|
||||||
|
|
||||||
|
// Step 2: Synth generates audio
|
||||||
|
emit progressUpdate(60);
|
||||||
|
|
||||||
|
AceAudio outputAudio;
|
||||||
|
outputAudio.samples = nullptr;
|
||||||
|
outputAudio.n_samples = 0;
|
||||||
|
outputAudio.sample_rate = 48000;
|
||||||
|
|
||||||
|
int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
|
||||||
|
nullptr, 0, // no source audio
|
||||||
|
nullptr, 0, // no reference audio
|
||||||
|
1, &outputAudio,
|
||||||
|
checkCancel, this);
|
||||||
|
|
||||||
|
if (m_cancelRequested.load())
|
||||||
|
{
|
||||||
|
emit generationCanceled(m_currentSong);
|
||||||
|
unloadModels();
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (synthResult != 0)
|
||||||
|
{
|
||||||
|
emit generationError("Synthesis failed or was canceled");
|
||||||
|
unloadModels();
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store audio in memory as WAV
|
||||||
|
auto audioData = std::make_shared<QByteArray>();
|
||||||
|
|
||||||
|
// Simple WAV header + stereo float data
|
||||||
|
int numChannels = 2;
|
||||||
|
int bitsPerSample = 16;
|
||||||
|
int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
|
||||||
|
int blockAlign = numChannels * (bitsPerSample / 8);
|
||||||
|
int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
|
||||||
|
|
||||||
|
// RIFF header
|
||||||
|
audioData->append("RIFF");
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
|
||||||
|
audioData->append("WAVE");
|
||||||
|
|
||||||
|
// fmt chunk
|
||||||
|
audioData->append("fmt ");
|
||||||
|
int fmtSize = 16;
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
|
||||||
|
short audioFormat = 1; // PCM
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
|
||||||
|
short numCh = numChannels;
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
|
||||||
|
int sampleRate = outputAudio.sample_rate;
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
|
||||||
|
|
||||||
|
// data chunk
|
||||||
|
audioData->append("data");
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
|
||||||
|
|
||||||
|
// Convert float samples to 16-bit and write
|
||||||
|
QVector<short> interleaved(outputAudio.n_samples * numChannels);
|
||||||
|
for (int i = 0; i < outputAudio.n_samples; i++)
|
||||||
|
{
|
||||||
|
float left = outputAudio.samples[i];
|
||||||
|
float right = outputAudio.samples[i + outputAudio.n_samples];
|
||||||
|
// Clamp and convert to 16-bit
|
||||||
|
left = std::max(-1.0f, std::min(1.0f, left));
|
||||||
|
right = std::max(-1.0f, std::min(1.0f, right));
|
||||||
|
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
|
||||||
|
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
|
||||||
|
}
|
||||||
|
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
|
||||||
|
|
||||||
|
// Free audio buffer
|
||||||
|
ace_audio_free(&outputAudio);
|
||||||
|
|
||||||
|
// Store the JSON with all generated fields
|
||||||
|
m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
|
||||||
|
m_currentSong.audioData = audioData;
|
||||||
|
|
||||||
|
// Extract BPM if available
|
||||||
|
if (lmOutput.bpm > 0)
|
||||||
|
m_currentSong.bpm = lmOutput.bpm;
|
||||||
|
|
||||||
|
// Extract key if available
|
||||||
|
if (!lmOutput.keyscale.empty())
|
||||||
|
m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
|
||||||
|
|
||||||
|
emit progressUpdate(100);
|
||||||
|
emit songGenerated(m_currentSong);
|
||||||
|
|
||||||
|
// Keep models loaded for next generation (normal mode)
|
||||||
|
m_busy.store(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AceStepWorker::loadModels()
|
bool AceStepWorker::loadModels()
|
||||||
|
|
@ -314,6 +485,65 @@ void AceStepWorker::unloadModels()
|
||||||
m_modelsLoaded.store(false);
|
m_modelsLoaded.store(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AceStepWorker::loadLm()
|
||||||
|
{
|
||||||
|
if (m_lmContext)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
AceLmParams lmParams;
|
||||||
|
ace_lm_default_params(&lmParams);
|
||||||
|
lmParams.model_path = m_lmModelPathBytes.constData();
|
||||||
|
lmParams.use_fsm = true;
|
||||||
|
lmParams.use_fa = true;
|
||||||
|
|
||||||
|
m_lmContext = ace_lm_load(&lmParams);
|
||||||
|
if (!m_lmContext)
|
||||||
|
{
|
||||||
|
emit generationError("Failed to load LM model: " + m_lmModelPath);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AceStepWorker::unloadLm()
|
||||||
|
{
|
||||||
|
if (m_lmContext)
|
||||||
|
{
|
||||||
|
ace_lm_free(m_lmContext);
|
||||||
|
m_lmContext = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AceStepWorker::loadSynth()
|
||||||
|
{
|
||||||
|
if (m_synthContext)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
AceSynthParams synthParams;
|
||||||
|
ace_synth_default_params(&synthParams);
|
||||||
|
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
||||||
|
synthParams.dit_path = m_ditPathBytes.constData();
|
||||||
|
synthParams.vae_path = m_vaePathBytes.constData();
|
||||||
|
synthParams.use_fa = true;
|
||||||
|
|
||||||
|
m_synthContext = ace_synth_load(&synthParams);
|
||||||
|
if (!m_synthContext)
|
||||||
|
{
|
||||||
|
emit generationError("Failed to load synthesis models");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AceStepWorker::unloadSynth()
|
||||||
|
{
|
||||||
|
if (m_synthContext)
|
||||||
|
{
|
||||||
|
ace_synth_free(m_synthContext);
|
||||||
|
m_synthContext = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson)
|
AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson)
|
||||||
{
|
{
|
||||||
AceRequest req;
|
AceRequest req;
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,10 @@ public:
|
||||||
// Model paths - set these before first generation
|
// Model paths - set these before first generation
|
||||||
void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath);
|
void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath);
|
||||||
|
|
||||||
|
// Low VRAM mode: unload models between phases to save VRAM
|
||||||
|
void setLowVramMode(bool enabled);
|
||||||
|
bool isLowVramMode() const { return m_lowVramMode; }
|
||||||
|
|
||||||
// Request a new song generation
|
// Request a new song generation
|
||||||
bool requestGeneration(SongItem song, QString requestTemplate);
|
bool requestGeneration(SongItem song, QString requestTemplate);
|
||||||
|
|
||||||
|
|
@ -54,6 +58,12 @@ private:
|
||||||
bool loadModels();
|
bool loadModels();
|
||||||
void unloadModels();
|
void unloadModels();
|
||||||
|
|
||||||
|
// Individual model load/unload for low VRAM mode
|
||||||
|
bool loadLm();
|
||||||
|
void unloadLm();
|
||||||
|
bool loadSynth();
|
||||||
|
void unloadSynth();
|
||||||
|
|
||||||
// Convert SongItem to AceRequest
|
// Convert SongItem to AceRequest
|
||||||
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
||||||
|
|
||||||
|
|
@ -64,6 +74,7 @@ private:
|
||||||
std::atomic<bool> m_busy{false};
|
std::atomic<bool> m_busy{false};
|
||||||
std::atomic<bool> m_cancelRequested{false};
|
std::atomic<bool> m_cancelRequested{false};
|
||||||
std::atomic<bool> m_modelsLoaded{false};
|
std::atomic<bool> m_modelsLoaded{false};
|
||||||
|
bool m_lowVramMode = false;
|
||||||
|
|
||||||
// Current request data
|
// Current request data
|
||||||
SongItem m_currentSong;
|
SongItem m_currentSong;
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,11 @@ QString AdvancedSettingsDialog::getVAEModelPath() const
|
||||||
return ui->vaeModelEdit->text();
|
return ui->vaeModelEdit->text();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AdvancedSettingsDialog::getLowVramMode() const
|
||||||
|
{
|
||||||
|
return ui->lowVramCheckBox->isChecked();
|
||||||
|
}
|
||||||
|
|
||||||
void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr)
|
void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr)
|
||||||
{
|
{
|
||||||
ui->jsonTemplateEdit->setPlainText(templateStr);
|
ui->jsonTemplateEdit->setPlainText(templateStr);
|
||||||
|
|
@ -80,6 +85,11 @@ void AdvancedSettingsDialog::setVAEModelPath(const QString &path)
|
||||||
ui->vaeModelEdit->setText(path);
|
ui->vaeModelEdit->setText(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AdvancedSettingsDialog::setLowVramMode(bool enabled)
|
||||||
|
{
|
||||||
|
ui->lowVramCheckBox->setChecked(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked()
|
void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked()
|
||||||
{
|
{
|
||||||
QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text());
|
QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text());
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ public:
|
||||||
QString getTextEncoderModelPath() const;
|
QString getTextEncoderModelPath() const;
|
||||||
QString getDiTModelPath() const;
|
QString getDiTModelPath() const;
|
||||||
QString getVAEModelPath() const;
|
QString getVAEModelPath() const;
|
||||||
|
bool getLowVramMode() const;
|
||||||
|
|
||||||
// Setters for settings
|
// Setters for settings
|
||||||
void setJsonTemplate(const QString &templateStr);
|
void setJsonTemplate(const QString &templateStr);
|
||||||
|
|
@ -37,6 +38,7 @@ public:
|
||||||
void setTextEncoderModelPath(const QString &path);
|
void setTextEncoderModelPath(const QString &path);
|
||||||
void setDiTModelPath(const QString &path);
|
void setDiTModelPath(const QString &path);
|
||||||
void setVAEModelPath(const QString &path);
|
void setVAEModelPath(const QString &path);
|
||||||
|
void setLowVramMode(bool enabled);
|
||||||
|
|
||||||
private slots:
|
private slots:
|
||||||
void on_aceStepBrowseButton_clicked();
|
void on_aceStepBrowseButton_clicked();
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,43 @@
|
||||||
<property name="currentIndex">
|
<property name="currentIndex">
|
||||||
<number>0</number>
|
<number>0</number>
|
||||||
</property>
|
</property>
|
||||||
|
<widget class="QWidget" name="performanceTab">
|
||||||
|
<attribute name="title">
|
||||||
|
<string>Performance</string>
|
||||||
|
</attribute>
|
||||||
|
<layout class="QVBoxLayout" name="performanceLayout">
|
||||||
|
<item>
|
||||||
|
<widget class="QCheckBox" name="lowVramCheckBox">
|
||||||
|
<property name="text">
|
||||||
|
<string>Low VRAM Mode</string>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<widget class="QLabel" name="lowVramLabel">
|
||||||
|
<property name="text">
|
||||||
|
<string>Unload models between generation phases to save VRAM. Slower but uses less memory.</string>
|
||||||
|
</property>
|
||||||
|
<property name="wordWrap">
|
||||||
|
<bool>true</bool>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
<item>
|
||||||
|
<spacer name="verticalSpacer">
|
||||||
|
<property name="orientation">
|
||||||
|
<enum>Qt::Vertical</enum>
|
||||||
|
</property>
|
||||||
|
<property name="sizeHint" stdset="0">
|
||||||
|
<size>
|
||||||
|
<width>20</width>
|
||||||
|
<height>40</height>
|
||||||
|
</size>
|
||||||
|
</property>
|
||||||
|
</spacer>
|
||||||
|
</item>
|
||||||
|
</layout>
|
||||||
|
</widget>
|
||||||
<widget class="QWidget" name="jsonTab">
|
<widget class="QWidget" name="jsonTab">
|
||||||
<attribute name="title">
|
<attribute name="title">
|
||||||
<string>JSON Template</string>
|
<string>JSON Template</string>
|
||||||
|
|
|
||||||
|
|
@ -155,6 +155,10 @@ void MainWindow::loadSettings()
|
||||||
appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-BF16.gguf").toString();
|
appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-BF16.gguf").toString();
|
||||||
ditModelPath = settings.value("ditModelPath", appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString();
|
ditModelPath = settings.value("ditModelPath", appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString();
|
||||||
vaeModelPath = settings.value("vaeModelPath", appDir + "/acestep.cpp/models/vae-BF16.gguf").toString();
|
vaeModelPath = settings.value("vaeModelPath", appDir + "/acestep.cpp/models/vae-BF16.gguf").toString();
|
||||||
|
|
||||||
|
// Load low VRAM mode
|
||||||
|
bool lowVram = settings.value("lowVramMode", false).toBool();
|
||||||
|
aceStep->setLowVramMode(lowVram);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MainWindow::saveSettings()
|
void MainWindow::saveSettings()
|
||||||
|
|
@ -174,6 +178,9 @@ void MainWindow::saveSettings()
|
||||||
settings.setValue("ditModelPath", ditModelPath);
|
settings.setValue("ditModelPath", ditModelPath);
|
||||||
settings.setValue("vaeModelPath", vaeModelPath);
|
settings.setValue("vaeModelPath", vaeModelPath);
|
||||||
|
|
||||||
|
// Save low VRAM mode
|
||||||
|
settings.setValue("lowVramMode", aceStep->isLowVramMode());
|
||||||
|
|
||||||
settings.setValue("firstRun", false);
|
settings.setValue("firstRun", false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -374,6 +381,7 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
||||||
dialog.setTextEncoderModelPath(textEncoderModelPath);
|
dialog.setTextEncoderModelPath(textEncoderModelPath);
|
||||||
dialog.setDiTModelPath(ditModelPath);
|
dialog.setDiTModelPath(ditModelPath);
|
||||||
dialog.setVAEModelPath(vaeModelPath);
|
dialog.setVAEModelPath(vaeModelPath);
|
||||||
|
dialog.setLowVramMode(aceStep->isLowVramMode());
|
||||||
|
|
||||||
if (dialog.exec() == QDialog::Accepted)
|
if (dialog.exec() == QDialog::Accepted)
|
||||||
{
|
{
|
||||||
|
|
@ -397,6 +405,9 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
||||||
// Update model paths for acestep.cpp
|
// Update model paths for acestep.cpp
|
||||||
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
||||||
|
|
||||||
|
// Update low VRAM mode
|
||||||
|
aceStep->setLowVramMode(dialog.getLowVramMode());
|
||||||
|
|
||||||
saveSettings();
|
saveSettings();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -324,6 +324,177 @@ TEST(cancellation)
|
||||||
ASSERT_TRUE(true);
|
ASSERT_TRUE(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test 12: Test low VRAM mode generation
|
||||||
|
TEST(generateSongLowVram)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
// Skip if models don't exist
|
||||||
|
SKIP_IF(!modelsExist(paths));
|
||||||
|
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||||
|
worker.setLowVramMode(true);
|
||||||
|
|
||||||
|
ASSERT_TRUE(worker.isLowVramMode());
|
||||||
|
|
||||||
|
SongItem song("Chill electronic music", "");
|
||||||
|
|
||||||
|
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||||
|
|
||||||
|
// Track generation result
|
||||||
|
bool generationCompleted = false;
|
||||||
|
SongItem resultSong;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||||
|
[&generationCompleted, &resultSong](const SongItem& song) {
|
||||||
|
std::cout << "\n Low VRAM mode: Song generated successfully!" << std::endl;
|
||||||
|
std::cout << " Caption: " << song.caption.toStdString() << std::endl;
|
||||||
|
if (song.audioData) {
|
||||||
|
std::cout << " Audio data size: " << song.audioData->size() << " bytes" << std::endl;
|
||||||
|
} else {
|
||||||
|
std::cout << " Audio data size: null" << std::endl;
|
||||||
|
}
|
||||||
|
resultSong = song;
|
||||||
|
generationCompleted = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
QString errorMsg;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError,
|
||||||
|
[&errorMsg](const QString& err) {
|
||||||
|
std::cout << "\n Error: " << err.toStdString() << std::endl;
|
||||||
|
errorMsg = err;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::cout << "\n Starting low VRAM mode generation..." << std::endl;
|
||||||
|
|
||||||
|
// Request generation
|
||||||
|
bool result = worker.requestGeneration(song, templateJson);
|
||||||
|
ASSERT_TRUE(result);
|
||||||
|
|
||||||
|
// Use QEventLoop with timer for proper event processing
|
||||||
|
QEventLoop loop;
|
||||||
|
QTimer timeoutTimer;
|
||||||
|
|
||||||
|
timeoutTimer.setSingleShot(true);
|
||||||
|
timeoutTimer.start(300000); // 5 minute timeout
|
||||||
|
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||||
|
|
||||||
|
loop.exec();
|
||||||
|
|
||||||
|
ASSERT_TRUE(generationCompleted);
|
||||||
|
ASSERT_TRUE(resultSong.audioData != nullptr);
|
||||||
|
ASSERT_TRUE(!resultSong.audioData->isEmpty());
|
||||||
|
|
||||||
|
std::cout << " Audio data size: " << resultSong.audioData->size() << " bytes" << std::endl;
|
||||||
|
ASSERT_TRUE(resultSong.audioData->size() > 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 13: Test normal mode keeps models loaded between generations
|
||||||
|
TEST(normalModeKeepsModelsLoaded)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
// Skip if models don't exist
|
||||||
|
SKIP_IF(!modelsExist(paths));
|
||||||
|
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||||
|
// Normal mode is default (lowVramMode = false)
|
||||||
|
|
||||||
|
ASSERT_FALSE(worker.isLowVramMode());
|
||||||
|
|
||||||
|
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||||
|
|
||||||
|
// Generate first song
|
||||||
|
bool firstGenerationCompleted = false;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||||
|
[&firstGenerationCompleted](const SongItem&) {
|
||||||
|
firstGenerationCompleted = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError,
|
||||||
|
[](const QString& err) {
|
||||||
|
std::cout << "\n Error: " << err.toStdString() << std::endl;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::cout << "\n Generating first song (normal mode)..." << std::endl;
|
||||||
|
|
||||||
|
SongItem song1("First song", "");
|
||||||
|
bool result = worker.requestGeneration(song1, templateJson);
|
||||||
|
ASSERT_TRUE(result);
|
||||||
|
|
||||||
|
QEventLoop loop;
|
||||||
|
QTimer timeoutTimer;
|
||||||
|
timeoutTimer.setSingleShot(true);
|
||||||
|
timeoutTimer.start(300000);
|
||||||
|
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||||
|
|
||||||
|
loop.exec();
|
||||||
|
|
||||||
|
ASSERT_TRUE(firstGenerationCompleted);
|
||||||
|
std::cout << " First generation completed, models should still be loaded" << std::endl;
|
||||||
|
|
||||||
|
// Generate second song - in normal mode this should be faster since models are already loaded
|
||||||
|
bool secondGenerationCompleted = false;
|
||||||
|
SongItem secondResult;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||||
|
[&secondGenerationCompleted, &secondResult](const SongItem& song) {
|
||||||
|
secondGenerationCompleted = true;
|
||||||
|
secondResult = song;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::cout << " Generating second song (should use cached models)..." << std::endl;
|
||||||
|
|
||||||
|
SongItem song2("Second song", "");
|
||||||
|
result = worker.requestGeneration(song2, templateJson);
|
||||||
|
ASSERT_TRUE(result);
|
||||||
|
|
||||||
|
QEventLoop loop2;
|
||||||
|
QTimer timeoutTimer2;
|
||||||
|
timeoutTimer2.setSingleShot(true);
|
||||||
|
timeoutTimer2.start(300000);
|
||||||
|
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop2, &QEventLoop::quit);
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError, &loop2, &QEventLoop::quit);
|
||||||
|
QObject::connect(&timeoutTimer2, &QTimer::timeout, &loop2, &QEventLoop::quit);
|
||||||
|
|
||||||
|
loop2.exec();
|
||||||
|
|
||||||
|
ASSERT_TRUE(secondGenerationCompleted);
|
||||||
|
ASSERT_TRUE(secondResult.audioData != nullptr);
|
||||||
|
ASSERT_TRUE(!secondResult.audioData->isEmpty());
|
||||||
|
|
||||||
|
std::cout << " Second generation completed successfully" << std::endl;
|
||||||
|
std::cout << " Audio data size: " << secondResult.audioData->size() << " bytes" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 14: Test setLowVramMode toggle
|
||||||
|
TEST(lowVramModeToggle)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
|
||||||
|
// Default should be false (normal mode)
|
||||||
|
ASSERT_FALSE(worker.isLowVramMode());
|
||||||
|
|
||||||
|
// Enable low VRAM mode
|
||||||
|
worker.setLowVramMode(true);
|
||||||
|
ASSERT_TRUE(worker.isLowVramMode());
|
||||||
|
|
||||||
|
// Disable low VRAM mode
|
||||||
|
worker.setLowVramMode(false);
|
||||||
|
ASSERT_FALSE(worker.isLowVramMode());
|
||||||
|
|
||||||
|
// Toggle again
|
||||||
|
worker.setLowVramMode(true);
|
||||||
|
ASSERT_TRUE(worker.isLowVramMode());
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
int main(int argc, char *argv[])
|
||||||
{
|
{
|
||||||
QCoreApplication app(argc, argv);
|
QCoreApplication app(argc, argv);
|
||||||
|
|
@ -341,6 +512,9 @@ int main(int argc, char *argv[])
|
||||||
RUN_TEST(checkModelFiles);
|
RUN_TEST(checkModelFiles);
|
||||||
RUN_TEST(generateSong);
|
RUN_TEST(generateSong);
|
||||||
RUN_TEST(cancellation);
|
RUN_TEST(cancellation);
|
||||||
|
RUN_TEST(generateSongLowVram);
|
||||||
|
RUN_TEST(normalModeKeepsModelsLoaded);
|
||||||
|
RUN_TEST(lowVramModeToggle);
|
||||||
|
|
||||||
std::cout << "\n=== Results ===" << std::endl;
|
std::cout << "\n=== Results ===" << std::endl;
|
||||||
std::cout << "Passed: " << testsPassed << std::endl;
|
std::cout << "Passed: " << testsPassed << std::endl;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue