Java使用最小二乘法实现线性回归预测

时间:2022-07-24
本文章向大家介绍Java使用最小二乘法实现线性回归预测,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

最小二乘法

在研究两个变量(x, y)之间的相互关系时

通常可以得到一系列成对的数据(x1, y1),(x2, y2)… (xm , ym)

将这些数据描绘在x-y直角坐标系中

若发现这些点在一条直线附近

可以令这条直线方程y= e + wx

其中:we是任意实数

为建立这直线方程就要确定e和w

应用《最小二乘法原理》

将实测值Yi与利用计算y= e + wx值的离差(yi-y)的平方和

即〔∑(yi - y)²〕最小

简单来说就是以下公式

y = a x + b

b = sum( y ) / n - a * sum( x ) / n

a = ( n * sum( xy ) - sum( x* ) * sum( y ) ) / ( n * sum( x^2 ) - sum(x) ^ 2 )

一个预测问题在回归模型下的解决步骤为:

1.构造训练集;

2.学习,得到输入输出间的关系;

3.预测,通过学习得到的关系预测输出

代码实现

你看,代码风格依旧良好

中间用到了Double类型的数据运算

而Double类型的数据直接加减乘除是有可能有问题的

所以附上了Double数据运算的常用方法


/**
 * 使用最小二乘法实现线性回归预测
 *
 * @author daijiyong
 */
public class LinearRegression {

    /**
     * 训练集数据
     */
    private Map<Double, Double> initData = new HashMap<>();

    /**
     * 截距
     */
    private double intercept = 0.0;
    //斜率
    private double slope = 0.0;

    /**
     * x、y平均值
     */
    private double averageX, averageY;
    /**
     * 求斜率的上下两个分式的值
     */
    private double slopeUp, slopeDown;


    public LinearRegression(Map<Double, Double> initData) {
        this.initData = initData;
        initData();
    }

    public LinearRegression() {
    }

    /**
     * 根据训练集数据进行训练预测
     * 并计算斜率和截距
     */
    public void initData() {
        if (initData.size() > 0) {
            //数据个数
            int number = 0;
            //x值、y值总和
            double sumX = 0;
            double sumY = 0;
            averageX = 0;
            averageY = 0;
            slopeUp = 0;
            slopeDown = 0;
            for (Double x : initData.keySet()) {
                if (x == null || initData.get(x) == null) {
                    continue;
                }
                number++;
                sumX = add(sumX, x);
                sumY = add(sumY, initData.get(x));
            }
            //求x,y平均值
            averageX = DoubleUtils.div(sumX, (double) number);
            averageY = DoubleUtils.div(sumY, (double) number);

            for (Double x : initData.keySet()) {
                if (x == null || initData.get(x) == null) {
                    continue;
                }

                slopeUp = add(slopeUp, mul(sub(x, averageX), sub(initData.get(x), averageY)));
                slopeDown = add(slopeDown, mul(sub(x, averageX), sub(x, averageX)));

            }
            initSlopeIntercept();
        }
    }

    /**
     * 计算斜率和截距
     */
    private void initSlopeIntercept() {
        if (slopeUp != 0 && slopeDown != 0) {
            slope = slopeUp / slopeDown;
        }
        intercept = averageY - averageX * slope;
    }

    /**
     * 根据x值预测y值
     *
     * @param x x值
     * @return y值
     */
    public Double getY(Double x) {
        return add(intercept, mul(slope, x));
    }

    /**
     * 根据y值预测x值
     *
     * @param y y值
     * @return x值
     */
    public Double getX(Double y) {
        return div(sub(y, intercept), slope);
    }


    public Map<Double, Double> getInitData() {
        return initData;
    }

    public void setInitData(Map<Double, Double> initData) {
        this.initData = initData;
    }

    public static void main(String[] args) {
        LinearRegression linearRegression = new LinearRegression();
        //训练集数据
        linearRegression.getInitData().put(1D, 8D);
        linearRegression.getInitData().put(1.5D, 9.5D);
        linearRegression.getInitData().put(2D, 11D);
        linearRegression.getInitData().put(2.5D, 10D);
        linearRegression.getInitData().put(3D, 14D);
        //根据训练集数据进行线性函数预测
        linearRegression.initData();

        /*
         * 给定x值,预测y值
         */
        System.out.println(linearRegression.getY(8D));
        /*
         * 给定y值,预测x值
         */
        System.out.println(linearRegression.getX(9.5D));
    }
}

/**
 * Created by daijiyong on 2017/4/6.
 */
public class DoubleUtils {
    private static final int DEF_DIV_SCALE = 10;

    /**
     * * 两个Double数相加 *
     *
     * @param v1 *
     * @param v2 *
     * @return Double
     */
    public static Double add(Double v1, Double v2) {
        BigDecimal b1 = new BigDecimal(v1.toString());
        BigDecimal b2 = new BigDecimal(v2.toString());
        return b1.add(b2).doubleValue();
    }

    /**
     * * 两个Double数相减 *
     *
     * @param v1 *
     * @param v2 *
     * @return Double
     */
    public static Double sub(Double v1, Double v2) {
        BigDecimal b1 = new BigDecimal(v1.toString());
        BigDecimal b2 = new BigDecimal(v2.toString());
        return b1.subtract(b2).doubleValue();
    }

    /**
     * * 两个Double数相乘 *
     *
     * @param v1 *
     * @param v2 *
     * @return Double
     */
    public static Double mul(Double v1, Double v2) {
        BigDecimal b1 = new BigDecimal(v1.toString());
        BigDecimal b2 = new BigDecimal(v2.toString());
        return b1.multiply(b2).doubleValue();
    }

    /**
     * * 两个Double数相除 *
     *
     * @param v1 *
     * @param v2 *
     * @return Double
     */
    public static Double div(Double v1, Double v2) {
        BigDecimal b1 = new BigDecimal(v1.toString());
        BigDecimal b2 = new BigDecimal(v2.toString());
        return b1.divide(b2, DEF_DIV_SCALE, BigDecimal.ROUND_HALF_UP).doubleValue();
    }

    /**
     * * 两个Double数相除,并保留scale位小数 *
     *
     * @param v1    *
     * @param v2    *
     * @param scale *
     * @return Double
     */
    public static Double div(Double v1, Double v2, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException(
                    "The scale must be a positive integer or zero");
        }
        BigDecimal b1 = new BigDecimal(v1.toString());
        BigDecimal b2 = new BigDecimal(v2.toString());
        return b1.divide(b2, scale, BigDecimal.ROUND_HALF_UP).doubleValue();
    }


    public static int max(int a, int b) {
        return Math.max(a, b);
    }


    public static int min(int a, int b) {
        return Math.min(a, b);
    }

运行测试

给个例子,测试一下

文/戴先生@2020年6月8日---end---

更多精彩推荐