From 55be24b36fb3c19b3d07648e40a37bf92d527022 Mon Sep 17 00:00:00 2001 From: Carl Philipp Klemm Date: Wed, 15 Apr 2026 11:52:38 +0200 Subject: [PATCH] Refactor generation pipeline to reduce code duplication --- src/AceStepWorker.cpp | 425 +++++++++++++----------------------------- src/AceStepWorker.h | 8 +- 2 files changed, 130 insertions(+), 303 deletions(-) diff --git a/src/AceStepWorker.cpp b/src/AceStepWorker.cpp index b401de8..731d021 100644 --- a/src/AceStepWorker.cpp +++ b/src/AceStepWorker.cpp @@ -8,11 +8,9 @@ #include #include #include -#include // acestep.cpp headers #include "pipeline-lm.h" -#include "pipeline-synth.h" #include "request.h" AceStepWorker::AceStepWorker(QObject* parent) @@ -142,6 +140,56 @@ bool AceStepWorker::checkCancel(void* data) return worker->m_cancelRequested.load(); } +std::shared_ptr AceStepWorker::convertToWav(const AceAudio& audio) +{ + auto audioData = std::make_shared(); + + // Simple WAV header + stereo float data + int numChannels = 2; + int bitsPerSample = 16; + int byteRate = audio.sample_rate * numChannels * (bitsPerSample / 8); + int blockAlign = numChannels * (bitsPerSample / 8); + int dataSize = audio.n_samples * numChannels * (bitsPerSample / 8); + + // RIFF header + audioData->append("RIFF"); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); + audioData->append("WAVE"); + + // fmt chunk + audioData->append("fmt "); + int fmtSize = 16; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&fmtSize), 4)); + short audioFormat = 1; // PCM + audioData->append(QByteArray::fromRawData(reinterpret_cast(&audioFormat), 2)); + short numCh = numChannels; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&numCh), 2)); + int sampleRate = audio.sample_rate; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&sampleRate), 4)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&byteRate), 4)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&blockAlign), 2)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&bitsPerSample), 2)); + + // data chunk + audioData->append("data"); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); + + // Convert float samples to 16-bit and write + QVector interleaved(audio.n_samples * numChannels); + for (int i = 0; i < audio.n_samples; i++) + { + float left = audio.samples[i]; + float right = audio.samples[i + audio.n_samples]; + // Clamp and convert to 16-bit + left = std::max(-1.0f, std::min(1.0f, left)); + right = std::max(-1.0f, std::min(1.0f, right)); + interleaved[i * 2] = static_cast(left * 32767.0f); + interleaved[i * 2 + 1] = static_cast(right * 32767.0f); + } + audioData->append(QByteArray::fromRawData(reinterpret_cast(interleaved.data()), dataSize)); + return audioData; +} + void AceStepWorker::runGeneration() { // Convert SongItem to AceRequest @@ -149,328 +197,109 @@ void AceStepWorker::runGeneration() AceRequest lmOutput; request_init(&lmOutput); - if (m_lowVramMode) + emit progressUpdate(10); + + if (!loadLm()) { - // Low VRAM mode: load LM → run LM → unload LM → load Synth → run Synth → unload Synth + m_busy.store(false); + return; + } - // Step 1: Load LM and generate - emit progressUpdate(10); + emit progressUpdate(30); - if (!loadLm()) - { - m_busy.store(false); - return; - } + int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput, + nullptr, nullptr, + checkCancel, this, + LM_MODE_GENERATE); - emit progressUpdate(30); + if (m_cancelRequested.load()) + { + if(m_lowVramMode) + unloadModels(); + emit generationCanceled(m_currentSong); + m_busy.store(false); + return; + } - int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput, - nullptr, nullptr, - checkCancel, this, - LM_MODE_GENERATE); + if (lmResult != 0) + { + if(m_lowVramMode) + unloadModels(); + emit generationError("LM generation failed or was canceled"); + m_busy.store(false); + return; + } - if (m_cancelRequested.load()) - { - unloadLm(); - emit generationCanceled(m_currentSong); - m_busy.store(false); - return; - } + m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics); - if (lmResult != 0) - { - unloadLm(); - emit generationError("LM generation failed or was canceled"); - m_busy.store(false); - return; - } - - // Update song with generated lyrics - m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics); - - // Unload LM to free VRAM + if(m_lowVramMode) unloadLm(); - // Step 2: Load Synth and generate audio - emit progressUpdate(50); + emit progressUpdate(50); - if (!loadSynth()) - { - m_busy.store(false); - return; - } + if (!loadSynth()) + { + m_busy.store(false); + return; + } - emit progressUpdate(60); + emit progressUpdate(60); - AceAudio outputAudio; - outputAudio.samples = nullptr; - outputAudio.n_samples = 0; - outputAudio.sample_rate = 48000; + AceAudio outputAudio; + outputAudio.samples = nullptr; + outputAudio.n_samples = 0; + outputAudio.sample_rate = 48000; - int synthResult = ace_synth_generate(m_synthContext, &lmOutput, - nullptr, 0, // no source audio - nullptr, 0, // no reference audio - 1, &outputAudio, - checkCancel, this); + int synthResult = ace_synth_generate(m_synthContext, &lmOutput, + nullptr, 0, // no source audio + nullptr, 0, // no reference audio + 1, &outputAudio, + checkCancel, this); - // Unload Synth to free VRAM + if(m_lowVramMode) unloadSynth(); - if (m_cancelRequested.load()) - { - emit generationCanceled(m_currentSong); - m_busy.store(false); - return; - } - - if (synthResult != 0) - { - emit generationError("Synthesis failed or was canceled"); - m_busy.store(false); - return; - } - - // Store audio in memory as WAV - auto audioData = std::make_shared(); - - // Simple WAV header + stereo float data - int numChannels = 2; - int bitsPerSample = 16; - int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8); - int blockAlign = numChannels * (bitsPerSample / 8); - int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8); - - // RIFF header - audioData->append("RIFF"); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); - audioData->append("WAVE"); - - // fmt chunk - audioData->append("fmt "); - int fmtSize = 16; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&fmtSize), 4)); - short audioFormat = 1; // PCM - audioData->append(QByteArray::fromRawData(reinterpret_cast(&audioFormat), 2)); - short numCh = numChannels; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&numCh), 2)); - int sampleRate = outputAudio.sample_rate; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&sampleRate), 4)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&byteRate), 4)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&blockAlign), 2)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&bitsPerSample), 2)); - - // data chunk - audioData->append("data"); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); - - // Convert float samples to 16-bit and write - QVector interleaved(outputAudio.n_samples * numChannels); - for (int i = 0; i < outputAudio.n_samples; i++) - { - float left = outputAudio.samples[i]; - float right = outputAudio.samples[i + outputAudio.n_samples]; - // Clamp and convert to 16-bit - left = std::max(-1.0f, std::min(1.0f, left)); - right = std::max(-1.0f, std::min(1.0f, right)); - interleaved[i * 2] = static_cast(left * 32767.0f); - interleaved[i * 2 + 1] = static_cast(right * 32767.0f); - } - audioData->append(QByteArray::fromRawData(reinterpret_cast(interleaved.data()), dataSize)); - - // Free audio buffer - ace_audio_free(&outputAudio); - - // Store the JSON with all generated fields - m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true)); - m_currentSong.audioData = audioData; - - // Extract BPM if available - if (lmOutput.bpm > 0) - m_currentSong.bpm = lmOutput.bpm; - - // Extract key if available - if (!lmOutput.keyscale.empty()) - m_currentSong.key = QString::fromStdString(lmOutput.keyscale); - - emit progressUpdate(100); - emit songGenerated(m_currentSong); - - m_busy.store(false); - } - else + if (m_cancelRequested.load()) { - // Normal mode: load all models at start, unload at end - - // Load models if needed - if (!loadModels()) - { - m_busy.store(false); - return; - } - - // Step 1: LM generates lyrics and audio codes - emit progressUpdate(30); - - int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput, - nullptr, nullptr, - checkCancel, this, - LM_MODE_GENERATE); - - if (m_cancelRequested.load()) - { - emit generationCanceled(m_currentSong); - m_busy.store(false); - return; - } - - if (lmResult != 0) - { - emit generationError("LM generation failed"); - unloadModels(); - m_busy.store(false); - return; - } - - // Update song with generated lyrics - m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics); - - // Step 2: Synth generates audio - emit progressUpdate(60); - - AceAudio outputAudio; - outputAudio.samples = nullptr; - outputAudio.n_samples = 0; - outputAudio.sample_rate = 48000; - - int synthResult = ace_synth_generate(m_synthContext, &lmOutput, - nullptr, 0, // no source audio - nullptr, 0, // no reference audio - 1, &outputAudio, - checkCancel, this); - - if (m_cancelRequested.load()) - { - emit generationCanceled(m_currentSong); - unloadModels(); - m_busy.store(false); - return; - } - - if (synthResult != 0) - { - emit generationError("Synthesis failed or was canceled"); - unloadModels(); - m_busy.store(false); - return; - } - - // Store audio in memory as WAV - auto audioData = std::make_shared(); - - // Simple WAV header + stereo float data - int numChannels = 2; - int bitsPerSample = 16; - int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8); - int blockAlign = numChannels * (bitsPerSample / 8); - int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8); - - // RIFF header - audioData->append("RIFF"); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); - audioData->append("WAVE"); - - // fmt chunk - audioData->append("fmt "); - int fmtSize = 16; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&fmtSize), 4)); - short audioFormat = 1; // PCM - audioData->append(QByteArray::fromRawData(reinterpret_cast(&audioFormat), 2)); - short numCh = numChannels; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&numCh), 2)); - int sampleRate = outputAudio.sample_rate; - audioData->append(QByteArray::fromRawData(reinterpret_cast(&sampleRate), 4)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&byteRate), 4)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&blockAlign), 2)); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&bitsPerSample), 2)); - - // data chunk - audioData->append("data"); - audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); - - // Convert float samples to 16-bit and write - QVector interleaved(outputAudio.n_samples * numChannels); - for (int i = 0; i < outputAudio.n_samples; i++) - { - float left = outputAudio.samples[i]; - float right = outputAudio.samples[i + outputAudio.n_samples]; - // Clamp and convert to 16-bit - left = std::max(-1.0f, std::min(1.0f, left)); - right = std::max(-1.0f, std::min(1.0f, right)); - interleaved[i * 2] = static_cast(left * 32767.0f); - interleaved[i * 2 + 1] = static_cast(right * 32767.0f); - } - audioData->append(QByteArray::fromRawData(reinterpret_cast(interleaved.data()), dataSize)); - - // Free audio buffer - ace_audio_free(&outputAudio); - - // Store the JSON with all generated fields - m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true)); - m_currentSong.audioData = audioData; - - // Extract BPM if available - if (lmOutput.bpm > 0) - m_currentSong.bpm = lmOutput.bpm; - - // Extract key if available - if (!lmOutput.keyscale.empty()) - m_currentSong.key = QString::fromStdString(lmOutput.keyscale); - - emit progressUpdate(100); - emit songGenerated(m_currentSong); - - // Keep models loaded for next generation (normal mode) + emit generationCanceled(m_currentSong); m_busy.store(false); + return; } + + if (synthResult != 0) + { + emit generationError("Synthesis failed or was canceled"); + m_busy.store(false); + return; + } + + std::shared_ptr audioData = convertToWav(outputAudio); + ace_audio_free(&outputAudio); + + m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true)); + m_currentSong.audioData = audioData; + + if (lmOutput.bpm > 0) + m_currentSong.bpm = lmOutput.bpm; + + if (!lmOutput.keyscale.empty()) + m_currentSong.key = QString::fromStdString(lmOutput.keyscale); + + emit progressUpdate(100); + emit songGenerated(m_currentSong); + + m_busy.store(false); } bool AceStepWorker::loadModels() { - if (m_modelsLoaded.load()) - return true; - - // Load LM - AceLmParams lmParams; - ace_lm_default_params(&lmParams); - lmParams.model_path = m_lmModelPathBytes.constData(); - lmParams.use_fsm = true; - lmParams.use_fa = m_flashAttention; - - m_lmContext = ace_lm_load(&lmParams); - if (!m_lmContext) - { - emit generationError("Failed to load LM model: " + m_lmModelPath); + bool ret = loadSynth(); + if(!ret) return false; - } - // Load Synth - AceSynthParams synthParams; - ace_synth_default_params(&synthParams); - synthParams.text_encoder_path = m_textEncoderPathBytes.constData(); - synthParams.dit_path = m_ditPathBytes.constData(); - synthParams.vae_path = m_vaePathBytes.constData(); - synthParams.use_fa = m_flashAttention; - - m_synthContext = ace_synth_load(&synthParams); - if (!m_synthContext) - { - emit generationError("Failed to load synthesis models"); - ace_lm_free(m_lmContext); - m_lmContext = nullptr; + ret = loadLm(); + if(!ret) return false; - } - - m_modelsLoaded.store(true); return true; } diff --git a/src/AceStepWorker.h b/src/AceStepWorker.h index ba856fa..ec8c8a6 100644 --- a/src/AceStepWorker.h +++ b/src/AceStepWorker.h @@ -13,8 +13,7 @@ #include #include "SongItem.h" - -// acestep.cpp headers +#include "pipeline-synth.h" #include "request.h" struct AceLm; @@ -74,6 +73,8 @@ private: // Convert AceRequest back to SongItem SongItem requestToSong(const AceRequest& req, const QString& json); + static std::shared_ptr convertToWav(const AceAudio& audio); + // Generation state std::atomic m_busy{false}; std::atomic m_cancelRequested{false}; @@ -96,13 +97,10 @@ private: AceLm* m_lmContext = nullptr; AceSynth* m_synthContext = nullptr; - // Cached model paths as byte arrays (to avoid dangling pointers) QByteArray m_lmModelPathBytes; QByteArray m_textEncoderPathBytes; QByteArray m_ditPathBytes; QByteArray m_vaePathBytes; - - const QString m_tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation); }; #endif // ACESTEPWORKER_H \ No newline at end of file