30行代码徒手实现logistic回归
时间:2022-07-22
本文章向大家介绍30行代码徒手实现logistic回归,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
准备数据集
采用的数据集是sklearn中的breast cancer数据集,30维特征,569个样本。训练前进行MinMax标准化缩放至[0,1]区间。按照75/25比例划分成训练集和验证集。
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
# 准备数据集
breast = datasets.load_breast_cancer()
scaler = preprocessing.MinMaxScaler()
data = scaler.fit_transform(breast['data'])
target = breast['target']
X_train,X_test,y_train,y_test = train_test_split(data,target)
二
模型结构图
三
正反传播公式
四
LR实现代码
import numpy as np
import pandas as pd
class LogisticRegression(object):
def __init__(self,alpha = 0.1,ITERNUM = 200000):
self.alpha,self.ITERNUM = alpha,ITERNUM
self.dfJ = pd.DataFrame(data = np.zeros((ITERNUM,1)),columns = ['J'])
self.w,self.b = np.nan,np.nan
def fit(self,X_train,y_train):
X,Y = X_train.T,y_train.reshape(1,-1)
n,m = X.shape
w,b = np.zeros((n,1)),0
for i in range(self.ITERNUM):
# 正向传播求函数值 X-->Z-->A-->J
Z = np.dot(w.T,X) + b
A = 1/(1 + np.exp(-Z))
J = (1/m) * np.sum(- Y*np.log(A) -(1-Y)*np.log(1-A))
self.dfJ.loc[i]['J']= J
# 反向传播求导数: J-->dA-->dZ(dw,db)
dA = 1/m*(-Y/A + (1-Y)/(1-A))
dZ = 1/m*(A-Y)
dw = np.dot(X,dZ.T)
db = np.sum(dZ)
# 梯度下降
w = w - self.alpha*dw
b = b - self.alpha*db
self.w,self.b = w,b
def predict_prob(self,X_test):
Z_test = np.dot(self.w.T,X_test.T) + self.b
Y_prob = 1/(1 + np.exp(-Z_test))
Y_prob = Y_prob.reshape(-1)
return(Y_prob)
def predict(self,X_test):
Y_prob = self.predict_prob(X_test)
Y_test = Y_prob.copy()
Y_test[Y_prob>=0.5] = 1
Y_test[Y_prob< 0.5] = 0
return(Y_test)
五
数据集测试
# 用数据喂养模型
clf = LogisticRegression(alpha = 0.1,ITERNUM = 200000)
clf.fit(X_train= X_train,y_train= y_train)
# 绘制目标函数的迭代曲线
%matplotlib inline
clf.dfJ.plot(y = 'J' ,kind = 'line',figsize = (10,7))
# 测试在验证集的auc得分
from sklearn.metrics import roc_auc_score
Y_prob = clf.predict_prob(X_test)
roc_auc_score(list(y_test),list(Y_prob))
# 和sklearn中的模型对比
from sklearn.linear_model import LogisticRegressionCV as LRCV
lr = LRCV()
lr.fit(X_train,y_train)
Y_proba = lr.predict_proba(X_test)
roc_auc_score(list(y_test),list(Y_proba[:,1]))
- 51 Nod 1008 N的阶乘 mod P【Java大数乱搞】
- 【AlphaGo Zero 核心技术-深度强化学习教程代码实战06】给Agent添加记忆功能
- Gym 100952A&&2015 HIAST Collegiate Programming Contest A. Who is the winner?【字符串,暴力】
- [开源,学习,分享]UWP第三方简书客户端分享
- HDU 1024 Max Sum Plus Plus【动态规划求最大M子段和详解 】
- 51 Nod 1057 N的阶乘【Java大数乱搞】
- 2017 Multi-University Training Contest - Team 1 1011&&HDU 6043 KazaQ's Socks【规律题,数学,水】
- 2017 Multi-University Training Contest - Team 1 1001&&HDU 6033 Add More Zero【签到题,数学,水】
- 51 Nod 1005 大数加法【Java大数乱搞,python大数乱搞】
- 51 Nod 1029 大数除法【Java大数乱搞】
- 51 Nod 1027 大数乘法【Java大数乱搞】
- SQL常用的基础语法
- 51 Nod 1028 大数乘法 V2【Java大数乱搞】
- Gym 100952J&&2015 HIAST Collegiate Programming Contest J. Polygons Intersection【计算几何求解两个凸多边形的相交面积板子题
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 【LeetCode每日一题】23. Merge k Sorted Lists
- MyBatis Generator( 逆向工程以及源码分析 )
- 虚拟机(Linux)常用命令
- Spring框架
- Servlet技术2
- 你缺的不是天赋,而是亲和度
- TypeScript: 为什么必须学
- 一、环境搭建、以及聊聊更重要的...
- 四、作用域与作用域链
- 【从0到1学算法】递归
- 手把手教你创建 Spring MVC 实例
- 举一反三:三种问题,两个指针,一种方法
- torch.backends.cudnn.benchmark ?!
- jQuery ui中sortable draggable droppable的使用
- 阿里面试:看你springBoot用的比较溜来,说说springboot自动装配是怎么回事?