diff --git a/CMakeLists.txt b/CMakeLists.txt
index a438385..33cdab5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,7 +5,7 @@ find_package(OpenCV REQUIRED)
 
 set(CMAKE_CXX_STANDARD 17)
 
-set(SRC_FILES main.cpp yolo.cpp tokenize.cpp log.cpp seamcarving.cpp utils.cpp intelligentroi.cpp)
+set(SRC_FILES main.cpp yolo.cpp tokenize.cpp log.cpp seamcarving.cpp utils.cpp intelligentroi.cpp facerecognizer.cpp)
 
 add_executable(${PROJECT_NAME} ${SRC_FILES})
 target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} -ltbb)
diff --git a/facerecognizer.cpp b/facerecognizer.cpp
new file mode 100644
index 0000000..6534aaf
--- /dev/null
+++ b/facerecognizer.cpp
@@ -0,0 +1,136 @@
+#include "facerecognizer.h"
+#include <…>
+
+#define INCBIN_PREFIX r
+#include "incbin.h"
+
+INCBIN(defaultRecognizer, "../face_recognition_sface_2021dec.onnx");
+INCBIN(defaultDetector, "../face_detection_yunet_2023mar.onnx");
+
+#include <fstream>
+#include <cassert>
+#include <filesystem>
+#include <…>
+
+#include "log.h"
+
+static const std::vector<unsigned char> onnx((unsigned char*)rdefaultDetectorData, ((unsigned char*)rdefaultDetectorData)+rdefaultDetectorSize);
+
+FaceRecognizer::FaceRecognizer(std::filesystem::path recognizerPath, const std::filesystem::path& detectorPath, const std::vector<cv::Mat>& referances)
+{
+	if(detectorPath.empty())
+	{
+		Log(Log::INFO)<<"Using builtin face detection model";
+		detector = cv::FaceDetectorYN::create("onnx", onnx, std::vector<unsigned char>(), {320, 320}, 0.6, 0.3, 5000, cv::dnn::Backend::DNN_BACKEND_OPENCV, cv::dnn::Target::DNN_TARGET_CPU);
+		if(!detector)
+			throw LoadException("Unable to load detector network from built in file");
+	}
+	else
+	{
+		detector = cv::FaceDetectorYN::create(detectorPath, "", {320, 320}, 0.6, 0.3, 5000, cv::dnn::Backend::DNN_BACKEND_OPENCV, cv::dnn::Target::DNN_TARGET_CPU);
+		if(!detector)
+			throw LoadException("Unable to load detector network from "+detectorPath.string());
+	}
+
+	bool defaultNetwork = recognizerPath.empty();
+
+	if(defaultNetwork)
+	{
+		Log(Log::INFO)<<"Using builtin face recognition model";
+		recognizerPath = cv::tempfile("onnx");
+		std::ofstream file(recognizerPath);
+		if(!file.is_open())
+			throw LoadException("Unable open temporary file at "+recognizerPath.string());
+		Log(Log::DEBUG)<<"Using "<<recognizerPath<<…
+		file.write(reinterpret_cast<const char*>(rdefaultRecognizerData), rdefaultRecognizerSize);
+		file.close();
+	}
+
+	recognizer = cv::FaceRecognizerSF::create(recognizerPath.string(), "", cv::dnn::Backend::DNN_BACKEND_OPENCV, cv::dnn::Target::DNN_TARGET_CPU);
+
+	if(defaultNetwork)
+		std::filesystem::remove(recognizerPath);
+
+	if(!recognizer)
+		throw LoadException("Unable to load recognizer network from "+recognizerPath.string());
+
+	addReferances(referances);
+}
+
+cv::Mat FaceRecognizer::detectFaces(const cv::Mat& input)
+{
+	detector->setInputSize(input.size());
+	cv::Mat faces;
+	detector->detect(input, faces);
+	return faces;
+}
+
+bool FaceRecognizer::addReferances(const std::vector<cv::Mat>& referances)
+{
+	bool ret = false;
+	for(const cv::Mat& image : referances)
+	{
+		cv::Mat faces = detectFaces(image);
+		assert(faces.cols == 15);
+		if(faces.empty())
+		{
+			Log(Log::WARN)<<"A referance image provided dose not contian any face";
+			continue;
+		}
+		if(faces.rows > 1)
+			Log(Log::WARN)<<"A referance image provided contains more than one face, only the first detected face will be considered";
+		cv::Mat cropedImage;
+		recognizer->alignCrop(image, faces.row(0), cropedImage);
+		cv::Mat features;
+		recognizer->feature(cropedImage, features);
+		referanceFeatures.push_back(features.clone());
+		ret = true;
+	}
+
+	return ret;
+}
+
+void FaceRecognizer::setThreshold(double threasholdIn)
+{
+	threshold = threasholdIn;
+}
+
+double FaceRecognizer::getThreshold()
+{
+	return threshold;
+}
+
+void FaceRecognizer::clearReferances()
+{
+	referanceFeatures.clear();
+}
+
+std::pair<int, double> FaceRecognizer::isMatch(const cv::Mat& input, bool alone)
+{
+	cv::Mat faces = detectFaces(input);
+
+	if(alone && faces.rows > 1)
+		return {-2, 0};
+
+	std::pair<int, double> bestMatch = {-1, 0};
+
+	for(int i = 0; i < faces.rows; ++i)
+	{
+		cv::Mat face;
+		recognizer->alignCrop(input, faces.row(i), face);
+		cv::Mat features;
+		recognizer->feature(face, features);
+		features = features.clone();
+		for(size_t referanceIndex = 0; referanceIndex < referanceFeatures.size(); ++referanceIndex)
+		{
+			double score = recognizer->match(referanceFeatures[referanceIndex], features, cv::FaceRecognizerSF::FR_COSINE);
+			if(score > threshold && score > bestMatch.second)
+			{
+				bestMatch = {static_cast<int>(referanceIndex), score};
+			}
+		}
+	}
+
+	return bestMatch;
+}
diff --git a/facerecognizer.h b/facerecognizer.h
new file mode 100644
index 0000000..20a2d9d
--- /dev/null
+++ b/facerecognizer.h
@@ -0,0 +1,41 @@
+#pragma once
+#include <filesystem>
+#include <string>
+#include <exception>
+#include <vector>
+#include <memory>
+#include <opencv2/core.hpp>
+#include <opencv2/objdetect.hpp>
+
+class FaceRecognizer
+{
+public:
+
+	class LoadException : public std::exception
+	{
+	private:
+		std::string message;
+	public:
+		LoadException(const std::string& msg): std::exception(), message(msg) {}
+		virtual const char* what() const throw() override
+		{
+			return message.c_str();
+		}
+	};
+
+private:
+	std::vector<cv::Mat> referanceFeatures;
+	std::shared_ptr<cv::FaceRecognizerSF> recognizer;
+	std::shared_ptr<cv::FaceDetectorYN> detector;
+
+	double threshold = 0.363;
+
+public:
+	FaceRecognizer(std::filesystem::path recognizerPath = "", const std::filesystem::path& detectorPath = "", const std::vector<cv::Mat>& referances = std::vector<cv::Mat>());
+	cv::Mat detectFaces(const cv::Mat& input);
+	std::pair<int, double> isMatch(const cv::Mat& input, bool alone = false);
+	bool addReferances(const std::vector<cv::Mat>& referances);
+	void setThreshold(double threashold);
+	double getThreshold();
+	void clearReferances();
+};
diff --git a/main.cpp b/main.cpp
index 640cf2e..834bfff 100644
--- a/main.cpp
+++ b/main.cpp
@@ -6,6 +6,7 @@
 #include <…>
 #include <…>
 #include <…>
