资源简介
使用 K 近邻（KNN）算法对 MNIST 手写数字进行识别，准确率可达 94%。具体原理请参阅作者的博客。
代码片段和文件信息
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
public class HandwritingRecogniton {
// Number of nearest neighbours whose labels vote in classify().
static final int K = 20;
public static double calDistance(int[] a int[] b) {
double temp = 0;
for (int x = 0; x < a.length; x++) {
temp += (a[x] - b[x]) * (a[x] - b[x]);
}
return temp = Math.sqrt(temp);
}
public static double cosDis(int[] a int[] b) {
double numerator = 0 aLength = 0 bLength = 0;
for (int x = 0; x < a.length; x++) {
numerator += a[x] * b[x];
aLength += a[x];
bLength += b[x];
}
return numerator / (Math.sqrt(aLength) * Math.sqrt(bLength));
}
public static double hanmingDis(int[] a int[] b) {
double result = 0;
for (int x = 0; x < a.length; x++) {
result += Math.abs(a[x] - b[x]);
}
return result;
}
/**
 * Parses every token of {@code a} as a decimal integer, preserving order.
 *
 * @param a tokens to parse
 * @return a new array whose i-th element is {@code Integer.parseInt(a[i])}
 * @throws NumberFormatException if any token is not a valid integer
 */
public static int[] str2int(String[] a) {
    final int[] parsed = new int[a.length];
    int i = 0;
    for (String token : a) {
        parsed[i++] = Integer.parseInt(token);
    }
    return parsed;
}
public static int classify(String filename int[] a) throws IOException {
FileReader fr = new FileReader(filename);
BufferedReader bufr = new BufferedReader(fr);
double[] d = new double[K];//存放K近邻的距离
for (int x = 0; x < K; x++) {//先将所有K近邻的距离初始化为最大距离28
d[x] = 28;
}
double temp = 0;
int lable = 0;
int[] num = new int[K];//记录对应距离的类标
String str = null;
int t = 0;
while ((str = bufr.readLine()) != null && t++ < 10000) {
int[] b = str2int(str.substring(0 str.length() - 1).split(““));
temp = calDistance(a b);
lable = Integer.parseInt(str.substring(str.length() - 1));
for (int x = 0; x < K; x++) {//找到K近邻的样本
if (temp < d[x]) {
d[x] = temp;
num[x] = lable;
break;
}
}
}
bufr.close();
int[] count = new int[10];
for (int x = 0; x < K; x++) {//统计各数字出现次数
count[num[x]]++;
}
int result = 0;
for (int x = 1; x < 10; x++) {//找出出现次数最多的
if (count[x] > count[result])
result = x;
}
return result;
}
public static void main(String[] arg) throws IOException {
System.out.println(System.currentTimeMillis());
FileReader fr = new FileReader(“validation.txt“);
BufferedReader bufr = new BufferedReader(fr);
int right = 0 sum = 0;
String str = null;
while ((str = bufr.readLine()) != null) {
int[] a = str2int(str.substring(0 str.length() - 1).split(““));
int result = classify(“train.txt“ a);
int lable = Integer.parseInt(str.substring(str.length() - 1));
sum++;
if (result == lable) {
right++;
// System.out.println(“result of classicication is:“ + result +
// “ original lable is:“ + lable);
} else {
int cc[][] = new int[28][28];
int count = 0;
for (int x = 0; x < 28; x++) {
for (int y = 0; y < 28; y++) {
cc[x][y] = a[count++];
}
}
System.out.println(“result of classicication is:“ + result + “ original lable is:“ + lable);
for (int x = 0; x <
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2017-01-21 19:17 ML_project\
文件 301 2016-12-12 17:47 ML_project\.classpath
文件 386 2016-12-12 17:47 ML_project\.project
目录 0 2017-01-21 19:17 ML_project\.settings\
文件 598 2016-12-12 17:47 ML_project\.settings\org.eclipse.jdt.core.prefs
目录 0 2017-01-21 20:26 ML_project\bin\
文件 3941 2017-01-21 20:26 ML_project\bin\HandwritingRecogniton.class
目录 0 2017-01-21 20:26 ML_project\src\
文件 3331 2017-01-21 20:26 ML_project\src\HandwritingRecogniton.java
文件 15710000 2016-12-12 18:22 ML_project\train.txt
文件 785500 2016-12-12 18:22 ML_project\validation.txt
评论
共有 条评论