用二叉树实现自动求导(Python版)
时间:2022-07-28
本文章向大家介绍用二叉树实现自动求导(Python版),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
最近在研究怎么用C++从头写一个深度学习训练框架,在写自动求导的时候顺手写了个Python版,代码更简单一些,在这里分享给大家。
思路
这里采用二叉树的形式来表示计算图,原理就是绝大部分深度学习中的运算都可以分解为二元运算(两个输入得到一个输出)。
比如说
可以表示成如下的形式(懒得画图了,大家将就看):
x
* -- temp
/
w
+ -- z
/
/
b
这样就形成了一个由二叉树表示的计算图,其中z是根节点。
在反向传播时,先从根节点开始计算梯度:
再从temp计算x和w的梯度:
同理
。
这是很简单的反向传播过程,相信大家都了解,那么要如何用Python实现呢。
Python实现
首先要定义一个节点类,这里我们给它取名叫Tensor(Pytorch用习惯了),然后重载这个类的四则运算,使对象在进行四则运算的同时建立运算关系。
class Tensor:
def __init__(self,data,left=None,right=None,op = None):
self.data = data
self.grad = 0 # 如果当前节点有多个父节点,则梯度需要叠加,所以grad初始化为0更方便一点
self.left = left
self.right = right
self.op = op
def __add__(self, other):
data = self.data + other.data
t = Tensor(data,left = self,right=other,op = "add")
return t
def __sub__(self, other):
data = self.data - other.data
t = Tensor(data,left = self,right=other,op = "sub")
return t
def __mul__(self, other):
data = self.data * other.data
t = Tensor(data, left=self, right=other, op="mul")
return t
def __truediv__(self, other):
if other.data - 0 < 1e-9:
raise Exception("Can't divide zero")
data = self.data / other.data
t = Tensor(data, left=self, right=other, op="div")
return t
def backward(self,init_grad = 1):
# init_grad: 来自上一层的梯度
if self.left is not None:
if self.op == "add":
self.left.grad += 1 * init_grad
elif self.op == "sub":
self.left.grad += 1 * init_grad
elif self.op == "mul":
self.left.grad += self.right.data * init_grad
elif self.op == "div":
self.left.grad += 1 / self.right.data * init_grad
else:
raise Exception("Op unacceptable")
self.left.backward(self.left.grad)
if self.right is not None:
if self.op == "add":
self.right.grad += 1 * init_grad
elif self.op == "sub":
self.right.grad += -1 * init_grad
elif self.op == "mul":
self.right.grad += self.left.data * init_grad
elif self.op == "div":
self.right.grad += (-1 * self.left.data / (self.right.data*self.right.data)) * init_grad
else:
raise Exception("Op unacceptable")
self.right.backward(self.right.grad)
验证
写好了反向传播之后,可以通过两种方式验证争取性,第一就是用带反向传播的框架写一个同样的计算,对比一下梯度计算结果是否相同。
a = Tensor(1.0)
b = Tensor(2.0)
c = a * b + a / b - a * a * a
c.backward()
print("grad na:{} b:{}".format(a.grad,b.grad))
import torch
# 需要用小写的torch.tensor才能添加requires_grad参数
m = torch.tensor([[1.0]],requires_grad=True)
n = torch.tensor([[2.0]],requires_grad=True)
k = m * n + m / n - m * m * m
k.backward()
print("grad torchnm:{} n:{}".format(m.grad.item(),n.grad.item()))
第二种方法就是用这个类写一个基于梯度下降的线性回归模型,看看能不能收敛了。
class Linear_regression:
def __init__(self):
self.w = Tensor(1.0)
self.b = Tensor(1.0)
self.lr = Tensor(0.02)
def fit(self,x,y,num_epochs = 60,show=True):
if show:
fig = plt.figure()
plt.scatter(x,y,color = 'r')
ims = []
for epoch in range(num_epochs):
losses = 0.0
for m,n in zip(x,y):
yp = self.w * Tensor(m) + self.b
loss = (Tensor(n) - yp) * (Tensor(n) - yp)
loss.backward()
self.w -= self.lr * Tensor(self.w.grad)
self.b -= self.lr * Tensor(self.b.grad)
self.w.grad = 0
self.b.grad = 0
# 切断计算图
self.w.left = None
self.w.right = None
self.b.right = None
self.b.left = None
losses += loss.data
print(losses)
if show:
im = plt.plot(x,[self.w.data * item + self.b.data for item in x],color = 'g')
ims.append(im)
if show:
ani = animation.ArtistAnimation(fig, ims, interval=200,
repeat_delay=1000)
ani.save("test.gif", writer='pillow')
x = [1,2,3,4,5]
y = [6,5,4,3,2]
clf = Linear_regression()
clf.fit(x,y)
最终的结果如下:
这段代码参考了B站up主EvilGeniusMR的视频:https://www.bilibili.com/video/av48101995?from=search&seid=14494572176379913757
这个示例的完整代码在:
https://github.com/Arctanxy/ToyNet/blob/master/Autograd_sample.py
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 你要的rmarkdown文献图表复现全套代码来了(单细胞)
- 祖传的单个10x样本的seurat标准代码
- 浏览器输入某URL后,HTTP开启了一段奇妙之旅!
- 【Pytorch】笔记一:数据载体张量与线性回归
- 为什么我不再用Redux了
- 【Pytorch 】笔记二:动态图、自动求导及逻辑回归
- 听说国漫最近崛起了,那我们就来爬几部国漫看看(动态加载,反爬)
- 微信小程序开发实战(25):预览图像
- 【Pytorch】笔记三:数据读取机制与图像预处理模块
- 表白利器,马赛克拼贴照片制作
- 【014期】JavaSE面试题(十四):基本IO流
- 微信小程序开发实战(24):选择图像
- 反 996 有理:催程序员交代码,写不出好软件
- 一千个不用 Null 的理由!
- WebAssembly 是 Deno 的好搭档