diff --git a/src/AceStepWorker.cpp b/src/AceStepWorker.cpp
index c9c9e25..02ca47f 100644
--- a/src/AceStepWorker.cpp
+++ b/src/AceStepWorker.cpp
@@ -28,6 +28,10 @@ AceStepWorker::~AceStepWorker()
 void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath)
 {
+    // Check if paths actually changed
+    bool pathsChanged = (m_lmModelPath != lmPath || m_textEncoderPath != textEncoderPath ||
+                         m_ditPath != ditPath || m_vaePath != vaePath);
+
     m_lmModelPath = lmPath;
     m_textEncoderPath = textEncoderPath;
     m_ditPath = ditPath;
@@ -38,6 +42,17 @@ void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QStri
     m_textEncoderPathBytes = textEncoderPath.toUtf8();
     m_ditPathBytes = ditPath.toUtf8();
     m_vaePathBytes = vaePath.toUtf8();
+
+    // If paths changed and models are loaded, unload them so they'll be reloaded with new paths
+    if (pathsChanged && m_modelsLoaded.load())
+    {
+        unloadModels();
+    }
+}
+
+void AceStepWorker::setLowVramMode(bool enabled)
+{
+    m_lowVramMode = enabled;
 }
 
 bool AceStepWorker::isGenerating(SongItem* song)
