Java使用最小二乘法实现线性回归预测
时间:2022-07-24
本文章向大家介绍Java使用最小二乘法实现线性回归预测,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
最小二乘法
在研究两个变量(x, y)之间的相互关系时
通常可以得到一系列成对的数据(x1, y1),(x2, y2)… (xm , ym)
将这些数据描绘在x-y直角坐标系中
若发现这些点在一条直线附近
可以令这条直线方程y= e + wx
其中:we是任意实数
为建立这直线方程就要确定e和w
应用《最小二乘法原理》
将实测值Yi与利用计算y= e + wx值的离差(yi-y)的平方和
即〔∑(yi - y)²〕最小
简单来说就是以下公式
y = a x + b
b = sum( y ) / n - a * sum( x ) / n
a = ( n * sum( xy ) - sum( x* ) * sum( y ) ) / ( n * sum( x^2 ) - sum(x) ^ 2 )
一个预测问题在回归模型下的解决步骤为:
1.构造训练集;
2.学习,得到输入输出间的关系;
3.预测,通过学习得到的关系预测输出
代码实现
你看,代码风格依旧良好
中间用到了Double类型的数据运算
而Double类型的数据直接加减乘除是有可能有问题的
所以附上了Double数据运算的常用方法
/**
* 使用最小二乘法实现线性回归预测
*
* @author daijiyong
*/
public class LinearRegression {
/**
* 训练集数据
*/
private Map<Double, Double> initData = new HashMap<>();
/**
* 截距
*/
private double intercept = 0.0;
//斜率
private double slope = 0.0;
/**
* x、y平均值
*/
private double averageX, averageY;
/**
* 求斜率的上下两个分式的值
*/
private double slopeUp, slopeDown;
public LinearRegression(Map<Double, Double> initData) {
this.initData = initData;
initData();
}
public LinearRegression() {
}
/**
* 根据训练集数据进行训练预测
* 并计算斜率和截距
*/
public void initData() {
if (initData.size() > 0) {
//数据个数
int number = 0;
//x值、y值总和
double sumX = 0;
double sumY = 0;
averageX = 0;
averageY = 0;
slopeUp = 0;
slopeDown = 0;
for (Double x : initData.keySet()) {
if (x == null || initData.get(x) == null) {
continue;
}
number++;
sumX = add(sumX, x);
sumY = add(sumY, initData.get(x));
}
//求x,y平均值
averageX = DoubleUtils.div(sumX, (double) number);
averageY = DoubleUtils.div(sumY, (double) number);
for (Double x : initData.keySet()) {
if (x == null || initData.get(x) == null) {
continue;
}
slopeUp = add(slopeUp, mul(sub(x, averageX), sub(initData.get(x), averageY)));
slopeDown = add(slopeDown, mul(sub(x, averageX), sub(x, averageX)));
}
initSlopeIntercept();
}
}
/**
* 计算斜率和截距
*/
private void initSlopeIntercept() {
if (slopeUp != 0 && slopeDown != 0) {
slope = slopeUp / slopeDown;
}
intercept = averageY - averageX * slope;
}
/**
* 根据x值预测y值
*
* @param x x值
* @return y值
*/
public Double getY(Double x) {
return add(intercept, mul(slope, x));
}
/**
* 根据y值预测x值
*
* @param y y值
* @return x值
*/
public Double getX(Double y) {
return div(sub(y, intercept), slope);
}
public Map<Double, Double> getInitData() {
return initData;
}
public void setInitData(Map<Double, Double> initData) {
this.initData = initData;
}
public static void main(String[] args) {
LinearRegression linearRegression = new LinearRegression();
//训练集数据
linearRegression.getInitData().put(1D, 8D);
linearRegression.getInitData().put(1.5D, 9.5D);
linearRegression.getInitData().put(2D, 11D);
linearRegression.getInitData().put(2.5D, 10D);
linearRegression.getInitData().put(3D, 14D);
//根据训练集数据进行线性函数预测
linearRegression.initData();
/*
* 给定x值,预测y值
*/
System.out.println(linearRegression.getY(8D));
/*
* 给定y值,预测x值
*/
System.out.println(linearRegression.getX(9.5D));
}
}
/**
* Created by daijiyong on 2017/4/6.
*/
public class DoubleUtils {
private static final int DEF_DIV_SCALE = 10;
/**
* * 两个Double数相加 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double add(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.add(b2).doubleValue();
}
/**
* * 两个Double数相减 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double sub(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.subtract(b2).doubleValue();
}
/**
* * 两个Double数相乘 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double mul(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.multiply(b2).doubleValue();
}
/**
* * 两个Double数相除 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double div(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.divide(b2, DEF_DIV_SCALE, BigDecimal.ROUND_HALF_UP).doubleValue();
}
/**
* * 两个Double数相除,并保留scale位小数 *
*
* @param v1 *
* @param v2 *
* @param scale *
* @return Double
*/
public static Double div(Double v1, Double v2, int scale) {
if (scale < 0) {
throw new IllegalArgumentException(
"The scale must be a positive integer or zero");
}
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.divide(b2, scale, BigDecimal.ROUND_HALF_UP).doubleValue();
}
public static int max(int a, int b) {
return Math.max(a, b);
}
public static int min(int a, int b) {
return Math.min(a, b);
}
运行测试
给个例子,测试一下
文/戴先生@2020年6月8日---end---
更多精彩推荐
- delete相关的pl/sql调优(r4笔记第87天)
- Java文件上传与下载【面试+工作】
- QBC查询
- 一条delete语句的调优(r4笔记第86天)
- Java支付宝接口开发【面试+工作】
- 03.SVN检出/解决冲突/提交
- Spring思维导图,让Spring不再难懂(mvc篇)
- SQL优化一(SQL使用技巧)
- Spring思维导图,让Spring不再难懂(aop篇)
- MongoDB初探第二篇 (r4笔记第82天)
- Spring思维导图,让Spring不再难懂(cache篇)
- 曲折的10g,11g中EM的安装配置过程(r4笔记第98天)
- Linux 学习记录 一(安装、基本文件操作).
- 实用的位运算应用(r4笔记第97天)
- java教程
- Java快速入门
- Java 开发环境配置
- Java基本语法
- Java 对象和类
- Java 基本数据类型
- Java 变量类型
- Java 修饰符
- Java 运算符
- Java 循环结构
- Java 分支结构
- Java Number类
- Java Character类
- Java String类
- Java StringBuffer和StringBuilder类
- Java 数组
- Java 日期时间
- Java 正则表达式
- Java 方法
- Java 流(Stream)、文件(File)和IO
- Java 异常处理
- Java 继承
- Java 重写(Override)与重载(Overload)
- Java 多态
- Java 抽象类
- Java 封装
- Java 接口
- Java 包(package)
- Java 数据结构
- Java 集合框架
- Java 泛型
- Java 序列化
- Java 网络编程
- Java 发送邮件
- Java 多线程编程
- Java Applet基础
- Java 文档注释
- c语言数组越界的避免方法
- 单片机的存储区范例
- 大点干!早点散----------Nginx+Tomcat动静分离
- 大点干!早点散----------深入剖析缓存加速--squid传统代理和透明代理
- stm32 HardFault_Handler调试及问题查找方法——飞思卡尔
- 堆栈的分布
- memset()函数的使用
- 质量保障的方法和实践
- Selenium4 IDE,它终于来了
- strtol函数的用法——字符串转长整形
- JsonPath工具类封装
- Ubuntu16.04 实时内核 RT Preempt 安装
- c语言实现整数转换为字符串——不考虑负数
- JsonPath工具类单元测试
- Selenium4 IDE特性:无代码趋势和SIDE Runner