+#include <mutex>
 #include <…>
 #include <…>
@@ -15,6 +16,7 @@
 #include "utils.h"
 #include "intelligentroi.h"
 #include "seamcarving.h"
+#include "facerecognizer.h"
 
 const Yolo::Detection* pointInDetectionHoriz(int x, const std::vector<Yolo::Detection>& detections, const Yolo::Detection* ignore = nullptr)
 {
@@ -223,7 +225,7 @@ void drawDebugInfo(cv::Mat &image, const cv::Rect& rect, const std::vector<…
[… diff lines lost in extraction: remainder of this hunk and the start of the pipeline() changes …]
 	std::vector<Yolo::Detection> detections = yolo.runInference(image);
+	yoloMutex.unlock();
 	Log(Log::DEBUG)<<"Got "<<…
[… diff lines lost in extraction …]
+			std::pair<int, double> match = recognizer->isMatch(person);
+			reconizerMutex.unlock();
+			if(match.first >= 0)
+			{
+				detection.priority += 10;
+				hasmatch = true;
+			}
+		}
+	}
+	Log(Log::DEBUG)<<…
[… diff lines lost in extraction: end of pipeline() and the recognizer setup in main() …]
+		recognizer->addReferances({personImage});
+		recognizer->setThreshold(config.threshold);
+	}
+
+	std::mutex yoloMutex;
+
+	auto pipelineLambda = [&yolo, &debugOutputPath, &config, &yoloMutex, &recognizer, &recognizerMutex](const std::filesystem::path& path)
+	{
+		pipeline(path, config, yolo, yoloMutex, recognizer, recognizerMutex, debugOutputPath);
+	};
+	std::for_each(std::execution::par_unseq, imagePaths.begin(), imagePaths.end(), pipelineLambda);
 
 	return 0;
 }
