Improve AceStep request creation

This commit is contained in:
Carl Philipp Klemm 2026-04-15 12:11:38 +02:00
parent e3fb4761b0
commit cc72f360b9
2 changed files with 17 additions and 63 deletions

View file

@ -119,16 +119,6 @@ bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate)
return false; return false;
} }
// Validate template
QJsonParseError parseError;
QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError);
if (!templateDoc.isObject())
{
emit generationError("Invalid JSON template: " + QString(parseError.errorString()));
m_busy.store(false);
return false;
}
// Run generation in the worker thread // Run generation in the worker thread
QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection); QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection);
return true; return true;
@ -382,54 +372,26 @@ AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& tem
AceRequest req; AceRequest req;
request_init(&req); request_init(&req);
req.caption = song.caption.toStdString(); // Parse template first to get all default values
req.lyrics = song.lyrics.toStdString(); QJsonObject requestJson;
req.use_cot_caption = song.cotCaption; if (!templateJson.isEmpty())
// Parse template and override defaults
QJsonParseError parseError;
QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError);
if (templateDoc.isObject())
{ {
QJsonObject obj = templateDoc.object(); QJsonParseError parseError;
if (obj.contains("inference_steps")) QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError);
req.inference_steps = obj["inference_steps"].toInt(8); if (templateDoc.isObject())
if (obj.contains("shift")) requestJson = templateDoc.object();
req.shift = obj["shift"].toDouble(3.0); else
if (obj.contains("vocal_language")) qWarning() << "Failed to parse request template:" << parseError.errorString();
req.vocal_language = obj["vocal_language"].toString().toStdString();
if (obj.contains("bpm"))
req.bpm = obj["bpm"].toInt(120);
if (obj.contains("duration"))
req.duration = obj["duration"].toDouble(180.0);
if (obj.contains("keyscale"))
req.keyscale = obj["keyscale"].toString().toStdString();
if (obj.contains("lm_temperature"))
req.lm_temperature = obj["lm_temperature"].toDouble(0.85);
if (obj.contains("lm_cfg_scale"))
req.lm_cfg_scale = obj["lm_cfg_scale"].toDouble(2.0);
} }
// Generate a seed for reproducibility song.store(requestJson);
req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate()); QJsonDocument requestJsonDoc(requestJson);
QString requestJsonString = QString::fromUtf8(requestJsonDoc.toJson(QJsonDocument::Compact));
if (!request_parse_json(&req, requestJsonString.toUtf8()))
qWarning() << "Failed to parse merged request JSON";
if (req.seed < 0)
req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate());
return req; return req;
} }
/// Builds a SongItem from the fields of an AceRequest.
///
/// The caption, lyrics and `use_cot_caption` flag are always copied across.
/// The bpm is copied only when it is positive, and the key / vocal language
/// only when the corresponding request string is non-empty; otherwise those
/// SongItem fields keep whatever value a default-constructed SongItem has.
///
/// @param req   Request whose fields are read.
/// @param json  Stored unchanged in the returned item's `json` field.
/// @return The populated SongItem.
SongItem AceStepWorker::requestToSong(const AceRequest& req, const QString& json)
{
    SongItem result;

    // Unconditional fields.
    result.json = json;
    result.cotCaption = req.use_cot_caption;
    result.lyrics = QString::fromStdString(req.lyrics);
    result.caption = QString::fromStdString(req.caption);

    // Optional fields: only overwrite the defaults when the request has a value.
    const bool hasBpm = req.bpm > 0;
    if (hasBpm)
    {
        result.bpm = req.bpm;
    }

    const bool hasKey = !req.keyscale.empty();
    if (hasKey)
    {
        result.key = QString::fromStdString(req.keyscale);
    }

    const bool hasLanguage = !req.vocal_language.empty();
    if (hasLanguage)
    {
        result.vocalLanguage = QString::fromStdString(req.vocal_language);
    }

    return result;
}

View file

@ -54,25 +54,17 @@ private slots:
void runGeneration(); void runGeneration();
private: private:
// Check if cancellation was requested
static bool checkCancel(void* data); static bool checkCancel(void* data);
// Load models if not already loaded
bool loadModels(); bool loadModels();
void unloadModels(); void unloadModels();
// Individual model load/unload for low VRAM mode
bool loadLm(); bool loadLm();
void unloadLm(); void unloadLm();
bool loadSynth(); bool loadSynth();
void unloadSynth(); void unloadSynth();
// Convert SongItem to AceRequest
AceRequest songToRequest(const SongItem& song, const QString& templateJson); AceRequest songToRequest(const SongItem& song, const QString& templateJson);
// Convert AceRequest back to SongItem
SongItem requestToSong(const AceRequest& req, const QString& json);
static std::shared_ptr<QByteArray> convertToWav(const AceAudio& audio); static std::shared_ptr<QByteArray> convertToWav(const AceAudio& audio);
// Generation state // Generation state