Add low vram mode (unloads models)
This commit is contained in:
parent
14dec9f335
commit
216e59c105
7 changed files with 597 additions and 122 deletions
|
|
@ -324,6 +324,177 @@ TEST(cancellation)
|
|||
ASSERT_TRUE(true);
|
||||
}
|
||||
|
||||
// Test 12: Test low VRAM mode generation
|
||||
TEST(generateSongLowVram)
|
||||
{
|
||||
ModelPaths paths = getModelPathsFromSettings();
|
||||
|
||||
// Skip if models don't exist
|
||||
SKIP_IF(!modelsExist(paths));
|
||||
|
||||
AceStepWorker worker;
|
||||
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||
worker.setLowVramMode(true);
|
||||
|
||||
ASSERT_TRUE(worker.isLowVramMode());
|
||||
|
||||
SongItem song("Chill electronic music", "");
|
||||
|
||||
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||
|
||||
// Track generation result
|
||||
bool generationCompleted = false;
|
||||
SongItem resultSong;
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||
[&generationCompleted, &resultSong](const SongItem& song) {
|
||||
std::cout << "\n Low VRAM mode: Song generated successfully!" << std::endl;
|
||||
std::cout << " Caption: " << song.caption.toStdString() << std::endl;
|
||||
if (song.audioData) {
|
||||
std::cout << " Audio data size: " << song.audioData->size() << " bytes" << std::endl;
|
||||
} else {
|
||||
std::cout << " Audio data size: null" << std::endl;
|
||||
}
|
||||
resultSong = song;
|
||||
generationCompleted = true;
|
||||
});
|
||||
|
||||
QString errorMsg;
|
||||
QObject::connect(&worker, &AceStepWorker::generationError,
|
||||
[&errorMsg](const QString& err) {
|
||||
std::cout << "\n Error: " << err.toStdString() << std::endl;
|
||||
errorMsg = err;
|
||||
});
|
||||
|
||||
std::cout << "\n Starting low VRAM mode generation..." << std::endl;
|
||||
|
||||
// Request generation
|
||||
bool result = worker.requestGeneration(song, templateJson);
|
||||
ASSERT_TRUE(result);
|
||||
|
||||
// Use QEventLoop with timer for proper event processing
|
||||
QEventLoop loop;
|
||||
QTimer timeoutTimer;
|
||||
|
||||
timeoutTimer.setSingleShot(true);
|
||||
timeoutTimer.start(300000); // 5 minute timeout
|
||||
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
|
||||
QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
|
||||
QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||
|
||||
loop.exec();
|
||||
|
||||
ASSERT_TRUE(generationCompleted);
|
||||
ASSERT_TRUE(resultSong.audioData != nullptr);
|
||||
ASSERT_TRUE(!resultSong.audioData->isEmpty());
|
||||
|
||||
std::cout << " Audio data size: " << resultSong.audioData->size() << " bytes" << std::endl;
|
||||
ASSERT_TRUE(resultSong.audioData->size() > 1000);
|
||||
}
|
||||
|
||||
// Test 13: Test normal mode keeps models loaded between generations
|
||||
TEST(normalModeKeepsModelsLoaded)
|
||||
{
|
||||
ModelPaths paths = getModelPathsFromSettings();
|
||||
|
||||
// Skip if models don't exist
|
||||
SKIP_IF(!modelsExist(paths));
|
||||
|
||||
AceStepWorker worker;
|
||||
worker.setModelPaths(paths.lmPath, paths.textEncoderPath, paths.ditPath, paths.vaePath);
|
||||
// Normal mode is default (lowVramMode = false)
|
||||
|
||||
ASSERT_FALSE(worker.isLowVramMode());
|
||||
|
||||
QString templateJson = R"({"inference_steps": 8, "shift": 3.0, "vocal_language": "en"})";
|
||||
|
||||
// Generate first song
|
||||
bool firstGenerationCompleted = false;
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||
[&firstGenerationCompleted](const SongItem&) {
|
||||
firstGenerationCompleted = true;
|
||||
});
|
||||
|
||||
QObject::connect(&worker, &AceStepWorker::generationError,
|
||||
[](const QString& err) {
|
||||
std::cout << "\n Error: " << err.toStdString() << std::endl;
|
||||
});
|
||||
|
||||
std::cout << "\n Generating first song (normal mode)..." << std::endl;
|
||||
|
||||
SongItem song1("First song", "");
|
||||
bool result = worker.requestGeneration(song1, templateJson);
|
||||
ASSERT_TRUE(result);
|
||||
|
||||
QEventLoop loop;
|
||||
QTimer timeoutTimer;
|
||||
timeoutTimer.setSingleShot(true);
|
||||
timeoutTimer.start(300000);
|
||||
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop, &QEventLoop::quit);
|
||||
QObject::connect(&worker, &AceStepWorker::generationError, &loop, &QEventLoop::quit);
|
||||
QObject::connect(&timeoutTimer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||
|
||||
loop.exec();
|
||||
|
||||
ASSERT_TRUE(firstGenerationCompleted);
|
||||
std::cout << " First generation completed, models should still be loaded" << std::endl;
|
||||
|
||||
// Generate second song - in normal mode this should be faster since models are already loaded
|
||||
bool secondGenerationCompleted = false;
|
||||
SongItem secondResult;
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated,
|
||||
[&secondGenerationCompleted, &secondResult](const SongItem& song) {
|
||||
secondGenerationCompleted = true;
|
||||
secondResult = song;
|
||||
});
|
||||
|
||||
std::cout << " Generating second song (should use cached models)..." << std::endl;
|
||||
|
||||
SongItem song2("Second song", "");
|
||||
result = worker.requestGeneration(song2, templateJson);
|
||||
ASSERT_TRUE(result);
|
||||
|
||||
QEventLoop loop2;
|
||||
QTimer timeoutTimer2;
|
||||
timeoutTimer2.setSingleShot(true);
|
||||
timeoutTimer2.start(300000);
|
||||
|
||||
QObject::connect(&worker, &AceStepWorker::songGenerated, &loop2, &QEventLoop::quit);
|
||||
QObject::connect(&worker, &AceStepWorker::generationError, &loop2, &QEventLoop::quit);
|
||||
QObject::connect(&timeoutTimer2, &QTimer::timeout, &loop2, &QEventLoop::quit);
|
||||
|
||||
loop2.exec();
|
||||
|
||||
ASSERT_TRUE(secondGenerationCompleted);
|
||||
ASSERT_TRUE(secondResult.audioData != nullptr);
|
||||
ASSERT_TRUE(!secondResult.audioData->isEmpty());
|
||||
|
||||
std::cout << " Second generation completed successfully" << std::endl;
|
||||
std::cout << " Audio data size: " << secondResult.audioData->size() << " bytes" << std::endl;
|
||||
}
|
||||
|
||||
// Test 14: Test setLowVramMode toggle
|
||||
TEST(lowVramModeToggle)
|
||||
{
|
||||
AceStepWorker worker;
|
||||
|
||||
// Default should be false (normal mode)
|
||||
ASSERT_FALSE(worker.isLowVramMode());
|
||||
|
||||
// Enable low VRAM mode
|
||||
worker.setLowVramMode(true);
|
||||
ASSERT_TRUE(worker.isLowVramMode());
|
||||
|
||||
// Disable low VRAM mode
|
||||
worker.setLowVramMode(false);
|
||||
ASSERT_FALSE(worker.isLowVramMode());
|
||||
|
||||
// Toggle again
|
||||
worker.setLowVramMode(true);
|
||||
ASSERT_TRUE(worker.isLowVramMode());
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
QCoreApplication app(argc, argv);
|
||||
|
|
@ -341,6 +512,9 @@ int main(int argc, char *argv[])
|
|||
RUN_TEST(checkModelFiles);
|
||||
RUN_TEST(generateSong);
|
||||
RUN_TEST(cancellation);
|
||||
RUN_TEST(generateSongLowVram);
|
||||
RUN_TEST(normalModeKeepsModelsLoaded);
|
||||
RUN_TEST(lowVramModeToggle);
|
||||
|
||||
std::cout << "\n=== Results ===" << std::endl;
|
||||
std::cout << "Passed: " << testsPassed << std::endl;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue