import math
import matplotlib.pyplot as pltclass Node:def __init__(self, data, left=None, right=None):self.data = dataself.left = leftself.right = right# 创建KDTree类
class KDTree:def __init__(self, k):self.k = kdef create_tree(self,dataset,depth):if not dataset:return Nonemid_index=len(dataset)//2 # 中位数axis = depth%self.k # 按照哪个坐标轴划分sorted_dataset = sorted(dataset,key=(lambda x : x[axis])) # 按照坐标轴划分mid_data = sorted_dataset[mid_index]#中位数数据值current_node = Node(mid_data) # 创建当前节点left_data = sorted_dataset[:mid_index] # 划分左节点数据right_data = sorted_dataset[mid_index+1:] # 划分右节点数据current_node.left = self.create_tree(left_data,depth+1) # 创建左子树current_node.right = self.create_tree(right_data,depth+1) # 创建右子树return current_nodedef search(self, tree, new_data):self.nearest_point = None # 当前最邻近点self.nearest_val = None # 当前最邻近点与目标节点间距离def dfs(node,depth): # 深度优先搜索# 递归找叶子节点if not node:return Noneaxis = depth % self.kif new_data[axis] < node.data[axis]:dfs(node.left, depth+1)else:dfs(node.right, depth+1)# 比较距离,判断是否更新最近邻点dist = self.distance(new_data,node.data)if not self.nearest_val or dist<self.nearest_val:self.nearest_val = distself.nearest_point = node.data# 判断是否遍历该节点另一边子树if abs(new_data[axis]-node.data[axis]) <= self.nearest_val: # 计算父节点在其分割特征上的data距离目标点在该特征上的data的距离。若该距离小于 nearest_val,则进入另一个孩子节点,否则不进入if new_data[axis] < node.data[axis]: # 之前若先遍历左子树,现在就要遍历右子树dfs(node.right, depth+1)else:dfs(node.left, depth+1)dfs(tree, 0)return self.nearest_pointdef distance(self,new_data, new_val):res = 0for i in range(self.k):res += (new_data[i]-new_val[i])**2return math.sqrt(res)if __name__ == '__main__':data_set = [[3,3],[5,4],[5,6],[2,7],[9,1],[2,5],[3,2],[2,0]new_data = [2,9]k = len(data_set[0])kd_tree = KDTree(k)our_tree = kd_tree.create_tree(data_set,0)predict = kd_tree.search(our_tree,new_data)print(f"Nearest Point of {new_data} is {predict}")plt.scatter([x[0] for x in data_set],[x[1] for x in data_set],c='purple',label='train_data')plt.scatter(new_data[0],new_data[1],c='red',label='target_data')plt.plot([predict[0], new_data[0]], [predict[1],new_data[1]], c='green',label='Nearest Point',linestyle='--')plt.legend()plt.show()
Node
类用于表示KD树的节点。data
保存当前节点的数据点。left
和right
分别指向左子树和右子树。KDTree
类用于创建和操作KD树。k
表示数据点的维度。
create_tree
方法用于递归地创建KD树。dataset
是要构建树的数据集。depth
表示当前节点的深度,用于确定划分的轴。- 根据深度计算轴并排序数据集,选择中位数作为当前节点的数据点。
- 递归地创建左子树和右子树。
search
方法用于在KD树中查找离new_data
最近的点。self.nearest_point
和self.nearest_val
用于保存当前找到的最近点及其距离。- 定义深度优先搜索
dfs
函数,递归地搜索树,更新最近点和距离。 - 检查是否需要遍历另一边的子树。
- 主程序创建数据集
data_set
和要查找的点new_data
。 - 初始化
KDTree
实例并创建KD树。 - 使用
search
方法查找最近点并打印结果。 - 使用
matplotlib
绘制数据点和最近邻点的连线。
参考文献Kd Tree算法详解_kd-tree-CSDN博客
Python手撸机器学习系列(十一):KNN之kd树实现_knn原理及python代码实现建立kd树-CSDN博客