1, 生成数据集
2,定义一个数据迭代器,迭代器可以每次遍历一次所有的数据集
3,模型参数初始化
4,定义线性回归函数
5,定义损失函数
6,定义优化算法,随机梯度下降
7,进行模型训练,输出迭代之后的loss函数
遇到的几个问题:
for epoch in epochs 这是一种错误的写法,一定要记得加range函数,integer直接不可以迭代iter
为防止广播,两个形状相同的矩阵,y_hat-y.reshape(y_hat) 错误,记得最后括号里面是y_hat.shape
虽然很简单,但是自己写和看代码的感觉完全不一样
需要学习几个点,一是迭代函数在python中,还有yield的使用(https://www.ibm.com/developerworks/cn/opensource/os-cn-python-yield/)
还有 nd函数的使用(https://www.jianshu.com/p/7faf137775c8)
初始化时候nd.random函数的参数 (https://www.jianshu.com/p/3ea2cf092815)
自己还需要再打一次。
原文:https://www.cnblogs.com/huayecai/p/10784906.html