Improve AceStep request creation
This commit is contained in:
parent
e3fb4761b0
commit
cc72f360b9
2 changed files with 17 additions and 63 deletions
|
|
@ -119,16 +119,6 @@ bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate template
|
|
||||||
QJsonParseError parseError;
|
|
||||||
QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError);
|
|
||||||
if (!templateDoc.isObject())
|
|
||||||
{
|
|
||||||
emit generationError("Invalid JSON template: " + QString(parseError.errorString()));
|
|
||||||
m_busy.store(false);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run generation in the worker thread
|
// Run generation in the worker thread
|
||||||
QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection);
|
QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection);
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -382,54 +372,26 @@ AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& tem
|
||||||
AceRequest req;
|
AceRequest req;
|
||||||
request_init(&req);
|
request_init(&req);
|
||||||
|
|
||||||
req.caption = song.caption.toStdString();
|
// Parse template first to get all default values
|
||||||
req.lyrics = song.lyrics.toStdString();
|
QJsonObject requestJson;
|
||||||
req.use_cot_caption = song.cotCaption;
|
if (!templateJson.isEmpty())
|
||||||
|
{
|
||||||
// Parse template and override defaults
|
|
||||||
QJsonParseError parseError;
|
QJsonParseError parseError;
|
||||||
QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError);
|
QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError);
|
||||||
if (templateDoc.isObject())
|
if (templateDoc.isObject())
|
||||||
{
|
requestJson = templateDoc.object();
|
||||||
QJsonObject obj = templateDoc.object();
|
else
|
||||||
if (obj.contains("inference_steps"))
|
qWarning() << "Failed to parse request template:" << parseError.errorString();
|
||||||
req.inference_steps = obj["inference_steps"].toInt(8);
|
|
||||||
if (obj.contains("shift"))
|
|
||||||
req.shift = obj["shift"].toDouble(3.0);
|
|
||||||
if (obj.contains("vocal_language"))
|
|
||||||
req.vocal_language = obj["vocal_language"].toString().toStdString();
|
|
||||||
if (obj.contains("bpm"))
|
|
||||||
req.bpm = obj["bpm"].toInt(120);
|
|
||||||
if (obj.contains("duration"))
|
|
||||||
req.duration = obj["duration"].toDouble(180.0);
|
|
||||||
if (obj.contains("keyscale"))
|
|
||||||
req.keyscale = obj["keyscale"].toString().toStdString();
|
|
||||||
if (obj.contains("lm_temperature"))
|
|
||||||
req.lm_temperature = obj["lm_temperature"].toDouble(0.85);
|
|
||||||
if (obj.contains("lm_cfg_scale"))
|
|
||||||
req.lm_cfg_scale = obj["lm_cfg_scale"].toDouble(2.0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a seed for reproducibility
|
song.store(requestJson);
|
||||||
|
QJsonDocument requestJsonDoc(requestJson);
|
||||||
|
QString requestJsonString = QString::fromUtf8(requestJsonDoc.toJson(QJsonDocument::Compact));
|
||||||
|
if (!request_parse_json(&req, requestJsonString.toUtf8()))
|
||||||
|
qWarning() << "Failed to parse merged request JSON";
|
||||||
|
|
||||||
|
if (req.seed < 0)
|
||||||
req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate());
|
req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate());
|
||||||
|
|
||||||
return req;
|
return req;
|
||||||
}
|
}
|
||||||
|
|
||||||
SongItem AceStepWorker::requestToSong(const AceRequest& req, const QString& json)
|
|
||||||
{
|
|
||||||
SongItem song;
|
|
||||||
song.caption = QString::fromStdString(req.caption);
|
|
||||||
song.lyrics = QString::fromStdString(req.lyrics);
|
|
||||||
song.cotCaption = req.use_cot_caption;
|
|
||||||
|
|
||||||
if (req.bpm > 0)
|
|
||||||
song.bpm = req.bpm;
|
|
||||||
if (!req.keyscale.empty())
|
|
||||||
song.key = QString::fromStdString(req.keyscale);
|
|
||||||
if (!req.vocal_language.empty())
|
|
||||||
song.vocalLanguage = QString::fromStdString(req.vocal_language);
|
|
||||||
|
|
||||||
song.json = json;
|
|
||||||
return song;
|
|
||||||
}
|
|
||||||
|
|
@ -54,25 +54,17 @@ private slots:
|
||||||
void runGeneration();
|
void runGeneration();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Check if cancellation was requested
|
|
||||||
static bool checkCancel(void* data);
|
static bool checkCancel(void* data);
|
||||||
|
|
||||||
// Load models if not already loaded
|
|
||||||
bool loadModels();
|
bool loadModels();
|
||||||
void unloadModels();
|
void unloadModels();
|
||||||
|
|
||||||
// Individual model load/unload for low VRAM mode
|
|
||||||
bool loadLm();
|
bool loadLm();
|
||||||
void unloadLm();
|
void unloadLm();
|
||||||
bool loadSynth();
|
bool loadSynth();
|
||||||
void unloadSynth();
|
void unloadSynth();
|
||||||
|
|
||||||
// Convert SongItem to AceRequest
|
|
||||||
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
||||||
|
|
||||||
// Convert AceRequest back to SongItem
|
|
||||||
SongItem requestToSong(const AceRequest& req, const QString& json);
|
|
||||||
|
|
||||||
static std::shared_ptr<QByteArray> convertToWav(const AceAudio& audio);
|
static std::shared_ptr<QByteArray> convertToWav(const AceAudio& audio);
|
||||||
|
|
||||||
// Generation state
|
// Generation state
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue