我正试图编写一个程序,识别目录中哪些图像与查询图像相似,查询图像与目录中的图像相似,但通常略有不同。目录中有数千个图像。这个问题与比较图像相似性的简单快速方法有关。
我有几个目标:
- 使用查询图像,在图像目录中识别相似的图像
- 查询图像可能与目录中的图像略有不同。这些变化可能包括被裁剪的图像和不同的图像质量
- 这个程序应该很快(最多几秒钟就能识别出类似的图像(
我知道这是一个需要大量研究的问题。一章,";构建反向图像搜索引擎:理解嵌入";从";面向云、移动和边缘的实用深度学习;解释了这个问题的一些方法。
我开始编写一个程序,使用SIFT(尺度不变特征变换(+单词袋的方法来实现这一点。我在这方面没有太多经验。我编写的程序适用于相同的图像,也适用于稍相似的图像,但一旦图像变得更不相似,它就不再检测到正确的图像。
我有两个问题:
- 我使用的方法好吗?如果不好,还有更好的方法吗
- 我的程序中是否有任何内容可能导致对不同图像的搜索不准确
这就是程序的工作方式:
- 遍历每个图像,使用SIFT获取其描述符,并构建这些描述符的列表
- 使用k-means,找到描述符列表的质心。这就是";字典">
- 再次遍历每个图像,并获得每个图像的描述符和质心的k个最近邻居knnMatch,其中k=1。使用match.trainIdx,使用每个匹配为每个图像创建直方图
- 通过将每个"直方图"的计数除以"直方图"来归一化每个图像的直方图;单词";乘以";单词">
- 将k=1的knnMatch与查询图像的描述符和质心一起使用。浏览匹配项并创建一个标准化的直方图
- 在查询图像的直方图以及数据库中所有图像的直方图上使用k=1的knnMatch。这将创建一个匹配列表,按与查询图像的相似性排序
import numpy as np
import cv2
import os
from matplotlib import pyplot as plt
sift = cv2.xfeatures2d.SIFT_create()
FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 100)
search_params = dict(checks = 100)
flann = cv2.FlannBasedMatcher(index_params, search_params)
bf = cv2.BFMatcher()
img1 = cv2.imread('path',0)
db = # load database
kp1, des1 = sift.detectAndCompute(img1,None)
load = False
clusters = 800
if load:
db.query('DELETE FROM centroids')
db.query('DELETE FROM histogram')
descriptors = []
for file in os.listdir('path'):
if file.endswith('.png'):
img = cv2.imread('path/{}'.format(file), 0)
kp, des = sift.detectAndCompute(img,None)
if des is None:
continue
descriptors.extend(des)
descriptors = np.float32(descriptors)
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
db.insert('centroids', d = np.ndarray.dumps(centroids))
for file in os.listdir('path'):
counter = np.zeros((clusters,), dtype=np.uint32)
if file.endswith('.png'):
img = cv2.imread('path/{}'.format(file),0)
kp, d = sift.detectAndCompute(img,None)
if d is None:
continue
matches = bf.knnMatch(d, centroids, k=1)
for match in matches:
counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]
db.insert('histogram', frame_id = file, count=','.join(np.char.mod('%f', counter)))
histograms_db = list(db.query('SELECT * FROM histogram'))
histograms = []
for histogram in histograms_db:
histogram = histogram['count'].split(',')
histograms.append(histogram)
histograms = np.array(histograms)
counter = np.zeros((clusters,), dtype=np.uint32)
centroids = np.loads(db.query('SELECT * FROM centroids')[0]['d'])
matches = bf.knnMatch(des1, centroids, k=1)
for match in matches:
counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]
matches = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=1)
for match in matches[0]:
print "{} {}".format(histograms_db[match.trainIdx]['frame_id'], match.distance)
name = histograms_db[match.trainIdx]['frame_id']
您可以使用任何近似的最近邻搜索库。例如,试试Faiss。