PyTorch创建简单的逻辑回归模型(LogisticRegression)
时间:2021-09-06
本文章向大家介绍PyTorch创建简单的逻辑回归模型(LogisticRegression),主要包括PyTorch创建简单的逻辑回归模型(LogisticRegression)使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
import torch
import torch.nn.functional as F # 从torch引入激活函数
x_data = torch.tensor([[1.0], [2.0], [3.0]]).cuda() # 将数据放在GPU上
y_data = torch.tensor([[0.0], [0.0], [1.0]]).cuda()
class LogisticRegressionModel(torch.nn.Module): # 继承torch.nn.Module
def __init__(self): # 初始化类
super(LogisticRegressionModel, self).__init__() # 继承父类
self.linear = torch.nn.Linear(1, 1) # 创建线性层
def forward(self, x): # 定义正向传播
y_pred = F.sigmoid(self.linear(x)) # 将线性层输出的结果经过sigmoid激活函数
return y_pred
model = LogisticRegressionModel().cuda() # 实例化对象为model然后将model的计算图放在GPU上
criterion = torch.nn.BCELoss(size_average=False).cuda() # 创建损失函数,BCELoss为二分类的损失函数,并设置size_average=False不平均损失
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) # 创建优化器SGD来优化model的parameters,并设置learning_rate=1e-2
for epoch in range(100): # 进行100轮训练
y_pred = model(x_data) # 通过model的正向传播得到y_pred
loss = criterion(y_pred, y_data) # 将预测值与真实值进行计算损失函数
print(epoch + 1, loss.item()) # 输出轮数和损失函数
optimizer.zero_grad() # 进行反向传播前将梯度置0
loss.backward() # 损失函数进行反向传播来计算梯度
optimizer.step() # 根据计算的梯度来更新参数权重
print(model(torch.tensor([[4.0]]).cuda())) # 使用模型来预测值
运行结果
原文地址:https://www.cnblogs.com/Reion/p/15235875.html
- Robot Framework | 03 基于Public API创建你RFS测试
- Robot Framework | 02 从抛弃RIDE开始创建你的RFS测试
- ASP.NET5 Beta8可用性
- Docker Swarm集群初探
- 数据库逻辑设计
- 06.移动先行之谁主沉浮----我的代码我来写(Xaml的优势)
- [快学Python3]迭代器和生成器
- [快学Python3]INI文件读写
- Vijos P1131 最小公倍数和最大公约数问题【暴力】
- Vjios P1736 铺地毯【暴力,思维】
- Vijos P1116 一元三次方程求解【多解,暴力,二分】
- Python Selenium设计模式-POM
- [快学Python3]HTTP处理 - urllib模块
- Vijos P1786 质因数分解【暴力】
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 从echarts-for-react源码中学习如何写单元测试
- 好用到飞起的12个jupyter lab插件
- Debug LinkedList
- Java对象公约
- 【Flutter 专题】96 图解 Draggable + DragTarget 基本拖拽效果
- Spring 基于注解(annotation)的配置之@Autowired注解
- 人心易变,这段有趣的C代码也一样!!!
- matplotlib绘制常见统计图形(一)
- python与安全(二)格式化字符串和Flask session
- ROS2机器人笔记20-07-24
- Postgresql 渗透利用总结
- Spring 基于注解(annotation)的配置之@Required注解
- 由一个系统激活工具引起的一次简单测试
- Golang channel 快速入门
- 潘石屹用Python解决100个问题 | 素数