Skip to content

最小二乘法求解回归直线


public-time:2022-10-06 17:05

from numpy import mean


class LeastSquare:
    """Fit a straight line ``y = w * x + b`` by ordinary least squares.

    Reference: the "Watermelon Book", p. 54 (closed-form least-squares
    solution for a regression line).

    Attributes:
        x: Sample inputs, stored as passed in.
        y: Sample targets, stored as passed in.
        w: Fitted slope.
        b: Fitted intercept.
    """

    def __init__(self, x, y):
        """Store the samples and solve for ``w`` and ``b`` immediately.

        Args:
            x: Sized, iterable collection of input values.
            y: Sized, iterable collection of target values; must have the
                same length as ``x``.

        Raises:
            ValueError: If ``x`` and ``y`` have different lengths.
            ZeroDivisionError: If fewer than two samples are given, or if
                every ``x`` value is identical (the slope is undefined).
        """
        if len(x) != len(y):
            # Previously a longer ``y`` was silently truncated.
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )

        self.x = x
        self.y = y
        self.__m = len(self.x)
        if self.__m < 2:
            # The closed-form denominator is always zero here.  Plain Python
            # numbers already raised ZeroDivisionError in this case, so the
            # exception type is kept; NumPy scalars used to yield ``nan``.
            raise ZeroDivisionError(
                "at least two samples are required to fit a line"
            )
        self.__mean_x = mean(self.x)

        # ``b`` depends on ``w``, so the slope must be computed first.
        self.w = self.calculate_w()
        self.b = self.calculate_b()

    def calculate_w(self):
        """Return the slope ``w`` of the least-squares line.

        Computes ``sum(y_i * (x_i - mean_x))`` divided by
        ``sum(x_i ** 2) - (sum(x_i) ** 2) / m``.

        Raises:
            ZeroDivisionError: If the denominator is zero, i.e. all ``x``
                values are identical.
        """
        numerator = sum(
            yi * (xi - self.__mean_x) for xi, yi in zip(self.x, self.y)
        )
        denominator = sum(xi ** 2 for xi in self.x) - (
            sum(self.x) ** 2
        ) / self.__m
        if denominator == 0:
            # Fail loudly instead of returning nan/inf for NumPy inputs.
            raise ZeroDivisionError(
                "cannot fit a line: all x values are identical"
            )
        return numerator / denominator

    def calculate_b(self):
        """Return the intercept ``b``: the mean of ``y_i - w * x_i``."""
        residual_sum = sum(
            yi - self.w * xi for xi, yi in zip(self.x, self.y)
        )
        return residual_sum / self.__m

测试代码

faker_xy_data() 用于生成数据集，实现详见：https://www.cnblogs.com/boran/p/16757677.html

import numpy as np
from matplotlib import pyplot as plt

from practice1.create_xy_data import faker_xy_data
from practice1.least_square_method import LeastSquare
import matplotlib

# Select the Qt5 backend for the plot window.
# Requires PyQt5: pip install PyQt5
matplotlib.use('Qt5Agg')

# Generate a synthetic data set and print the parameters it exposes as a / b.
x1 = faker_xy_data()
print(f"a={x1.a} b={x1.b}")

# Fit the regression line and print the estimated slope and intercept.
E = LeastSquare(x=x1.x, y=x1.y)
print(f"w={E.w} b={E.b}")

# Sample the fitted line at 100 evenly spaced points from 0 to 100.
line_x = np.linspace(start=0, stop=100, num=100)
line_y = line_x * E.w + E.b

# Draw the fitted line in red, then overlay the raw samples as a scatter plot.
plt.plot(line_x, line_y, color='red', linewidth=1.0)
plt.scatter(x1.x, x1.y)
plt.show()

一次测试结果

a=1.2948318280599067 b=8.944712778469608
w=1.1574286803853961 b=15.848312488159783