需求:将下图中的花朵提取出来。
代码:
import cv2
import torch
import numpy as np
import timedef get_similar_colors(image, color_list, threshold):# 将图像和颜色列表转换为torch张量device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)# 计算每个像素与颜色列表中每个颜色的距离distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)# 找到最小距离及其索引min_distances, _ = torch.min(distances, dim=-1)# 创建掩码,标记接近目标颜色的像素mask = min_distances < threshold# 根据掩码提取接近颜色的部分result = torch.where(mask.unsqueeze(-1), image_tensor, torch.zeros_like(image_tensor))# 将结果转换回numpy数组result_np = result.cpu().numpy().astype(np.uint8)return result_np
# 读取图像s
image = cv2.imread('flower2.jpg')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(15, 220, 255),(30, 50, 220)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
进一步,输出掩码部分的黑白图像
import cv2
import torch
import numpy as np
import timedef get_similar_colors(image, color_list, threshold):# 将图像和颜色列表转换为torch张量device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)# 计算每个像素与颜色列表中每个颜色的距离distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)# 找到最小距离及其索引min_distances, _ = torch.min(distances, dim=-1)# 创建掩码,标记接近目标颜色的像素mask = min_distances < threshold# 将符合条件的像素设置为黑色result = np.ones_like(image_tensor)result[mask] = [0, 0, 0] # 设置为黑色return result
# 读取图像s
image = cv2.imread('your/image/path')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(50, 15, 0), (45, 10, 0), (30, 10, 0)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()