Use the acestop api directly instead of calling binaries
This commit is contained in:
parent
6be80f1d5c
commit
de7207f07e
8 changed files with 789 additions and 210 deletions
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
[submodule "third_party/acestep.cpp"]
|
||||||
|
path = third_party/acestep.cpp
|
||||||
|
url = https://github.com/ServeurpersoCom/acestep.cpp.git
|
||||||
|
|
@ -10,7 +10,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
# Find Qt packages
|
# Find Qt packages
|
||||||
find_package(Qt6 COMPONENTS Core Gui Widgets Multimedia REQUIRED)
|
find_package(Qt6 COMPONENTS Core Gui Widgets Multimedia REQUIRED)
|
||||||
|
|
||||||
# Note: acestep.cpp binaries and models should be provided at runtime
|
# Add acestep.cpp subdirectory
|
||||||
|
add_subdirectory(third_party/acestep.cpp)
|
||||||
|
|
||||||
set(CMAKE_AUTOMOC ON)
|
set(CMAKE_AUTOMOC ON)
|
||||||
set(CMAKE_AUTOUIC ON)
|
set(CMAKE_AUTOUIC ON)
|
||||||
|
|
@ -47,21 +48,22 @@ add_executable(${PROJECT_NAME}
|
||||||
# UI file
|
# UI file
|
||||||
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
# Link libraries (only Qt libraries - acestep.cpp is external)
|
# Link libraries (Qt + acestep.cpp)
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||||
Qt6::Core
|
Qt6::Core
|
||||||
Qt6::Gui
|
Qt6::Gui
|
||||||
Qt6::Widgets
|
Qt6::Widgets
|
||||||
Qt6::Multimedia
|
Qt6::Multimedia
|
||||||
|
acestep-core
|
||||||
|
ggml
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include directories (only our source directory - acestep.cpp is external)
|
# Include directories
|
||||||
target_include_directories(${PROJECT_NAME} PRIVATE
|
target_include_directories(${PROJECT_NAME} PRIVATE
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/third_party/acestep.cpp/src
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: acestep.cpp binaries (ace-qwen3, dit-vae) and models should be provided at runtime
|
|
||||||
|
|
||||||
# Install targets
|
# Install targets
|
||||||
install(TARGETS ${PROJECT_NAME} DESTINATION bin)
|
install(TARGETS ${PROJECT_NAME} DESTINATION bin)
|
||||||
|
|
||||||
|
|
@ -71,3 +73,24 @@ install(FILES res/xyz.uvos.aceradio.desktop DESTINATION share/applications)
|
||||||
# Install icon files
|
# Install icon files
|
||||||
install(FILES res/xyz.uvos.aceradio.png DESTINATION share/icons/hicolor/256x256/apps RENAME xyz.uvos.aceradio.png)
|
install(FILES res/xyz.uvos.aceradio.png DESTINATION share/icons/hicolor/256x256/apps RENAME xyz.uvos.aceradio.png)
|
||||||
install(FILES res/xyz.uvos.aceradio.svg DESTINATION share/icons/hicolor/scalable/apps RENAME xyz.uvos.aceradio.svg)
|
install(FILES res/xyz.uvos.aceradio.svg DESTINATION share/icons/hicolor/scalable/apps RENAME xyz.uvos.aceradio.svg)
|
||||||
|
|
||||||
|
# Test executable
|
||||||
|
add_executable(test_acestep_worker
|
||||||
|
tests/test_acestep_worker.cpp
|
||||||
|
src/AceStepWorker.cpp
|
||||||
|
src/AceStepWorker.h
|
||||||
|
src/SongItem.cpp
|
||||||
|
src/SongItem.h
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(test_acestep_worker PRIVATE
|
||||||
|
Qt6::Core
|
||||||
|
Qt6::Widgets
|
||||||
|
acestep-core
|
||||||
|
ggml
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(test_acestep_worker PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/third_party/acestep.cpp/src
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,207 +5,379 @@
|
||||||
#include <QFile>
|
#include <QFile>
|
||||||
#include <QJsonDocument>
|
#include <QJsonDocument>
|
||||||
#include <QJsonObject>
|
#include <QJsonObject>
|
||||||
#include <QProcess>
|
|
||||||
#include <QDir>
|
#include <QDir>
|
||||||
#include <QStandardPaths>
|
|
||||||
#include <QDebug>
|
#include <QDebug>
|
||||||
#include <QCoreApplication>
|
|
||||||
#include <QRandomGenerator>
|
#include <QRandomGenerator>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
AceStep::AceStep(QObject* parent): QObject(parent)
|
// acestep.cpp headers
|
||||||
|
#include "pipeline-lm.h"
|
||||||
|
#include "pipeline-synth.h"
|
||||||
|
#include "request.h"
|
||||||
|
|
||||||
|
AceStepWorker::AceStepWorker(QObject* parent)
|
||||||
|
: QObject(parent)
|
||||||
{
|
{
|
||||||
connect(&qwenProcess, &QProcess::finished, this, &AceStep::qwenProcFinished);
|
|
||||||
connect(&ditVaeProcess, &QProcess::finished, this, &AceStep::ditProcFinished);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AceStep::isGenerating(SongItem* song)
|
AceStepWorker::~AceStepWorker()
|
||||||
{
|
{
|
||||||
if(!busy && song)
|
cancelGeneration();
|
||||||
*song = this->request.song;
|
unloadModels();
|
||||||
return busy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AceStep::cancelGeneration()
|
void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath)
|
||||||
{
|
{
|
||||||
qwenProcess.blockSignals(true);
|
m_lmModelPath = lmPath;
|
||||||
qwenProcess.terminate();
|
m_textEncoderPath = textEncoderPath;
|
||||||
qwenProcess.waitForFinished();
|
m_ditPath = ditPath;
|
||||||
qwenProcess.blockSignals(false);
|
m_vaePath = vaePath;
|
||||||
|
|
||||||
ditVaeProcess.blockSignals(true);
|
// Cache as byte arrays to avoid dangling pointers
|
||||||
ditVaeProcess.terminate();
|
m_lmModelPathBytes = lmPath.toUtf8();
|
||||||
ditVaeProcess.waitForFinished();
|
m_textEncoderPathBytes = textEncoderPath.toUtf8();
|
||||||
ditVaeProcess.blockSignals(false);
|
m_ditPathBytes = ditPath.toUtf8();
|
||||||
|
m_vaePathBytes = vaePath.toUtf8();
|
||||||
progressUpdate(100);
|
|
||||||
if(busy)
|
|
||||||
generationCanceled(request.song);
|
|
||||||
|
|
||||||
busy = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AceStep::requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath,
|
bool AceStepWorker::isGenerating(SongItem* song)
|
||||||
QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath,
|
|
||||||
QString vaeModelPath)
|
|
||||||
{
|
{
|
||||||
if(busy)
|
if (!m_busy.load() && song)
|
||||||
{
|
*song = m_currentSong;
|
||||||
qWarning()<<"Dropping song:"<<song.caption;
|
return m_busy.load();
|
||||||
return false;
|
}
|
||||||
}
|
|
||||||
busy = true;
|
|
||||||
|
|
||||||
request = {song, QRandomGenerator::global()->generate(), aceStepPath, textEncoderModelPath, ditModelPath, vaeModelPath};
|
void AceStepWorker::cancelGeneration()
|
||||||
|
{
|
||||||
|
m_cancelRequested.store(true);
|
||||||
|
}
|
||||||
|
|
||||||
QString qwen3Binary = aceStepPath + "/ace-lm" + EXE_EXT;
|
bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate)
|
||||||
QFileInfo qwen3Info(qwen3Binary);
|
{
|
||||||
if (!qwen3Info.exists() || !qwen3Info.isExecutable())
|
if (m_busy.load())
|
||||||
{
|
{
|
||||||
generationError("ace-lm binary not found at: " + qwen3Binary);
|
qWarning() << "Dropping song:" << song.caption;
|
||||||
busy = false;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!QFileInfo::exists(qwen3ModelPath))
|
|
||||||
{
|
|
||||||
generationError("Qwen3 model not found: " + qwen3ModelPath);
|
|
||||||
busy = false;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!QFileInfo::exists(textEncoderModelPath))
|
|
||||||
{
|
|
||||||
generationError("Text encoder model not found: " + textEncoderModelPath);
|
|
||||||
busy = false;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!QFileInfo::exists(ditModelPath))
|
|
||||||
{
|
|
||||||
generationError("DiT model not found: " + ditModelPath);
|
|
||||||
busy = false;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!QFileInfo::exists(vaeModelPath))
|
|
||||||
{
|
|
||||||
generationError("VAE model not found: " + vaeModelPath);
|
|
||||||
busy = false;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
request.requestFilePath = tempDir + "/request_" + QString::number(request.uid) + ".json";
|
m_busy.store(true);
|
||||||
|
m_cancelRequested.store(false);
|
||||||
|
m_currentSong = song;
|
||||||
|
m_requestTemplate = requestTemplate;
|
||||||
|
m_uid = QRandomGenerator::global()->generate();
|
||||||
|
|
||||||
|
// Validate model paths
|
||||||
|
if (m_lmModelPath.isEmpty() || m_textEncoderPath.isEmpty() ||
|
||||||
|
m_ditPath.isEmpty() || m_vaePath.isEmpty())
|
||||||
|
{
|
||||||
|
emit generationError("Model paths not set. Call setModelPaths() first.");
|
||||||
|
m_busy.store(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate model files exist
|
||||||
|
if (!QFileInfo::exists(m_lmModelPath))
|
||||||
|
{
|
||||||
|
emit generationError("LM model not found: " + m_lmModelPath);
|
||||||
|
m_busy.store(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!QFileInfo::exists(m_textEncoderPath))
|
||||||
|
{
|
||||||
|
emit generationError("Text encoder model not found: " + m_textEncoderPath);
|
||||||
|
m_busy.store(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!QFileInfo::exists(m_ditPath))
|
||||||
|
{
|
||||||
|
emit generationError("DiT model not found: " + m_ditPath);
|
||||||
|
m_busy.store(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!QFileInfo::exists(m_vaePath))
|
||||||
|
{
|
||||||
|
emit generationError("VAE model not found: " + m_vaePath);
|
||||||
|
m_busy.store(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate template
|
||||||
QJsonParseError parseError;
|
QJsonParseError parseError;
|
||||||
QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError);
|
QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError);
|
||||||
if (!templateDoc.isObject())
|
if (!templateDoc.isObject())
|
||||||
{
|
{
|
||||||
generationError("Invalid JSON template: " + QString(parseError.errorString()));
|
emit generationError("Invalid JSON template: " + QString(parseError.errorString()));
|
||||||
busy = false;
|
m_busy.store(false);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
QJsonObject requestObj = templateDoc.object();
|
// Run generation in the worker thread
|
||||||
song.store(requestObj);
|
QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection);
|
||||||
|
|
||||||
// Write the request file
|
|
||||||
QFile requestFileHandle(request.requestFilePath);
|
|
||||||
if (!requestFileHandle.open(QIODevice::WriteOnly | QIODevice::Text))
|
|
||||||
{
|
|
||||||
emit generationError("Failed to create request file: " + requestFileHandle.errorString());
|
|
||||||
busy = false;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
requestFileHandle.write(QJsonDocument(requestObj).toJson(QJsonDocument::Indented));
|
|
||||||
requestFileHandle.close();
|
|
||||||
|
|
||||||
QStringList qwen3Args;
|
|
||||||
qwen3Args << "--request" << request.requestFilePath;
|
|
||||||
qwen3Args << "--lm" << qwen3ModelPath;
|
|
||||||
|
|
||||||
progressUpdate(30);
|
|
||||||
|
|
||||||
qwenProcess.start(qwen3Binary, qwen3Args);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AceStep::qwenProcFinished(int code, QProcess::ExitStatus status)
|
bool AceStepWorker::checkCancel(void* data)
|
||||||
{
|
{
|
||||||
QFile::remove(request.requestFilePath);
|
AceStepWorker* worker = static_cast<AceStepWorker*>(data);
|
||||||
if(code != 0)
|
return worker->m_cancelRequested.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AceStepWorker::runGeneration()
|
||||||
|
{
|
||||||
|
// Load models if needed
|
||||||
|
if (!loadModels())
|
||||||
{
|
{
|
||||||
QString errorOutput = qwenProcess.readAllStandardError();
|
m_busy.store(false);
|
||||||
generationError("ace-lm exited with code " + QString::number(code) + ": " + errorOutput);
|
|
||||||
busy = false;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
QString ditVaeBinary = request.aceStepPath + "/ace-synth" + EXE_EXT;
|
// Convert SongItem to AceRequest
|
||||||
QFileInfo ditVaeInfo(ditVaeBinary);
|
AceRequest req = songToRequest(m_currentSong, m_requestTemplate);
|
||||||
if (!ditVaeInfo.exists() || !ditVaeInfo.isExecutable())
|
|
||||||
|
// Step 1: LM generates lyrics and audio codes
|
||||||
|
emit progressUpdate(30);
|
||||||
|
|
||||||
|
AceRequest lmOutput;
|
||||||
|
request_init(&lmOutput);
|
||||||
|
|
||||||
|
int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput,
|
||||||
|
nullptr, nullptr,
|
||||||
|
checkCancel, this,
|
||||||
|
LM_MODE_GENERATE);
|
||||||
|
|
||||||
|
if (m_cancelRequested.load())
|
||||||
{
|
{
|
||||||
generationError("ace-synth binary not found at: " + ditVaeBinary);
|
emit generationCanceled(m_currentSong);
|
||||||
busy = false;
|
m_busy.store(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
request.requestLlmFilePath = tempDir + "/request_" + QString::number(request.uid) + "0.json";
|
if (lmResult != 0)
|
||||||
if (!QFileInfo::exists(request.requestLlmFilePath))
|
|
||||||
{
|
{
|
||||||
generationError("ace-lm failed to create enhanced request file "+request.requestLlmFilePath);
|
emit generationError("LM generation failed or was canceled");
|
||||||
busy = false;
|
m_busy.store(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load lyrics from the enhanced request file
|
// Update song with generated lyrics
|
||||||
QFile lmOutputFile(request.requestLlmFilePath);
|
m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics);
|
||||||
if (lmOutputFile.open(QIODevice::ReadOnly | QIODevice::Text))
|
|
||||||
|
// Step 2: Synth generates audio
|
||||||
|
emit progressUpdate(60);
|
||||||
|
|
||||||
|
AceAudio* audioOut = nullptr;
|
||||||
|
AceAudio outputAudio;
|
||||||
|
outputAudio.samples = nullptr;
|
||||||
|
outputAudio.n_samples = 0;
|
||||||
|
outputAudio.sample_rate = 48000;
|
||||||
|
|
||||||
|
int synthResult = ace_synth_generate(m_synthContext, &lmOutput,
|
||||||
|
nullptr, 0, // no source audio
|
||||||
|
nullptr, 0, // no reference audio
|
||||||
|
1, &outputAudio,
|
||||||
|
checkCancel, this);
|
||||||
|
|
||||||
|
if (m_cancelRequested.load())
|
||||||
{
|
{
|
||||||
|
emit generationCanceled(m_currentSong);
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (synthResult != 0)
|
||||||
|
{
|
||||||
|
emit generationError("Synthesis failed or was canceled");
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save audio to file
|
||||||
|
QString wavFile = m_tempDir + "/request_" + QString::number(m_uid) + ".wav";
|
||||||
|
|
||||||
|
// Write WAV file
|
||||||
|
QFile outFile(wavFile);
|
||||||
|
if (!outFile.open(QIODevice::WriteOnly))
|
||||||
|
{
|
||||||
|
emit generationError("Failed to create output file: " + outFile.errorString());
|
||||||
|
ace_audio_free(&outputAudio);
|
||||||
|
m_busy.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple WAV header + stereo float data
|
||||||
|
int numChannels = 2;
|
||||||
|
int bitsPerSample = 16;
|
||||||
|
int byteRate = outputAudio.sample_rate * numChannels * (bitsPerSample / 8);
|
||||||
|
int blockAlign = numChannels * (bitsPerSample / 8);
|
||||||
|
int dataSize = outputAudio.n_samples * numChannels * (bitsPerSample / 8);
|
||||||
|
|
||||||
|
// RIFF header
|
||||||
|
outFile.write("RIFF");
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&dataSize), 4);
|
||||||
|
outFile.write("WAVE");
|
||||||
|
|
||||||
|
// fmt chunk
|
||||||
|
outFile.write("fmt ");
|
||||||
|
int fmtSize = 16;
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&fmtSize), 4);
|
||||||
|
short audioFormat = 1; // PCM
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&audioFormat), 2);
|
||||||
|
short numCh = numChannels;
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&numCh), 2);
|
||||||
|
int sampleRate = outputAudio.sample_rate;
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&sampleRate), 4);
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&byteRate), 4);
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&blockAlign), 2);
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&bitsPerSample), 2);
|
||||||
|
|
||||||
|
// data chunk
|
||||||
|
outFile.write("data");
|
||||||
|
outFile.write(reinterpret_cast<const char*>(&dataSize), 4);
|
||||||
|
|
||||||
|
// Convert float samples to 16-bit and write
|
||||||
|
QVector<short> interleaved(outputAudio.n_samples * numChannels);
|
||||||
|
for (int i = 0; i < outputAudio.n_samples; i++)
|
||||||
|
{
|
||||||
|
float left = outputAudio.samples[i];
|
||||||
|
float right = outputAudio.samples[i + outputAudio.n_samples];
|
||||||
|
// Clamp and convert to 16-bit
|
||||||
|
left = std::max(-1.0f, std::min(1.0f, left));
|
||||||
|
right = std::max(-1.0f, std::min(1.0f, right));
|
||||||
|
interleaved[i * 2] = static_cast<short>(left * 32767.0f);
|
||||||
|
interleaved[i * 2 + 1] = static_cast<short>(right * 32767.0f);
|
||||||
|
}
|
||||||
|
outFile.write(reinterpret_cast<const char*>(interleaved.data()), dataSize);
|
||||||
|
outFile.close();
|
||||||
|
|
||||||
|
// Free audio buffer
|
||||||
|
ace_audio_free(&outputAudio);
|
||||||
|
|
||||||
|
// Store the JSON with all generated fields
|
||||||
|
m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true));
|
||||||
|
m_currentSong.file = wavFile;
|
||||||
|
|
||||||
|
// Extract BPM if available
|
||||||
|
if (lmOutput.bpm > 0)
|
||||||
|
m_currentSong.bpm = lmOutput.bpm;
|
||||||
|
|
||||||
|
// Extract key if available
|
||||||
|
if (!lmOutput.keyscale.empty())
|
||||||
|
m_currentSong.key = QString::fromStdString(lmOutput.keyscale);
|
||||||
|
|
||||||
|
emit progressUpdate(100);
|
||||||
|
emit songGenerated(m_currentSong);
|
||||||
|
|
||||||
|
m_busy.store(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AceStepWorker::loadModels()
|
||||||
|
{
|
||||||
|
if (m_modelsLoaded.load())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// Load LM
|
||||||
|
AceLmParams lmParams;
|
||||||
|
ace_lm_default_params(&lmParams);
|
||||||
|
lmParams.model_path = m_lmModelPathBytes.constData();
|
||||||
|
lmParams.use_fsm = true;
|
||||||
|
lmParams.use_fa = true;
|
||||||
|
|
||||||
|
m_lmContext = ace_lm_load(&lmParams);
|
||||||
|
if (!m_lmContext)
|
||||||
|
{
|
||||||
|
emit generationError("Failed to load LM model: " + m_lmModelPath);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load Synth
|
||||||
|
AceSynthParams synthParams;
|
||||||
|
ace_synth_default_params(&synthParams);
|
||||||
|
synthParams.text_encoder_path = m_textEncoderPathBytes.constData();
|
||||||
|
synthParams.dit_path = m_ditPathBytes.constData();
|
||||||
|
synthParams.vae_path = m_vaePathBytes.constData();
|
||||||
|
synthParams.use_fa = true;
|
||||||
|
|
||||||
|
m_synthContext = ace_synth_load(&synthParams);
|
||||||
|
if (!m_synthContext)
|
||||||
|
{
|
||||||
|
emit generationError("Failed to load synthesis models");
|
||||||
|
ace_lm_free(m_lmContext);
|
||||||
|
m_lmContext = nullptr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_modelsLoaded.store(true);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AceStepWorker::unloadModels()
|
||||||
|
{
|
||||||
|
if (m_synthContext)
|
||||||
|
{
|
||||||
|
ace_synth_free(m_synthContext);
|
||||||
|
m_synthContext = nullptr;
|
||||||
|
}
|
||||||
|
if (m_lmContext)
|
||||||
|
{
|
||||||
|
ace_lm_free(m_lmContext);
|
||||||
|
m_lmContext = nullptr;
|
||||||
|
}
|
||||||
|
m_modelsLoaded.store(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson)
|
||||||
|
{
|
||||||
|
AceRequest req;
|
||||||
|
request_init(&req);
|
||||||
|
|
||||||
|
req.caption = song.caption.toStdString();
|
||||||
|
req.lyrics = song.lyrics.toStdString();
|
||||||
|
req.use_cot_caption = song.cotCaption;
|
||||||
|
|
||||||
|
// Parse template and override defaults
|
||||||
QJsonParseError parseError;
|
QJsonParseError parseError;
|
||||||
request.song.json = lmOutputFile.readAll();
|
QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError);
|
||||||
QJsonDocument doc = QJsonDocument::fromJson(request.song.json.toUtf8(), &parseError);
|
if (templateDoc.isObject())
|
||||||
lmOutputFile.close();
|
|
||||||
|
|
||||||
if (doc.isObject() && !parseError.error)
|
|
||||||
{
|
{
|
||||||
QJsonObject obj = doc.object();
|
QJsonObject obj = templateDoc.object();
|
||||||
if (obj.contains("lyrics") && obj["lyrics"].isString())
|
if (obj.contains("inference_steps"))
|
||||||
request.song.lyrics = obj["lyrics"].toString();
|
req.inference_steps = obj["inference_steps"].toInt(8);
|
||||||
}
|
if (obj.contains("shift"))
|
||||||
|
req.shift = obj["shift"].toDouble(3.0);
|
||||||
|
if (obj.contains("vocal_language"))
|
||||||
|
req.vocal_language = obj["vocal_language"].toString().toStdString();
|
||||||
|
if (obj.contains("bpm"))
|
||||||
|
req.bpm = obj["bpm"].toInt(120);
|
||||||
|
if (obj.contains("duration"))
|
||||||
|
req.duration = obj["duration"].toDouble(180.0);
|
||||||
|
if (obj.contains("keyscale"))
|
||||||
|
req.keyscale = obj["keyscale"].toString().toStdString();
|
||||||
|
if (obj.contains("lm_temperature"))
|
||||||
|
req.lm_temperature = obj["lm_temperature"].toDouble(0.85);
|
||||||
|
if (obj.contains("lm_cfg_scale"))
|
||||||
|
req.lm_cfg_scale = obj["lm_cfg_scale"].toDouble(2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 2: Run ace-synth to generate audio
|
// Generate a seed for reproducibility
|
||||||
QStringList ditVaeArgs;
|
req.seed = static_cast<int64_t>(QRandomGenerator::global()->generate());
|
||||||
ditVaeArgs << "--request"<<request.requestLlmFilePath;
|
|
||||||
ditVaeArgs << "--embedding"<<request.textEncoderModelPath;
|
|
||||||
ditVaeArgs << "--dit"<<request.ditModelPath;
|
|
||||||
ditVaeArgs << "--vae"<<request.vaeModelPath;
|
|
||||||
ditVaeArgs << "--wav";
|
|
||||||
|
|
||||||
progressUpdate(60);
|
return req;
|
||||||
|
|
||||||
ditVaeProcess.start(ditVaeBinary, ditVaeArgs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AceStep::ditProcFinished(int code, QProcess::ExitStatus status)
|
SongItem AceStepWorker::requestToSong(const AceRequest& req, const QString& json)
|
||||||
{
|
{
|
||||||
QFile::remove(request.requestLlmFilePath);
|
SongItem song;
|
||||||
if (code != 0)
|
song.caption = QString::fromStdString(req.caption);
|
||||||
{
|
song.lyrics = QString::fromStdString(req.lyrics);
|
||||||
QString errorOutput = ditVaeProcess.readAllStandardError();
|
song.cotCaption = req.use_cot_caption;
|
||||||
generationError("ace-synth exited with code " + QString::number(code) + ": " + errorOutput);
|
|
||||||
busy = false;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the generated WAV file
|
if (req.bpm > 0)
|
||||||
QString wavFile = tempDir+"/request_" + QString::number(request.uid) + "00.wav";
|
song.bpm = req.bpm;
|
||||||
if (!QFileInfo::exists(wavFile))
|
if (!req.keyscale.empty())
|
||||||
{
|
song.key = QString::fromStdString(req.keyscale);
|
||||||
generationError("No WAV file generated at "+wavFile);
|
if (!req.vocal_language.empty())
|
||||||
busy = false;
|
song.vocalLanguage = QString::fromStdString(req.vocal_language);
|
||||||
return;
|
|
||||||
}
|
|
||||||
busy = false;
|
|
||||||
|
|
||||||
progressUpdate(100);
|
song.json = json;
|
||||||
request.song.file = wavFile;
|
return song;
|
||||||
songGenerated(request.song);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,40 +8,34 @@
|
||||||
|
|
||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
#include <QProcess>
|
#include <QThread>
|
||||||
#include <QStandardPaths>
|
#include <QStandardPaths>
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
#include "SongItem.h"
|
#include "SongItem.h"
|
||||||
|
|
||||||
#ifdef Q_OS_WIN
|
// acestep.cpp headers
|
||||||
inline const QString EXE_EXT = ".exe";
|
#include "request.h"
|
||||||
#else
|
|
||||||
inline const QString EXE_EXT = "";
|
|
||||||
#endif
|
|
||||||
|
|
||||||
class AceStep : public QObject
|
struct AceLm;
|
||||||
|
struct AceSynth;
|
||||||
|
|
||||||
|
class AceStepWorker : public QObject
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
QProcess qwenProcess;
|
|
||||||
QProcess ditVaeProcess;
|
|
||||||
|
|
||||||
bool busy = false;
|
public:
|
||||||
|
explicit AceStepWorker(QObject* parent = nullptr);
|
||||||
|
~AceStepWorker();
|
||||||
|
|
||||||
struct Request
|
bool isGenerating(SongItem* song = nullptr);
|
||||||
{
|
void cancelGeneration();
|
||||||
SongItem song;
|
|
||||||
uint64_t uid;
|
|
||||||
QString aceStepPath;
|
|
||||||
QString textEncoderModelPath;
|
|
||||||
QString ditModelPath;
|
|
||||||
QString vaeModelPath;
|
|
||||||
QString requestFilePath;
|
|
||||||
QString requestLlmFilePath;
|
|
||||||
};
|
|
||||||
|
|
||||||
Request request;
|
// Model paths - set these before first generation
|
||||||
|
void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath);
|
||||||
|
|
||||||
const QString tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation);
|
// Request a new song generation
|
||||||
|
bool requestGeneration(SongItem song, QString requestTemplate);
|
||||||
|
|
||||||
signals:
|
signals:
|
||||||
void songGenerated(SongItem song);
|
void songGenerated(SongItem song);
|
||||||
|
|
@ -49,19 +43,50 @@ signals:
|
||||||
void generationError(QString error);
|
void generationError(QString error);
|
||||||
void progressUpdate(int progress);
|
void progressUpdate(int progress);
|
||||||
|
|
||||||
public slots:
|
|
||||||
bool requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath,
|
|
||||||
QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath,
|
|
||||||
QString vaeModelPath);
|
|
||||||
|
|
||||||
public:
|
|
||||||
AceStep(QObject* parent = nullptr);
|
|
||||||
bool isGenerating(SongItem* song = nullptr);
|
|
||||||
void cancelGeneration();
|
|
||||||
|
|
||||||
private slots:
|
private slots:
|
||||||
void qwenProcFinished(int code, QProcess::ExitStatus status);
|
void runGeneration();
|
||||||
void ditProcFinished(int code, QProcess::ExitStatus status);
|
|
||||||
|
private:
|
||||||
|
// Check if cancellation was requested
|
||||||
|
static bool checkCancel(void* data);
|
||||||
|
|
||||||
|
// Load models if not already loaded
|
||||||
|
bool loadModels();
|
||||||
|
void unloadModels();
|
||||||
|
|
||||||
|
// Convert SongItem to AceRequest
|
||||||
|
AceRequest songToRequest(const SongItem& song, const QString& templateJson);
|
||||||
|
|
||||||
|
// Convert AceRequest back to SongItem
|
||||||
|
SongItem requestToSong(const AceRequest& req, const QString& json);
|
||||||
|
|
||||||
|
// Generation state
|
||||||
|
std::atomic<bool> m_busy{false};
|
||||||
|
std::atomic<bool> m_cancelRequested{false};
|
||||||
|
std::atomic<bool> m_modelsLoaded{false};
|
||||||
|
|
||||||
|
// Current request data
|
||||||
|
SongItem m_currentSong;
|
||||||
|
QString m_requestTemplate;
|
||||||
|
uint64_t m_uid;
|
||||||
|
|
||||||
|
// Model paths
|
||||||
|
QString m_lmModelPath;
|
||||||
|
QString m_textEncoderPath;
|
||||||
|
QString m_ditPath;
|
||||||
|
QString m_vaePath;
|
||||||
|
|
||||||
|
// Loaded models (accessed from worker thread only)
|
||||||
|
AceLm* m_lmContext = nullptr;
|
||||||
|
AceSynth* m_synthContext = nullptr;
|
||||||
|
|
||||||
|
// Cached model paths as byte arrays (to avoid dangling pointers)
|
||||||
|
QByteArray m_lmModelPathBytes;
|
||||||
|
QByteArray m_textEncoderPathBytes;
|
||||||
|
QByteArray m_ditPathBytes;
|
||||||
|
QByteArray m_vaePathBytes;
|
||||||
|
|
||||||
|
const QString m_tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // ACESTEPWORKER_H
|
#endif // ACESTEPWORKER_H
|
||||||
|
|
@ -21,7 +21,7 @@ MainWindow::MainWindow(QWidget *parent)
|
||||||
ui(new Ui::MainWindow),
|
ui(new Ui::MainWindow),
|
||||||
songModel(new SongListModel(this)),
|
songModel(new SongListModel(this)),
|
||||||
audioPlayer(new AudioPlayer(this)),
|
audioPlayer(new AudioPlayer(this)),
|
||||||
aceStep(new AceStep(this)),
|
aceStep(new AceStepWorker(this)),
|
||||||
playbackTimer(new QTimer(this)),
|
playbackTimer(new QTimer(this)),
|
||||||
isPlaying(false),
|
isPlaying(false),
|
||||||
isPaused(false),
|
isPaused(false),
|
||||||
|
|
@ -41,6 +41,9 @@ MainWindow::MainWindow(QWidget *parent)
|
||||||
// Load settings
|
// Load settings
|
||||||
loadSettings();
|
loadSettings();
|
||||||
|
|
||||||
|
// Set model paths for acestep.cpp
|
||||||
|
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
||||||
|
|
||||||
// Auto-load playlist from config location on startup
|
// Auto-load playlist from config location on startup
|
||||||
autoLoadPlaylist();
|
autoLoadPlaylist();
|
||||||
|
|
||||||
|
|
@ -62,10 +65,10 @@ MainWindow::MainWindow(QWidget *parent)
|
||||||
connect(audioPlayer, &AudioPlayer::playbackStarted, this, &MainWindow::playbackStarted);
|
connect(audioPlayer, &AudioPlayer::playbackStarted, this, &MainWindow::playbackStarted);
|
||||||
connect(audioPlayer, &AudioPlayer::positionChanged, this, &MainWindow::updatePosition);
|
connect(audioPlayer, &AudioPlayer::positionChanged, this, &MainWindow::updatePosition);
|
||||||
connect(audioPlayer, &AudioPlayer::durationChanged, this, &MainWindow::updateDuration);
|
connect(audioPlayer, &AudioPlayer::durationChanged, this, &MainWindow::updateDuration);
|
||||||
connect(aceStep, &AceStep::songGenerated, this, &MainWindow::songGenerated);
|
connect(aceStep, &AceStepWorker::songGenerated, this, &MainWindow::songGenerated);
|
||||||
connect(aceStep, &AceStep::generationCanceled, this, &MainWindow::generationCanceld);
|
connect(aceStep, &AceStepWorker::generationCanceled, this, &MainWindow::generationCanceld);
|
||||||
connect(aceStep, &AceStep::generationError, this, &MainWindow::generationError);
|
connect(aceStep, &AceStepWorker::generationError, this, &MainWindow::generationError);
|
||||||
connect(aceStep, &AceStep::progressUpdate, ui->progressBar, &QProgressBar::setValue);
|
connect(aceStep, &AceStepWorker::progressUpdate, ui->progressBar, &QProgressBar::setValue);
|
||||||
|
|
||||||
// Connect double-click on song list for editing (works with QTableView too)
|
// Connect double-click on song list for editing (works with QTableView too)
|
||||||
connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked);
|
connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked);
|
||||||
|
|
@ -391,6 +394,9 @@ void MainWindow::on_advancedSettingsButton_clicked()
|
||||||
ditModelPath = dialog.getDiTModelPath();
|
ditModelPath = dialog.getDiTModelPath();
|
||||||
vaeModelPath = dialog.getVAEModelPath();
|
vaeModelPath = dialog.getVAEModelPath();
|
||||||
|
|
||||||
|
// Update model paths for acestep.cpp
|
||||||
|
aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath);
|
||||||
|
|
||||||
saveSettings();
|
saveSettings();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -533,10 +539,7 @@ void MainWindow::ensureSongsInQueue(bool enqeueCurrent)
|
||||||
isGeneratingNext = true;
|
isGeneratingNext = true;
|
||||||
|
|
||||||
ui->statusbar->showMessage("Generateing: "+nextSong.caption);
|
ui->statusbar->showMessage("Generateing: "+nextSong.caption);
|
||||||
aceStep->requestGeneration(nextSong, jsonTemplate,
|
aceStep->requestGeneration(nextSong, jsonTemplate);
|
||||||
aceStepPath, qwen3ModelPath,
|
|
||||||
textEncoderModelPath, ditModelPath,
|
|
||||||
vaeModelPath);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void MainWindow::flushGenerationQueue()
|
void MainWindow::flushGenerationQueue()
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ class MainWindow : public QMainWindow
|
||||||
SongListModel *songModel;
|
SongListModel *songModel;
|
||||||
AudioPlayer *audioPlayer;
|
AudioPlayer *audioPlayer;
|
||||||
QThread aceThread;
|
QThread aceThread;
|
||||||
AceStep *aceStep;
|
AceStepWorker *aceStep;
|
||||||
QTimer *playbackTimer;
|
QTimer *playbackTimer;
|
||||||
|
|
||||||
QString formatTime(int milliseconds);
|
QString formatTime(int milliseconds);
|
||||||
|
|
|
||||||
352
tests/test_acestep_worker.cpp
Normal file
352
tests/test_acestep_worker.cpp
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
// Test for AceStepWorker
|
||||||
|
// Compile with: cmake .. && make test_acestep_worker && ./test_acestep_worker
|
||||||
|
|
||||||
|
#include <QCoreApplication>
|
||||||
|
#include <QTimer>
|
||||||
|
#include <QEventLoop>
|
||||||
|
#include <QDebug>
|
||||||
|
#include <QThread>
|
||||||
|
#include <QSettings>
|
||||||
|
#include <QFile>
|
||||||
|
#include <QFileInfo>
|
||||||
|
#include <iostream>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "../src/AceStepWorker.h"
|
||||||
|
|
||||||
|
// Test result tracking
|
||||||
|
static int testsPassed = 0;
|
||||||
|
static int testsFailed = 0;
|
||||||
|
static int testsSkipped = 0;
|
||||||
|
|
||||||
|
#define TEST(name) void test_##name()
|
||||||
|
#define RUN_TEST(name) do { \
|
||||||
|
std::cout << "Running " << #name << "... "; \
|
||||||
|
test_##name(); \
|
||||||
|
if (test_skipped) { \
|
||||||
|
std::cout << "SKIPPED" << std::endl; \
|
||||||
|
testsSkipped++; \
|
||||||
|
test_skipped = false; \
|
||||||
|
} else if (test_failed) { \
|
||||||
|
std::cout << "FAILED" << std::endl; \
|
||||||
|
testsFailed++; \
|
||||||
|
test_failed = false; \
|
||||||
|
} else { \
|
||||||
|
std::cout << "PASSED" << std::endl; \
|
||||||
|
testsPassed++; \
|
||||||
|
} \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
static bool test_failed = false;
|
||||||
|
static bool test_skipped = false;
|
||||||
|
|
||||||
|
#define ASSERT_TRUE(cond) do { \
|
||||||
|
if (!(cond)) { \
|
||||||
|
std::cout << "FAILED: " << #cond << " at line " << __LINE__ << std::endl; \
|
||||||
|
test_failed = true; \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
#define ASSERT_FALSE(cond) ASSERT_TRUE(!(cond))
|
||||||
|
|
||||||
|
#define SKIP_IF(cond) do { \
|
||||||
|
if (cond) { \
|
||||||
|
std::cout << "(skipping: " << #cond << ") "; \
|
||||||
|
test_skipped = true; \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
// Helper to get model paths from settings like main app
|
||||||
|
struct ModelPaths {
|
||||||
|
QString lmPath;
|
||||||
|
QString textEncoderPath;
|
||||||
|
QString ditPath;
|
||||||
|
QString vaePath;
|
||||||
|
};
|
||||||
|
|
||||||
|
static ModelPaths getModelPathsFromSettings()
|
||||||
|
{
|
||||||
|
ModelPaths paths;
|
||||||
|
QSettings settings("MusicGenerator", "AceStepGUI");
|
||||||
|
|
||||||
|
QString appDir = QCoreApplication::applicationDirPath();
|
||||||
|
paths.lmPath = settings.value("qwen3ModelPath",
|
||||||
|
appDir + "/acestep.cpp/models/acestep-5Hz-lm-4B-Q8_0.gguf").toString();
|
||||||
|
paths.textEncoderPath = settings.value("textEncoderModelPath",
|
||||||
|
appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-Q8_0.gguf").toString();
|
||||||
|
paths.ditPath = settings.value("ditModelPath",
|
||||||
|
appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString();
|
||||||
|
paths.vaePath = settings.value("vaeModelPath",
|
||||||
|
appDir + "/acestep.cpp/models/vae-BF16.gguf").toString();
|
||||||
|
|
||||||
|
return paths;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool modelsExist(const ModelPaths& paths)
|
||||||
|
{
|
||||||
|
return QFileInfo::exists(paths.lmPath) &&
|
||||||
|
QFileInfo::exists(paths.textEncoderPath) &&
|
||||||
|
QFileInfo::exists(paths.ditPath) &&
|
||||||
|
QFileInfo::exists(paths.vaePath);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 1: Check that isGenerating returns false initially
|
||||||
|
TEST(initialState)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
ASSERT_TRUE(!worker.isGenerating());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Check that requestGeneration returns false when no model paths set
|
||||||
|
TEST(noModelPaths)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
SongItem song("test caption", "");
|
||||||
|
|
||||||
|
bool result = worker.requestGeneration(song, "{}");
|
||||||
|
ASSERT_FALSE(result);
|
||||||
|
ASSERT_TRUE(!worker.isGenerating());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Check that setModelPaths stores paths correctly
|
||||||
|
TEST(setModelPaths)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf");
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 4: Check async behavior - requestGeneration returns immediately
|
||||||
|
TEST(asyncReturnsImmediately)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf");
|
||||||
|
|
||||||
|
SongItem song("test caption", "");
|
||||||
|
|
||||||
|
// If this blocks, the test will hang
|
||||||
|
bool result = worker.requestGeneration(song, "{}");
|
||||||
|
|
||||||
|
// Should return false due to invalid paths, but immediately
|
||||||
|
ASSERT_FALSE(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 5: Check that cancelGeneration sets the cancel flag
|
||||||
|
TEST(cancellationFlag)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf");
|
||||||
|
worker.cancelGeneration();
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 6: Check that signals are defined correctly
|
||||||
|
TEST(signalsExist)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
|
||||||
|
// Verify signals exist by connecting to them (compile-time check)
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated, [](const SongItem&) {});
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationCanceled, [](const SongItem&) {});
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError, [](const QString&) {});
|
||||||
|
QObject::connect(&worker, &AceStepWorker::progressUpdate, [](int) {});
|
||||||
|
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 7: Check SongItem to AceRequest conversion (internal)
|
||||||
|
TEST(requestConversion)
|
||||||
|
{
|
||||||
|
AceStepWorker worker;
|
||||||
|
|
||||||
|
SongItem song("Upbeat pop rock", "[Verse 1]");
|
||||||
|
song.cotCaption = true;
|
||||||
|
|
||||||
|
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||||
|
|
||||||
|
worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf");
|
||||||
|
bool result = worker.requestGeneration(song, templateJson);
|
||||||
|
|
||||||
|
// Should fail due to invalid paths, but shouldn't crash
|
||||||
|
ASSERT_FALSE(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 8: Read model paths from settings
|
||||||
|
TEST(readSettings)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
std::cout << "\n Model paths from settings:" << std::endl;
|
||||||
|
std::cout << " LM: " << paths.lmPath.toStdString() << std::endl;
|
||||||
|
std::cout << " Text Encoder: " << paths.textEncoderPath.toStdString() << std::endl;
|
||||||
|
std::cout << " DiT: " << paths.ditPath.toStdString() << std::endl;
|
||||||
|
std::cout << " VAE: " << paths.vaePath.toStdString() << std::endl;
|
||||||
|
|
||||||
|
ASSERT_TRUE(!paths.lmPath.isEmpty());
|
||||||
|
ASSERT_TRUE(!paths.textEncoderPath.isEmpty());
|
||||||
|
ASSERT_TRUE(!paths.ditPath.isEmpty());
|
||||||
|
ASSERT_TRUE(!paths.vaePath.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 9: Check if model files exist
|
||||||
|
TEST(checkModelFiles)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
bool lmExists = QFileInfo::exists(paths.lmPath);
|
||||||
|
bool encoderExists = QFileInfo::exists(paths.textEncoderPath);
|
||||||
|
bool ditExists = QFileInfo::exists(paths.ditPath);
|
||||||
|
bool vaeExists = QFileInfo::exists(paths.vaePath);
|
||||||
|
|
||||||
|
std::cout << "\n Model file status:" << std::endl;
|
||||||
|
std::cout << " LM: " << (lmExists ? "EXISTS" : "MISSING") << std::endl;
|
||||||
|
std::cout << " Text Encoder: " << (encoderExists ? "EXISTS" : "MISSING") << std::endl;
|
||||||
|
std::cout << " DiT: " << (ditExists ? "EXISTS" : "MISSING") << std::endl;
|
||||||
|
std::cout << " VAE: " << (vaeExists ? "EXISTS" : "MISSING") << std::endl;
|
||||||
|
|
||||||
|
ASSERT_TRUE(lmExists);
|
||||||
|
ASSERT_TRUE(encoderExists);
|
||||||
|
ASSERT_TRUE(ditExists);
|
||||||
|
ASSERT_TRUE(vaeExists);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 10: Actually generate a song (requires valid model paths)
|
||||||
|
TEST(generateSong)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
// Skip if models don't exist
|
||||||
|
SKIP_IF(!modelsExist(paths));
|
||||||
|
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||||
|
|
||||||
|
SongItem song("Upbeat pop rock with driving guitars", "");
|
||||||
|
|
||||||
|
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||||
|
|
||||||
|
// Track if we get progress updates
|
||||||
|
bool gotProgress = false;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::progressUpdate, [&gotProgress](int p) {
|
||||||
|
std::cout << "\n Progress: " << p << "%" << std::endl;
|
||||||
|
gotProgress = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Track generation result
|
||||||
|
bool generationCompleted = false;
|
||||||
|
SongItem resultSong;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||||
|
[&generationCompleted, &resultSong](const SongItem& song) {
|
||||||
|
std::cout << "\n Song generated successfully!" << std::endl;
|
||||||
|
std::cout << " Caption: " << song.caption.toStdString() << std::endl;
|
||||||
|
std::cout << " Lyrics: " << song.lyrics.left(100).toStdString() << "..." << std::endl;
|
||||||
|
std::cout << " File: " << song.file.toStdString() << std::endl;
|
||||||
|
resultSong = song;
|
||||||
|
generationCompleted = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
QString errorMsg;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError,
|
||||||
|
[&errorMsg](const QString& err) {
|
||||||
|
std::cout << "\n Error: " << err.toStdString() << std::endl;
|
||||||
|
errorMsg = err;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::cout << "\n Starting generation..." << std::endl;
|
||||||
|
|
||||||
|
// Request generation
|
||||||
|
bool result = worker.requestGeneration(song, templateJson);
|
||||||
|
ASSERT_TRUE(result);
|
||||||
|
|
||||||
|
// Use QEventLoop with timer for proper event processing
|
||||||
|
QEventLoop loop;
|
||||||
|
QTimer timeoutTimer;
|
||||||
|
|
||||||
|
timeoutTimer.setSingleShot(true);
|
||||||
|
timeoutTimer.start(300000); // 5 minute timeout
|
||||||
|
|
||||||
|
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||||
|
|
||||||
|
loop.exec();
|
||||||
|
|
||||||
|
ASSERT_TRUE(generationCompleted);
|
||||||
|
ASSERT_TRUE(!resultSong.file.isEmpty());
|
||||||
|
ASSERT_TRUE(QFileInfo::exists(resultSong.file));
|
||||||
|
|
||||||
|
// Check file is not empty
|
||||||
|
QFileInfo fileInfo(resultSong.file);
|
||||||
|
std::cout << " File size: " << fileInfo.size() << " bytes" << std::endl;
|
||||||
|
ASSERT_TRUE(fileInfo.size() > 1000); // Should be at least 1KB for valid audio
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 11: Test cancellation
|
||||||
|
TEST(cancellation)
|
||||||
|
{
|
||||||
|
ModelPaths paths = getModelPathsFromSettings();
|
||||||
|
|
||||||
|
// Skip if models don't exist
|
||||||
|
SKIP_IF(!modelsExist(paths));
|
||||||
|
|
||||||
|
AceStepWorker worker;
|
||||||
|
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||||
|
|
||||||
|
SongItem song("A very long ambient piece", "");
|
||||||
|
|
||||||
|
QString templateJson = R"({"inference_steps": 50, "shift": 3.0, "vocal_language": "en"})";
|
||||||
|
|
||||||
|
bool cancelReceived = false;
|
||||||
|
QObject::connect(&worker, &AceStepWorker::generationCanceled,
|
||||||
|
[&cancelReceived](const SongItem&) {
|
||||||
|
std::cout << "\n Generation was canceled!" << std::endl;
|
||||||
|
cancelReceived = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::cout << "\n Starting generation and will cancel after 2 seconds..." << std::endl;
|
||||||
|
|
||||||
|
// Start generation
|
||||||
|
bool result = worker.requestGeneration(song, templateJson);
|
||||||
|
ASSERT_TRUE(result);
|
||||||
|
|
||||||
|
// Wait 2 seconds then cancel
|
||||||
|
QThread::sleep(2);
|
||||||
|
worker.cancelGeneration();
|
||||||
|
|
||||||
|
// Wait a bit for cancellation to be processed
|
||||||
|
QThread::sleep(1);
|
||||||
|
QCoreApplication::processEvents();
|
||||||
|
|
||||||
|
// Note: cancellation may or may not complete depending on where in the process
|
||||||
|
// the cancel was requested. The important thing is it doesn't crash.
|
||||||
|
std::cout << " Cancel requested, no crash detected" << std::endl;
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char *argv[])
|
||||||
|
{
|
||||||
|
QCoreApplication app(argc, argv);
|
||||||
|
|
||||||
|
std::cout << "=== AceStepWorker Tests ===" << std::endl;
|
||||||
|
|
||||||
|
RUN_TEST(initialState);
|
||||||
|
RUN_TEST(noModelPaths);
|
||||||
|
RUN_TEST(setModelPaths);
|
||||||
|
RUN_TEST(asyncReturnsImmediately);
|
||||||
|
RUN_TEST(cancellationFlag);
|
||||||
|
RUN_TEST(signalsExist);
|
||||||
|
RUN_TEST(requestConversion);
|
||||||
|
RUN_TEST(readSettings);
|
||||||
|
RUN_TEST(checkModelFiles);
|
||||||
|
RUN_TEST(generateSong);
|
||||||
|
RUN_TEST(cancellation);
|
||||||
|
|
||||||
|
std::cout << "\n=== Results ===" << std::endl;
|
||||||
|
std::cout << "Passed: " << testsPassed << std::endl;
|
||||||
|
std::cout << "Skipped: " << testsSkipped << std::endl;
|
||||||
|
std::cout << "Failed: " << testsFailed << std::endl;
|
||||||
|
|
||||||
|
return testsFailed > 0 ? 1 : 0;
|
||||||
|
}
|
||||||
1
third_party/acestep.cpp
vendored
Submodule
1
third_party/acestep.cpp
vendored
Submodule
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit d28398db0ffdb77e8ae071ff31bde8c559e7085a
|
||||||
Loading…
Add table
Add a link
Reference in a new issue