diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..f58ab56 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/acestep.cpp"] + path = third_party/acestep.cpp + url = https://github.com/ServeurpersoCom/acestep.cpp.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 191c922..b1f0110 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Find Qt packages find_package(Qt6 COMPONENTS Core Gui Widgets Multimedia REQUIRED) -# Note: acestep.cpp binaries and models should be provided at runtime +# Add acestep.cpp subdirectory +add_subdirectory(third_party/acestep.cpp) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTOUIC ON) @@ -47,21 +48,22 @@ add_executable(${PROJECT_NAME} # UI file target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -# Link libraries (only Qt libraries - acestep.cpp is external) +# Link libraries (Qt + acestep.cpp) target_link_libraries(${PROJECT_NAME} PRIVATE Qt6::Core Qt6::Gui Qt6::Widgets Qt6::Multimedia + acestep-core + ggml ) -# Include directories (only our source directory - acestep.cpp is external) +# Include directories target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/acestep.cpp/src ) -# Note: acestep.cpp binaries (ace-qwen3, dit-vae) and models should be provided at runtime - # Install targets install(TARGETS ${PROJECT_NAME} DESTINATION bin) @@ -71,3 +73,24 @@ install(FILES res/xyz.uvos.aceradio.desktop DESTINATION share/applications) # Install icon files install(FILES res/xyz.uvos.aceradio.png DESTINATION share/icons/hicolor/256x256/apps RENAME xyz.uvos.aceradio.png) install(FILES res/xyz.uvos.aceradio.svg DESTINATION share/icons/hicolor/scalable/apps RENAME xyz.uvos.aceradio.svg) + +# Test executable +add_executable(test_acestep_worker + tests/test_acestep_worker.cpp + src/AceStepWorker.cpp + src/AceStepWorker.h + src/SongItem.cpp + src/SongItem.h +) + +target_link_libraries(test_acestep_worker PRIVATE + Qt6::Core + Qt6::Widgets + acestep-core + ggml +) + +target_include_directories(test_acestep_worker PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/acestep.cpp/src +) diff --git a/src/AceStepWorker.cpp b/src/AceStepWorker.cpp index 50d940c..8804613 100644 --- a/src/AceStepWorker.cpp +++ b/src/AceStepWorker.cpp @@ -5,207 +5,393 @@ #include #include #include -#include #include -#include #include -#include #include -AceStep::AceStep(QObject* parent): QObject(parent) +// acestep.cpp headers +#include "pipeline-lm.h" +#include "request.h" + +AceStepWorker::AceStepWorker(QObject* parent) + : QObject(parent) { - connect(&qwenProcess, &QProcess::finished, this, &AceStep::qwenProcFinished); - connect(&ditVaeProcess, &QProcess::finished, this, &AceStep::ditProcFinished); } -bool AceStep::isGenerating(SongItem* song) +AceStepWorker::~AceStepWorker() { - if(!busy && song) - *song = this->request.song; - return busy; + cancelGeneration(); + unloadModels(); } -void AceStep::cancelGeneration() +void AceStepWorker::setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath) { - qwenProcess.blockSignals(true); - qwenProcess.terminate(); - qwenProcess.waitForFinished(); - qwenProcess.blockSignals(false); - - ditVaeProcess.blockSignals(true); - ditVaeProcess.terminate(); - ditVaeProcess.waitForFinished(); - ditVaeProcess.blockSignals(false); - - progressUpdate(100); - if(busy) - generationCanceled(request.song); - - busy = false; + // Check if paths actually changed + bool pathsChanged = (m_lmModelPath != lmPath || m_textEncoderPath != textEncoderPath || + m_ditPath != ditPath || m_vaePath != vaePath); + + m_lmModelPath = lmPath; + m_textEncoderPath = textEncoderPath; + m_ditPath = ditPath; + m_vaePath = vaePath; + + // Cache as byte arrays to avoid dangling pointers + m_lmModelPathBytes = lmPath.toUtf8(); + m_textEncoderPathBytes = textEncoderPath.toUtf8(); + m_ditPathBytes = ditPath.toUtf8(); + m_vaePathBytes = vaePath.toUtf8(); + + // If paths changed and models are loaded, unload them so they'll be reloaded with new paths + if (pathsChanged && m_modelsLoaded.load()) + { + unloadModels(); + } } -bool AceStep::requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath, - QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath, - QString vaeModelPath) +void AceStepWorker::setLowVramMode(bool enabled) { - if(busy) - { - qWarning()<<"Dropping song:"<generate(), aceStepPath, textEncoderModelPath, ditModelPath, vaeModelPath}; +void AceStepWorker::setFlashAttention(bool enabled) +{ + m_flashAttention = enabled; +} - QString qwen3Binary = aceStepPath + "/ace-lm" + EXE_EXT; - QFileInfo qwen3Info(qwen3Binary); - if (!qwen3Info.exists() || !qwen3Info.isExecutable()) +bool AceStepWorker::isGenerating(SongItem* song) +{ + if (!m_busy.load() && song) + *song = m_currentSong; + return m_busy.load(); +} + +void AceStepWorker::cancelGeneration() +{ + m_cancelRequested.store(true); +} + +bool AceStepWorker::requestGeneration(SongItem song, QString requestTemplate) +{ + if (m_busy.load()) { - generationError("ace-lm binary not found at: " + qwen3Binary); - busy = false; - return false; - } - if (!QFileInfo::exists(qwen3ModelPath)) - { - generationError("Qwen3 model not found: " + qwen3ModelPath); - busy = false; - return false; - } - if (!QFileInfo::exists(textEncoderModelPath)) - { - generationError("Text encoder model not found: " + textEncoderModelPath); - busy = false; - return false; - } - if (!QFileInfo::exists(ditModelPath)) - { - generationError("DiT model not found: " + ditModelPath); - busy = false; - return false; - } - if (!QFileInfo::exists(vaeModelPath)) - { - generationError("VAE model not found: " + vaeModelPath); - busy = false; + qWarning() << "Dropping song:" << song.caption; return false; } - request.requestFilePath = tempDir + "/request_" + QString::number(request.uid) + ".json"; + m_busy.store(true); + m_cancelRequested.store(false); + m_currentSong = song; + m_requestTemplate = requestTemplate; + m_uid = QRandomGenerator::global()->generate(); - QJsonParseError parseError; - QJsonDocument templateDoc = QJsonDocument::fromJson(requestTemplate.toUtf8(), &parseError); - if (!templateDoc.isObject()) + // Validate model paths + if (m_lmModelPath.isEmpty() || m_textEncoderPath.isEmpty() || + m_ditPath.isEmpty() || m_vaePath.isEmpty()) { - generationError("Invalid JSON template: " + QString(parseError.errorString())); - busy = false; + emit generationError("Model paths not set. Call setModelPaths() first."); + m_busy.store(false); return false; } - QJsonObject requestObj = templateDoc.object(); - song.store(requestObj); - - // Write the request file - QFile requestFileHandle(request.requestFilePath); - if (!requestFileHandle.open(QIODevice::WriteOnly | QIODevice::Text)) + // Validate model files exist + if (!QFileInfo::exists(m_lmModelPath)) { - emit generationError("Failed to create request file: " + requestFileHandle.errorString()); - busy = false; + emit generationError("LM model not found: " + m_lmModelPath); + m_busy.store(false); + return false; + } + if (!QFileInfo::exists(m_textEncoderPath)) + { + emit generationError("Text encoder model not found: " + m_textEncoderPath); + m_busy.store(false); + return false; + } + if (!QFileInfo::exists(m_ditPath)) + { + emit generationError("DiT model not found: " + m_ditPath); + m_busy.store(false); + return false; + } + if (!QFileInfo::exists(m_vaePath)) + { + emit generationError("VAE model not found: " + m_vaePath); + m_busy.store(false); return false; } - requestFileHandle.write(QJsonDocument(requestObj).toJson(QJsonDocument::Indented)); - requestFileHandle.close(); - QStringList qwen3Args; - qwen3Args << "--request" << request.requestFilePath; - qwen3Args << "--lm" << qwen3ModelPath; - - progressUpdate(30); - - qwenProcess.start(qwen3Binary, qwen3Args); + // Run generation in the worker thread + QMetaObject::invokeMethod(this, &AceStepWorker::runGeneration, Qt::QueuedConnection); return true; } -void AceStep::qwenProcFinished(int code, QProcess::ExitStatus status) +bool AceStepWorker::checkCancel(void* data) { - QFile::remove(request.requestFilePath); - if(code != 0) + AceStepWorker* worker = static_cast(data); + return worker->m_cancelRequested.load(); +} + +std::shared_ptr AceStepWorker::convertToWav(const AceAudio& audio) +{ + auto audioData = std::make_shared(); + + // Simple WAV header + stereo float data + int numChannels = 2; + int bitsPerSample = 16; + int byteRate = audio.sample_rate * numChannels * (bitsPerSample / 8); + int blockAlign = numChannels * (bitsPerSample / 8); + int dataSize = audio.n_samples * numChannels * (bitsPerSample / 8); + + // RIFF header + audioData->append("RIFF"); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); + audioData->append("WAVE"); + + // fmt chunk + audioData->append("fmt "); + int fmtSize = 16; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&fmtSize), 4)); + short audioFormat = 1; // PCM + audioData->append(QByteArray::fromRawData(reinterpret_cast(&audioFormat), 2)); + short numCh = numChannels; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&numCh), 2)); + int sampleRate = audio.sample_rate; + audioData->append(QByteArray::fromRawData(reinterpret_cast(&sampleRate), 4)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&byteRate), 4)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&blockAlign), 2)); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&bitsPerSample), 2)); + + // data chunk + audioData->append("data"); + audioData->append(QByteArray::fromRawData(reinterpret_cast(&dataSize), 4)); + + // Convert float samples to 16-bit and write + QVector interleaved(audio.n_samples * numChannels); + for (int i = 0; i < audio.n_samples; i++) { - QString errorOutput = qwenProcess.readAllStandardError(); - generationError("ace-lm exited with code " + QString::number(code) + ": " + errorOutput); - busy = false; + float left = audio.samples[i]; + float right = audio.samples[i + audio.n_samples]; + // Clamp and convert to 16-bit + left = std::max(-1.0f, std::min(1.0f, left)); + right = std::max(-1.0f, std::min(1.0f, right)); + interleaved[i * 2] = static_cast(left * 32767.0f); + interleaved[i * 2 + 1] = static_cast(right * 32767.0f); + } + audioData->append(QByteArray::fromRawData(reinterpret_cast(interleaved.data()), dataSize)); + return audioData; +} + +void AceStepWorker::runGeneration() +{ + // Convert SongItem to AceRequest + AceRequest req = songToRequest(m_currentSong, m_requestTemplate); + AceRequest lmOutput; + request_init(&lmOutput); + + emit progressUpdate(10); + + if (!loadLm()) + { + m_busy.store(false); return; } - QString ditVaeBinary = request.aceStepPath + "/ace-synth" + EXE_EXT; - QFileInfo ditVaeInfo(ditVaeBinary); - if (!ditVaeInfo.exists() || !ditVaeInfo.isExecutable()) + emit progressUpdate(30); + + int lmResult = ace_lm_generate(m_lmContext, &req, 1, &lmOutput, + nullptr, nullptr, + checkCancel, this, + LM_MODE_GENERATE); + + if (m_cancelRequested.load()) { - generationError("ace-synth binary not found at: " + ditVaeBinary); - busy = false; + if(m_lowVramMode) + unloadModels(); + emit generationCanceled(m_currentSong); + m_busy.store(false); return; } - request.requestLlmFilePath = tempDir + "/request_" + QString::number(request.uid) + "0.json"; - if (!QFileInfo::exists(request.requestLlmFilePath)) + if (lmResult != 0) { - generationError("ace-lm failed to create enhanced request file "+request.requestLlmFilePath); - busy = false; + if(m_lowVramMode) + unloadModels(); + emit generationError("LM generation failed or was canceled"); + m_busy.store(false); return; } - // Load lyrics from the enhanced request file - QFile lmOutputFile(request.requestLlmFilePath); - if (lmOutputFile.open(QIODevice::ReadOnly | QIODevice::Text)) + m_currentSong.lyrics = QString::fromStdString(lmOutput.lyrics); + + if(m_lowVramMode) + unloadLm(); + + emit progressUpdate(50); + + if (!loadSynth()) + { + m_busy.store(false); + return; + } + + emit progressUpdate(60); + + AceAudio outputAudio; + outputAudio.samples = nullptr; + outputAudio.n_samples = 0; + outputAudio.sample_rate = 48000; + + int synthResult = ace_synth_generate(m_synthContext, &lmOutput, + nullptr, 0, // no source audio + nullptr, 0, // no reference audio + 1, &outputAudio, + checkCancel, this); + + if(m_lowVramMode) + unloadSynth(); + + if (m_cancelRequested.load()) + { + emit generationCanceled(m_currentSong); + m_busy.store(false); + return; + } + + if (synthResult != 0) + { + emit generationError("Synthesis failed or was canceled"); + m_busy.store(false); + return; + } + + std::shared_ptr audioData = convertToWav(outputAudio); + ace_audio_free(&outputAudio); + + m_currentSong.json = QString::fromStdString(request_to_json(&lmOutput, true)); + m_currentSong.audioData = audioData; + + if (lmOutput.bpm > 0) + m_currentSong.bpm = lmOutput.bpm; + + if (!lmOutput.keyscale.empty()) + m_currentSong.key = QString::fromStdString(lmOutput.keyscale); + + emit progressUpdate(100); + emit songGenerated(m_currentSong); + + m_busy.store(false); +} + +bool AceStepWorker::loadModels() +{ + bool ret = loadSynth(); + if(!ret) + return false; + + ret = loadLm(); + if(!ret) + return false; + return true; +} + +void AceStepWorker::unloadModels() +{ + if (m_synthContext) + { + ace_synth_free(m_synthContext); + m_synthContext = nullptr; + } + if (m_lmContext) + { + ace_lm_free(m_lmContext); + m_lmContext = nullptr; + } + m_modelsLoaded.store(false); +} + +bool AceStepWorker::loadLm() +{ + if (m_lmContext) + return true; + + AceLmParams lmParams; + ace_lm_default_params(&lmParams); + lmParams.model_path = m_lmModelPathBytes.constData(); + lmParams.use_fsm = true; + lmParams.use_fa = m_flashAttention; + + m_lmContext = ace_lm_load(&lmParams); + if (!m_lmContext) + { + emit generationError("Failed to load LM model: " + m_lmModelPath); + return false; + } + return true; +} + +void AceStepWorker::unloadLm() +{ + if (m_lmContext) + { + ace_lm_free(m_lmContext); + m_lmContext = nullptr; + } +} + +bool AceStepWorker::loadSynth() +{ + if (m_synthContext) + return true; + + AceSynthParams synthParams; + ace_synth_default_params(&synthParams); + synthParams.text_encoder_path = m_textEncoderPathBytes.constData(); + synthParams.dit_path = m_ditPathBytes.constData(); + synthParams.vae_path = m_vaePathBytes.constData(); + synthParams.use_fa = m_flashAttention; + + m_synthContext = ace_synth_load(&synthParams); + if (!m_synthContext) + { + emit generationError("Failed to load synthesis models"); + return false; + } + return true; +} + +void AceStepWorker::unloadSynth() +{ + if (m_synthContext) + { + ace_synth_free(m_synthContext); + m_synthContext = nullptr; + } +} + +AceRequest AceStepWorker::songToRequest(const SongItem& song, const QString& templateJson) +{ + AceRequest req; + request_init(&req); + + // Parse template first to get all default values + QJsonObject requestJson; + if (!templateJson.isEmpty()) { QJsonParseError parseError; - request.song.json = lmOutputFile.readAll(); - QJsonDocument doc = QJsonDocument::fromJson(request.song.json.toUtf8(), &parseError); - lmOutputFile.close(); - - if (doc.isObject() && !parseError.error) - { - QJsonObject obj = doc.object(); - if (obj.contains("lyrics") && obj["lyrics"].isString()) - request.song.lyrics = obj["lyrics"].toString(); - } + QJsonDocument templateDoc = QJsonDocument::fromJson(templateJson.toUtf8(), &parseError); + if (templateDoc.isObject()) + requestJson = templateDoc.object(); + else + qWarning() << "Failed to parse request template:" << parseError.errorString(); } - // Step 2: Run ace-synth to generate audio - QStringList ditVaeArgs; - ditVaeArgs << "--request"<(QRandomGenerator::global()->generate()); + return req; +} \ No newline at end of file diff --git a/src/AceStepWorker.h b/src/AceStepWorker.h index 17962f8..56d95b1 100644 --- a/src/AceStepWorker.h +++ b/src/AceStepWorker.h @@ -8,40 +8,41 @@ #include #include -#include +#include #include +#include #include "SongItem.h" +#include "pipeline-synth.h" +#include "request.h" -#ifdef Q_OS_WIN -inline const QString EXE_EXT = ".exe"; -#else -inline const QString EXE_EXT = ""; -#endif +struct AceLm; +struct AceSynth; -class AceStep : public QObject +class AceStepWorker : public QObject { Q_OBJECT - QProcess qwenProcess; - QProcess ditVaeProcess; - bool busy = false; +public: + explicit AceStepWorker(QObject* parent = nullptr); + ~AceStepWorker(); - struct Request - { - SongItem song; - uint64_t uid; - QString aceStepPath; - QString textEncoderModelPath; - QString ditModelPath; - QString vaeModelPath; - QString requestFilePath; - QString requestLlmFilePath; - }; + bool isGenerating(SongItem* song = nullptr); + void cancelGeneration(); - Request request; + // Model paths - set these before first generation + void setModelPaths(QString lmPath, QString textEncoderPath, QString ditPath, QString vaePath); - const QString tempDir = QStandardPaths::writableLocation(QStandardPaths::TempLocation); + // Low VRAM mode: unload models between phases to save VRAM + void setLowVramMode(bool enabled); + bool isLowVramMode() const { return m_lowVramMode; } + + // Flash attention mode + void setFlashAttention(bool enabled); + bool isFlashAttention() const { return m_flashAttention; } + + // Request a new song generation + bool requestGeneration(SongItem song, QString requestTemplate); signals: void songGenerated(SongItem song); @@ -49,19 +50,49 @@ signals: void generationError(QString error); void progressUpdate(int progress); -public slots: - bool requestGeneration(SongItem song, QString requestTemplate, QString aceStepPath, - QString qwen3ModelPath, QString textEncoderModelPath, QString ditModelPath, - QString vaeModelPath); - -public: - AceStep(QObject* parent = nullptr); - bool isGenerating(SongItem* song = nullptr); - void cancelGeneration(); - private slots: - void qwenProcFinished(int code, QProcess::ExitStatus status); - void ditProcFinished(int code, QProcess::ExitStatus status); + void runGeneration(); + +private: + static bool checkCancel(void* data); + + bool loadModels(); + void unloadModels(); + bool loadLm(); + void unloadLm(); + bool loadSynth(); + void unloadSynth(); + + AceRequest songToRequest(const SongItem& song, const QString& templateJson); + + static std::shared_ptr convertToWav(const AceAudio& audio); + + // Generation state + std::atomic m_busy{false}; + std::atomic m_cancelRequested{false}; + std::atomic m_modelsLoaded{false}; + bool m_lowVramMode = false; + bool m_flashAttention = true; + + // Current request data + SongItem m_currentSong; + QString m_requestTemplate; + uint64_t m_uid; + + // Model paths + QString m_lmModelPath; + QString m_textEncoderPath; + QString m_ditPath; + QString m_vaePath; + + // Loaded models (accessed from worker thread only) + AceLm* m_lmContext = nullptr; + AceSynth* m_synthContext = nullptr; + + QByteArray m_lmModelPathBytes; + QByteArray m_textEncoderPathBytes; + QByteArray m_ditPathBytes; + QByteArray m_vaePathBytes; }; -#endif // ACESTEPWORKER_H +#endif // ACESTEPWORKER_H \ No newline at end of file diff --git a/src/AdvancedSettingsDialog.cpp b/src/AdvancedSettingsDialog.cpp index a34dc4e..7e2e412 100644 --- a/src/AdvancedSettingsDialog.cpp +++ b/src/AdvancedSettingsDialog.cpp @@ -13,6 +13,12 @@ AdvancedSettingsDialog::AdvancedSettingsDialog(QWidget *parent) ui(new Ui::AdvancedSettingsDialog) { ui->setupUi(this); + + // Connect signals and slots explicitly + connect(ui->qwen3BrowseButton, &QPushButton::clicked, this, &AdvancedSettingsDialog::onQwen3BrowseButtonClicked); + connect(ui->textEncoderBrowseButton, &QPushButton::clicked, this, &AdvancedSettingsDialog::onTextEncoderBrowseButtonClicked); + connect(ui->ditBrowseButton, &QPushButton::clicked, this, &AdvancedSettingsDialog::onDiTBrowseButtonClicked); + connect(ui->vaeBrowseButton, &QPushButton::clicked, this, &AdvancedSettingsDialog::onVAEBrowseButtonClicked); } AdvancedSettingsDialog::~AdvancedSettingsDialog() @@ -25,11 +31,6 @@ QString AdvancedSettingsDialog::getJsonTemplate() const return ui->jsonTemplateEdit->toPlainText(); } -QString AdvancedSettingsDialog::getAceStepPath() const -{ - return ui->aceStepPathEdit->text(); -} - QString AdvancedSettingsDialog::getQwen3ModelPath() const { return ui->qwen3ModelEdit->text(); @@ -50,16 +51,21 @@ QString AdvancedSettingsDialog::getVAEModelPath() const return ui->vaeModelEdit->text(); } +bool AdvancedSettingsDialog::getLowVramMode() const +{ + return ui->lowVramCheckBox->isChecked(); +} + +bool AdvancedSettingsDialog::getFlashAttention() const +{ + return ui->flashAttentionCheckBox->isChecked(); +} + void AdvancedSettingsDialog::setJsonTemplate(const QString &templateStr) { ui->jsonTemplateEdit->setPlainText(templateStr); } -void AdvancedSettingsDialog::setAceStepPath(const QString &path) -{ - ui->aceStepPathEdit->setText(path); -} - void AdvancedSettingsDialog::setQwen3ModelPath(const QString &path) { ui->qwen3ModelEdit->setText(path); @@ -80,16 +86,17 @@ void AdvancedSettingsDialog::setVAEModelPath(const QString &path) ui->vaeModelEdit->setText(path); } -void AdvancedSettingsDialog::on_aceStepBrowseButton_clicked() +void AdvancedSettingsDialog::setLowVramMode(bool enabled) { - QString dir = QFileDialog::getExistingDirectory(this, "Select AceStep Build Directory", ui->aceStepPathEdit->text()); - if (!dir.isEmpty()) - { - ui->aceStepPathEdit->setText(dir); - } + ui->lowVramCheckBox->setChecked(enabled); } -void AdvancedSettingsDialog::on_qwen3BrowseButton_clicked() +void AdvancedSettingsDialog::setFlashAttention(bool enabled) +{ + ui->flashAttentionCheckBox->setChecked(enabled); +} + +void AdvancedSettingsDialog::onQwen3BrowseButtonClicked() { QString file = QFileDialog::getOpenFileName(this, "Select Qwen3 Model", ui->qwen3ModelEdit->text(), "GGUF Files (*.gguf)"); @@ -99,7 +106,7 @@ void AdvancedSettingsDialog::on_qwen3BrowseButton_clicked() } } -void AdvancedSettingsDialog::on_textEncoderBrowseButton_clicked() +void AdvancedSettingsDialog::onTextEncoderBrowseButtonClicked() { QString file = QFileDialog::getOpenFileName(this, "Select Text Encoder Model", ui->textEncoderEdit->text(), "GGUF Files (*.gguf)"); @@ -109,7 +116,7 @@ void AdvancedSettingsDialog::on_textEncoderBrowseButton_clicked() } } -void AdvancedSettingsDialog::on_ditBrowseButton_clicked() +void AdvancedSettingsDialog::onDiTBrowseButtonClicked() { QString file = QFileDialog::getOpenFileName(this, "Select DiT Model", ui->ditModelEdit->text(), "GGUF Files (*.gguf)"); if (!file.isEmpty()) @@ -118,7 +125,7 @@ void AdvancedSettingsDialog::on_ditBrowseButton_clicked() } } -void AdvancedSettingsDialog::on_vaeBrowseButton_clicked() +void AdvancedSettingsDialog::onVAEBrowseButtonClicked() { QString file = QFileDialog::getOpenFileName(this, "Select VAE Model", ui->vaeModelEdit->text(), "GGUF Files (*.gguf)"); if (!file.isEmpty()) diff --git a/src/AdvancedSettingsDialog.h b/src/AdvancedSettingsDialog.h index 8db247c..8a3047a 100644 --- a/src/AdvancedSettingsDialog.h +++ b/src/AdvancedSettingsDialog.h @@ -24,26 +24,27 @@ public: // Getters for settings QString getJsonTemplate() const; - QString getAceStepPath() const; QString getQwen3ModelPath() const; QString getTextEncoderModelPath() const; QString getDiTModelPath() const; QString getVAEModelPath() const; + bool getLowVramMode() const; + bool getFlashAttention() const; // Setters for settings void setJsonTemplate(const QString &templateStr); - void setAceStepPath(const QString &path); void setQwen3ModelPath(const QString &path); void setTextEncoderModelPath(const QString &path); void setDiTModelPath(const QString &path); void setVAEModelPath(const QString &path); + void setLowVramMode(bool enabled); + void setFlashAttention(bool enabled); private slots: - void on_aceStepBrowseButton_clicked(); - void on_qwen3BrowseButton_clicked(); - void on_textEncoderBrowseButton_clicked(); - void on_ditBrowseButton_clicked(); - void on_vaeBrowseButton_clicked(); + void onQwen3BrowseButtonClicked(); + void onTextEncoderBrowseButtonClicked(); + void onDiTBrowseButtonClicked(); + void onVAEBrowseButtonClicked(); private: Ui::AdvancedSettingsDialog *ui; diff --git a/src/AdvancedSettingsDialog.ui b/src/AdvancedSettingsDialog.ui index e9fadf1..6e95e90 100644 --- a/src/AdvancedSettingsDialog.ui +++ b/src/AdvancedSettingsDialog.ui @@ -1,203 +1,239 @@ - - AdvancedSettingsDialog - - - - 0 - 0 - 600 - 450 - - - - Advanced Settings - - - - - - 0 - - - - JSON Template - - - - - - JSON Template for AceStep generation: - - - true - - - - - - - - - - - Model Paths - - - - QFormLayout::FieldGrowthPolicy::AllNonFixedFieldsGrow - - - - - AceStep Path: - - - - - - - - - - - - Browse... - - - - - - - - - Qwen3 Model: - - - - - - - - - - - - Browse... - - - - - - - - - Text Encoder Model: - - - - - - - - - - - - Browse... - - - - - - - - - DiT Model: - - - - - - - - - - - - Browse... - - - - - - - - - VAE Model: - - - - - - - - - - - - Browse... - - - - - - - - - - - - - QDialogButtonBox::StandardButton::Cancel|QDialogButtonBox::StandardButton::Save - - - - - - - - - buttonBox - accepted() - AdvancedSettingsDialog - accept() - - - 248 - 254 - - - 157 - 254 - - - - - buttonBox - rejected() - AdvancedSettingsDialog - reject() - - - 316 - 260 - - - 286 - 260 - - - - - + + AdvancedSettingsDialog + + + + 0 + 0 + 600 + 450 + + + + Advanced Settings + + + + + + 0 + + + + Performance + + + + + + Low VRAM Mode + + + + + + + Unload models between generation phases to save VRAM. Slower but uses less memory. + + + true + + + + + + + Flash Attention + + + true + + + + + + + Use flash attention for faster generation. Disable if experiencing poor output quality on vulkan. + + + true + + + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + + + + JSON Template + + + + + + JSON Template for AceStep generation: + + + true + + + + + + + + + + + Model Paths + + + + QFormLayout::FieldGrowthPolicy::AllNonFixedFieldsGrow + + + + + Qwen3 Model: + + + + + + + + + + + + Browse... + + + + + + + + + Text Encoder Model: + + + + + + + + + + + + Browse... + + + + + + + + + DiT Model: + + + + + + + + + + + + Browse... + + + + + + + + + VAE Model: + + + + + + + + + + + + Browse... + + + + + + + + + + + + + QDialogButtonBox::StandardButton::Cancel|QDialogButtonBox::StandardButton::Save + + + + + + + + + buttonBox + accepted() + AdvancedSettingsDialog + accept() + + + 248 + 254 + + + 157 + 254 + + + + + buttonBox + rejected() + AdvancedSettingsDialog + reject() + + + 316 + 260 + + + 286 + 260 + + + + + diff --git a/src/AudioPlayer.cpp b/src/AudioPlayer.cpp index 2f039bb..bb53670 100644 --- a/src/AudioPlayer.cpp +++ b/src/AudioPlayer.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: GPL-3.0-or-later #include "AudioPlayer.h" +#include #include AudioPlayer::AudioPlayer(QObject *parent) @@ -48,6 +49,33 @@ void AudioPlayer::play(const QString &filePath) positionTimer->start(); } +void AudioPlayer::play(std::shared_ptr audioData) +{ + if (isPlaying()) + { + stop(); + } + + if (!audioData || audioData->isEmpty()) + { + emit playbackError("No audio data available"); + return; + } + + // Create a buffer with the audio data + QBuffer *buffer = new QBuffer(); + buffer->setData(*audioData); + buffer->open(QIODevice::ReadOnly); + buffer->setParent(this); + + // Use QMediaPlayer::setSourceDevice for in-memory playback + mediaPlayer->setSourceDevice(buffer, QUrl("memory://audio.wav")); + mediaPlayer->play(); + + // Start position timer + positionTimer->start(); +} + void AudioPlayer::play() { if (!isPlaying()) diff --git a/src/AudioPlayer.h b/src/AudioPlayer.h index 6eb4b25..e1d0b47 100644 --- a/src/AudioPlayer.h +++ b/src/AudioPlayer.h @@ -14,6 +14,7 @@ #include #include #include +#include class AudioPlayer : public QObject { @@ -23,6 +24,7 @@ public: ~AudioPlayer(); void play(const QString &filePath); + void play(std::shared_ptr audioData); void play(); void stop(); void pause(); diff --git a/src/MainWindow.cpp b/src/MainWindow.cpp index cff9f9f..5ed9b5a 100644 --- a/src/MainWindow.cpp +++ b/src/MainWindow.cpp @@ -21,7 +21,7 @@ MainWindow::MainWindow(QWidget *parent) ui(new Ui::MainWindow), songModel(new SongListModel(this)), audioPlayer(new AudioPlayer(this)), - aceStep(new AceStep(this)), + aceStep(new AceStepWorker), playbackTimer(new QTimer(this)), isPlaying(false), isPaused(false), @@ -29,6 +29,7 @@ MainWindow::MainWindow(QWidget *parent) isGeneratingNext(false) { aceStep->moveToThread(&aceThread); + aceThread.setObjectName("AceStep Woker Thread"); ui->setupUi(this); @@ -41,15 +42,25 @@ MainWindow::MainWindow(QWidget *parent) // Load settings loadSettings(); + // Set model paths for acestep.cpp + aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath); + // Auto-load playlist from config location on startup autoLoadPlaylist(); // Connect signals and slots - connect(ui->actionAdvancedSettings, &QAction::triggered, this, &MainWindow::on_advancedSettingsButton_clicked); - connect(ui->actionSavePlaylist, &QAction::triggered, this, &MainWindow::on_actionSavePlaylist); - connect(ui->actionLoadPlaylist, &QAction::triggered, this, &MainWindow::on_actionLoadPlaylist); - connect(ui->actionAppendPlaylist, &QAction::triggered, this, &MainWindow::on_actionAppendPlaylist); - connect(ui->actionSaveSong, &QAction::triggered, this, &MainWindow::on_actionSaveSong); + connect(ui->playButton, &QPushButton::clicked, this, &MainWindow::onPlayButtonClicked); + connect(ui->pauseButton, &QPushButton::clicked, this, &MainWindow::onPauseButtonClicked); + connect(ui->skipButton, &QPushButton::clicked, this, &MainWindow::onSkipButtonClicked); + connect(ui->stopButton, &QPushButton::clicked, this, &MainWindow::onStopButtonClicked); + connect(ui->shuffleButton, &QPushButton::clicked, this, &MainWindow::onShuffleButtonClicked); + connect(ui->addSongButton, &QPushButton::clicked, this, &MainWindow::onAddSongButtonClicked); + connect(ui->removeSongButton, &QPushButton::clicked, this, &MainWindow::onRemoveSongButtonClicked); + connect(ui->actionAdvancedSettings, &QAction::triggered, this, &MainWindow::onAdvancedSettingsButtonClicked); + connect(ui->actionSavePlaylist, &QAction::triggered, this, &MainWindow::onActionSavePlaylist); + connect(ui->actionLoadPlaylist, &QAction::triggered, this, &MainWindow::onActionLoadPlaylist); + connect(ui->actionAppendPlaylist, &QAction::triggered, this, &MainWindow::onActionAppendPlaylist); + connect(ui->actionSaveSong, &QAction::triggered, this, &MainWindow::onActionSaveSong); connect(ui->actionQuit, &QAction::triggered, this, [this]() { close(); @@ -62,16 +73,16 @@ MainWindow::MainWindow(QWidget *parent) connect(audioPlayer, &AudioPlayer::playbackStarted, this, &MainWindow::playbackStarted); connect(audioPlayer, &AudioPlayer::positionChanged, this, &MainWindow::updatePosition); connect(audioPlayer, &AudioPlayer::durationChanged, this, &MainWindow::updateDuration); - connect(aceStep, &AceStep::songGenerated, this, &MainWindow::songGenerated); - connect(aceStep, &AceStep::generationCanceled, this, &MainWindow::generationCanceld); - connect(aceStep, &AceStep::generationError, this, &MainWindow::generationError); - connect(aceStep, &AceStep::progressUpdate, ui->progressBar, &QProgressBar::setValue); + connect(aceStep, &AceStepWorker::songGenerated, this, &MainWindow::songGenerated); + connect(aceStep, &AceStepWorker::generationCanceled, this, &MainWindow::generationCanceld); + connect(aceStep, &AceStepWorker::generationError, this, &MainWindow::generationError); + connect(aceStep, &AceStepWorker::progressUpdate, ui->progressBar, &QProgressBar::setValue); // Connect double-click on song list for editing (works with QTableView too) - connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked); + connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::onSongListViewDoubleClicked); // Connect audio player error signal - connect(audioPlayer, &AudioPlayer::playbackError, [this](const QString &error) + connect(audioPlayer, &AudioPlayer::playbackError, this, [this](const QString &error) { QMessageBox::warning(this, "Playback Error", "Failed to play audio: " + error); }); @@ -98,11 +109,17 @@ MainWindow::MainWindow(QWidget *parent) ui->nowPlayingLabel->setText("Now Playing:"); currentSong = songModel->getSong(0); + + // Start the worker thread and enter its event loop + QObject::connect(&aceThread, &QThread::started, [this]() {qDebug() << "Worker thread started";}); + aceThread.start(); } MainWindow::~MainWindow() { aceStep->cancelGeneration(); + aceThread.quit(); + aceThread.wait(); autoSavePlaylist(); saveSettings(); @@ -145,13 +162,20 @@ void MainWindow::loadSettings() // Load path settings with defaults based on application directory QString appDir = QCoreApplication::applicationDirPath(); - aceStepPath = settings.value("aceStepPath", appDir + "/acestep.cpp").toString(); qwen3ModelPath = settings.value("qwen3ModelPath", appDir + "/acestep.cpp/models/acestep-5Hz-lm-4B-Q8_0.gguf").toString(); textEncoderModelPath = settings.value("textEncoderModelPath", appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-BF16.gguf").toString(); ditModelPath = settings.value("ditModelPath", appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString(); vaeModelPath = settings.value("vaeModelPath", appDir + "/acestep.cpp/models/vae-BF16.gguf").toString(); + + // Load low VRAM mode + bool lowVram = settings.value("lowVramMode", false).toBool(); + aceStep->setLowVramMode(lowVram); + + // Load flash attention setting + bool flashAttention = settings.value("flashAttention", false).toBool(); + aceStep->setFlashAttention(flashAttention); } void MainWindow::saveSettings() @@ -165,12 +189,17 @@ void MainWindow::saveSettings() settings.setValue("shuffleMode", shuffleMode); // Save path settings - settings.setValue("aceStepPath", aceStepPath); settings.setValue("qwen3ModelPath", qwen3ModelPath); settings.setValue("textEncoderModelPath", textEncoderModelPath); settings.setValue("ditModelPath", ditModelPath); settings.setValue("vaeModelPath", vaeModelPath); + // Save low VRAM mode + settings.setValue("lowVramMode", aceStep->isLowVramMode()); + + // Save flash attention setting + settings.setValue("flashAttention", aceStep->isFlashAttention()); + settings.setValue("firstRun", false); } @@ -219,7 +248,7 @@ void MainWindow::updateControls() ui->removeSongButton->setEnabled(hasSongs && ui->songListView->currentIndex().isValid()); } -void MainWindow::on_playButton_clicked() +void MainWindow::onPlayButtonClicked() { if (isPaused) { @@ -240,7 +269,7 @@ void MainWindow::on_playButton_clicked() updateControls(); } -void MainWindow::on_pauseButton_clicked() +void MainWindow::onPauseButtonClicked() { if (isPlaying && !isPaused && audioPlayer->isPlaying()) { @@ -251,7 +280,7 @@ void MainWindow::on_pauseButton_clicked() } } -void MainWindow::on_skipButton_clicked() +void MainWindow::onSkipButtonClicked() { if (isPlaying) { @@ -261,7 +290,7 @@ void MainWindow::on_skipButton_clicked() } } -void MainWindow::on_stopButton_clicked() +void MainWindow::onStopButtonClicked() { if (isPlaying) { @@ -275,7 +304,7 @@ void MainWindow::on_stopButton_clicked() } } -void MainWindow::on_shuffleButton_clicked() +void MainWindow::onShuffleButtonClicked() { shuffleMode = ui->shuffleButton->isChecked(); updateControls(); @@ -285,7 +314,7 @@ void MainWindow::on_shuffleButton_clicked() ensureSongsInQueue(); } -void MainWindow::on_addSongButton_clicked() +void MainWindow::onAddSongButtonClicked() { SongDialog dialog(this); @@ -300,12 +329,12 @@ void MainWindow::on_addSongButton_clicked() } } -void MainWindow::on_songListView_doubleClicked(const QModelIndex &index) +void MainWindow::onSongListViewDoubleClicked(const QModelIndex &index) { if (!index.isValid()) return; - disconnect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked); + disconnect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::onSongListViewDoubleClicked); int row = index.row(); @@ -337,10 +366,10 @@ void MainWindow::on_songListView_doubleClicked(const QModelIndex &index) songModel->updateSong(songModel->index(row, 1), dialog.getSong()); } - connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::on_songListView_doubleClicked); + connect(ui->songListView, &QTableView::doubleClicked, this, &MainWindow::onSongListViewDoubleClicked); } -void MainWindow::on_removeSongButton_clicked() +void MainWindow::onRemoveSongButtonClicked() { QModelIndex currentIndex = ui->songListView->currentIndex(); if (!currentIndex.isValid()) @@ -360,17 +389,18 @@ void MainWindow::on_removeSongButton_clicked() } } -void MainWindow::on_advancedSettingsButton_clicked() +void MainWindow::onAdvancedSettingsButtonClicked() { AdvancedSettingsDialog dialog(this); // Set current values dialog.setJsonTemplate(jsonTemplate); - dialog.setAceStepPath(aceStepPath); dialog.setQwen3ModelPath(qwen3ModelPath); dialog.setTextEncoderModelPath(textEncoderModelPath); dialog.setDiTModelPath(ditModelPath); dialog.setVAEModelPath(vaeModelPath); + dialog.setLowVramMode(aceStep->isLowVramMode()); + dialog.setFlashAttention(aceStep->isFlashAttention()); if (dialog.exec() == QDialog::Accepted) { @@ -385,12 +415,20 @@ void MainWindow::on_advancedSettingsButton_clicked() // Update settings jsonTemplate = dialog.getJsonTemplate(); - aceStepPath = dialog.getAceStepPath(); qwen3ModelPath = dialog.getQwen3ModelPath(); textEncoderModelPath = dialog.getTextEncoderModelPath(); ditModelPath = dialog.getDiTModelPath(); vaeModelPath = dialog.getVAEModelPath(); + // Update model paths for acestep.cpp + aceStep->setModelPaths(qwen3ModelPath, textEncoderModelPath, ditModelPath, vaeModelPath); + + // Update low VRAM mode + aceStep->setLowVramMode(dialog.getLowVramMode()); + + // Update flash attention setting + aceStep->setFlashAttention(dialog.getFlashAttention()); + saveSettings(); } } @@ -403,7 +441,14 @@ void MainWindow::playbackStarted() void MainWindow::playSong(const SongItem& song) { currentSong = song; - audioPlayer->play(song.file); + if (song.audioData) + { + audioPlayer->play(song.audioData); + } + else if (!song.file.isEmpty()) + { + audioPlayer->play(song.file); + } songModel->setPlayingIndex(songModel->findSongIndexById(song.uniqueId)); ui->nowPlayingLabel->setText("Now Playing: " + song.caption); ui->lyricsTextEdit->setPlainText(song.lyrics); @@ -494,7 +539,7 @@ void MainWindow::updatePlaybackStatus(bool playing) updateControls(); } -void MainWindow::on_positionSlider_sliderMoved(int position) +void MainWindow::onPositionSliderSliderMoved(int position) { if (isPlaying && audioPlayer->isPlaying()) { @@ -533,10 +578,7 @@ void MainWindow::ensureSongsInQueue(bool enqeueCurrent) isGeneratingNext = true; ui->statusbar->showMessage("Generateing: "+nextSong.caption); - aceStep->requestGeneration(nextSong, jsonTemplate, - aceStepPath, qwen3ModelPath, - textEncoderModelPath, ditModelPath, - vaeModelPath); + QMetaObject::invokeMethod(aceStep, &AceStepWorker::requestGeneration, Qt::QueuedConnection, nextSong, jsonTemplate); } void MainWindow::flushGenerationQueue() @@ -547,7 +589,7 @@ void MainWindow::flushGenerationQueue() } // Playlist save/load methods -void MainWindow::on_actionSavePlaylist() +void MainWindow::onActionSavePlaylist() { QString filePath = QFileDialog::getSaveFileName(this, "Save Playlist", QStandardPaths::writableLocation(QStandardPaths::DocumentsLocation) + "/playlist.json", @@ -559,7 +601,7 @@ void MainWindow::on_actionSavePlaylist() } } -void MainWindow::on_actionLoadPlaylist() +void MainWindow::onActionLoadPlaylist() { QString filePath = QFileDialog::getOpenFileName(this, "Load Playlist", QStandardPaths::writableLocation(QStandardPaths::DocumentsLocation), @@ -572,7 +614,7 @@ void MainWindow::on_actionLoadPlaylist() } } -void MainWindow::on_actionAppendPlaylist() +void MainWindow::onActionAppendPlaylist() { QString filePath = QFileDialog::getOpenFileName(this, "Load Playlist", QStandardPaths::writableLocation(QStandardPaths::DocumentsLocation), @@ -583,7 +625,7 @@ void MainWindow::on_actionAppendPlaylist() } } -void MainWindow::on_actionSaveSong() +void MainWindow::onActionSaveSong() { QString filePath = QFileDialog::getSaveFileName(this, "Save Playlist", QStandardPaths::writableLocation(QStandardPaths::DocumentsLocation) + "/song.json", @@ -607,7 +649,22 @@ void MainWindow::on_actionSaveSong() QFile file(filePath); if (!file.open(QIODevice::WriteOnly | QIODevice::Text)) return; - QFile::copy(currentSong.file, filePath + ".wav"); + + // Save audio from memory if available, otherwise fall back to file + if (currentSong.audioData) + { + QFile wavFile(filePath + ".wav"); + if (wavFile.open(QIODevice::WriteOnly)) + { + wavFile.write(*currentSong.audioData); + wavFile.close(); + } + } + else if (!currentSong.file.isEmpty()) + { + QFile::copy(currentSong.file, filePath + ".wav"); + } + file.write(jsonData); file.close(); } diff --git a/src/MainWindow.h b/src/MainWindow.h index 9a7406a..6b3663f 100644 --- a/src/MainWindow.h +++ b/src/MainWindow.h @@ -36,7 +36,7 @@ class MainWindow : public QMainWindow SongListModel *songModel; AudioPlayer *audioPlayer; QThread aceThread; - AceStep *aceStep; + AceStepWorker *aceStep; QTimer *playbackTimer; QString formatTime(int milliseconds); @@ -50,7 +50,6 @@ class MainWindow : public QMainWindow QString jsonTemplate; // Path settings - QString aceStepPath; QString qwen3ModelPath; QString textEncoderModelPath; QString ditModelPath; @@ -68,19 +67,19 @@ public slots: void show(); private slots: - void on_playButton_clicked(); - void on_pauseButton_clicked(); - void on_skipButton_clicked(); - void on_stopButton_clicked(); - void on_shuffleButton_clicked(); - void on_positionSlider_sliderMoved(int position); + void onPlayButtonClicked(); + void onPauseButtonClicked(); + void onSkipButtonClicked(); + void onStopButtonClicked(); + void onShuffleButtonClicked(); + void onPositionSliderSliderMoved(int position); void updatePosition(int position); void updateDuration(int duration); - void on_addSongButton_clicked(); - void on_removeSongButton_clicked(); - void on_advancedSettingsButton_clicked(); + void onAddSongButtonClicked(); + void onRemoveSongButtonClicked(); + void onAdvancedSettingsButtonClicked(); - void on_songListView_doubleClicked(const QModelIndex &index); + void onSongListViewDoubleClicked(const QModelIndex &index); void songGenerated(const SongItem& song); void generationCanceld(const SongItem& song); @@ -89,10 +88,10 @@ private slots: void updatePlaybackStatus(bool playing); void generationError(const QString &error); - void on_actionSavePlaylist(); - void on_actionLoadPlaylist(); - void on_actionAppendPlaylist(); - void on_actionSaveSong(); + void onActionSavePlaylist(); + void onActionLoadPlaylist(); + void onActionAppendPlaylist(); + void onActionSaveSong(); private: void loadSettings(); diff --git a/src/SongDialog.cpp b/src/SongDialog.cpp index 58f4ec6..e0ebde8 100644 --- a/src/SongDialog.cpp +++ b/src/SongDialog.cpp @@ -12,6 +12,10 @@ SongDialog::SongDialog(QWidget *parent, const SongItem &song) { ui->setupUi(this); + // Connect signals and slots explicitly + connect(ui->okButton, &QPushButton::clicked, this, &SongDialog::onOkButtonClicked); + connect(ui->cancelButton, &QPushButton::clicked, this, &SongDialog::onCancelButtonClicked); + ui->captionEdit->setPlainText(song.caption); ui->lyricsEdit->setPlainText(song.lyrics); ui->checkBoxEnhanceCaption->setChecked(song.cotCaption); @@ -140,7 +144,7 @@ SongDialog::~SongDialog() delete ui; } -void SongDialog::on_okButton_clicked() +void SongDialog::onOkButtonClicked() { // Validate that caption is not empty QString caption = ui->captionEdit->toPlainText(); @@ -153,7 +157,7 @@ void SongDialog::on_okButton_clicked() accept(); } -void SongDialog::on_cancelButton_clicked() +void SongDialog::onCancelButtonClicked() { reject(); } diff --git a/src/SongDialog.h b/src/SongDialog.h index 1ab08de..5e4ce48 100644 --- a/src/SongDialog.h +++ b/src/SongDialog.h @@ -28,8 +28,8 @@ public: const SongItem& getSong(); private slots: - void on_okButton_clicked(); - void on_cancelButton_clicked(); + void onOkButtonClicked(); + void onCancelButtonClicked(); private: Ui::SongDialog *ui; diff --git a/src/SongItem.h b/src/SongItem.h index a8e993c..0a80926 100644 --- a/src/SongItem.h +++ b/src/SongItem.h @@ -8,6 +8,7 @@ #include #include #include +#include class SongItem { @@ -22,6 +23,7 @@ public: uint64_t uniqueId; QString file; QString json; + std::shared_ptr audioData; SongItem(const QString &caption = "", const QString &lyrics = ""); SongItem(const QJsonObject& json); diff --git a/tests/test_acestep_worker.cpp b/tests/test_acestep_worker.cpp new file mode 100644 index 0000000..60a5ffa --- /dev/null +++ b/tests/test_acestep_worker.cpp @@ -0,0 +1,525 @@ +// Test for AceStepWorker +// Compile with: cmake .. && make test_acestep_worker && ./test_acestep_worker + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/AceStepWorker.h" + +// Test result tracking +static int testsPassed = 0; +static int testsFailed = 0; +static int testsSkipped = 0; + +#define TEST(name) void test_##name() +#define RUN_TEST(name) do { \ + std::cout << "Running " << #name << "... "; \ + test_##name(); \ + if (test_skipped) { \ + std::cout << "SKIPPED" << std::endl; \ + testsSkipped++; \ + test_skipped = false; \ + } else if (test_failed) { \ + std::cout << "FAILED" << std::endl; \ + testsFailed++; \ + test_failed = false; \ + } else { \ + std::cout << "PASSED" << std::endl; \ + testsPassed++; \ + } \ +} while(0) + +static bool test_failed = false; +static bool test_skipped = false; + +#define ASSERT_TRUE(cond) do { \ + if (!(cond)) { \ + std::cout << "FAILED: " << #cond << " at line " << __LINE__ << std::endl; \ + test_failed = true; \ + return; \ + } \ +} while(0) + +#define ASSERT_FALSE(cond) ASSERT_TRUE(!(cond)) + +#define SKIP_IF(cond) do { \ + if (cond) { \ + std::cout << "(skipping: " << #cond << ") "; \ + test_skipped = true; \ + return; \ + } \ +} while(0) + +// Helper to get model paths from settings like main app +struct ModelPaths { + QString lmPath; + QString textEncoderPath; + QString ditPath; + QString vaePath; +}; + +static ModelPaths getModelPathsFromSettings() +{ + ModelPaths paths; + QSettings settings("MusicGenerator", "AceStepGUI"); + + QString appDir = QCoreApplication::applicationDirPath(); + paths.lmPath = settings.value("qwen3ModelPath", + appDir + "/acestep.cpp/models/acestep-5Hz-lm-4B-Q8_0.gguf").toString(); + paths.textEncoderPath = settings.value("textEncoderModelPath", + appDir + "/acestep.cpp/models/Qwen3-Embedding-0.6B-Q8_0.gguf").toString(); + paths.ditPath = settings.value("ditModelPath", + appDir + "/acestep.cpp/models/acestep-v15-turbo-Q8_0.gguf").toString(); + paths.vaePath = settings.value("vaeModelPath", + appDir + "/acestep.cpp/models/vae-BF16.gguf").toString(); + + return paths; +} + +static bool modelsExist(const ModelPaths& paths) +{ + return QFileInfo::exists(paths.lmPath) && + QFileInfo::exists(paths.textEncoderPath) && + QFileInfo::exists(paths.ditPath) && + QFileInfo::exists(paths.vaePath); +} + +// Test 1: Check that isGenerating returns false initially +TEST(initialState) +{ + AceStepWorker worker; + ASSERT_TRUE(!worker.isGenerating()); +} + +// Test 2: Check that requestGeneration returns false when no model paths set +TEST(noModelPaths) +{ + AceStepWorker worker; + SongItem song("test caption", ""); + + bool result = worker.requestGeneration(song, "{}"); + ASSERT_FALSE(result); + ASSERT_TRUE(!worker.isGenerating()); +} + +// Test 3: Check that setModelPaths stores paths correctly +TEST(setModelPaths) +{ + AceStepWorker worker; + worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf"); + ASSERT_TRUE(true); +} + +// Test 4: Check async behavior - requestGeneration returns immediately +TEST(asyncReturnsImmediately) +{ + AceStepWorker worker; + worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf"); + + SongItem song("test caption", ""); + + // If this blocks, the test will hang + bool result = worker.requestGeneration(song, "{}"); + + // Should return false due to invalid paths, but immediately + ASSERT_FALSE(result); +} + +// Test 5: Check that cancelGeneration sets the cancel flag +TEST(cancellationFlag) +{ + AceStepWorker worker; + worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf"); + worker.cancelGeneration(); + ASSERT_TRUE(true); +} + +// Test 6: Check that signals are defined correctly +TEST(signalsExist) +{ + AceStepWorker worker; + + // Verify signals exist by connecting to them (compile-time check) + QObject::connect(&worker, &AceStepWorker::songGenerated, [](const SongItem&) {}); + QObject::connect(&worker, &AceStepWorker::generationCanceled, [](const SongItem&) {}); + QObject::connect(&worker, &AceStepWorker::generationError, [](const QString&) {}); + QObject::connect(&worker, &AceStepWorker::progressUpdate, [](int) {}); + + ASSERT_TRUE(true); +} + +// Test 7: Check SongItem to AceRequest conversion (internal) +TEST(requestConversion) +{ + AceStepWorker worker; + + SongItem song("Upbeat pop rock", "[Verse 1]"); + song.cotCaption = true; + + QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})"; + + worker.setModelPaths("/path/lm.gguf", "/path/encoder.gguf", "/path/dit.gguf", "/path/vae.gguf"); + bool result = worker.requestGeneration(song, templateJson); + + // Should fail due to invalid paths, but shouldn't crash + ASSERT_FALSE(result); +} + +// Test 8: Read model paths from settings +TEST(readSettings) +{ + ModelPaths paths = getModelPathsFromSettings(); + + std::cout << "\n Model paths from settings:" << std::endl; + std::cout << " LM: " << paths.lmPath.toStdString() << std::endl; + std::cout << " Text Encoder: " << paths.textEncoderPath.toStdString() << std::endl; + std::cout << " DiT: " << paths.ditPath.toStdString() << std::endl; + std::cout << " VAE: " << paths.vaePath.toStdString() << std::endl; + + ASSERT_TRUE(!paths.lmPath.isEmpty()); + ASSERT_TRUE(!paths.textEncoderPath.isEmpty()); + ASSERT_TRUE(!paths.ditPath.isEmpty()); + ASSERT_TRUE(!paths.vaePath.isEmpty()); +} + +// Test 9: Check if model files exist +TEST(checkModelFiles) +{ + ModelPaths paths = getModelPathsFromSettings(); + + bool lmExists = QFileInfo::exists(paths.lmPath); + bool encoderExists = QFileInfo::exists(paths.textEncoderPath); + bool ditExists = QFileInfo::exists(paths.ditPath); + bool vaeExists = QFileInfo::exists(paths.vaePath); + + std::cout << "\n Model file status:" << std::endl; + std::cout << " LM: " << (lmExists ? "EXISTS" : "MISSING") << std::endl; + std::cout << " Text Encoder: " << (encoderExists ? "EXISTS" : "MISSING") << std::endl; + std::cout << " DiT: " << (ditExists ? "EXISTS" : "MISSING") << std::endl; + std::cout << " VAE: " << (vaeExists ? "EXISTS" : "MISSING") << std::endl; + + ASSERT_TRUE(lmExists); + ASSERT_TRUE(encoderExists); + ASSERT_TRUE(ditExists); + ASSERT_TRUE(vaeExists); +} + +// Test 10: Actually generate a song (requires valid model paths) +TEST(generateSong) +{ + ModelPaths paths = getModelPathsFromSettings(); + + // Skip if models don't exist + SKIP_IF(!modelsExist(paths)); + + AceStepWorker worker; + worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath); + + SongItem song("Upbeat pop rock with driving guitars", ""); + + QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})"; + + // Track if we get progress updates + bool gotProgress = false; + QObject::connect(&worker, &AceStepWorker::progressUpdate, [&gotProgress](int p) { + std::cout << "\n Progress: " << p << "%" << std::endl; + gotProgress = true; + }); + + // Track generation result + bool generationCompleted = false; + SongItem resultSong; + QObject::connect(&worker, &AceStepWorker::songGenerated, + [&generationCompleted, &resultSong](const SongItem& song) { + std::cout << "\n Song generated successfully!" << std::endl; + std::cout << " Caption: " << song.caption.toStdString() << std::endl; + std::cout << " Lyrics: " << song.lyrics.left(100).toStdString() << "..." << std::endl; + std::cout << " File: " << song.file.toStdString() << std::endl; + resultSong = song; + generationCompleted = true; + }); + + QString errorMsg; + QObject::connect(&worker, &AceStepWorker::generationError, + [&errorMsg](const QString& err) { + std::cout << "\n Error: " << err.toStdString() << std::endl; + errorMsg = err; + }); + + std::cout << "\n Starting generation..." << std::endl; + + // Request generation + bool result = worker.requestGeneration(song, templateJson); + ASSERT_TRUE(result); + + // Use QEventLoop with timer for proper event processing + QEventLoop loop; + QTimer timeoutTimer; + + timeoutTimer.setSingleShot(true); + timeoutTimer.start(300000); // 5 minute timeout + + QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit); + QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit); + QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit); + + loop.exec(); + + ASSERT_TRUE(generationCompleted); + ASSERT_TRUE(resultSong.audioData != nullptr); + ASSERT_TRUE(!resultSong.audioData->isEmpty()); + + // Check audio data is not empty + std::cout << " Audio data size: " << resultSong.audioData->size() << " bytes" << std::endl; + ASSERT_TRUE(resultSong.audioData->size() > 1000); // Should be at least 1KB for valid audio +} + +// Test 11: Test cancellation +TEST(cancellation) +{ + ModelPaths paths = getModelPathsFromSettings(); + + // Skip if models don't exist + SKIP_IF(!modelsExist(paths)); + + AceStepWorker worker; + worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath); + + SongItem song("A very long ambient piece", ""); + + QString templateJson = R"({"inference_steps": 50, "shift": 3.0, "vocal_language": "en"})"; + + bool cancelReceived = false; + QObject::connect(&worker, &AceStepWorker::generationCanceled, + [&cancelReceived](const SongItem&) { + std::cout << "\n Generation was canceled!" << std::endl; + cancelReceived = true; + }); + + std::cout << "\n Starting generation and will cancel after 2 seconds..." << std::endl; + + // Start generation + bool result = worker.requestGeneration(song, templateJson); + ASSERT_TRUE(result); + + // Wait 2 seconds then cancel + QThread::sleep(2); + worker.cancelGeneration(); + + // Wait a bit for cancellation to be processed + QThread::sleep(1); + QCoreApplication::processEvents(); + + // Note: cancellation may or may not complete depending on where in the process + // the cancel was requested. The important thing is it doesn't crash. + std::cout << " Cancel requested, no crash detected" << std::endl; + ASSERT_TRUE(true); +} + +// Test 12: Test low VRAM mode generation +TEST(generateSongLowVram) +{ + ModelPaths paths = getModelPathsFromSettings(); + + // Skip if models don't exist + SKIP_IF(!modelsExist(paths)); + + AceStepWorker worker; + worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath); + worker.setLowVramMode(true); + + ASSERT_TRUE(worker.isLowVramMode()); + + SongItem song("Chill electronic music", ""); + + QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})"; + + // Track generation result + bool generationCompleted = false; + SongItem resultSong; + QObject::connect(&worker, &AceStepWorker::songGenerated, + [&generationCompleted, &resultSong](const SongItem& song) { + std::cout << "\n Low VRAM mode: Song generated successfully!" << std::endl; + std::cout << " Caption: " << song.caption.toStdString() << std::endl; + if (song.audioData) { + std::cout << " Audio data size: " << song.audioData->size() << " bytes" << std::endl; + } else { + std::cout << " Audio data size: null" << std::endl; + } + resultSong = song; + generationCompleted = true; + }); + + QString errorMsg; + QObject::connect(&worker, &AceStepWorker::generationError, + [&errorMsg](const QString& err) { + std::cout << "\n Error: " << err.toStdString() << std::endl; + errorMsg = err; + }); + + std::cout << "\n Starting low VRAM mode generation..." << std::endl; + + // Request generation + bool result = worker.requestGeneration(song, templateJson); + ASSERT_TRUE(result); + + // Use QEventLoop with timer for proper event processing + QEventLoop loop; + QTimer timeoutTimer; + + timeoutTimer.setSingleShot(true); + timeoutTimer.start(300000); // 5 minute timeout + + QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit); + QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit); + QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit); + + loop.exec(); + + ASSERT_TRUE(generationCompleted); + ASSERT_TRUE(resultSong.audioData != nullptr); + ASSERT_TRUE(!resultSong.audioData->isEmpty()); + + std::cout << " Audio data size: " << resultSong.audioData->size() << " bytes" << std::endl; + ASSERT_TRUE(resultSong.audioData->size() > 1000); +} + +// Test 13: Test normal mode keeps models loaded between generations +TEST(normalModeKeepsModelsLoaded) +{ + ModelPaths paths = getModelPathsFromSettings(); + + // Skip if models don't exist + SKIP_IF(!modelsExist(paths)); + + AceStepWorker worker; + worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath); + // Normal mode is default (lowVramMode = false) + + ASSERT_FALSE(worker.isLowVramMode()); + + QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})"; + + // Generate first song + bool firstGenerationCompleted = false; + QObject::connect(&worker, &AceStepWorker::songGenerated, + [&firstGenerationCompleted](const SongItem&) { + firstGenerationCompleted = true; + }); + + QObject::connect(&worker, &AceStepWorker::generationError, + [](const QString& err) { + std::cout << "\n Error: " << err.toStdString() << std::endl; + }); + + std::cout << "\n Generating first song (normal mode)..." << std::endl; + + SongItem song1("First song", ""); + bool result = worker.requestGeneration(song1, templateJson); + ASSERT_TRUE(result); + + QEventLoop loop; + QTimer timeoutTimer; + timeoutTimer.setSingleShot(true); + timeoutTimer.start(300000); + + QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit); + QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit); + QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit); + + loop.exec(); + + ASSERT_TRUE(firstGenerationCompleted); + std::cout << " First generation completed, models should still be loaded" << std::endl; + + // Generate second song - in normal mode this should be faster since models are already loaded + bool secondGenerationCompleted = false; + SongItem secondResult; + QObject::connect(&worker, &AceStepWorker::songGenerated, + [&secondGenerationCompleted, &secondResult](const SongItem& song) { + secondGenerationCompleted = true; + secondResult = song; + }); + + std::cout << " Generating second song (should use cached models)..." << std::endl; + + SongItem song2("Second song", ""); + result = worker.requestGeneration(song2, templateJson); + ASSERT_TRUE(result); + + QEventLoop loop2; + QTimer timeoutTimer2; + timeoutTimer2.setSingleShot(true); + timeoutTimer2.start(300000); + + QObject::connect(&worker, &AceStepWorker::songGenerated, &loop2, &QEventLoop::quit); + QObject::connect(&worker, &AceStepWorker::generationError, &loop2, &QEventLoop::quit); + QObject::connect(&timeoutTimer2, &QTimer::timeout, &loop2, &QEventLoop::quit); + + loop2.exec(); + + ASSERT_TRUE(secondGenerationCompleted); + ASSERT_TRUE(secondResult.audioData != nullptr); + ASSERT_TRUE(!secondResult.audioData->isEmpty()); + + std::cout << " Second generation completed successfully" << std::endl; + std::cout << " Audio data size: " << secondResult.audioData->size() << " bytes" << std::endl; +} + +// Test 14: Test setLowVramMode toggle +TEST(lowVramModeToggle) +{ + AceStepWorker worker; + + // Default should be false (normal mode) + ASSERT_FALSE(worker.isLowVramMode()); + + // Enable low VRAM mode + worker.setLowVramMode(true); + ASSERT_TRUE(worker.isLowVramMode()); + + // Disable low VRAM mode + worker.setLowVramMode(false); + ASSERT_FALSE(worker.isLowVramMode()); + + // Toggle again + worker.setLowVramMode(true); + ASSERT_TRUE(worker.isLowVramMode()); +} + +int main(int argc, char *argv[]) +{ + QCoreApplication app(argc, argv); + + std::cout << "=== AceStepWorker Tests ===" << std::endl; + + RUN_TEST(initialState); + RUN_TEST(noModelPaths); + RUN_TEST(setModelPaths); + RUN_TEST(asyncReturnsImmediately); + RUN_TEST(cancellationFlag); + RUN_TEST(signalsExist); + RUN_TEST(requestConversion); + RUN_TEST(readSettings); + RUN_TEST(checkModelFiles); + RUN_TEST(generateSong); + RUN_TEST(cancellation); + RUN_TEST(generateSongLowVram); + RUN_TEST(normalModeKeepsModelsLoaded); + RUN_TEST(lowVramModeToggle); + + std::cout << "\n=== Results ===" << std::endl; + std::cout << "Passed: " << testsPassed << std::endl; + std::cout << "Skipped: " << testsSkipped << std::endl; + std::cout << "Failed: " << testsFailed << std::endl; + + return testsFailed > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/third_party/acestep.cpp b/third_party/acestep.cpp new file mode 160000 index 0000000..d28398d --- /dev/null +++ b/third_party/acestep.cpp @@ -0,0 +1 @@ +Subproject commit d28398db0ffdb77e8ae071ff31bde8c559e7085a