diff --git a/options.h b/options.h
index 4bb1abb..2d2aad7 100644
--- a/options.h
+++ b/options.h
@@ -20,7 +20,10 @@ static struct argp_option options[] =
 	{"classes", 'c', "[FILENAME]", 0, "classes text file to use" },
 	{"out", 'o', "[DIRECTORY]", 0, "directory whre images are to be saved" },
 	{"debug", 'd', 0, 0, "output debug images" },
-	{"seam-carving", 's', 0, 0, "model to train"},
+	{"seam-carving", 's', 0, 0, "use seam carving to change image aspect ratio instead of croping"},
+	{"size", 'z', "[PIXELS]", 0, "target output size, default: 512"},
+	{"focus-person", 'f', "[FILENAME]", 0, "a file name to an image of a person that the crop should focus on"},
+	{"person-threshold", 't', "[NUMBER]", 0, "the threshold at witch to consider a person matched, defaults to 0.363"},
 	{0}
 };
@@ -30,42 +33,64 @@ struct Config
 	std::filesystem::path modelPath;
 	std::filesystem::path classesPath;
 	std::filesystem::path outputDir;
+	std::filesystem::path focusPersonImage;
 	bool seamCarving = false;
 	bool debug = false;
+	double threshold = 0.363;
 	cv::Size targetSize = cv::Size(512, 512);
 };
 
 static error_t parse_opt (int key, char *arg, struct argp_state *state)
 {
 	Config *config = reinterpret_cast<Config*>(state->input);
-	switch (key)
+	try
 	{
-	case 'q':
-		Log::level = Log::ERROR;
-		break;
-	case 'v':
-		Log::level = Log::DEBUG;
-		break;
-	case 'm':
-		config->modelPath = arg;
-		break;
-	case 'c':
-		config->classesPath = arg;
-		break;
-	case 'd':
-		config->debug = true;
-		break;
-	case 'o':
-		config->outputDir.assign(arg);
-		break;
-	case 's':
-		config->seamCarving = true;
-		break;
-	case ARGP_KEY_ARG:
-		config->imagePaths.push_back(arg);
-		break;
-	default:
-		return ARGP_ERR_UNKNOWN;
+		switch (key)
+		{
+		case 'q':
+			Log::level = Log::ERROR;
+			break;
+		case 'v':
+			Log::level = Log::DEBUG;
+			break;
+		case 'm':
+			config->modelPath = arg;
+			break;
+		case 'c':
+			config->classesPath = arg;
+			break;
+		case 'd':
+			config->debug = true;
+			break;
+		case 'o':
+			config->outputDir.assign(arg);
+			break;
+		case 's':
+			config->seamCarving = true;
+			break;
+		case 'f':
+			config->focusPersonImage = arg;
+			break;
+		case 't':
+			config->threshold = std::atof(arg);
+			break;
+		case 'z':
+		{
+			int x = std::stoi(arg);
+			config->targetSize = cv::Size(x, x);
+			break;
+		}
+		case ARGP_KEY_ARG:
+			config->imagePaths.push_back(arg);
+			break;
+		default:
+			return ARGP_ERR_UNKNOWN;
+		}
+	}
+	catch(const std::invalid_argument& ex)
+	{
+		std::cout<<static_cast<char>(key)<<" is not a valid number.\n";
+		return ARGP_KEY_ERROR;
+	}
 	return 0;
 }