#include <algorithm>
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
#include <thread>
#include <vector>
#include <ggml.h>
#include <llama.h>
#include <indicators/progress_bar.hpp>

#include "options.h"
#include "utils.h"
#include "bar.h"
#include "llmai.h"
#include "tag.h"

// Accept only files whose extension is a known music format.
static bool match_music_files(const std::filesystem::path& path)
{
	static const std::vector<std::filesystem::path> musicExtensions = {".mp3", ".ogg", ".flac", ".opus"};
	auto search = std::find(musicExtensions.begin(), musicExtensions.end(), path.extension());
	return search != musicExtensions.end();
}

// Tokenize a file name into llama tokens, ticking the progress bar if one was given.
static std::vector<llama_token> tokenize_path(const std::filesystem::path& p, llama_model* model, indicators::ProgressBar* bar = nullptr)
{
	constexpr int32_t max_len = 255;
	std::vector<llama_token> out(max_len);
	std::string filename = p.filename().string();
	int32_t ret = llama_tokenize(model, filename.c_str(), filename.length(), out.data(), out.size(), false, false);
	if(ret < 0)
		out.clear();
	else
		out.resize(ret);
	if(bar)
		bar->tick();
	return out;
}

std::vector<std::string> get_responses_for_paths(llama_model* model, const std::vector<std::filesystem::path>& prompts)
{
	auto* bar = bar_create_new("Tokenizing prompts", prompts.size());
	std::vector<std::vector<llama_token>> tokenizedPrompts(prompts.size());
	std::transform(prompts.begin(), prompts.end(), tokenizedPrompts.begin(),
		[model, bar](const std::filesystem::path& p){return tokenize_path(p, model, bar);});
	delete bar;

	std::vector<std::string> responses(prompts.size());
	bar = bar_create_new("Generating responses", prompts.size());
	for(size_t i = 0; i < prompts.size(); ++i)
	{
		std::vector<llama_token> tokens = generate_text(tokenizedPrompts[i], model);
		responses[i] = llama_untokenize(tokens, model);
		bar->tick();
	}
	delete bar;
	return responses;
}

// Discards llama.cpp log output when not running in debug mode.
void drop_log(enum ggml_log_level level, const char* text, void* user_data)
{
	(void)level;
	(void)text;
	(void)user_data;
}

std::vector<Tag> parse_tags_from_responses(const std::vector<std::string>& responses, const std::vector<std::filesystem::path>& paths)
{
	assert(responses.size() == paths.size());
	std::vector<Tag> out(responses.size());
	auto bar = bar_create_new("Processing responses", responses.size());
	#pragma omp parallel for
	for(size_t i = 0; i < responses.size(); ++i)
	{
		out[i].parseFromResponse(responses[i], paths[i]);
		bar->tick();
	}
	delete bar;
	return out;
}

int main(int argc, char** argv)
{
	Config config = get_arguments(argc, argv);

	if(!config.debug)
		llama_log_set(drop_log, nullptr);

	llama_backend_init(false);
	llama_model_params modelParams = llama_model_default_params();
	modelParams.n_gpu_layers = config.gpu ? 1000 : 0;
	llama_model* model = llama_load_model_from_file(config.model.c_str(), modelParams);
	if(!model)
	{
		std::cerr<<"Unable to load model from "<<config.model<<std::endl;
		return 1;
	}

	// Assumption: bar.h pairs bar_stop_indeterminate() with a starter helper like this.
	auto barThread = bar_start_indeterminate("Searching for music files");
	std::vector<std::filesystem::path> musicPaths = recursive_get_matching_files(config.in, match_music_files);
	bar_stop_indeterminate(barThread);

	std::vector<std::string> responses = get_responses_for_paths(model, musicPaths);
	llama_free_model(model);
	llama_backend_free();

	if(config.debug)
	{
		std::ofstream file;
		file.open("./debug.log");
		if(file.is_open())
		{
			for(size_t i = 0; i < responses.size(); ++i)
				file<<responses[i]<<'\n';
			file.close();
		}
	}

	std::vector<Tag> aiTags = parse_tags_from_responses(responses, musicPaths);

	if(!std::filesystem::is_directory(config.out) && !std::filesystem::create_directory(config.out))
	{
		std::cerr<<"Unable to create output directory at "<<config.out<<std::endl;
		return 1;
	}

	// Assumption: bar label and loop header reconstructed; only the loop body survived.
	auto* bar = bar_create_new("Writing tags", aiTags.size());
	for(Tag& tag : aiTags)
	{
		bar->tick();
		if(!tag.isFilled())
		{
			// Files whose tags could not be filled are copied to the reject directory.
			std::error_code ec;
			std::filesystem::copy(tag.path, config.rejectDir, ec);
			if(ec)
				std::cerr<<"Unable to copy "<<tag.path<<" to "<<config.rejectDir<<std::endl;
		}
		// NOTE: the step that writes filled tags into config.out is missing here.
	}
	delete bar;

	return 0;
}