资源简介
基于KERAS实现的LSTM网络,有run.py, model.py , 数据处理模块和参数文件。用KERAS搭建的网络。很好理解。
代码片段和文件信息
__author__ = “Jakob Aungiers“
__copyright__ = “Jakob Aungiers 2018“
__version__ = “2.0.0“
__license__ = “MIT“
import os
import json
import time
import math
import matplotlib.pyplot as plt
from core.data_processor import DataLoader
from core.model import Model
def plot_results(predicted_data true_data):
fig = plt.figure(facecolor=‘white‘)
ax = fig.add_subplot(111)
ax.plot(true_data label=‘True Data‘)
plt.plot(predicted_data label=‘Prediction‘)
plt.legend()
plt.show()
def plot_results_multiple(predicted_data true_data prediction_len):
fig = plt.figure(facecolor=‘white‘)
ax = fig.add_subplot(111)
ax.plot(true_data label=‘True Data‘)
# Pad the list of predictions to shift it in the graph to it‘s correct start
for i data in enumerate(predicted_data):
padding = [None for p in range(i * prediction_len)]
plt.plot(padding + data label=‘Prediction‘)
plt.legend()
plt.show()
def main():
configs = json.load(open(‘config.json‘ ‘r‘))
if not os.path.exists(configs[‘model‘][‘save_dir‘]): os.makedirs(configs[‘model‘][‘save_dir‘])
data = DataLoader(
os.path.join(‘data‘ configs[‘data‘][‘filename‘])
configs[‘data‘][‘train_test_split‘]
configs[‘data‘][‘columns‘]
)
model = Model()
model.build_model(configs)
x y = data.get_train_data(
seq_len=configs[‘data‘][‘sequence_length‘]
normalise=configs[‘data‘][‘normalise‘]
)
‘‘‘
# in-memory training
model.train(
x
y
epochs = configs[‘training‘][‘epochs‘]
batch_size = configs[‘training‘][‘batch_size‘]
save_dir = configs[‘model‘][‘save_dir‘]
)
‘‘‘
# out-of memory generative training
steps_per_epoch = math.ceil((data.len_train - configs[‘data‘][‘sequence_length‘]) / configs[‘training‘][‘batch_size‘])
model.train_generator(
data_gen=data.generate_train_batch(
seq_len=configs[‘data‘][‘sequence_length‘]
batch_size=configs[‘training‘][‘batch_size‘]
normalise=configs[‘data‘][‘normalise‘]
)
epochs=configs[‘training‘][‘epochs‘]
batch_size=configs[‘training‘][‘batch_size‘]
steps_per_epoch=steps_per_epoch
save_dir=configs[‘model‘][‘save_dir‘]
)
x_test y_test = data.get_test_data(
seq_len=configs[‘data‘][‘sequence_length‘]
normalise=configs[‘data‘][‘normalise‘]
)
predictions = model.predict_sequences_multiple(x_test configs[‘data‘][‘sequence_length‘] configs[‘data‘][‘sequence_length‘])
# predictions = model.predict_sequence_full(x_test configs[‘data‘][‘sequence_length‘])
# predictions = model.predict_point_by_point(x_test)
plot_results_multiple(predictions y_test configs[‘data‘][‘sequence_length‘])
# plot_results(predictions y_test)
if __name__ == ‘__main__‘:
main()
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\
文件 12 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\.gitignore
文件 1084 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\README.md
文件 765 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\config.json
目录 0 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\core\
文件 362 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\core\__init__.py
文件 3562 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\core\data_processor.py
文件 4274 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\core\model.py
文件 237 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\core\utils.py
目录 0 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\data\
文件 61721 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\data\sinewave.csv
文件 310533 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\data\sp500.csv
文件 82 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\requirements.txt
文件 2870 2018-10-15 20:50 LSTM-Neural-Network-for-Time-Series-Prediction-master\run.py
相关资源
- 基于simbad平台的避障算法---栅格法改
- 学图像的必看-图像特征提取算法总汇
- 相关峰细化的精确时延估计快速算法
- 快速线结构光中心提取算法
- 密码学算法
- 遗传算法优化rbf网络108003
- 用于图象处理的量子遗传算法
- 页面置换算法
- 遗传算法优化BP网络(用于电力负荷预
- GAOT工具箱
- 工业调度粒子群算法
- 几种经典的Hash算法的实现(源代码)
- TensorFlow平台上基于LSTM神经网络的人体
- 粒子群算法(详细的算法介绍讲解及
- Patchwork水印算法 可以嵌入图像哦
- 多层建筑物应急疏散的模型和算法啊
- 文学研究助手与模式匹配算法
- 基于遗传算法的多目标优化.rar
- 遗传算法工具箱gatool
- 改进的蚁群算法及其在TSP中的应用研
- 基于msp430f149的FFT算法,结果在1602液晶
- 多种K-means聚类算法或改进算法包,
- 随机快速扩展树RRT路径规划算法代码
- 直线、平面、球体的拟合算法
- ( 高速数据链的挖掘算法——VFDT算法
- 光线追踪算法
- acm学习课件《ACM算法与程序设计》
- Chord算法实现
- 一维大地电磁测深遗传算法反演
- 直接插入排序/快速排序/选择排序/冒
评论
共有 条评论