@@ -124,139 +139,295 @@ bool AceStepWorker::checkCancel(void* data)
 
 void AceStepWorker::runGeneration()
 {
-    // Load models if needed
-    if (!loadModels())
-    {
-        m_busy.store(false);
-        return;
-    }
-
     // Convert SongItem to AceRequest
     AceRequest req = songToRequest(m_currentSong, m_requestTemplate);
 
-    // Step 1: LM generates lyrics and audio codes
-    emit progressUpdate(30);
-
     AceRequest lmOutput;
     request_init(&lmOutput);
-    int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
-                                   nullptr, nullptr,
-                                   checkCancel, this,
-                                   LM_MODE_GENERATE);
-
-    if (m_cancelRequested.load())
-    {
-        emit generationCanceled(m_currentSong);
-        m_busy.store(false);
-        return;
-    }
-
-    if (lmResult != 0)
-    {
-        emit generationError("LM generation failed or was canceled");
-        m_busy.store(false);
-        return;
-    }
-
-    // Update song with generated lyrics
-    m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
-
-    // Step 2: Synth generates audio
-    emit progressUpdate(60);
-
-    AceAudio* audioOut = nullptr;
-    AceAudio outputAudio;
-    outputAudio.samples = nullptr;
-    outputAudio.n_samples = 0;
-    outputAudio.sample_rate = 48000;
-
-    int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
-                                         nullptr, 0, // no source audio
-                                         nullptr, 0, // no reference audio
-                                         1, &outputAudio,
-                                         checkCancel, this);
-
-    if (m_cancelRequested.load())
-    {
-        emit generationCanceled(m_currentSong);
-        m_busy.store(false);
-        return;
-    }
-
-    if (synthResult != 0)
-    {
-        emit generationError("Synthesis failed or was canceled");
-        m_busy.store(false);
-        return;
-    }
-
-    // Store audio in memory as WAV
-    auto audioData = std::make_shared<QByteArray>();
-
-    // Simple WAV header + stereo float data
-    int numChannels = 2;
-    int bitsPerSample = 16;
-    int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
-    int blockAlign = numChannels * (bitsPerSample / 8);
-    int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
-
-    // RIFF header
-    audioData->append("RIFF");
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
-    audioData->append("WAVE");
-
-    // fmt chunk
-    audioData->append("fmt ");
-    int fmtSize = 16;
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
-    short audioFormat = 1; // PCM
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
-    short numCh = numChannels;
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
-    int sampleRate = outputAudio.sample_rate;
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
-
-    // data chunk
-    audioData->append("data");
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
-
-    // Convert float samples to 16-bit and write
-    QVector<short> interleaved(outputAudio.n_samples * numChannels);
-    for (int i = 0; i < outputAudio.n_samples; i++)
-    {
-        float left = outputAudio.samples[i];
-        float right = outputAudio.samples[i + outputAudio.n_samples];
-        // Clamp and convert to 16-bit
-        left = std::max(-1.0f, std::min(1.0f, left));
-        right = std::max(-1.0f, std::min(1.0f, right));
-        interleaved[i * 2] = static_cast<short>(left * 32767.0f);
-        interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
-    }
-    audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
-
-    // Free audio buffer
-    ace_audio_free(&outputAudio);
-
-    // Store the JSON with all generated fields
-    m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
-    m_currentSong.audioData = audioData;
-
-    // Extract BPM if available
-    if (lmOutput.bpm > 0)
-        m_currentSong.bpm = lmOutput.bpm;
-
-    // Extract key if available
-    if (!lmOutput.keyscale.empty())
-        m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
-
-    emit progressUpdate(100);
-    emit songGenerated(m_currentSong);
-
-    m_busy.store(false);
+    if (m_lowVramMode)
+    {
+        // Low VRAM mode: load LM -> run LM -> unload LM -> load Synth -> run Synth -> unload Synth
+
+        // Step 1: Load LM and generate
+        emit progressUpdate(10);
+
+        if (!loadLm())
+        {
+            m_busy.store(false);
+            return;
+        }
+
+        emit progressUpdate(30);
+
+        int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
+                                       nullptr, nullptr,
+                                       checkCancel, this,
+                                       LM_MODE_GENERATE);
+
+        if (m_cancelRequested.load())
+        {
+            unloadLm();
+            emit generationCanceled(m_currentSong);
+            m_busy.store(false);
+            return;
+        }
+
+        if (lmResult != 0)
+        {
+            unloadLm();
+            emit generationError("LM generation failed or was canceled");
+            m_busy.store(false);
+            return;
+        }
+
+        // Update song with generated lyrics
+        m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
+
+        // Unload LM to free VRAM
+        unloadLm();
+
+        // Step 2: Load Synth and generate audio
+        emit progressUpdate(50);
+
+        if (!loadSynth())
+        {
+            m_busy.store(false);
+            return;
+        }
+
+        emit progressUpdate(60);
+
+        AceAudio outputAudio;
+        outputAudio.samples = nullptr;
+        outputAudio.n_samples = 0;
+        outputAudio.sample_rate = 48000;
+
+        int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
+                                             nullptr, 0, // no source audio
+                                             nullptr, 0, // no reference audio
+                                             1, &outputAudio,
+                                             checkCancel, this);
+
+        // Unload Synth to free VRAM
+        unloadSynth();
+
+        if (m_cancelRequested.load())
+        {
+            emit generationCanceled(m_currentSong);
+            m_busy.store(false);
+            return;
+        }
+
+        if (synthResult != 0)
+        {
+            emit generationError("Synthesis failed or was canceled");
+            m_busy.store(false);
+            return;
+        }
+
+        // Store audio in memory as WAV
+        auto audioData = std::make_shared<QByteArray>();
+
+        // Simple WAV header + stereo 16-bit PCM data
+        int numChannels = 2;
+        int bitsPerSample = 16;
+        int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
+        int blockAlign = numChannels * (bitsPerSample / 8);
+        int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
+
+        // RIFF header; chunk size is everything after this field: 4 ("WAVE") + 24 (fmt) + 8 + dataSize
+        audioData->append("RIFF");
+        int riffChunkSize = 36 + dataSize;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&riffChunkSize), 4));
+        audioData->append("WAVE");
+
+        // fmt chunk
+        audioData->append("fmt ");
+        int fmtSize = 16;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
+        short audioFormat = 1; // PCM
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
+        short numCh = numChannels;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
+        int sampleRate = outputAudio.sample_rate;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
+
+        // data chunk
+        audioData->append("data");
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
+
+        // Convert planar float samples (left plane then right plane) to interleaved 16-bit
+        QVector<short> interleaved(outputAudio.n_samples * numChannels);
+        for (int i = 0; i < outputAudio.n_samples; i++)
+        {
+            float left = outputAudio.samples[i];
+            float right = outputAudio.samples[i + outputAudio.n_samples];
+            // Clamp and convert to 16-bit
+            left = std::max(-1.0f, std::min(1.0f, left));
+            right = std::max(-1.0f, std::min(1.0f, right));
+            interleaved[i * 2] = static_cast<short>(left * 32767.0f);
+            interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
+        }
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
+
+        // Free audio buffer
+        ace_audio_free(&outputAudio);
+
+        // Store the JSON with all generated fields
+        m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
+        m_currentSong.audioData = audioData;
+
+        // Extract BPM if available
+        if (lmOutput.bpm > 0)
+            m_currentSong.bpm = lmOutput.bpm;
+
+        // Extract key if available
+        if (!lmOutput.keyscale.empty())
+            m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
+
+        emit progressUpdate(100);
+        emit songGenerated(m_currentSong);
+
+        m_busy.store(false);
+    }
+    else
+    {
+        // Normal mode: load all models at start, keep them loaded for the next generation
+
+        // Load models if needed
+        if (!loadModels())
+        {
+            m_busy.store(false);
+            return;
+        }
+
+        // Step 1: LM generates lyrics and audio codes
+        emit progressUpdate(30);
+
+        int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
+                                       nullptr, nullptr,
+                                       checkCancel, this,
+                                       LM_MODE_GENERATE);
+
+        if (m_cancelRequested.load())
+        {
+            emit generationCanceled(m_currentSong);
+            unloadModels();
+            m_busy.store(false);
+            return;
+        }
+
+        if (lmResult != 0)
+        {
+            emit generationError("LM generation failed or was canceled");
+            unloadModels();
+            m_busy.store(false);
+            return;
+        }
+
+        // Update song with generated lyrics
+        m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
+
+        // Step 2: Synth generates audio
+        emit progressUpdate(60);
+
+        AceAudio outputAudio;
+        outputAudio.samples = nullptr;
+        outputAudio.n_samples = 0;
+        outputAudio.sample_rate = 48000;
+
+        int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
+                                             nullptr, 0, // no source audio
+                                             nullptr, 0, // no reference audio
+                                             1, &outputAudio,
+                                             checkCancel, this);
+
+        if (m_cancelRequested.load())
+        {
+            emit generationCanceled(m_currentSong);
+            unloadModels();
+            m_busy.store(false);
+            return;
+        }
+
+        if (synthResult != 0)
+        {
+            emit generationError("Synthesis failed or was canceled");
+            unloadModels();
+            m_busy.store(false);
+            return;
+        }
+
+        // Store audio in memory as WAV
+        auto audioData = std::make_shared<QByteArray>();
+
+        // Simple WAV header + stereo 16-bit PCM data
+        int numChannels = 2;
+        int bitsPerSample = 16;
+        int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
+        int blockAlign = numChannels * (bitsPerSample / 8);
+        int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
+
+        // RIFF header; chunk size is everything after this field: 4 ("WAVE") + 24 (fmt) + 8 + dataSize
+        audioData->append("RIFF");
+        int riffChunkSize = 36 + dataSize;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&riffChunkSize), 4));
+        audioData->append("WAVE");
+
+        // fmt chunk
+        audioData->append("fmt ");
+        int fmtSize = 16;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
+        short audioFormat = 1; // PCM
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
+        short numCh = numChannels;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
+        int sampleRate = outputAudio.sample_rate;
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
+
+        // data chunk
+        audioData->append("data");
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
+
+        // Convert planar float samples (left plane then right plane) to interleaved 16-bit
+        QVector<short> interleaved(outputAudio.n_samples * numChannels);
+        for (int i = 0; i < outputAudio.n_samples; i++)
+        {
+            float left = outputAudio.samples[i];
+            float right = outputAudio.samples[i + outputAudio.n_samples];
+            // Clamp and convert to 16-bit
+            left = std::max(-1.0f, std::min(1.0f, left));
+            right = std::max(-1.0f, std::min(1.0f, right));
+            interleaved[i * 2] = static_cast<short>(left * 32767.0f);
+            interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
+        }
+        audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
+
+        // Free audio buffer
+        ace_audio_free(&outputAudio);
+
+        // Store the JSON with all generated fields
+        m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
+        m_currentSong.audioData = audioData;
+
+        // Extract BPM if available
+        if (lmOutput.bpm > 0)
+            m_currentSong.bpm = lmOutput.bpm;
+
+        // Extract key if available
+        if (!lmOutput.keyscale.empty())
+            m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
+
+        emit progressUpdate(100);
+        emit songGenerated(m_currentSong);
+
+        // Keep models loaded for next generation (normal mode)
+        m_busy.store(false);
+    }
 }
 
 bool AceStepWorker::loadModels()
