From cc72f360b9b91877e4f783cdbe59b1b271da85c8 Mon Sep 17 00:00:00 2001 From: Carl Philipp Klemm Date: Wed, 15 Apr 2026 12:11:38 +0200 Subject: [PATCH] Improve AceStep request creation --- src/AceStepWorker.cpp | 72 ++++++++++--------------------------------- src/AceStepWorker.h | 8 ----- 2 files changed, 17 insertions(+), 63 deletions(-) diff --git a/src/AceStepWorker.cpp b/src/AceStepWorker.cpp index 731d021..8804613 100644 --- a/src/AceStepWorker.cpp +++ b/src/AceStepWorker.cpp @@ -119,16 +119,6 @@ bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate) return false; } - // Validate template - QJsonParseError parseError; - QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError); - if (!templateDoc.isObject()) - { - emit generationError("Invalid JSON template: " + QString(parseError.errorString())); - m_busy.store(false); - return false; - } - // Run generation in the worker thread QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection); return true; @@ -382,54 +372,26 @@ AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& tem AceRequest req; request_init(&req); - req.caption = song.caption.toStdString(); - req.lyrics = song.lyrics.toStdString(); - req.use_cot_caption = song.cotCaption; - - // Parse template and override defaults - QJsonParseError parseError; - QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError); - if (templateDoc.isObject()) + // Parse template first to get all default values + QJsonObject requestJson; + if (!templateJson.isEmpty()) { - QJsonObject obj = templateDoc.object(); - if (obj.contains("inference_steps")) - req.inference_steps = obj["inference_steps"].toInt(8); - if (obj.contains("shift")) - req.shift = obj["shift"].toDouble(3.0); - if (obj.contains("vocal_language")) - req.vocal_language = obj["vocal_language"].toString().toStdString(); - if (obj.contains("bpm")) - req.bpm = obj["bpm"].toInt(120); - if (obj.contains("duration")) - req.duration = obj["duration"].toDouble(180.0); - if (obj.contains("keyscale")) - req.keyscale = obj["keyscale"].toString().toStdString(); - if (obj.contains("lm_temperature")) - req.lm_temperature = obj["lm_temperature"].toDouble(0.85); - if (obj.contains("lm_cfg_scale")) - req.lm_cfg_scale = obj["lm_cfg_scale"].toDouble(2.0); + QJsonParseError parseError; + QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError); + if (templateDoc.isObject()) + requestJson = templateDoc.object(); + else + qWarning() << "Failed to parse request template:" << parseError.errorString(); } - // Generate a seed for reproducibility - req.seed = static_cast(QRandomGenerator::global()->generate()); + song.store(requestJson); + QJsonDocument requestJsonDoc(requestJson); + QString requestJsonString = QString::fromUtf8(requestJsonDoc.toJson(QJsonDocument::Compact)); + if (!request_parse_json(&req, requestJsonString.toUtf8())) + qWarning() << "Failed to parse merged request JSON"; + + if (req.seed < 0) + req.seed = static_cast(QRandomGenerator::global()->generate()); return req; -} - -SongItem AceStepWorker::requestToSong(const AceRequest& req, const QString& json) -{ - SongItem song; - song.caption = QString::fromStdString(req.caption); - song.lyrics = QString::fromStdString(req.lyrics); - song.cotCaption = req.use_cot_caption; - - if (req.bpm > 0) - song.bpm = req.bpm; - if (!req.keyscale.empty()) - song.key = QString::fromStdString(req.keyscale); - if (!req.vocal_language.empty()) - song.vocalLanguage = QString::fromStdString(req.vocal_language); - - song.json = json; - return song; } \ No newline at end of file diff --git a/src/AceStepWorker.h b/src/AceStepWorker.h index ec8c8a6..56d95b1 100644 --- a/src/AceStepWorker.h +++ b/src/AceStepWorker.h @@ -54,25 +54,17 @@ private slots: void runGeneration(); private: - // Check if cancellation was requested static bool checkCancel(void* data); - // Load models if not already loaded bool loadModels(); void unloadModels(); - - // Individual model load/unload for low VRAM mode bool loadLm(); void unloadLm(); bool loadSynth(); void unloadSynth(); - // Convert SongItem to AceRequest AceRequest songToRequest(const SongItem& song, const QString& templateJson); - // Convert AceRequest back to SongItem - SongItem requestToSong(const AceRequest& req, const QString& json); - static std::shared_ptr convertToWav(const AceAudio& audio); // Generation state