Refactor generation pipeline to reduce code duplication

This commit is contained in:
Carl Philipp Klemm 2026-04-15 11:52:38 +02:00
parent be21c1f2bd
commit 55be24b36f
2 changed files with 130 additions and 303 deletions

View file

@ -8,11 +8,9 @@
#include <QDir> #include <QDir>
#include <QDebug> #include <QDebug>
#include <QRandomGenerator> #include <QRandomGenerator>
#include <cstring>
// acestep.cpp headers // acestep.cpp headers
#include "pipeline-lm.h" #include "pipeline-lm.h"
#include "pipeline-synth.h"
#include "request.h" #include "request.h"
AceStepWorker::AceStepWorker(QObject* parent) AceStepWorker::AceStepWorker(QObject* parent)
@ -142,6 +140,56 @@ bool AceStepWorker::checkCancel(void* data)
return worker->m_cancelRequested.load(); return worker->m_cancelRequested.load();
} }
std::shared_ptr<QByteArray> AceStepWorker::convertToWav(const AceAudio& audio)
{
auto audioData = std::make_shared<QByteArray>();
// Simple WAV header + stereo float data
int numChannels = 2;
int bitsPerSample = 16;
int byteRate = audio.sample_rate * numChannels * (bitsPerSample / 8);
int blockAlign = numChannels * (bitsPerSample / 8);
int dataSize = audio.n_samples * numChannels * (bitsPerSample / 8);
// RIFF header
audioData->append("RIFF");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
audioData->append("WAVE");
// fmt chunk
audioData->append("fmt ");
int fmtSize = 16;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
short audioFormat = 1; // PCM
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
short numCh = numChannels;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
int sampleRate = audio.sample_rate;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
// data chunk
audioData->append("data");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
// Convert float samples to 16-bit and write
QVector<short> interleaved(audio.n_samples * numChannels);
for (int i = 0; i < audio.n_samples; i++)
{
float left = audio.samples[i];
float right = audio.samples[i + audio.n_samples];
// Clamp and convert to 16-bit
left = std::max(-1.0f, std::min(1.0f, left));
right = std::max(-1.0f, std::min(1.0f, right));
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
}
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
return audioData;
}
void AceStepWorker::runGeneration() void AceStepWorker::runGeneration()
{ {
// Convert SongItem to AceRequest // Convert SongItem to AceRequest
@ -149,11 +197,6 @@ void AceStepWorker::runGeneration()
AceRequest lmOutput; AceRequest lmOutput;
request_init(&lmOutput); request_init(&lmOutput);
if (m_lowVramMode)
{
// Low VRAM mode: load LM → run LM → unload LM → load Synth → run Synth → unload Synth
// Step 1: Load LM and generate
emit progressUpdate(10); emit progressUpdate(10);
if (!loadLm()) if (!loadLm())
@ -171,7 +214,8 @@ void AceStepWorker::runGeneration()
if (m_cancelRequested.load()) if (m_cancelRequested.load())
{ {
unloadLm(); if(m_lowVramMode)
unloadModels();
emit generationCanceled(m_currentSong); emit generationCanceled(m_currentSong);
m_busy.store(false); m_busy.store(false);
return; return;
@ -179,19 +223,18 @@ void AceStepWorker::runGeneration()
if (lmResult != 0) if (lmResult != 0)
{ {
unloadLm(); if(m_lowVramMode)
unloadModels();
emit generationError("LM generation failed or was canceled"); emit generationError("LM generation failed or was canceled");
m_busy.store(false); m_busy.store(false);
return; return;
} }
// Update song with generated lyrics
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics); m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
// Unload LM to free VRAM if(m_lowVramMode)
unloadLm(); unloadLm();
// Step 2: Load Synth and generate audio
emit progressUpdate(50); emit progressUpdate(50);
if (!loadSynth()) if (!loadSynth())
@ -213,7 +256,7 @@ void AceStepWorker::runGeneration()
1, &outputAudio, 1, &outputAudio,
checkCancel, this); checkCancel, this);
// Unload Synth to free VRAM if(m_lowVramMode)
unloadSynth(); unloadSynth();
if (m_cancelRequested.load()) if (m_cancelRequested.load())
@ -230,65 +273,15 @@ void AceStepWorker::runGeneration()
return; return;
} }
// Store audio in memory as WAV std::shared_ptr<QByteArray> audioData = convertToWav(outputAudio);
auto audioData = std::make_shared<QByteArray>();
// Simple WAV header + stereo float data
int numChannels = 2;
int bitsPerSample = 16;
int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
int blockAlign = numChannels * (bitsPerSample / 8);
int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
// RIFF header
audioData->append("RIFF");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
audioData->append("WAVE");
// fmt chunk
audioData->append("fmt ");
int fmtSize = 16;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
short audioFormat = 1; // PCM
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
short numCh = numChannels;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
int sampleRate = outputAudio.sample_rate;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
// data chunk
audioData->append("data");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
// Convert float samples to 16-bit and write
QVector<short> interleaved(outputAudio.n_samples * numChannels);
for (int i = 0; i < outputAudio.n_samples; i++)
{
float left = outputAudio.samples[i];
float right = outputAudio.samples[i + outputAudio.n_samples];
// Clamp and convert to 16-bit
left = std::max(-1.0f, std::min(1.0f, left));
right = std::max(-1.0f, std::min(1.0f, right));
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
}
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
// Free audio buffer
ace_audio_free(&outputAudio); ace_audio_free(&outputAudio);
// Store the JSON with all generated fields
m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true)); m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
m_currentSong.audioData = audioData; m_currentSong.audioData = audioData;
// Extract BPM if available
if (lmOutput.bpm > 0) if (lmOutput.bpm > 0)
m_currentSong.bpm = lmOutput.bpm; m_currentSong.bpm = lmOutput.bpm;
// Extract key if available
if (!lmOutput.keyscale.empty()) if (!lmOutput.keyscale.empty())
m_currentSong.key = QString::fromStdString(lmOutput.keyscale); m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
@ -296,181 +289,17 @@ void AceStepWorker::runGeneration()
emit songGenerated(m_currentSong); emit songGenerated(m_currentSong);
m_busy.store(false); m_busy.store(false);
}
else
{
// Normal mode: load all models at start, unload at end
// Load models if needed
if (!loadModels())
{
m_busy.store(false);
return;
}
// Step 1: LM generates lyrics and audio codes
emit progressUpdate(30);
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
nullptr, nullptr,
checkCancel, this,
LM_MODE_GENERATE);
if (m_cancelRequested.load())
{
emit generationCanceled(m_currentSong);
m_busy.store(false);
return;
}
if (lmResult != 0)
{
emit generationError("LM generation failed");
unloadModels();
m_busy.store(false);
return;
}
// Update song with generated lyrics
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
// Step 2: Synth generates audio
emit progressUpdate(60);
AceAudio outputAudio;
outputAudio.samples = nullptr;
outputAudio.n_samples = 0;
outputAudio.sample_rate = 48000;
int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
nullptr, 0, // no source audio
nullptr, 0, // no reference audio
1, &outputAudio,
checkCancel, this);
if (m_cancelRequested.load())
{
emit generationCanceled(m_currentSong);
unloadModels();
m_busy.store(false);
return;
}
if (synthResult != 0)
{
emit generationError("Synthesis failed or was canceled");
unloadModels();
m_busy.store(false);
return;
}
// Store audio in memory as WAV
auto audioData = std::make_shared<QByteArray>();
// Simple WAV header + stereo float data
int numChannels = 2;
int bitsPerSample = 16;
int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
int blockAlign = numChannels * (bitsPerSample / 8);
int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
// RIFF header
audioData->append("RIFF");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
audioData->append("WAVE");
// fmt chunk
audioData->append("fmt ");
int fmtSize = 16;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&fmtSize), 4));
short audioFormat = 1; // PCM
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&audioFormat), 2));
short numCh = numChannels;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&numCh), 2));
int sampleRate = outputAudio.sample_rate;
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&sampleRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&byteRate), 4));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&blockAlign), 2));
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&bitsPerSample), 2));
// data chunk
audioData->append("data");
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(&dataSize), 4));
// Convert float samples to 16-bit and write
QVector<short> interleaved(outputAudio.n_samples * numChannels);
for (int i = 0; i < outputAudio.n_samples; i++)
{
float left = outputAudio.samples[i];
float right = outputAudio.samples[i + outputAudio.n_samples];
// Clamp and convert to 16-bit
left = std::max(-1.0f, std::min(1.0f, left));
right = std::max(-1.0f, std::min(1.0f, right));
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
}
audioData->append(QByteArray::fromRawData(reinterpret_cast<const char*>(interleaved.data()), dataSize));
// Free audio buffer
ace_audio_free(&outputAudio);
// Store the JSON with all generated fields
m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
m_currentSong.audioData = audioData;
// Extract BPM if available
if (lmOutput.bpm > 0)
m_currentSong.bpm = lmOutput.bpm;
// Extract key if available
if (!lmOutput.keyscale.empty())
m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
emit progressUpdate(100);
emit songGenerated(m_currentSong);
// Keep models loaded for next generation (normal mode)
m_busy.store(false);
}
} }
bool AceStepWorker::loadModels() bool AceStepWorker::loadModels()
{ {
if (m_modelsLoaded.load()) bool ret = loadSynth();
return true; if(!ret)
// Load LM
AceLmParams lmParams;
ace_lm_default_params(&lmParams);
lmParams.model_path = m_lmModelPathBytes.constData();
lmParams.use_fsm = true;
lmParams.use_fa = m_flashAttention;
m_lmContext = ace_lm_load(&lmParams);
if (!m_lmContext)
{
emit generationError("Failed to load LM model: " + m_lmModelPath);
return false; return false;
}
// Load Synth ret = loadLm();
AceSynthParams synthParams; if(!ret)
ace_synth_default_params(&synthParams);
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
synthParams.dit_path = m_ditPathBytes.constData();
synthParams.vae_path = m_vaePathBytes.constData();
synthParams.use_fa = m_flashAttention;
m_synthContext = ace_synth_load(&synthParams);
if (!m_synthContext)
{
emit generationError("Failed to load synthesis models");
ace_lm_free(m_lmContext);
m_lmContext = nullptr;
return false; return false;
}
m_modelsLoaded.store(true);
return true; return true;
} }