@@ -314,6 +485,65 @@ void AceStepWorker::unloadModels()
     m_modelsLoaded.store(false);
 }
 
+bool AceStepWorker::loadLm()
+{
+    if (m_lmContext)
+        return true;
+
+    AceLmParams lmParams;
+    ace_lm_default_params(&lmParams);
+    lmParams.model_path = m_lmModelPathBytes.constData();
+    lmParams.use_fsm = true;
+    lmParams.use_fa = true;
+
+    m_lmContext = ace_lm_load(&lmParams);
+    if (!m_lmContext)
+    {
+        emit generationError("Failed to load LM model: " + m_lmModelPath);
+        return false;
+    }
+    return true;
+}
+
+void AceStepWorker::unloadLm()
+{
+    if (m_lmContext)
+    {
+        ace_lm_free(m_lmContext);
+        m_lmContext = nullptr;
+    }
+}
+
+bool AceStepWorker::loadSynth()
+{
+    if (m_synthContext)
+        return true;
+
+    AceSynthParams synthParams;
+    ace_synth_default_params(&synthParams);
+    synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
+    synthParams.dit_path = m_ditPathBytes.constData();
+    synthParams.vae_path = m_vaePathBytes.constData();
+    synthParams.use_fa = true;
+
+    m_synthContext = ace_synth_load(&synthParams);
+    if (!m_synthContext)
+    {
+        emit generationError("Failed to load synthesis models");
+        return false;
+    }
+    return true;
+}
+
+void AceStepWorker::unloadSynth()
+{
+    if (m_synthContext)
+    {
+        ace_synth_free(m_synthContext);
+        m_synthContext = nullptr;
+    }
+}
+
 AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson)
 {
     AceRequest req;
diff --git a/src/AceStepWorker.h b/src/AceStepWorker.h
index e1362aa..b6498eb 100644
--- a/src/AceStepWorker.h
+++ b/src/AceStepWorker.h
@@ -34,6 +34,10 @@ public:
     // Model paths - set these before first generation
     void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath);
 
