资源简介

使用LSTM预测股价案例,超级精简,便于理解,是LSTM入门的好案例。

资源截图

代码片段和文件信息


# coding: utf-8

# In[1]:


import requests
import numpy as np
import matplotlib.pyplot as plt
import math
from numpy import *
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error

#通过深交所获取股票历史数据
url=“http://www.szse.cn/api/market/ssjjhq/getHistoryData?random=0.08048793408036281&cycleType=32&marketId=1&code=000058“
    
r=requests.get(url)
data=r.json()

#从数据中提取收盘价备用
data_arr=np.array(data[“data“][“picupdata“])
close_prices=data_arr[:2].astype(np.float32)
#print(close_prices)


# In[2]:


np.random.seed(8)


# In[3]:


#数据归一化
scaler=MinMaxScaler(feature_range=(01))
dataset=scaler.fit_transform(close_prices.reshape(len(close_prices)1))


# In[4]:


#从原始数据中分割出训练数据和测试数据
train_size=int(len(dataset)*0.618)
test_size=len(dataset)-train_size
traintest=dataset[:train_size:]dataset[train_size::]


# In[5]:


#格式化数据便于输入
def create_dataset(dataset look_back=1):
    dataX dataY = [] []
    for i in range(look_backlen(dataset)):
        a = dataset[i-look_back:i0]                          
        dataX.append([a])
    Y=dataset[look_back:0]
    return np.array(dataX) Y


# In[6]:


short_term=5 #设定短期周期为15天
train_Xtrain_Y=create_dataset(trainlook_back=short_term)
test_Xtest_Y=create_dataset(testlook_back=short_term)
origin_Xorigin_Y=create_dataset(datasetlook_back=short_term)


# In[7]:


#使用序列模型训练
model=Sequential()
model.add(LSTM(100input_shape=(train_X.shape[1]train_X.shape[2])))
model.add(Dense(1))
model.compile(loss=‘mean_squared_error‘ optimizer=‘adam‘) 

history = model.fit(train_X train_Y 
                    epochs=200 
                    batch_size=100 
                    validation_data=(test_X test_Y))

plt.plot(history.history[“loss“]label=“train“)
plt.plot(history.history[“val_loss“]label=“test“)
plt.legend()
plt.show()


# In[8]:


y_predicted=model.predict(origin_X)

plt.rcParams[‘font.sans-serif‘] = [‘SimHei‘]  # 用来正常显示中文标签
plt.rcParams[‘axes.unicode_minus‘] = False  # 用来正常显示负号


plt.figure(figsize=(12 4))

x_index=range(len(origin_Y))

#反归一化,还原为原价
origin_Y=scaler.inverse_transform(np.array(origin_Y).reshape((len(origin_Y)1)))
y_predicted=scaler.inverse_transform(y_predicted)

plt.plot(x_indexorigin_Ylabel=“真实收盘价“color=“#008080“)
plt.plot(x_indexy_predictedlabel=“预测收盘价“color=“#ff6666“)

plt.legend()

plt.show()


 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2019-08-29 18:15  LSTM股价预测(python)\
     文件       56650  2019-08-29 18:13  LSTM股价预测(python)\LSTM_stock_price_prediction.pdf
     文件        2655  2019-08-29 18:15  LSTM股价预测(python)\LSTM_stock_price_prediction.py

评论

共有 条评论