AITagger/main.cpp

#include <climits>
#include <cassert>
#include <atomic>
#include <iostream>
#include <numeric>
#include <vector>
#include <string>
#include <filesystem>
#include <functional>
#include <llama.h>
#include <algorithm>
#include <thread>
#include <cctype>
#include <fstream>
#include "options.h"
#include "utils.h"
#include "bar.h"
#include "llmai.h"
#include "tag.h"
static bool match_music_files(const std::filesystem::path& path)
{
    static const std::vector<std::string> musicExtensions = {".mp3", ".ogg", ".flac", ".opus"};
    // Compare the extension case-insensitively so that ".MP3" etc. are matched as well.
    std::string extension = path.extension().string();
    std::transform(extension.begin(), extension.end(), extension.begin(),
        [](unsigned char c){return static_cast<char>(std::tolower(c));});
    return std::find(musicExtensions.begin(), musicExtensions.end(), extension) != musicExtensions.end();
}
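
// Tokenizes a path's filename component into llama tokens, ticking the
// optional progress bar; returns an empty vector on tokenization failure.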
static std::vector<llama_token> tokenize_path(const std::filesystem::path& p, llama_model* model, indicators::ProgressBar* bar = nullptr)
{
    constexpr int32_t max_len = 255;
    std::vector<llama_token> out(max_len);
    std::string filename = p.filename().string();
    int32_t ret = llama_tokenize(model, filename.c_str(), filename.length(), out.data(), out.size(), false, false);
    // A negative return value means the buffer was too small; resize(ret) with
    // a negative count would wrap around to a huge size_t, so return an empty
    // vector instead.
    if(ret < 0)
        out.clear();
    else
        out.resize(ret);
    if(bar)
        bar->tick();
    return out;
}
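
// Tokenizes each path and generates one model response per path, with
// progress bars for both phases.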
std::vector<std::string> get_responses_for_paths(llama_model* model, const std::vector<std::filesystem::path>& prompts)
{
    auto* bar = bar_create_new("Tokenizing prompts", prompts.size());
    std::vector<std::vector<llama_token>> tokenizedPrompts(prompts.size());
    std::transform(prompts.begin(), prompts.end(), tokenizedPrompts.begin(),
        [model, bar](const std::filesystem::path& p){return tokenize_path(p, model, bar);});
    delete bar;
    std::vector<std::string> responses(prompts.size());
    bar = bar_create_new("Generating responses", prompts.size());
    for(size_t i = 0; i < prompts.size(); ++i)
    {
        std::vector<llama_token> tokens = generate_text(tokenizedPrompts[i], model);
        responses[i] = llama_untokenize(tokens, model);
        bar->tick();
    }
    delete bar;
    return responses;
}
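
// No-op log callback used to silence llama.cpp's logging unless debugging is
// enabled.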
void drop_log(enum ggml_log_level level, const char* text, void* user_data)
{
    (void)level;
    (void)text;
    (void)user_data;
}
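
// Parses each raw model response into an AiTags record for its path, in
// parallel across responses.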
std::vector<AiTags> parse_tags_from_responses(const std::vector<std::string>& responses, const std::vector<std::filesystem::path>& paths)
{
    assert(responses.size() == paths.size());
    std::vector<AiTags> out(responses.size());
    auto bar = bar_create_new("Processing responses", responses.size());
    #pragma omp parallel for
    for(size_t i = 0; i < responses.size(); ++i)
    {
        out[i].parseFromResponse(responses[i], paths[i]);
        bar->tick();
    }
    delete bar;
    return out;
}
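
// Exit codes: 0 on success, 2 if the model can not be loaded, 1 for
// unrecoverable filesystem errors.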
int main(int argc, char** argv)
{
    Config config = get_arguments(argc, argv);
    if(!config.debug)
        llama_log_set(drop_log, nullptr);
    llama_backend_init(false);
    llama_model_params modelParams = llama_model_default_params();
    // Offload (practically) all layers to the GPU when requested, none otherwise.
    modelParams.n_gpu_layers = config.gpu ? 1000 : 0;
    llama_model* model = llama_load_model_from_file(config.model.c_str(), modelParams);
    if(!model)
    {
        std::cerr<<"Unable to load model from "<<config.model<<'\n';
        return 2;
    }
    std::thread* barThread = bar_create_new_indeterminate("Searching for music files");
    std::vector<std::filesystem::path> musicPaths = recursive_get_matching_files(config.in, match_music_files);
    bar_stop_indeterminate(barThread);
    std::vector<std::string> responses = get_responses_for_paths(model, musicPaths);
    llama_free_model(model);
    llama_backend_free();
    if(config.debug)
    {
        std::ofstream file;
        file.open("./debug.log");
        if(file.is_open())
        {
            for(size_t i = 0; i < responses.size(); ++i)
                file<<musicPaths[i]<<":\n"<<responses[i]<<"\n\n";
            file.close();
        }
        else
        {
            std::cerr<<"Could not open ./debug.log for writing\n";
        }
    }
    std::vector<AiTags> aiTags = parse_tags_from_responses(responses, musicPaths);
    if(!std::filesystem::is_directory(config.out) && !std::filesystem::create_directory(config.out))
    {
        std::cerr<<config.out<<" is not a directory and a directory could not be created at this location\n";
        return 1;
    }
    if(!config.rejectDir.empty() && !std::filesystem::is_directory(config.rejectDir) && !std::filesystem::create_directory(config.rejectDir))
    {
        std::cerr<<config.rejectDir<<" is not a directory and a directory could not be created at this location\n";
        return 1;
    }
    auto bar = bar_create_new("Writing tags to disk", aiTags.size());
    // OpenMP can not parallelize a range-based for loop and does not allow
    // returning from inside a parallel region, so use an index loop and record
    // failures in an atomic flag instead.
    std::atomic<bool> failed(false);
    #pragma omp parallel for
    for(size_t i = 0; i < aiTags.size(); ++i)
    {
        AiTags& tag = aiTags[i];
        bar->tick();
        if(!tag.isFilled())
        {
            // Only copy rejects when a reject directory was configured.
            if(!config.rejectDir.empty())
            {
                std::error_code ec;
                std::filesystem::copy(tag.path, config.rejectDir, ec);
                if(ec)
                    std::cerr<<ec.message()<<'\n';
            }
            continue;
        }
        std::filesystem::path artistDir;
        if(config.flatOutput)
        {
            artistDir = config.out;
        }
        else
        {
            artistDir = config.out/tag.artist;
            // create_directory is a no-op if the directory already exists;
            // re-check with is_directory so that concurrent creation attempts
            // from other threads are not misreported as failures.
            std::error_code dirEc;
            std::filesystem::create_directory(artistDir, dirEc);
            if(!std::filesystem::is_directory(artistDir))
            {
                std::cerr<<artistDir<<" is not a directory and a directory could not be created at this location\n";
                failed = true;
                continue;
            }
        }
        std::error_code ec;
        std::filesystem::copy(tag.path, artistDir, ec);
        if(ec)
        {
            std::cerr<<ec.message()<<'\n';
            continue;
        }
        tag.path = artistDir/tag.path.filename();
        tag.writeToFile();
    }
    delete bar;
    return failed ? 1 : 0;
}