神经网络有三层,输入层A,隐藏层B,输出层C,满足:
A(10x1)*W1(1x10)+b1(1x10)=B(10x10)
B(10x10)*W2(10x1)+b2(10x1)=C(10x1)
我们需要做的,便是通过多次训练(尝试不同 w、b 的值),找到合适的 w1w2、b1b2,使预测结果更接近真实结果。
代码:
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt
# 兼容 tensorflow 1.0
tf.disable_eager_execution()
# 日期,0-9
date = np.linspace(0, 9, 10)
# 收盘价格
endPrice = np.array([2511.90, 2538.26, 2510.68, 2591.66, 2732.98, 2701.69, 2701.29, 2678.67, 2726.50, 2681.50])
# 开盘价格
beginPrice = np.array([2438.71, 2739.17, 2715.07, 2823.58, 2864.90, 2919.08, 2500.88, 2534.95, 2512.52, 2594.04])
# 绘制涨幅趋势
# 价格上涨,红色
# 价格下跌,绿色
plt.figure()
for i in range(10):
temData = np.zeros([2])
temData[0] = i
temData[1] = i
temPrice = np.zeros([2])
temPrice[0] = beginPrice[i]
temPrice[1] = endPrice[i]
if beginPrice[i] > endPrice[i]:
plt.plot(temData, temPrice, "green", lw="2")
else:
plt.plot(temData, temPrice, "red", lw="2")
# 神经网络有三层,输入层A,隐藏层B,输出层C
# A(10x1)*W1(1x10)+b1(1x10)=B(10x10)
# B(10x10)*W2(10x1)+b2(10x1)=C(10x1)
# 输入层
# 归一化处理,减小计算量
dateNormal = np.zeros([10, 1])
priceNormal = np.zeros(([10, 1]))
for i in range(10):
dateNormal[i] = i / 9.0
priceNormal[i] = endPrice[i]/3000.0
# 声明输入输出参数,xy
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
# 隐藏层
w1 = tf.Variable(tf.random_uniform([1, 10], 0, 1))
b1 = tf.Variable(tf.zeros([1, 10]))
wb1 = tf.matmul(x, w1) + b1
# 激励函数
layer1 = tf.nn.relu(wb1)
# 输出层
w2 = tf.Variable(tf.random_uniform([10, 1], 0, 1))
b2 = tf.Variable(tf.zeros([10, 1]))
wb2 = tf.matmul(layer1, w2) + b2
# 激励函数
layer2 = tf.nn.relu(wb2)
# 预测值与实际值的标准差
loss = tf.reduce_mean(tf.square(y - layer2))
# 使用梯度下降法,减小差异
tran_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# 变量初始化
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 训练10000次
for i in range(10000):
sess.run(tran_step, feed_dict={x: dateNormal, y: priceNormal})
# 查看拟合结果
pre = sess.run(layer2, feed_dict={x: dateNormal})
prePrice = np.zeros([10, 1])
# 反归一化
for i in range(10):
prePrice[i, 0] = (pre * 3000)[i, 0]
# 绘制折线
plt.plot(date, prePrice, "blue", lw=2)
plt.show()
运行效果:
原文:https://www.cnblogs.com/bjxqmy/p/13488642.html