+    // Low VRAM mode: unload models between phases to save VRAM
+    void setLowVramMode(bool enabled);
+    bool isLowVramMode() const { return m_lowVramMode; }
+
     // Request a new song generation
     bool requestGeneration(SongItem song, QString requestTemplate);
 
@@ -54,6 +58,12 @@ private:
     bool loadModels();
     void unloadModels();
 
+    // Individual model load/unload for low VRAM mode
+    bool loadLm();
+    void unloadLm();
+    bool loadSynth();
+    void unloadSynth();
+
     // Convert SongItem to AceRequest
     AceRequest songToRequest(const SongItem& song, const QString& templateJson);
 
@@ -64,6 +74,7 @@ private:
     std::atomic<bool> m_busy{false};
     std::atomic<bool> m_cancelRequested{false};
    std::atomic<bool> m_modelsLoaded{false};
+    bool m_lowVramMode = false;
 
     // Current request data
     SongItem m_currentSong;
diff --git a/src/AdvancedSettingsDialog.cpp b/src/AdvancedSettingsDialog.cpp
index a34dc4e..c081a52 100644
--- a/src/AdvancedSettingsDialog.cpp
+++ b/src/AdvancedSettingsDialog.cpp
@@ -50,6 +50,11 @@ QString AdvancedSettingsDialog::getVAEModelPath() const
     return ui->vaeModelEdit->text();
 }
 
