资源简介
构建CNN
LayerBuilder builder = new LayerBuilder();
builder.addLayer(Layer.buildInputLayer(new Size(28, 28)));
builder.addLayer(Layer.buildConvLayer(6, new Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
builder.addLayer(Layer.buildConvLayer(12, new Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
builder.addLayer(Layer.buildOutputLayer(10));
CNN cnn = new CNN(builder, 50);
运行MNIST数据集
String fileName = "data/train.format";
Dataset dataset = Dataset.load(fileName, ",", 784);
cnn.train(dataset, 100);
Dataset testset = Dataset.load("data/test.format", ",", -1);
cnn.predict(testset, "data/test.predict");
计算精度可以达到97.8%。
代码片段和文件信息
package info.hb.ccn.main;
import info.hb.cnn.core.CNN;
import info.hb.cnn.core.CNN.layerBuilder;
import info.hb.cnn.core.layer;
import info.hb.cnn.core.layer.Size;
import info.hb.cnn.data.DataSet;
import info.hb.cnn.utils.ConcurentRunner;
public class CNNMnist {
private static final String MODEL_NAME = “mnist/model/model.cnn“;
private static final String TRAIN_DATA = “mnist/train.format“;
private static final String TEST_DATA = “mnist/test.format“;
private static final String TEST_PREDICT = “mnist/test.predict“;
public static void main(String[] args) {
System.err.println(“训练阶段:“);
runTrain();
System.err.println(“测试阶段:“);
runTest();
ConcurentRunner.stop();
}
public static void runTrain() {
// 构建网络层次结构
layerBuilder builder = new layerBuilder();
builder.addlayer(layer.buildInputlayer(new Size(28 28))); // 输入层输出map大小为28×28
builder.addlayer(layer.buildConvlayer(6 new Size(5 5))); // 卷积层输出map大小为24×2424=28+1-5
builder.addlayer(layer.buildSamplayer(new Size(2 2))); // 采样层输出map大小为12×1212=24/2
builder.addlayer(layer.buildConvlayer(12 new Size(5 5))); // 卷积层输出map大小为8×88=12+1-5
builder.addlayer(layer.buildSamplayer(new Size(2 2))); // 采样层输出map大小为4×44=8/2
builder.addlayer(layer.buildOutputlayer(10));
CNN cnn = new CNN(builder 10);
// 加载训练数据
DataSet dataset = DataSet.load(TRAIN_DATA ““ 784);
// 开始训练模型
cnn.train(dataset 5);
// 保存训练好的模型
cnn.saveModel(MODEL_NAME);
dataset.clear();
}
public static void runTest() {
// 加载训练好的模型
CNN cnn = CNN.loadModel(MODEL_NAME);
// 加载测试数据
DataSet testSet = DataSet.load(TEST_DATA ““ -1);
// 预测结果
cnn.predict(testSet TEST_PREDICT);
testSet.clear();
}
}
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2016-08-21 08:14 cnn-master\
文件 45 2016-08-21 08:14 cnn-master\.gitignore
文件 754 2016-08-21 08:14 cnn-master\README.md
目录 0 2016-08-21 08:14 cnn-master\mnist\
文件 208 2016-08-21 08:14 cnn-master\mnist\readme.md
文件 43904000 2016-08-21 08:14 cnn-master\mnist\test.format
文件 18841570 2016-08-21 08:14 cnn-master\mnist\train.format
文件 3024 2016-08-21 08:14 cnn-master\pom.xm
目录 0 2016-08-21 08:14 cnn-master\speech\
目录 0 2016-08-21 08:14 cnn-master\speech\model\
文件 16718 2016-08-21 08:14 cnn-master\speech\model\model.cnn
文件 31110 2016-08-21 08:14 cnn-master\speech\test.format
文件 86 2016-08-21 08:14 cnn-master\speech\test.label
文件 86 2016-08-21 08:14 cnn-master\speech\test.predict
文件 144463 2016-08-21 08:14 cnn-master\speech\train.format
目录 0 2016-08-21 08:14 cnn-master\src\
目录 0 2016-08-21 08:14 cnn-master\src\main\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\ccn\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\ccn\main\
文件 1887 2016-08-21 08:14 cnn-master\src\main\java\info\hb\ccn\main\CNNMnist.java
文件 1681 2016-08-21 08:14 cnn-master\src\main\java\info\hb\ccn\main\CNNSpeech.java
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\core\
文件 14926 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\core\CNN.java
文件 5893 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\core\la
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\data\
文件 3714 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\data\DataSet.java
目录 0 2016-08-21 08:14 cnn-master\src\main\java\info\hb\cnn\utils\
............此处省略17个文件信息
- 上一篇:当当网javaweb-SSM框架项目
- 下一篇:安卓拼音输入法代码
评论
共有 条评论