update danboorutagger

commit 831b137f8d (parent 28debaf267)

1 changed file with 19 additions and 55 deletions
@@ -1,12 +1,10 @@
-import warnings
+from deepdanbooru_onnx import DeepDanbooru
 from PIL import Image
 import argparse
 import cv2
-import torch
 import os
 import numpy
-from typing import Iterator
-from torch.multiprocessing import Process, Queue
+from multiprocessing import Process, Queue
 import json
 from tqdm import tqdm
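For context: the deepdanbooru_onnx package that replaces the torch-based stack here wraps the DeepDanbooru tagger as an ONNX model. A minimal usage sketch, assuming the PyPI deepdanbooru-onnx package and a local example.jpg (not part of the commit):

from deepdanbooru_onnx import DeepDanbooru

danbooru = DeepDanbooru()

# calling the wrapper with an image path runs inference and returns
# a dict mapping tag names to confidence scores
tags = danbooru("example.jpg")
for tag, confidence in tags.items():
	print(tag, confidence)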
@@ -24,7 +22,7 @@ def find_image_files(path: str) -> list[str]:
 	return paths
 
 
-def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
+def image_loader(paths: list[str]):
 	for path in paths:
 		name, extension = os.path.splitext(path)
 		extension = extension.lower()
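Only part of image_loader is visible in the hunks; a hedged reconstruction of the new loader, stitching this hunk together with the else-branch shown in the next one (the cv2.imread call is assumed from the unchanged lines outside the hunks):

def image_loader(paths: list[str]):
	for path in paths:
		name, extension = os.path.splitext(path)
		extension = extension.lower()
		# assumed: the unchanged lines filter on extension and load via OpenCV;
		# cv2.imread returns None when the file cannot be decoded
		image = cv2.imread(path)
		if image is None:
			print(f"Warning: could not load {path}")
		else:
			# cv2 returns BGR arrays while PIL assumes RGB, so the colors in
			# image_pil are channel-swapped unless cv2.cvtColor is applied first
			image_pil = Image.fromarray(image)
			yield image_pil, path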
@@ -33,46 +31,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
 		if image is None:
 			print(f"Warning: could not load {path}")
 		else:
-			yield image, path
+			image_pil = Image.fromarray(image)
+			yield image_pil, path
 
 
-def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int):
-	model = LlavaForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=None,
-		quantization_config=BitsAndBytesConfig(
-			load_in_4bit=True,
-			bnb_4bit_compute_dtype=torch.float16,
-			bnb_4bit_use_double_quant=False,
-			bnb_4bit_quant_type='nf4',
-			), device_map=device, attn_implementation="flash_attention_2")
-	processor = AutoProcessor.from_pretrained(model_name_or_path)
-	image_generator = image_loader(image_paths)
+def pipeline(queue: Queue, image_paths: list[str], device: int):
+	danbooru = DeepDanbooru()
 
-	stop = False
-	finished_count = 0
-	while not stop:
-		prompts = list()
-		images = list()
-		filenames = list()
-		for i in range(0, batch_size):
-			image, filename = next(image_generator, (None, None))
-			if image is None:
-				stop = True
-				break
+	for path in image_paths:
+		imageprompt = ""
+		tags = danbooru(path)
+		for tag in tags:
+			imageprompt = imageprompt + ", " + tag
 
-			filenames.append(filename)
-			images.append(image)
-			prompts.append(prompt)
-
-		if len(images) == 0:
-			break
-
-		inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
-		generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
-		decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-		finished_count += len(images)
-		for i, decoded in enumerate(decodes):
-			trim = len(prompt) - len("<image>")
-			queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
+		queue.put({"file_name": path, "text": imageprompt})
 
 
 def split_list(input_list, count):
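The new pipeline concatenates tag names into one caption per image; because the loop starts from an empty string and prefixes each tag with ", ", every caption begins with a leading separator. An equivalent sketch using str.join that avoids this (not part of the commit):

def pipeline(queue: Queue, image_paths: list[str], device: int):
	danbooru = DeepDanbooru()
	for path in image_paths:
		# iterating the returned dict yields tag names; join inserts the
		# separator only between tags, avoiding the leading ", "
		imageprompt = ", ".join(danbooru(path))
		queue.put({"file_name": path, "text": imageprompt})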
@@ -90,31 +62,23 @@ def save_meta(meta_file, meta, reldir, common_description):
 
 
 if __name__ == "__main__":
-	parser = argparse.ArgumentParser("A script to tag images via llava")
-	parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
-	parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
-	parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on eatch image")
-	parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
+	parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
 	parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
 	parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
 	args = parser.parse_args()
 
-	prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
-	os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+	nparalell = 2
 
 	image_paths = find_image_files(args.image_dir)
-	image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))
+	image_path_chunks = list(split_list(image_paths, nparalell))
 
-	print(f"Will use {torch.cuda.device_count()} processies to create tags")
 
-	logging.set_verbosity_error()
-	warnings.filterwarnings("ignore")
-	torch.multiprocessing.set_start_method('spawn')
+	print(f"Will use {nparalell} processies to create tags")
 
 	queue = Queue()
 	processies = list()
-	for i in range(0, torch.cuda.device_count()):
-		processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch)))
+	for i in range(0, nparalell):
+		processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
 		processies[-1].start()
 
 	progress = tqdm(desc="Generateing tags", total=len(image_paths))
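The loop that drains this queue sits outside the changed hunks; a hedged sketch of the consumer implied by the per-image queue.put calls and the tqdm progress bar (names other than queue, processies, progress, and image_paths are illustrative):

	entries = list()
	for i in range(0, len(image_paths)):
		# each pipeline process puts one {"file_name": ..., "text": ...}
		# dict per image, so exactly len(image_paths) entries arrive in total
		entries.append(queue.get())
		progress.update()

	for process in processies:
		process.join()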