update danboorutagger

2024-06-10 19:41:51 +02:00
parent 28debaf267
commit 831b137f8d


@@ -1,12 +1,10 @@
 import warnings
 from deepdanbooru_onnx import DeepDanbooru
+from PIL import Image
 import argparse
 import cv2
-import torch
 import os
-import numpy
-from typing import Iterator
-from torch.multiprocessing import Process, Queue
+from multiprocessing import Process, Queue
 import json
 from tqdm import tqdm
@@ -24,7 +22,7 @@ def find_image_files(path: str) -> list[str]:
     return paths

-def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
+def image_loader(paths: list[str]):
     for path in paths:
         name, extension = os.path.splitext(path)
         extension = extension.lower()
@@ -33,46 +31,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
         if image is None:
             print(f"Warning: could not load {path}")
         else:
-            yield image, path
+            image_pil = Image.fromarray(image)
+            yield image_pil, path
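A caveat on the conversion above: cv2.imread returns pixel data in BGR channel order, while PIL assumes RGB, so Image.fromarray on the raw array silently swaps the red and blue channels. A minimal sketch of a loader that converts explicitly (load_as_pil is a hypothetical helper, not part of this commit):

import cv2
from PIL import Image

def load_as_pil(path: str) -> Image.Image | None:
    bgr = cv2.imread(path)  # OpenCV decodes images in BGR channel order
    if bgr is None:
        return None  # unreadable file, mirrors the script's warning path
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # reorder to the RGB that PIL expects
    return Image.fromarray(rgb)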

-def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int):
-    model = LlavaForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=None,
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_quant_type='nf4',
-        ), device_map=device, attn_implementation="flash_attention_2")
-    processor = AutoProcessor.from_pretrained(model_name_or_path)
-    image_generator = image_loader(image_paths)
-    stop = False
-    finished_count = 0
-    while not stop:
-        prompts = list()
-        images = list()
-        filenames = list()
-        for i in range(0, batch_size):
-            image, filename = next(image_generator, (None, None))
-            if image is None:
-                stop = True
-                break
-            filenames.append(filename)
-            images.append(image)
-            prompts.append(prompt)
-        if len(images) == 0:
-            break
-        inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
-        generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
-        decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        finished_count += len(images)
-        for i, decoded in enumerate(decodes):
-            trim = len(prompt) - len("<image>")
-            queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
+def pipeline(queue: Queue, image_paths: list[str], device: int):
+    danbooru = DeepDanbooru()
+
+    for path in image_paths:
+        imageprompt = ""
+        tags = danbooru(path)
+        for tag in tags:
+            imageprompt = imageprompt + ", " + tag
+
+        queue.put({"file_name": path, "text": imageprompt})

 def split_list(input_list, count):
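The rewritten worker is one DeepDanbooru instance per process and one tag string per image. A minimal sketch of the same pattern, using str.join to avoid the leading ", " that the concatenation above produces; it assumes, as the diff implies, that calling DeepDanbooru on a file path yields tag strings:

from multiprocessing import Queue
from deepdanbooru_onnx import DeepDanbooru

def tag_images(queue: Queue, image_paths: list[str]) -> None:
    danbooru = DeepDanbooru()  # load the ONNX model once per worker process
    for path in image_paths:
        tags = danbooru(path)  # called directly on the file path, as in the diff
        # join instead of concatenating so no leading ", " sneaks into the text
        queue.put({"file_name": path, "text": ", ".join(tags)})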
@@ -90,31 +62,23 @@ def save_meta(meta_file, meta, reldir, common_description):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("A script to tag images via llava")
-    parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
-    parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
-    parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on eatch image")
+    parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
     parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
     parser.add_argument('--common_description', '-c', help="An optional description that will be prepended to the AI-generated one")
     parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag")
     args = parser.parse_args()

-    prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
-    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+    nparalell = 2
+
     image_paths = find_image_files(args.image_dir)
-    image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))
-    print(f"Will use {torch.cuda.device_count()} processies to create tags")
-    logging.set_verbosity_error()
-    warnings.filterwarnings("ignore")
-    torch.multiprocessing.set_start_method('spawn')
+    image_path_chunks = list(split_list(image_paths, nparalell))
+    print(f"Will use {nparalell} processes to create tags")
+
     queue = Queue()
     processies = list()
-    for i in range(0, torch.cuda.device_count()):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch)))
+    for i in range(0, nparalell):
+        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
         processies[-1].start()

     progress = tqdm(desc="Generating tags", total=len(image_paths))
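The hunk ends before the consumer side, but the structure is the usual multiprocessing producer/consumer split: each worker puts one {"file_name", "text"} dict per image, and the parent drains the queue while advancing the tqdm bar. A hypothetical sketch of that drain loop (everything beyond queue, processies, progress, and image_paths is an assumption, since the actual loop lies outside this diff):

# hypothetical drain loop; the script's real one is not part of this hunk
results = []
for _ in range(len(image_paths)):
    results.append(queue.get())  # blocks until some worker finishes an image
    progress.update()

for process in processies:
    process.join()  # workers exit after exhausting their chunk of paths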