initial commit

2024-06-14 08:54:09 +02:00
commit cd1e2756bc
39 changed files with 4163 additions and 0 deletions

LLavaTagger/LLavaTagger.py Normal file

@@ -0,0 +1,142 @@
import warnings
warnings.simplefilter(action='ignore')
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig, logging
import argparse
import cv2
import torch
import os
import numpy
from typing import Iterator
from torch.multiprocessing import Process, Queue
import json
from tqdm import tqdm
image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
def find_image_files(path: str) -> list[str]:
    # Recursively collect every file under path whose extension OpenCV can load
    paths = list()
    for root, dirs, files in os.walk(path):
        for filename in files:
            name, extension = os.path.splitext(filename)
            if extension.lower() in image_ext_ocv:
                paths.append(os.path.join(root, filename))
    return paths


def image_loader(paths: list[str]) -> Iterator[tuple[numpy.ndarray, str]]:
    # Yield (rgb_image, path) pairs, skipping files that OpenCV cannot read
    for path in paths:
        imagebgr = cv2.imread(path)
        if imagebgr is None:
            # cv2.imread returns None on failure, so check before converting
            print(f"Warning: could not load {path}")
            continue
        image = cv2.cvtColor(imagebgr, cv2.COLOR_BGR2RGB)
        yield image, path
def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int):
    # Each worker process loads its own 4-bit quantized copy of the model on its assigned gpu
    model = LlavaForConditionalGeneration.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=None,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_quant_type='nf4',
        ),
        device_map=device,
        attn_implementation="flash_attention_2")
    processor = AutoProcessor.from_pretrained(model_name_or_path)

    image_generator = image_loader(image_paths)
    stop = False
    finished_count = 0
    while not stop:
        # Assemble the next batch of images and matching prompts
        prompts = list()
        images = list()
        filenames = list()
        for i in range(0, batch_size):
            image, filename = next(image_generator, (None, None))
            if image is None:
                stop = True
                break
            filenames.append(filename)
            images.append(image)
            prompts.append(prompt)

        if len(images) == 0:
            break

        inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
        generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
        decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        finished_count += len(images)

        for i, decoded in enumerate(decodes):
            # The decoded text repeats the prompt (minus the <image> placeholder), strip it off
            trim = len(prompt) - len("<image>")
            queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
def split_list(input_list, count):
    target_length = int(len(input_list) / count)
    for i in range(0, count - 1):
        yield input_list[i * target_length: (i + 1) * target_length]
    yield input_list[(count - 1) * target_length: len(input_list)]


def save_meta(meta_file, meta, reldir, common_description):
    meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
    if common_description is not None:
        meta["text"] = common_description + meta["text"]
    meta_file.write(json.dumps(meta) + '\n')
if __name__ == "__main__":
    parser = argparse.ArgumentParser("A script to tag images via llava")
    parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
    parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
    parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on each image")
    parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
    parser.add_argument('--common_description', '-c', help="An optional description that will be prepended to the AI-generated one")
    parser.add_argument('--image_dir', '-i', required=True, help="A directory containing the images to tag")
    args = parser.parse_args()

    # Llava-1.5 chat template; the <image> placeholder is replaced by the image embeddings
    prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"

    # Split the work evenly across all visible gpus, one worker process per device
    image_paths = find_image_files(args.image_dir)
    image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))

    print(f"Will use {torch.cuda.device_count()} processes to create tags")

    logging.set_verbosity_error()
    warnings.filterwarnings("ignore")
    torch.multiprocessing.set_start_method('spawn')
    queue = Queue()
    processes = list()
    for i in range(0, torch.cuda.device_count()):
        processes.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch)))
        processes[-1].start()

    # Drain the result queue into metadata.jsonl until every worker has exited
    progress = tqdm(desc="Generating tags", total=len(image_paths))
    exit = False
    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
        while not exit:
            if not queue.empty():
                meta = queue.get()
                save_meta(output_file, meta, args.image_dir, args.common_description)
                progress.update()
            exit = True
            for process in processes:
                if process.is_alive():
                    exit = False
                    break
        while not queue.empty():
            meta = queue.get()
            save_meta(output_file, meta, args.image_dir, args.common_description)
            progress.update()

    for process in processes:
        process.join()

LLavaTagger/README.md Normal file

@@ -0,0 +1,21 @@
# LLavaTagger
LLavaTagger is a Python script that tags images based on a given prompt using the [LLaVA](https://llava-vl.github.io/) multimodal LLM. LLavaTagger supports using any number of GPUs in DDP parallel for this task.
## How to use
First create a Python venv and install the required packages into it:
$ python -m venv venv
$ source venv/bin/activate
$ pip install -r requirements.txt
Then run LLavaTagger, for instance like so:
$ python LLavaTagger.py --common_description "an image of a cat, " --prompt "describe the cat in 10 to 20 words" --batch 8 --quantize --image_dir ~/cat_images
By default LLavaTagger will run in parallel on all available GPUs. If this is undesirable, use the ROCR_VISIBLE_DEVICES= or CUDA_VISIBLE_DEVICES= environment variable to hide unwanted GPUs, as shown below.
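For example, to expose only the first two CUDA devices (the device ids here are illustrative; the ROCR variable works the same way on ROCm systems):
$ CUDA_VISIBLE_DEVICES=0,1 python LLavaTagger.py --batch 8 --quantize --image_dir ~/cat_images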
LLavaTagger will then create a metadata.jsonl in the image directory suitable to be used by the scripts of [diffusers](https://github.com/huggingface/diffusers) to train Stable Diffusion (XL); see the example line below. If other formats are desired, ../utils contains scripts to transform the metadata into other formats, for instance for use with [kohya](https://github.com/bmaltais/kohya_ss).
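Each line of metadata.jsonl is one JSON object holding the image path relative to the image directory and the generated caption, with the common description prepended; the values below are illustrative:
{"file_name": "cat_01.jpg", "text": "an image of a cat, a grey tabby sitting on a windowsill looking outside"}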
If editing the created tags is desired, [QImageTagger](https://uvos.xyz/git/uvos/QImageTagger) can be used for this purpose.
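To check that the generated metadata is picked up correctly, the image directory can be loaded with the HuggingFace datasets "imagefolder" loader, which reads metadata.jsonl and is also what the diffusers text-to-image scripts consume. A minimal sketch, assuming the datasets package is installed (it is not in requirements.txt) and using a placeholder path:
from datasets import load_dataset
# "imagefolder" picks up metadata.jsonl and exposes its "text" column alongside the images
dataset = load_dataset("imagefolder", data_dir="path/to/cat_images", split="train")
print(dataset[0]["text"])  # caption written by LLavaTagger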

LLavaTagger/requirements.txt Normal file

@@ -0,0 +1,11 @@
accelerate==0.29.0
bitsandbytes
huggingface-hub==0.22.2
ninja==1.11.1.1
safetensors==0.4.2
tokenizers==0.15.2
transformers
torch
opencv-python
numpy
tqdm