• 大小: 4KB
    文件类型: .m
    金币: 1
    下载: 0 次
    发布日期: 2021-06-13
  • 语言: Matlab
  • 标签: matlab  Kd-tree  kNN  

资源简介

使用matlab对输入数据建立Kd-tree并通过Kd-tree进行k-NN查询。k-NN查询的主要算法思路来自知乎【量化课堂】kd 树算法之详细篇

资源截图

代码片段和文件信息

clear;
close all;
clc;
% 生成数据
data = [2 3;
    5 4;
    9 6;
    4 7;
    8 1;
    7 2];
% 给数据标号
for i = 1: size(data1)
    data(i3) = i;
end
% 建立Kd树
Kd_tree = Kd_tree_create(data);
% 利用Kd树进行kNN查询
closest = Kd_tree_search_knn(Kd_tree [63.1] 2);

%% 使用data建立Kd树
function [tree] = Kd_tree_create(data)
% 生成Kd树,每次分割以方差最大的维度进行分割
[num dimension] = size(data);
dimension = dimension - 1;
for i = 1: dimension
    data_var(i) = var(data(:i));
end
[~ choose_dim] = max(data_var);
data = sortrows(data choose_dim);
tree.id = data(round(num/2)end);
tree.node = data(round(num/2)1:end-1);
tree.dim = choose_dim;
tree.parent = [];
tree.left = [];
tree.right = [];

% 递归生成左右子树
lefttree = [];
righttree = [];
if round(num/2) > 1
    leftdata = data(1:(round(num/2)-1) :);
    lefttree = Kd_tree_create(leftdata);
    for i = 1: size(lefttree 1)
        if isempty(lefttree(i).parent)
            lefttree(i).parent = tree.id;
            tree.left = lefttree(i).id;
        end
    end
end
if round(num/2) < num
    rightdata = data((round(num/2)+1):end :);
    righttree = Kd_tree_create(rightdata);
    for i = 1: size(righttree 1)
        if isempty(righttree(i).parent)
            righttree(i).parent = tree.id;
            tree.right = righttree(i).id;
        end
    end
end
tree = [tree; lefttree];
tree = [tree; righttree];
end


%% 利用Kd树进行kNN查询
function [closest_point] = Kd_tree_search_knn(Kd_tree data n)
% 从根节点开始一直查询到叶节点,找到和data在一个区域的叶节点
closest = Kd_tree(1);
while(1)
    if closest.node(closest.dim) >= data(closest.dim) && ~isempty(closest.left)
        closest = Kd_tree(find([Kd_tree.id]==closest.left));
    elseif closest.node(closest.dim) <= data(closest.dim) && ~isempty(closest.right)
        closest = Kd_tree(find([Kd_tree.id]==closest.right));
    else
        break
    end
end

Kd_tree(find([Kd_tree.id]==closest.id)).done = 1;
closest_point = closest.node;
[max_dis max_idx] = max(sum((closest_point - data).^2 2));
max_dis = max_dis(1);
max_idx = max_idx(1);

% 从当前节点向上回溯
node_now = closest;
while(1)
    % 回溯

评论

共有 条评论