+bool AdvancedSettingsDialog::getLowVramMode() const
+{
+    return ui->lowVramCheckBox->isChecked();
+}
+
 void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr)
 {
     ui->jsonTemplateEdit->setPlainText(templateStr);
@@ -80,6 +85,11 @@ void AdvancedSettingsDialog::setVAEModelPath(const QString &path)
     ui->vaeModelEdit->setText(path);
 }
 
+void AdvancedSettingsDialog::setLowVramMode(bool enabled)
+{
+    ui->lowVramCheckBox->setChecked(enabled);
+}
+
 void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked()
 {
     QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text());
diff --git a/src/AdvancedSettingsDialog.h b/src/AdvancedSettingsDialog.h
index 8db247c..3388be8 100644
--- a/src/AdvancedSettingsDialog.h
+++ b/src/AdvancedSettingsDialog.h
@@ -29,6 +29,7 @@ public:
     QString getTextEncoderModelPath() const;
     QString getDiTModelPath() const;
     QString getVAEModelPath() const;
+    bool getLowVramMode() const;
 
     // Setters for settings
     void setJsonTemplate(const QString &templateStr);
@@ -37,6 +38,7 @@ public:
     void setTextEncoderModelPath(const QString &path);
     void setDiTModelPath(const QString &path);
     void setVAEModelPath(const QString &path);
+    void setLowVramMode(bool enabled);
 
 private slots:
     void on_aceStepBrowseButton_clicked();
diff --git a/src/AdvancedSettingsDialog.ui b/src/AdvancedSettingsDialog.ui
index e9fadf1..987c807 100644
--- a/src/AdvancedSettingsDialog.ui
+++ b/src/AdvancedSettingsDialog.ui
@@ -19,6 +19,43 @@
       <number>0</number>
      </property>
+     <widget class="QWidget" name="performanceTab">
+      <attribute name="title">
+       <string>Performance</string>
+      </attribute>
+      <layout class="QVBoxLayout" name="performanceLayout">
+       <item>
+        <widget class="QCheckBox" name="lowVramCheckBox">
+         <property name="text">
+          <string>Low VRAM Mode</string>
+         </property>
+        </widget>
+       </item>
+       <item>
+        <widget class="QLabel" name="lowVramHelpLabel">
+         <property name="text">
+          <string>Unload models between generation phases to save VRAM. Slower but uses less memory.</string>
+         </property>
+         <property name="wordWrap">
+          <bool>true</bool>
+         </property>
+        </widget>
+       </item>
+       <item>
+        <spacer name="performanceSpacer">
+         <property name="orientation">
+          <enum>Qt::Vertical</enum>
+         </property>
+         <property name="sizeHint" stdset="0">
+          <size>
+           <width>20</width>
+           <height>40</height>
+          </size>
+         </property>
+        </spacer>
+       </item>
+      </layout>
+     </widget>
+     <widget class="QWidget" name="jsonTemplateTab">
+      <attribute name="title">
       <string>JSON Template</string>
diff --git a/src/MainWindow.cpp b/src/MainWindow.cpp
index a8b64f7..23ede18 100644
--- a/src/MainWindow.cpp
+++ b/src/MainWindow.cpp
@@ -155,6 +155,10 @@ void MainWindow::loadSettings()
                                       appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-BF16.gguf").toString();
     ditModelPath = settings.value("ditModelPath", appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString();
     vaeModelPath = settings.value("vaeModelPath", appDir + "/acestep.cpp/models/vae-BF16.gguf").toString();
+
+    // Load low VRAM mode
+    bool lowVram = settings.value("lowVramMode", false).toBool();
+    aceStep->setLowVramMode(lowVram);
 }
 
 void MainWindow::saveSettings()
@@ -174,6 +178,9 @@ void MainWindow::saveSettings()
     settings.setValue("ditModelPath", ditModelPath);
     settings.setValue("vaeModelPath", vaeModelPath);
 
