From 831b137f8d445f2c247801264edd8109cb468977 Mon Sep 17 00:00:00 2001
From: uvos
Date: Mon, 10 Jun 2024 19:41:51 +0200
Subject: [PATCH] update danboorutagger

---
 DanbooruTagger/DanbooruTagger.py | 74 ++++++++------------------
 1 file changed, 19 insertions(+), 55 deletions(-)

diff --git a/DanbooruTagger/DanbooruTagger.py b/DanbooruTagger/DanbooruTagger.py
index 4c138d4..0ebeee1 100644
--- a/DanbooruTagger/DanbooruTagger.py
+++ b/DanbooruTagger/DanbooruTagger.py
@@ -1,12 +1,10 @@
 import warnings
 from deepdanbooru_onnx import DeepDanbooru
+from PIL import Image
 import argparse
 import cv2
-import torch
 import os
-import numpy
-from typing import Iterator
-from torch.multiprocessing import Process, Queue
+from multiprocessing import Process, Queue
 import json
 from tqdm import tqdm
 
@@ -24,7 +22,7 @@ def find_image_files(path: str) -> list[str]:
     return paths
 
 
-def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
+def image_loader(paths: list[str]):
     for path in paths:
         name, extension = os.path.splitext(path)
         extension = extension.lower()
@@ -33,46 +31,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
         if image is None:
             print(f"Warning: could not load {path}")
         else:
-            yield image, path
+            image_pil = Image.fromarray(image)
+            yield image_pil, path
 
 
-def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int):
-    model = LlavaForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=None,
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_quant_type='nf4',
-        ), device_map=device, attn_implementation="flash_attention_2")
-    processor = AutoProcessor.from_pretrained(model_name_or_path)
-    image_generator = image_loader(image_paths)
+def pipeline(queue: Queue, image_paths: list[str], device: int):
+    danbooru = DeepDanbooru()
 
-    stop = False
-    finished_count = 0
-    while not stop:
-        prompts = list()
-        images = list()
-        filenames = list()
-        for i in range(0, batch_size):
-            image, filename = next(image_generator, (None, None))
-            if image is None:
-                stop = True
-                break
+    for path in image_paths:
+        imageprompt = ""
+        tags = danbooru(path)
+        for tag in tags:
+            imageprompt = imageprompt + ", " + tag
 
-            filenames.append(filename)
-            images.append(image)
-            prompts.append(prompt)
-
-        if len(images) == 0:
-            break
-
-        inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
-        generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
-        decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        finished_count += len(images)
-        for i, decoded in enumerate(decodes):
-            trim = len(prompt) - len("<image>")
-            queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
+        queue.put({"file_name": path, "text": imageprompt})
 
 
 def split_list(input_list, count):
@@ -90,31 +62,23 @@ def save_meta(meta_file, meta, reldir, common_description):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("A script to tag images via llava")
-    parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
-    parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
-    parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on eatch image")
+    parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
     parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
     parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
     parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
     args = parser.parse_args()
 
-    prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
-    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+    nparalell = 2
 
     image_paths = find_image_files(args.image_dir)
-    image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))
+    image_path_chunks = list(split_list(image_paths, nparalell))
 
-    print(f"Will use {torch.cuda.device_count()} processies to create tags")
-
-    logging.set_verbosity_error()
-    warnings.filterwarnings("ignore")
-    torch.multiprocessing.set_start_method('spawn')
+    print(f"Will use {nparalell} processies to create tags")
 
     queue = Queue()
     processies = list()
-    for i in range(0, torch.cuda.device_count()):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch)))
+    for i in range(0, nparalell):
+        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
         processies[-1].start()
 
     progress = tqdm(desc="Generateing tags", total=len(image_paths))