View file

@ -13,8 +13,7 @@
#include <atomic> #include <atomic>
#include "SongItem.h" #include "SongItem.h"
#include "pipeline-synth.h"
// acestep.cpp headers
#include "request.h" #include "request.h"
struct AceLm; struct AceLm;
@ -74,6 +73,8 @@ private:
// Convert AceRequest back to SongItem // Convert AceRequest back to SongItem
SongItem requestToSong(const AceRequest& req, const QString& json); SongItem requestToSong(const AceRequest& req, const QString& json);
static std::shared_ptr<QByteArray> convertToWav(const AceAudio& audio);
// Generation state // Generation state
std::atomic<bool> m_busy{false}; std::atomic<bool> m_busy{false};
std::atomic<bool> m_cancelRequested{false}; std::atomic<bool> m_cancelRequested{false};
@ -96,13 +97,10 @@ private:
AceLm* m_lmContext = nullptr; AceLm* m_lmContext = nullptr;
AceSynth* m_synthContext = nullptr; AceSynth* m_synthContext = nullptr;
// Cached model paths as byte arrays (to avoid dangling pointers)
QByteArray m_lmModelPathBytes; QByteArray m_lmModelPathBytes;
QByteArray m_textEncoderPathBytes; QByteArray m_textEncoderPathBytes;
QByteArray m_ditPathBytes; QByteArray m_ditPathBytes;
QByteArray m_vaePathBytes; QByteArray m_vaePathBytes;
const QString m_tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation);
}; };
#endif // ACESTEPWORKER_H #endif // ACESTEPWORKER_H