+    // Save low VRAM mode
+    settings.setValue("lowVramMode", aceStep->isLowVramMode());
+
     settings.setValue("firstRun", false);
 }
 
@@ -374,6 +381,7 @@ void MainWindow::on_advancedSettingsButton_clicked()
     dialog.setTextEncoderModelPath(textEncoderModelPath);
     dialog.setDiTModelPath(ditModelPath);
     dialog.setVAEModelPath(vaeModelPath);
+    dialog.setLowVramMode(aceStep->isLowVramMode());
 
     if (dialog.exec() == QDialog::Accepted)
     {
@@ -397,6 +405,9 @@ void MainWindow::on_advancedSettingsButton_clicked()
         // Update model paths for acestep.cpp
         aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
 
+        // Update low VRAM mode
+        aceStep->setLowVramMode(dialog.getLowVramMode());
+
         saveSettings();
     }
 }
diff --git a/tests/test_acestep_worker.cpp b/tests/test_acestep_worker.cpp
index 2d32416..60a5ffa 100644
--- a/tests/test_acestep_worker.cpp
+++ b/tests/test_acestep_worker.cpp
@@ -324,6 +324,177 @@ TEST(cancellation)
     ASSERT_TRUE(true);
 }
 
+// Test 12: Test low VRAM mode generation
+TEST(generateSongLowVram)
+{
+    ModelPaths paths = getModelPathsFromSettings();
+
+    // Skip if models don't exist
+    SKIP_IF(!modelsExist(paths));
+
+    AceStepWorker worker;
+    worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
+    worker.setLowVramMode(true);
+
+    ASSERT_TRUE(worker.isLowVramMode());
+
+    SongItem song("Chill electronic music", "");
+
+    QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
+
+    // Track generation result
+    bool generationCompleted = false;
+    SongItem resultSong;
+    QObject::connect(&worker, &AceStepWorker::songGenerated,
+                     [&generationCompleted, &resultSong](const SongItem& song) {
+                         std::cout << "\n  Low VRAM mode: Song generated successfully!" << std::endl;
+                         std::cout << "  Caption: " << song.caption.toStdString() << std::endl;
+                         if (song.audioData) {
+                             std::cout << "  Audio data size: " << song.audioData->size() << " bytes" << std::endl;
+                         } else {
+                             std::cout << "  Audio data size: null" << std::endl;
+                         }
+                         resultSong = song;
+                         generationCompleted = true;
+                     });
+
+    QString errorMsg;
+    QObject::connect(&worker, &AceStepWorker::generationError,
+                     [&errorMsg](const QString& err) {
+                         std::cout << "\n  Error: " << err.toStdString() << std::endl;
+                         errorMsg = err;
+                     });
+
+    std::cout << "\n  Starting low VRAM mode generation..." << std::endl;
+
+    // Request generation
+    bool result = worker.requestGeneration(song, templateJson);
+    ASSERT_TRUE(result);
+
+    // Use QEventLoop with timer for proper event processing
+    QEventLoop loop;
+    QTimer timeoutTimer;
+
+    timeoutTimer.setSingleShot(true);
+    timeoutTimer.start(300000); // 5 minute timeout
+
+    QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
+    QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
+    QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
+
+    loop.exec();
+
+    ASSERT_TRUE(generationCompleted);
+    ASSERT_TRUE(resultSong.audioData != nullptr);
+    ASSERT_TRUE(!resultSong.audioData->isEmpty());
+
+    std::cout << "  Audio data size: " << resultSong.audioData->size() << " bytes" << std::endl;
+    ASSERT_TRUE(resultSong.audioData->size() > 1000);
+}
+
+// Test 13: Test normal mode keeps models loaded between generations
+TEST(normalModeKeepsModelsLoaded)
+{
+    ModelPaths paths = getModelPathsFromSettings();
+
+    // Skip if models don't exist
+    SKIP_IF(!modelsExist(paths));
+
+    AceStepWorker worker;
+    worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
+    // Normal mode is default (lowVramMode = false)
+
+    ASSERT_FALSE(worker.isLowVramMode());
+
+    QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
+
+    // Generate first song
+    bool firstGenerationCompleted = false;
+    QObject::connect(&worker, &AceStepWorker::songGenerated,
+                     [&firstGenerationCompleted](const SongItem&) {
+                         firstGenerationCompleted = true;
+                     });
+
+    QObject::connect(&worker, &AceStepWorker::generationError,
+                     [](const QString& err) {
+                         std::cout << "\n  Error: " << err.toStdString() << std::endl;
+                     });
+
+    std::cout << "\n  Generating first song (normal mode)..." << std::endl;
+
+    SongItem song1("First song", "");
+    bool result = worker.requestGeneration(song1, templateJson);
+    ASSERT_TRUE(result);
+
+    QEventLoop loop;
+    QTimer timeoutTimer;
+    timeoutTimer.setSingleShot(true);
+    timeoutTimer.start(300000);
+
+    QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
+    QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
+    QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
+
+    loop.exec();
+
+    ASSERT_TRUE(firstGenerationCompleted);
+    std::cout << "  First generation completed, models should still be loaded" << std::endl;
+
+    // Generate second song - in normal mode this should be faster since models are already loaded
+    bool secondGenerationCompleted = false;
+    SongItem secondResult;
+    QObject::connect(&worker, &AceStepWorker::songGenerated,
+                     [&secondGenerationCompleted, &secondResult](const SongItem& song) {
+                         secondGenerationCompleted = true;
+                         secondResult = song;
+                     });
+
+    std::cout << "  Generating second song (should use cached models)..." << std::endl;
+
+    SongItem song2("Second song", "");
+    result = worker.requestGeneration(song2, templateJson);
+    ASSERT_TRUE(result);
+
+    QEventLoop loop2;
+    QTimer timeoutTimer2;
+    timeoutTimer2.setSingleShot(true);
+    timeoutTimer2.start(300000);
+
+    QObject::connect(&worker, &AceStepWorker::songGenerated, &loop2, &QEventLoop::quit);
+    QObject::connect(&worker, &AceStepWorker::generationError, &loop2, &QEventLoop::quit);
+    QObject::connect(&timeoutTimer2, &QTimer::timeout, &loop2, &QEventLoop::quit);
+
+    loop2.exec();
+
+    ASSERT_TRUE(secondGenerationCompleted);
+    ASSERT_TRUE(secondResult.audioData != nullptr);
+    ASSERT_TRUE(!secondResult.audioData->isEmpty());
+
+    std::cout << "  Second generation completed successfully" << std::endl;
+    std::cout << "  Audio data size: " << secondResult.audioData->size() << " bytes" << std::endl;
+}
+
+// Test 14: Test setLowVramMode toggle
+TEST(lowVramModeToggle)
+{
+    AceStepWorker worker;
+
+    // Default should be false (normal mode)
+    ASSERT_FALSE(worker.isLowVramMode());
+
+    // Enable low VRAM mode
+    worker.setLowVramMode(true);
+    ASSERT_TRUE(worker.isLowVramMode());
+
+    // Disable low VRAM mode
+    worker.setLowVramMode(false);
+    ASSERT_FALSE(worker.isLowVramMode());
+
+    // Toggle again
+    worker.setLowVramMode(true);
+    ASSERT_TRUE(worker.isLowVramMode());
+}
+
 int main(int argc, char *argv[])
 {
     QCoreApplication app(argc, argv);
@@ -341,6 +512,9 @@ int main(int argc, char *argv[])
     RUN_TEST(checkModelFiles);
     RUN_TEST(generateSong);
     RUN_TEST(cancellation);
+    RUN_TEST(generateSongLowVram);
+    RUN_TEST(normalModeKeepsModelsLoaded);
+    RUN_TEST(lowVramModeToggle);
 
     std::cout << "\n=== Results ===" << std::endl;
     std::cout << "Passed: " << testsPassed << std::endl;