简介
Matting和分割是图像处理中两个重要的任务,它们在抠图和图像分析中起着不同的作用。
分割方法将图像分成不同的区域,并为每个像素分配一个分类标签,因此其输出是一个像素级别的分类标签图,通常是整型数据。这种方法适用于将图像中的不同对象或区域进行明确的区分。
而Matting方法则更侧重于提供像素级别的前景和背景的概率信息,通常表示为概率值P。Matting模型会为图像中的每个像素生成一个代表其属于前景的概率,从而在前景和背景交界处产生渐变效果,使得抠图更加自然。Matting模型训练完成后,会生成一个称为Alpha的值,用于表示每个像素的前景透明度。所有Alpha值的集合称为Alpha Matte,它可以被用来对原始图像进行精细的背景替换,使得合成的效果更加逼真。
PP-Matting
PP-Matting是一种无三分图的Matting架构,旨在实现高精度的自然图像Matting。其主要贡献包括:
- 提出了双分支架构,包括上下文分支(SCB)和高分辨率细节分支(HRDB),共享一个编码器。这种结构有助于并行高效地提取细节和语义特征,并通过引导流机制实现适当的交互,从而提高了Matting的准确性和自然度。
- 应用金字塔池模块(PPM)来加强语义上下文信息,并通过引导流策略帮助HRDB进行细节预测。这些方法使得PP-Matting能够在没有任何辅助信息的情况下进行端到端的训练,从而轻松地实现高精度的Matting。
- 在多个数据集上评估了PP-Matting的性能,结果表明其在消光任务上优于其他方法,并在人体抠图实验中展现出突出的性能。
PP-Matting的模型可根据用户对图像分辨率的需求,提供最匹配的模型,并在Trimap Free方向上实现了SOTA级别的精度。除了考虑模型性能外,PaddleSeg还特别优化了模型的部署环境,包括边缘端和服务端,针对模型体积等指标进行了优化。
针对人像场景,PaddleSeg还进行了特殊优化处理,提供了不同场景下的预训练模型和部署模型。这些模型既可直接部署使用,也可以根据具体任务进行微调,为用户提供了更加灵活的选择。
在技术实现方面,基于深度学习的Matting方法通常分为两大类:一种是基于辅助信息输入,另一种是不依赖任何辅助信息直接进行Alpha预测。
PP-Matting的设计初衷是为了方便用户快速实现抠图,因此用户在使用时无需依赖辅助信息的输入,即可直接获得预测结果。为了实现更高的效果,PP-Matting采用了Semantic context branch (SCB)和high-resolution detail branch (HRDB)两个分支,分别进行语义和细节预测,并通过引导流机制实现了语义引导下的高分辨率细节预测,从而实现了Trimap-free高精度图像抠图。
模型推理
源码下载地址:https://download.csdn.net/download/matt45m/89005564?spm=1001.2014.3001.5501 ,源码里面只有一个尺寸的模型,如果想更多尺寸模型,私信博主。
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <fstream>
#include <string>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
//#include <cuda_provider_factory.h> ///cuda加速,要配合onnxruntime gpu版本使用
#include <onnxruntime_cxx_api.h>class Matting
{
public:Matting();Matting(std::string model_path);void inference(cv::Mat &cv_src,std::vector<cv::Mat> &cv_dsts);
private:void preprocess(cv::Mat &cv_src);int inpWidth;int inpHeight;std::vector<float> input_image_;const float conf_threshold = 0.65;Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "Matting");Ort::Session* ort_session = nullptr;Ort::SessionOptions sessionOptions = Ort::SessionOptions();std::vector<char*> input_names;std::vector<char*> output_names;std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputsstd::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};Matting::Matting()
{}Matting::Matting(std::string model_path)
{std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); //windows//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); ///使用cuda加速sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions); //windows写法//ort_session = new Session(env, model_path.c_str(), sessionOptions); //linux写法size_t numInputNodes = ort_session->GetInputCount();size_t numOutputNodes = ort_session->GetOutputCount();Ort::AllocatorWithDefaultOptions allocator;for (int i = 0; i < numInputNodes; i++){input_names.push_back(ort_session->GetInputName(i, allocator));Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();auto input_dims = input_tensor_info.GetShape();input_node_dims.push_back(input_dims);}for (int i = 0; i < numOutputNodes; i++){output_names.push_back(ort_session->GetOutputName(i, allocator));Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();auto output_dims = output_tensor_info.GetShape();output_node_dims.push_back(output_dims);}this->inpHeight = input_node_dims[0][2];this->inpWidth = input_node_dims[0][3];
}void Matting::preprocess(cv::Mat &cv_src)
{cv::Mat cv_dst;cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight), cv::INTER_LINEAR);int row = cv_dst.rows;int col = cv_dst.cols;this->input_image_.resize(row * col * cv_dst.channels());for (int c = 0; c < 3; c++){for (int i = 0; i < row; i++){for (int j = 0; j < col; j++){float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];this->input_image_[c * row * col + i * col + j] = pix / 255.0;}}}
}void Matting::inference(cv::Mat &cv_src,std::vector<cv::Mat> &cv_dsts)
{this->preprocess(cv_src);std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(),input_image_.size(), input_shape_.data(), input_shape_.size());std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, input_names.data(), &input_tensor_, 1, output_names.data(), output_names.size()); // 开始推理Ort::Value& mask_pred = ort_outputs.at(0);const int out_h = this->output_node_dims[0][2];const int out_w = this->output_node_dims[0][3];float* mask_ptr = mask_pred.GetTensorMutableData<float>();cv::Mat cv_map;cv::Mat cv_mask_out(out_h, out_w, CV_32FC1, mask_ptr);cv::resize(cv_mask_out, cv_map, cv::Size(cv_src.cols, cv_src.rows));cv::Mat cv_three_channel = cv::Mat::zeros(cv_src.rows, cv_src.cols, CV_32FC3);std::vector<cv::Mat> channels(3);for (int i = 0; i < 3; i++){channels[i] = cv_map;}merge(channels, cv_three_channel);cv::Mat cv_rgbimg = cv_src.clone();cv_rgbimg.setTo(cv::Scalar(0, 255, 0), cv_three_channel > this->conf_threshold);cv::Mat dstimg;cv::addWeighted(cv_src, 0.5, cv_rgbimg, 0.5, 0, dstimg);cv_dsts.push_back(cv_map);cv_dsts.push_back(dstimg);
}cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{int width = cv_src.cols;int height = cv_src.rows;cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);float* alpha_data = (float*)alpha.data;for (int i = 0; i < height; i++){for (int j = 0; j < width; j++){float alpha_ = alpha_data[i * width + j];cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];}}return cv_matting;
}int main()
{cv::Mat cv_src = cv::imread("images/6.jpg");Matting net("models/ppmatting_736x1280.onnx");std::vector<cv::Mat> cv_dsts;net.inference(cv_src, cv_dsts);std::vector<int> color{ 255, 255, 255 };cv::Mat cv_dst = replaceBG(cv_src, cv_dsts[0], color);cv::namedWindow("src", 0);cv::namedWindow("alpha", 0);cv::namedWindow("BG", 0);cv::imshow("src", cv_src);cv::imshow("alpha", cv_dsts[0]);cv::imshow("BG", cv_dst);cv::waitKey();
}