PersonDatasetAssembler: fix raw extensions being case-sensitive

This commit is contained in:
2024-06-17 08:58:46 +02:00
parent cd1e2756bc
commit e083ce3da7

View File

@ -25,7 +25,7 @@ from typing import Iterator
import cv2 import cv2
import numpy import numpy
from tqdm import tqdm from tqdm import tqdm
from wand.exceptions import BlobError from wand.exceptions import BlobError, CoderError
from wand.image import Image from wand.image import Image
image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"] image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
@ -41,7 +41,7 @@ def find_image_files(path: str) -> list[str]:
for root, dirs, files in os.walk(path): for root, dirs, files in os.walk(path):
for filename in files: for filename in files:
name, extension = os.path.splitext(filename) name, extension = os.path.splitext(filename)
if extension.lower() in image_ext_ocv or extension in image_ext_wand: if extension.lower() in image_ext_ocv or extension.lower() in image_ext_wand:
paths.append(os.path.join(root, filename)) paths.append(os.path.join(root, filename))
return paths return paths
@ -58,10 +58,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
yield image yield image
elif extension in image_ext_wand: elif extension in image_ext_wand:
try: try:
image = Image(filename=path) wandImage = Image(filename=path)
wandImage.auto_orient()
image = numpy.array(wandImage)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
yield image
except BlobError as e: except BlobError as e:
print(f"Warning: could not load {path}, {e}") print(f"Warning: could not load {path}, {e}")
continue continue
except CoderError as e:
print(f"Warning: failure in wand while loading {path}, {e}")
else:
print(f"Warning: could not load {path}, {e}")
continue
def extract_video_images(video: cv2.VideoCapture, interval: int = 0): def extract_video_images(video: cv2.VideoCapture, interval: int = 0):
@ -132,7 +142,7 @@ if __name__ == "__main__":
recognizer = cv2.FaceRecognizerSF.create(model=args.match_model, config="", backend_id=cv2.dnn.DNN_BACKEND_DEFAULT , target_id=cv2.dnn.DNN_TARGET_CPU) recognizer = cv2.FaceRecognizerSF.create(model=args.match_model, config="", backend_id=cv2.dnn.DNN_BACKEND_DEFAULT , target_id=cv2.dnn.DNN_TARGET_CPU)
detector = cv2.FaceDetectorYN.create(model=args.detect_model, config="", input_size=[320, 320], detector = cv2.FaceDetectorYN.create(model=args.detect_model, config="", input_size=[320, 320],
score_threshold=0.6, nms_threshold=0.3, top_k=5000, backend_id=cv2.dnn.DNN_BACKEND_DEFAULT, target_id=cv2.dnn.DNN_TARGET_CPU) score_threshold=0.4, nms_threshold=0.2, top_k=5000, backend_id=cv2.dnn.DNN_BACKEND_DEFAULT, target_id=cv2.dnn.DNN_TARGET_CPU)
referance_features = process_referance(detector, recognizer, args.referance) referance_features = process_referance(detector, recognizer, args.referance)
if len(referance_features) < 1: if len(referance_features) < 1:
@ -166,7 +176,7 @@ if __name__ == "__main__":
resized = image resized = image
score, match = contains_face_match(detector, recognizer, resized, referance_features, args.threshold) score, match = contains_face_match(detector, recognizer, resized, referance_features, args.threshold)
if match and not args.invert or not match and args.invert: if match and not args.invert or not match and args.invert:
filename = f"{counter:04}.png" filename = f"{counter:04}.jpg"
cv2.imwrite(os.path.join(args.out, filename), image) cv2.imwrite(os.path.join(args.out, filename), image)
counter += 1 counter += 1
progress.set_description(f"{score:1.2f}") progress.set_description(f"{score:1.2f}")