Use the acestep API directly instead of calling binaries
This commit is contained in:
parent
6be80f1d5c
commit
de7207f07e
8 changed files with 789 additions and 210 deletions
|
|
@ -5,207 +5,379 @@
|
|||
#include <QFile>
|
||||
#include <QJsonDocument>
|
||||
#include <QJsonObject>
|
||||
#include <QProcess>
|
||||
#include <QDir>
|
||||
#include <QStandardPaths>
|
||||
#include <QDebug>
|
||||
#include <QCoreApplication>
|
||||
#include <QRandomGenerator>
|
||||
#include <cstring>
|
||||
|
||||
AceStep::AceStep(QObject* parent): QObject(parent)
|
||||
// acestep.cpp headers
|
||||
#include "pipeline-lm.h"
|
||||
#include "pipeline-synth.h"
|
||||
#include "request.h"
|
||||
|
||||
AceStepWorker::AceStepWorker(QObject* parent)
|
||||
: QObject(parent)
|
||||
{
|
||||
connect(&qwenProcess, &QProcess::finished, this, &AceStep::qwenProcFinished);
|
||||
connect(&ditVaeProcess, &QProcess::finished, this, &AceStep::ditProcFinished);
|
||||
}
|
||||
|
||||
bool AceStep::isGenerating(SongItem* song)
|
||||
AceStepWorker::~AceStepWorker()
|
||||
{
|
||||
if(!busy && song)
|
||||
*song = this->request.song;
|
||||
return busy;
|
||||
cancelGeneration();
|
||||
unloadModels();
|
||||
}
|
||||
|
||||
void AceStep::cancelGeneration()
|
||||
void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath)
|
||||
{
|
||||
qwenProcess.blockSignals(true);
|
||||
qwenProcess.terminate();
|
||||
qwenProcess.waitForFinished();
|
||||
qwenProcess.blockSignals(false);
|
||||
|
||||
ditVaeProcess.blockSignals(true);
|
||||
ditVaeProcess.terminate();
|
||||
ditVaeProcess.waitForFinished();
|
||||
ditVaeProcess.blockSignals(false);
|
||||
|
||||
progressUpdate(100);
|
||||
if(busy)
|
||||
generationCanceled(request.song);
|
||||
|
||||
busy = false;
|
||||
m_lmModelPath = lmPath;
|
||||
m_textEncoderPath = textEncoderPath;
|
||||
m_ditPath = ditPath;
|
||||
m_vaePath = vaePath;
|
||||
|
||||
// Cache as byte arrays to avoid dangling pointers
|
||||
m_lmModelPathBytes = lmPath.toUtf8();
|
||||
m_textEncoderPathBytes = textEncoderPath.toUtf8();
|
||||
m_ditPathBytes = ditPath.toUtf8();
|
||||
m_vaePathBytes = vaePath.toUtf8();
|
||||
}
|
||||
|
||||
bool AceStep::requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath,
|
||||
QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath,
|
||||
QString vaeModelPath)
|
||||
bool AceStepWorker::isGenerating(SongItem* song)
|
||||
{
|
||||
if(busy)
|
||||
{
|
||||
qWarning()<<"Dropping song:"<<song.caption;
|
||||
return false;
|
||||
}
|
||||
busy = true;
|
||||
if (!m_busy.load() && song)
|
||||
*song = m_currentSong;
|
||||
return m_busy.load();
|
||||
}
|
||||
|
||||
request = {song, QRandomGenerator::global()->generate(), aceStepPath, textEncoderModelPath, ditModelPath, vaeModelPath};
|
||||
void AceStepWorker::cancelGeneration()
|
||||
{
|
||||
m_cancelRequested.store(true);
|
||||
}
|
||||
|
||||
QString qwen3Binary = aceStepPath + "/ace-lm" + EXE_EXT;
|
||||
QFileInfo qwen3Info(qwen3Binary);
|
||||
if (!qwen3Info.exists() || !qwen3Info.isExecutable())
|
||||
bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate)
|
||||
{
|
||||
if (m_busy.load())
|
||||
{
|
||||
generationError("ace-lm binary not found at: " + qwen3Binary);
|
||||
busy = false;
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(qwen3ModelPath))
|
||||
{
|
||||
generationError("Qwen3 model not found: " + qwen3ModelPath);
|
||||
busy = false;
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(textEncoderModelPath))
|
||||
{
|
||||
generationError("Text encoder model not found: " + textEncoderModelPath);
|
||||
busy = false;
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(ditModelPath))
|
||||
{
|
||||
generationError("DiT model not found: " + ditModelPath);
|
||||
busy = false;
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(vaeModelPath))
|
||||
{
|
||||
generationError("VAE model not found: " + vaeModelPath);
|
||||
busy = false;
|
||||
qWarning() << "Dropping song:" << song.caption;
|
||||
return false;
|
||||
}
|
||||
|
||||
request.requestFilePath = tempDir + "/request_" + QString::number(request.uid) + ".json";
|
||||
m_busy.store(true);
|
||||
m_cancelRequested.store(false);
|
||||
m_currentSong = song;
|
||||
m_requestTemplate = requestTemplate;
|
||||
m_uid = QRandomGenerator::global()->generate();
|
||||
|
||||
// Validate model paths
|
||||
if (m_lmModelPath.isEmpty() || m_textEncoderPath.isEmpty() ||
|
||||
m_ditPath.isEmpty() || m_vaePath.isEmpty())
|
||||
{
|
||||
emit generationError("Model paths not set. Call setModelPaths() first.");
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate model files exist
|
||||
if (!QFileInfo::exists(m_lmModelPath))
|
||||
{
|
||||
emit generationError("LM model not found: " + m_lmModelPath);
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(m_textEncoderPath))
|
||||
{
|
||||
emit generationError("Text encoder model not found: " + m_textEncoderPath);
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(m_ditPath))
|
||||
{
|
||||
emit generationError("DiT model not found: " + m_ditPath);
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
if (!QFileInfo::exists(m_vaePath))
|
||||
{
|
||||
emit generationError("VAE model not found: " + m_vaePath);
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate template
|
||||
QJsonParseError parseError;
|
||||
QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError);
|
||||
if (!templateDoc.isObject())
|
||||
{
|
||||
generationError("Invalid JSON template: " + QString(parseError.errorString()));
|
||||
busy = false;
|
||||
emit generationError("Invalid JSON template: " + QString(parseError.errorString()));
|
||||
m_busy.store(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
QJsonObject requestObj = templateDoc.object();
|
||||
song.store(requestObj);
|
||||
|
||||
// Write the request file
|
||||
QFile requestFileHandle(request.requestFilePath);
|
||||
if (!requestFileHandle.open(QIODevice::WriteOnly | QIODevice::Text))
|
||||
{
|
||||
emit generationError("Failed to create request file: " + requestFileHandle.errorString());
|
||||
busy = false;
|
||||
return false;
|
||||
}
|
||||
requestFileHandle.write(QJsonDocument(requestObj).toJson(QJsonDocument::Indented));
|
||||
requestFileHandle.close();
|
||||
|
||||
QStringList qwen3Args;
|
||||
qwen3Args << "--request" << request.requestFilePath;
|
||||
qwen3Args << "--lm" << qwen3ModelPath;
|
||||
|
||||
progressUpdate(30);
|
||||
|
||||
qwenProcess.start(qwen3Binary, qwen3Args);
|
||||
// Run generation in the worker thread
|
||||
QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection);
|
||||
return true;
|
||||
}
|
||||
|
||||
void AceStep::qwenProcFinished(int code, QProcess::ExitStatus status)
|
||||
bool AceStepWorker::checkCancel(void* data)
|
||||
{
|
||||
QFile::remove(request.requestFilePath);
|
||||
if(code != 0)
|
||||
{
|
||||
QString errorOutput = qwenProcess.readAllStandardError();
|
||||
generationError("ace-lm exited with code " + QString::number(code) + ": " + errorOutput);
|
||||
busy = false;
|
||||
return;
|
||||
}
|
||||
|
||||
QString ditVaeBinary = request.aceStepPath + "/ace-synth" + EXE_EXT;
|
||||
QFileInfo ditVaeInfo(ditVaeBinary);
|
||||
if (!ditVaeInfo.exists() || !ditVaeInfo.isExecutable())
|
||||
{
|
||||
generationError("ace-synth binary not found at: " + ditVaeBinary);
|
||||
busy = false;
|
||||
return;
|
||||
}
|
||||
|
||||
request.requestLlmFilePath = tempDir + "/request_" + QString::number(request.uid) + "0.json";
|
||||
if (!QFileInfo::exists(request.requestLlmFilePath))
|
||||
{
|
||||
generationError("ace-lm failed to create enhanced request file "+request.requestLlmFilePath);
|
||||
busy = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Load lyrics from the enhanced request file
|
||||
QFile lmOutputFile(request.requestLlmFilePath);
|
||||
if (lmOutputFile.open(QIODevice::ReadOnly | QIODevice::Text))
|
||||
{
|
||||
QJsonParseError parseError;
|
||||
request.song.json = lmOutputFile.readAll();
|
||||
QJsonDocument doc = QJsonDocument::fromJson(request.song.json.toUtf8(), &parseError);
|
||||
lmOutputFile.close();
|
||||
|
||||
if (doc.isObject() && !parseError.error)
|
||||
{
|
||||
QJsonObject obj = doc.object();
|
||||
if (obj.contains("lyrics") && obj["lyrics"].isString())
|
||||
request.song.lyrics = obj["lyrics"].toString();
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Run ace-synth to generate audio
|
||||
QStringList ditVaeArgs;
|
||||
ditVaeArgs << "--request"<<request.requestLlmFilePath;
|
||||
ditVaeArgs << "--embedding"<<request.textEncoderModelPath;
|
||||
ditVaeArgs << "--dit"<<request.ditModelPath;
|
||||
ditVaeArgs << "--vae"<<request.vaeModelPath;
|
||||
ditVaeArgs << "--wav";
|
||||
|
||||
progressUpdate(60);
|
||||
|
||||
ditVaeProcess.start(ditVaeBinary, ditVaeArgs);
|
||||
AceStepWorker* worker = static_cast<AceStepWorker*>(data);
|
||||
return worker->m_cancelRequested.load();
|
||||
}
|
||||
|
||||
void AceStep::ditProcFinished(int code, QProcess::ExitStatus status)
|
||||
void AceStepWorker::runGeneration()
|
||||
{
|
||||
QFile::remove(request.requestLlmFilePath);
|
||||
if (code != 0)
|
||||
// Load models if needed
|
||||
if (!loadModels())
|
||||
{
|
||||
QString errorOutput = ditVaeProcess.readAllStandardError();
|
||||
generationError("ace-synth exited with code " + QString::number(code) + ": " + errorOutput);
|
||||
busy = false;
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Find the generated WAV file
|
||||
QString wavFile = tempDir+"/request_" + QString::number(request.uid) + "00.wav";
|
||||
if (!QFileInfo::exists(wavFile))
|
||||
// Convert SongItem to AceRequest
|
||||
AceRequest req = songToRequest(m_currentSong, m_requestTemplate);
|
||||
|
||||
// Step 1: LM generates lyrics and audio codes
|
||||
emit progressUpdate(30);
|
||||
|
||||
AceRequest lmOutput;
|
||||
request_init(&lmOutput);
|
||||
|
||||
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
|
||||
nullptr, nullptr,
|
||||
checkCancel, this,
|
||||
LM_MODE_GENERATE);
|
||||
|
||||
if (m_cancelRequested.load())
|
||||
{
|
||||
generationError("No WAV file generated at "+wavFile);
|
||||
busy = false;
|
||||
emit generationCanceled(m_currentSong);
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
busy = false;
|
||||
|
||||
progressUpdate(100);
|
||||
request.song.file = wavFile;
|
||||
songGenerated(request.song);
|
||||
if (lmResult != 0)
|
||||
{
|
||||
emit generationError("LM generation failed or was canceled");
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Update song with generated lyrics
|
||||
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
|
||||
|
||||
// Step 2: Synth generates audio
|
||||
emit progressUpdate(60);
|
||||
|
||||
AceAudio* audioOut = nullptr;
|
||||
AceAudio outputAudio;
|
||||
outputAudio.samples = nullptr;
|
||||
outputAudio.n_samples = 0;
|
||||
outputAudio.sample_rate = 48000;
|
||||
|
||||
int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
|
||||
nullptr, 0, // no source audio
|
||||
nullptr, 0, // no reference audio
|
||||
1, &outputAudio,
|
||||
checkCancel, this);
|
||||
|
||||
if (m_cancelRequested.load())
|
||||
{
|
||||
emit generationCanceled(m_currentSong);
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (synthResult != 0)
|
||||
{
|
||||
emit generationError("Synthesis failed or was canceled");
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Save audio to file
|
||||
QString wavFile = m_tempDir + "/request_" + QString::number(m_uid) + ".wav";
|
||||
|
||||
// Write WAV file
|
||||
QFile outFile(wavFile);
|
||||
if (!outFile.open(QIODevice::WriteOnly))
|
||||
{
|
||||
emit generationError("Failed to create output file: " + outFile.errorString());
|
||||
ace_audio_free(&outputAudio);
|
||||
m_busy.store(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Simple WAV header + stereo float data
|
||||
int numChannels = 2;
|
||||
int bitsPerSample = 16;
|
||||
int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
|
||||
int blockAlign = numChannels * (bitsPerSample / 8);
|
||||
int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
|
||||
|
||||
// RIFF header
|
||||
outFile.write("RIFF");
|
||||
outFile.write(reinterpret_cast<const char*>(&dataSize), 4);
|
||||
outFile.write("WAVE");
|
||||
|
||||
// fmt chunk
|
||||
outFile.write("fmt ");
|
||||
int fmtSize = 16;
|
||||
outFile.write(reinterpret_cast<const char*>(&fmtSize), 4);
|
||||
short audioFormat = 1; // PCM
|
||||
outFile.write(reinterpret_cast<const char*>(&audioFormat), 2);
|
||||
short numCh = numChannels;
|
||||
outFile.write(reinterpret_cast<const char*>(&numCh), 2);
|
||||
int sampleRate = outputAudio.sample_rate;
|
||||
outFile.write(reinterpret_cast<const char*>(&sampleRate), 4);
|
||||
outFile.write(reinterpret_cast<const char*>(&byteRate), 4);
|
||||
outFile.write(reinterpret_cast<const char*>(&blockAlign), 2);
|
||||
outFile.write(reinterpret_cast<const char*>(&bitsPerSample), 2);
|
||||
|
||||
// data chunk
|
||||
outFile.write("data");
|
||||
outFile.write(reinterpret_cast<const char*>(&dataSize), 4);
|
||||
|
||||
// Convert float samples to 16-bit and write
|
||||
QVector<short> interleaved(outputAudio.n_samples * numChannels);
|
||||
for (int i = 0; i < outputAudio.n_samples; i++)
|
||||
{
|
||||
float left = outputAudio.samples[i];
|
||||
float right = outputAudio.samples[i + outputAudio.n_samples];
|
||||
// Clamp and convert to 16-bit
|
||||
left = std::max(-1.0f, std::min(1.0f, left));
|
||||
right = std::max(-1.0f, std::min(1.0f, right));
|
||||
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
|
||||
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
|
||||
}
|
||||
outFile.write(reinterpret_cast<const char*>(interleaved.data()), dataSize);
|
||||
outFile.close();
|
||||
|
||||
// Free audio buffer
|
||||
ace_audio_free(&outputAudio);
|
||||
|
||||
// Store the JSON with all generated fields
|
||||
m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
|
||||
m_currentSong.file = wavFile;
|
||||
|
||||
// Extract BPM if available
|
||||
if (lmOutput.bpm > 0)
|
||||
m_currentSong.bpm = lmOutput.bpm;
|
||||
|
||||
// Extract key if available
|
||||
if (!lmOutput.keyscale.empty())
|
||||
m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
|
||||
|
||||
emit progressUpdate(100);
|
||||
emit songGenerated(m_currentSong);
|
||||
|
||||
m_busy.store(false);
|
||||
}
|
||||
|
||||
bool AceStepWorker::loadModels()
|
||||
{
|
||||
if (m_modelsLoaded.load())
|
||||
return true;
|
||||
|
||||
// Load LM
|
||||
AceLmParams lmParams;
|
||||
ace_lm_default_params(&lmParams);
|
||||
lmParams.model_path = m_lmModelPathBytes.constData();
|
||||
lmParams.use_fsm = true;
|
||||
lmParams.use_fa = true;
|
||||
|
||||
m_lmContext = ace_lm_load(&lmParams);
|
||||
if (!m_lmContext)
|
||||
{
|
||||
emit generationError("Failed to load LM model: " + m_lmModelPath);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Load Synth
|
||||
AceSynthParams synthParams;
|
||||
ace_synth_default_params(&synthParams);
|
||||
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
||||
synthParams.dit_path = m_ditPathBytes.constData();
|
||||
synthParams.vae_path = m_vaePathBytes.constData();
|
||||
synthParams.use_fa = true;
|
||||
|
||||
m_synthContext = ace_synth_load(&synthParams);
|
||||
if (!m_synthContext)
|
||||
{
|
||||
emit generationError("Failed to load synthesis models");
|
||||
ace_lm_free(m_lmContext);
|
||||
m_lmContext = nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
m_modelsLoaded.store(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
void AceStepWorker::unloadModels()
|
||||
{
|
||||
if (m_synthContext)
|
||||
{
|
||||
ace_synth_free(m_synthContext);
|
||||
m_synthContext = nullptr;
|
||||
}
|
||||
if (m_lmContext)
|
||||
{
|
||||
ace_lm_free(m_lmContext);
|
||||
m_lmContext = nullptr;
|
||||
}
|
||||
m_modelsLoaded.store(false);
|
||||
}
|
||||
|
||||
/**
 * Builds an AceRequest from a song plus a JSON template.
 *
 * The request is first initialised with request_init(), then caption,
 * lyrics and the CoT-caption flag are copied from @p song. If
 * @p templateJson parses to a JSON object, the recognised keys
 * (inference_steps, shift, vocal_language, bpm, duration, keyscale,
 * lm_temperature, lm_cfg_scale) override the corresponding request fields.
 *
 * A key that is missing, or present with the wrong JSON type, leaves the
 * field at whatever request_init() put there instead of replacing it with
 * a hard-coded fallback or an empty string.
 *
 * @param song          Source of caption, lyrics and the CoT-caption flag.
 * @param templateJson  JSON text with optional overrides; ignored when it
 *                      is not a JSON object.
 * @return The populated request, with a freshly generated seed.
 */
AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson)
{
    AceRequest req;
    request_init(&req);

    req.caption = song.caption.toStdString();
    req.lyrics = song.lyrics.toStdString();
    req.use_cot_caption = song.cotCaption;

    // Parse template and override defaults
    const QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8());
    if (templateDoc.isObject())
    {
        const QJsonObject obj = templateDoc.object();

        // QJsonObject::value() yields an undefined value for a missing key,
        // and toInt()/toDouble() return the supplied default for anything
        // that is not a suitable number, so the current field value survives.
        req.inference_steps = obj.value("inference_steps").toInt(req.inference_steps);
        req.shift = obj.value("shift").toDouble(req.shift);
        req.bpm = obj.value("bpm").toInt(req.bpm);
        req.duration = obj.value("duration").toDouble(req.duration);
        req.lm_temperature = obj.value("lm_temperature").toDouble(req.lm_temperature);
        req.lm_cfg_scale = obj.value("lm_cfg_scale").toDouble(req.lm_cfg_scale);

        // String fields are only overridden by actual JSON strings.
        const QJsonValue vocalLanguage = obj.value("vocal_language");
        if (vocalLanguage.isString())
            req.vocal_language = vocalLanguage.toString().toStdString();

        const QJsonValue keyscale = obj.value("keyscale");
        if (keyscale.isString())
            req.keyscale = keyscale.toString().toStdString();
    }

    // Fresh random seed per request; QRandomGenerator::generate() yields a
    // 32-bit value, widened here to the seed type.
    req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate());

    return req;
}
|
||||
|
||||
/**
 * Converts an AceRequest back into a SongItem.
 *
 * Caption, lyrics and the CoT-caption flag are always copied. bpm is copied
 * only when positive, and key / vocal language only when the request
 * strings are non-empty; otherwise the SongItem keeps its
 * default-constructed values. @p json is stored verbatim in the result.
 *
 * @param req   Request to read the fields from.
 * @param json  JSON text to attach to the returned song.
 * @return The populated SongItem.
 */
SongItem AceStepWorker::requestToSong(const AceRequest& req, const QString& json)
{
    SongItem result;

    // Fields that are copied unconditionally.
    result.json = json;
    result.cotCaption = req.use_cot_caption;
    result.caption = QString::fromStdString(req.caption);
    result.lyrics = QString::fromStdString(req.lyrics);

    // Optional fields: only taken over when the request actually carries a value.
    const bool hasBpm = req.bpm > 0;
    if (hasBpm)
    {
        result.bpm = req.bpm;
    }

    const bool hasKey = !req.keyscale.empty();
    if (hasKey)
    {
        result.key = QString::fromStdString(req.keyscale);
    }

    const bool hasLanguage = !req.vocal_language.empty();
    if (hasLanguage)
    {
        result.vocalLanguage = QString::fromStdString(req.vocal_language);
    }

    return result;
}
|
||||
|
|
@ -8,40 +8,34 @@
|
|||
|
||||
#include <QObject>
|
||||
#include <QString>
|
||||
#include <QProcess>
|
||||
#include <QThread>
|
||||
#include <QStandardPaths>
|
||||
#include <atomic>
|
||||
|
||||
#include "SongItem.h"
|
||||
|
||||
#ifdef Q_OS_WIN
|
||||
inline const QString EXE_EXT = ".exe";
|
||||
#else
|
||||
inline const QString EXE_EXT = "";
|
||||
#endif
|
||||
// acestep.cpp headers
|
||||
#include "request.h"
|
||||
|
||||
class AceStep : public QObject
|
||||
struct AceLm;
|
||||
struct AceSynth;
|
||||
|
||||
class AceStepWorker : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
QProcess qwenProcess;
|
||||
QProcess ditVaeProcess;
|
||||
|
||||
bool busy = false;
|
||||
public:
|
||||
explicit AceStepWorker(QObject* parent = nullptr);
|
||||
~AceStepWorker();
|
||||
|
||||
struct Request
|
||||
{
|
||||
SongItem song;
|
||||
uint64_t uid;
|
||||
QString aceStepPath;
|
||||
QString textEncoderModelPath;
|
||||
QString ditModelPath;
|
||||
QString vaeModelPath;
|
||||
QString requestFilePath;
|
||||
QString requestLlmFilePath;
|
||||
};
|
||||
bool isGenerating(SongItem* song = nullptr);
|
||||
void cancelGeneration();
|
||||
|
||||
Request request;
|
||||
// Model paths - set these before first generation
|
||||
void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath);
|
||||
|
||||
const QString tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation);
|
||||
// Request a new song generation
|
||||
bool requestGeneration(SongItem song, QString requestTemplate);
|
||||
|
||||
signals:
|
||||
void songGenerated(SongItem song);
|
||||
|
|
@ -49,19 +43,50 @@ signals:
|
|||
void generationError(QString error);
|
||||
void progressUpdate(int progress);
|
||||
|
||||
public slots:
|
||||
bool requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath,
|
||||
QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath,
|
||||
QString vaeModelPath);
|
||||
|
||||
public:
|
||||
AceStep(QObject* parent = nullptr);
|
||||
bool isGenerating(SongItem* song = nullptr);
|
||||
void cancelGeneration();
|
||||
|
||||
private slots:
|
||||
void qwenProcFinished(int code, QProcess::ExitStatus status);
|
||||
void ditProcFinished(int code, QProcess::ExitStatus status);
|
||||
void runGeneration();
|
||||
|
||||
private:
|
||||
// Check if cancellation was requested
|
||||
static bool checkCancel(void* data);
|
||||
|
||||
// Load models if not already loaded
|
||||
bool loadModels();
|
||||
void unloadModels();
|
||||
|
||||
// Convert SongItem to AceRequest
|
||||
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
||||
|
||||
// Convert AceRequest back to SongItem
|
||||
SongItem requestToSong(const AceRequest& req, const QString& json);
|
||||
|
||||
// Generation state
|
||||
std::atomic<bool> m_busy{false};
|
||||
std::atomic<bool> m_cancelRequested{false};
|
||||
std::atomic<bool> m_modelsLoaded{false};
|
||||
|
||||
// Current request data
|
||||
SongItem m_currentSong;
|
||||
QString m_requestTemplate;
|
||||
uint64_t m_uid;
|
||||
|
||||
// Model paths
|
||||
QString m_lmModelPath;
|
||||
QString m_textEncoderPath;
|
||||
QString m_ditPath;
|
||||
QString m_vaePath;
|
||||
|
||||
// Loaded models (accessed from worker thread only)
|
||||
AceLm* m_lmContext = nullptr;
|
||||
AceSynth* m_synthContext = nullptr;
|
||||
|
||||
// Cached model paths as byte arrays (to avoid dangling pointers)
|
||||
QByteArray m_lmModelPathBytes;
|
||||
QByteArray m_textEncoderPathBytes;
|
||||
QByteArray m_ditPathBytes;
|
||||
QByteArray m_vaePathBytes;
|
||||
|
||||
const QString m_tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation);
|
||||
};
|
||||
|
||||
#endif // ACESTEPWORKER_H
|
||||
#endif // ACESTEPWORKER_H
|
||||
|
|
@ -21,7 +21,7 @@ MainWindow::MainWindow(QWidget *parent)
|
|||
ui(new Ui::MainWindow),
|
||||
songModel(new SongListModel(this)),
|
||||
audioPlayer(new AudioPlayer(this)),
|
||||
aceStep(new AceStep(this)),
|
||||
aceStep(new AceStepWorker(this)),
|
||||
playbackTimer(new QTimer(this)),
|
||||
isPlaying(false),
|
||||
isPaused(false),
|
||||
|
|
@ -41,6 +41,9 @@ MainWindow::MainWindow(QWidget *parent)
|
|||
// Load settings
|
||||
loadSettings();
|
||||
|
||||
// Set model paths for acestep.cpp
|
||||
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
||||
|
||||
// Auto-load playlist from config location on startup
|
||||
autoLoadPlaylist();
|
||||
|
||||
|
|
@ -62,10 +65,10 @@ MainWindow::MainWindow(QWidget *parent)
|
|||
connect(audioPlayer, &AudioPlayer::playbackStarted, this, &MainWindow::playbackStarted);
|
||||
connect(audioPlayer, &AudioPlayer::positionChanged, this, &MainWindow::updatePosition);
|
||||
connect(audioPlayer, &AudioPlayer::durationChanged, this, &MainWindow::updateDuration);
|
||||
connect(aceStep, &AceStep::songGenerated, this, &MainWindow::songGenerated);
|
||||
connect(aceStep, &AceStep::generationCanceled, this, &MainWindow::generationCanceld);
|
||||
connect(aceStep, &AceStep::generationError, this, &MainWindow::generationError);
|
||||
connect(aceStep, &AceStep::progressUpdate, ui->progressBar, &QProgressBar::setValue);
|
||||
connect(aceStep, &AceStepWorker::songGenerated, this, &MainWindow::songGenerated);
|
||||
connect(aceStep, &AceStepWorker::generationCanceled, this, &MainWindow::generationCanceld);
|
||||
connect(aceStep, &AceStepWorker::generationError, this, &MainWindow::generationError);
|
||||
connect(aceStep, &AceStepWorker::progressUpdate, ui->progressBar, &QProgressBar::setValue);
|
||||
|
||||
// Connect double-click on song list for editing (works with QTableView too)
|
||||
connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked);
|
||||
|
|
@ -391,6 +394,9 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
|||
ditModelPath = dialog.getDiTModelPath();
|
||||
vaeModelPath = dialog.getVAEModelPath();
|
||||
|
||||
// Update model paths for acestep.cpp
|
||||
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
||||
|
||||
saveSettings();
|
||||
}
|
||||
}
|
||||
|
|
@ -533,10 +539,7 @@ void MainWindow::ensureSongsInQueue(bool enqeueCurrent)
|
|||
isGeneratingNext = true;
|
||||
|
||||
ui->statusbar->showMessage("Generateing: "+nextSong.caption);
|
||||
aceStep->requestGeneration(nextSong, jsonTemplate,
|
||||
aceStepPath, qwen3ModelPath,
|
||||
textEncoderModelPath, ditModelPath,
|
||||
vaeModelPath);
|
||||
aceStep->requestGeneration(nextSong, jsonTemplate);
|
||||
}
|
||||
|
||||
void MainWindow::flushGenerationQueue()
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class MainWindow : public QMainWindow
|
|||
SongListModel *songModel;
|
||||
AudioPlayer *audioPlayer;
|
||||
QThread aceThread;
|
||||
AceStep *aceStep;
|
||||
AceStepWorker *aceStep;
|
||||
QTimer *playbackTimer;
|
||||
|
||||
QString formatTime(int milliseconds);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue