update danboorutagger

uvos 2024-06-10 19:41:51 +02:00
parent 28debaf267
commit 831b137f8d


@@ -1,12 +1,10 @@
-import warnings
+from deepdanbooru_onnx import DeepDanbooru
+from PIL import Image
 import argparse
 import cv2
-import torch
 import os
 import numpy
-from typing import Iterator
-from torch.multiprocessing import Process, Queue
+from multiprocessing import Process, Queue
 import json
 from tqdm import tqdm
@@ -24,7 +22,7 @@ def find_image_files(path: str) -> list[str]:
     return paths
 
-def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
+def image_loader(paths: list[str]):
     for path in paths:
         name, extension = os.path.splitext(path)
         extension = extension.lower()
@@ -33,46 +31,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
         if image is None:
             print(f"Warning: could not load {path}")
         else:
-            yield image, path
+            image_pil = Image.fromarray(image)
+            yield image_pil, path
 
-def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int):
-    model = LlavaForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=None,
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_quant_type='nf4',
-        ), device_map=device, attn_implementation="flash_attention_2")
-    processor = AutoProcessor.from_pretrained(model_name_or_path)
-    image_generator = image_loader(image_paths)
+def pipeline(queue: Queue, image_paths: list[str], device: int):
+    danbooru = DeepDanbooru()
 
-    stop = False
-    finished_count = 0
-    while not stop:
-        prompts = list()
-        images = list()
-        filenames = list()
-        for i in range(0, batch_size):
-            image, filename = next(image_generator, (None, None))
-            if image is None:
-                stop = True
-                break
+    for path in image_paths:
+        imageprompt = ""
+        tags = danbooru(path)
+        for tag in tags:
+            imageprompt = imageprompt + ", " + tag
-            filenames.append(filename)
-            images.append(image)
-            prompts.append(prompt)
-        if len(images) == 0:
-            break
-        inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
-        generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
-        decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        finished_count += len(images)
-        for i, decoded in enumerate(decodes):
-            trim = len(prompt) - len("<image>")
-            queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
+        queue.put({"file_name": path, "text": imageprompt})
 
 def split_list(input_list, count):
@@ -90,31 +62,23 @@ def save_meta(meta_file, meta, reldir, common_description):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("A script to tag images via llava")
-    parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
-    parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
-    parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on eatch image")
+    parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
     parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
     parser.add_argument('--common_description', '-c', help="An optional description that will be prepended to the AI-generated one")
     parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag")
     args = parser.parse_args()
 
-    prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
-    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+    nparalell = 2
 
     image_paths = find_image_files(args.image_dir)
-    image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))
+    image_path_chunks = list(split_list(image_paths, nparalell))
 
-    print(f"Will use {torch.cuda.device_count()} processies to create tags")
-    logging.set_verbosity_error()
-    warnings.filterwarnings("ignore")
-    torch.multiprocessing.set_start_method('spawn')
+    print(f"Will use {nparalell} processes to create tags")
 
     queue = Queue()
     processies = list()
-    for i in range(0, torch.cuda.device_count()):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch)))
+    for i in range(0, nparalell):
+        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
         processies[-1].start()
 
     progress = tqdm(desc="Generating tags", total=len(image_paths))
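For context, the tagging flow this commit introduces reduces to the standalone sketch below. It is based only on the calls visible in the diff (constructing DeepDanbooru(), calling it on an image path, and pushing results onto a Queue); the helper name tag_worker and the example path are illustrative, not part of the repository.

from multiprocessing import Process, Queue
from deepdanbooru_onnx import DeepDanbooru

def tag_worker(queue: Queue, image_paths: list[str]):
    # Each worker process loads its own ONNX model, as pipeline() does above.
    danbooru = DeepDanbooru()
    for path in image_paths:
        tags = danbooru(path)  # predicted Danbooru tags for one image
        queue.put({"file_name": path, "text": ", ".join(tags)})

if __name__ == "__main__":
    queue = Queue()
    paths = ["example.jpg"]  # illustrative input
    worker = Process(target=tag_worker, args=(queue, paths))
    worker.start()
    for _ in paths:
        print(queue.get())
    worker.join()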