Simple Linear Regression (Gradient Descent): a Python Implementation

Date: 2019-08-29

This post walks through a Python implementation of simple linear regression fitted with gradient descent: loading the data, defining the loss function, running the descent loop, and plotting the fitted line.


 

Simple Linear Regression (Gradient Descent)

 

0. Import dependencies

In [1]:
import numpy as np
import matplotlib.pyplot as plt
 

1. Load the data

In [34]:
points = np.genfromtxt("data.csv", delimiter=",")
# points
# Extract the two columns of points as x and y
x = points[:, 0]
y = points[:, 1]

# Draw a scatter plot of the data with plt
plt.scatter(x, y)
plt.show()
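
The notebook assumes a local data.csv with two comma-separated columns (x first, y second) and no header row. If the original file is not at hand, a synthetic stand-in can be generated; the distribution, range, and seed below are assumptions for illustration, not the original data:

import numpy as np

# Hypothetical stand-in for data.csv: noisy points along y ≈ 2x
rng = np.random.default_rng(0)
x_syn = rng.uniform(1000, 4000, size=100)
y_syn = 2 * x_syn + rng.normal(0, 100, size=100)
np.savetxt("data.csv", np.column_stack([x_syn, y_syn]), delimiter=",")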
 
 

2. Define the loss function

In [35]:
# The loss is a function of the coefficients w and b; the data points supply x and y
def compute_cost(w, b, points):
    total_cost = 0
    M = len(points)
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        total_cost += (y - w * x - b) ** 2
    return total_cost / M  # in Python 3, / is float division; // is floor division, e.g. 3 // 4 == 0
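
The loop computes the mean squared error over the M points: cost(w, b) = (1/M) * Σ (y_i - w*x_i - b)^2. The same quantity can be written without an explicit loop; the vectorized variant below is an equivalent sketch, not a cell from the original notebook:

def compute_cost_vec(w, b, points):
    # Vectorized MSE: mean of the squared residuals over all points
    x = points[:, 0]
    y = points[:, 1]
    return np.mean((y - w * x - b) ** 2)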
 

3. Define the model's hyperparameters

In [52]:
alpha = 0.0000001  # learning rate
initial_w = 0
initial_b = 0
num_iter = 20      # number of iterations
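
The learning rate looks extreme, but it compensates for the unscaled data: the gradient of the squared error grows with x², so with large raw x values a bigger alpha would overshoot and diverge. As an aside (not in the original post), the usual fix is to standardize x, after which a conventional learning rate such as 0.1 converges in far fewer iterations:

# Assumed alternative: standardize x so a conventional alpha works
x_scaled = (x - x.mean()) / x.std()
points_scaled = np.column_stack([x_scaled, y])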
 

4. Define the core gradient descent functions

In [37]:
def grad_desc(points, initial_w, initial_b, alpha, num_iter):
    w = initial_w
    b = initial_b
    # Keep a list of all loss values to visualize the descent
    cost_list = []
    for i in range(num_iter):
        cost_list.append(compute_cost(w, b, points))
        w, b = step_grad_desc(w, b, alpha, points)
    return [w, b, cost_list]

def step_grad_desc(current_w, current_b, alpha, points):
    sum_grad_w = 0
    sum_grad_b = 0
    M = len(points)
    # Accumulate each point's contribution to the gradient
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        sum_grad_w += (current_w * x + current_b - y) * x
        sum_grad_b += current_w * x + current_b - y
    # Current gradient from the analytic formulas
    grad_w = 2 / M * sum_grad_w
    grad_b = 2 / M * sum_grad_b

    # Gradient descent step: update w and b
    updated_w = current_w - alpha * grad_w
    updated_b = current_b - alpha * grad_b
    return updated_w, updated_b
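
step_grad_desc applies the analytic gradients of the cost: dJ/dw = (2/M) * Σ (w*x_i + b - y_i) * x_i and dJ/db = (2/M) * Σ (w*x_i + b - y_i). The vectorized sketch below performs the same update without the Python loop (an equivalent rewrite, not the original cell):

def step_grad_desc_vec(current_w, current_b, alpha, points):
    x = points[:, 0]
    y = points[:, 1]
    err = current_w * x + current_b - y  # residual of every point
    grad_w = 2 * np.mean(err * x)        # dJ/dw
    grad_b = 2 * np.mean(err)            # dJ/db
    return current_w - alpha * grad_w, current_b - alpha * grad_b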
 

5. Test: run the gradient descent algorithm

In [54]:
w, b, cost_list = grad_desc(points, initial_w, initial_b, alpha, num_iter)
print("w is :", w)
print("b is :", b)

cost = compute_cost(w, b, points)

print("cost_list:", cost_list)
print("cost is:", cost)
plt.plot(cost_list)
 
w is : 1.9845988031472985
b is : 0.0004970348345541671
cost_list: [30684366.833333332, 9539857.724899232, 2973884.507279095, 934962.5312039739, 301819.09812286275, 105210.00196432497, 44157.269403835446, 25198.654436632325, 19311.463701577555, 17483.323048238828, 16915.633228203948, 16739.349345812665, 16684.608171772015, 16667.609475636003, 16662.3308954798, 16660.691745647422, 16660.182742743986, 16660.02468269902, 16659.975600422167, 16659.960358850854]
cost is: 16659.95562578394
Out[54]:
[<matplotlib.lines.Line2D at 0x1218cd978>]
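
As a sanity check (not part of the original post), the closed-form least-squares line can be computed with np.polyfit; gradient descent should approach these coefficients once the learning rate and iteration count let it fully converge (after only 20 iterations, b in particular is still far from its optimum):

# Closed-form least-squares fit for comparison; polyfit returns [slope, intercept]
w_ls, b_ls = np.polyfit(x, y, deg=1)
print("least squares w, b:", w_ls, b_ls)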
 
In [55]:
plt.scatter(x, y)

pred_y = w * x + b

plt.plot(x, pred_y, c='r')
Out[55]:
[<matplotlib.lines.Line2D at 0x121984940>]
 

Original article: https://www.cnblogs.com/arli/p/11428236.html