最小二乘法求解回归直线
Published: 2022-10-06 17:05
from numpy import mean
class LeastSquare:
    """Fit a simple regression line ``y = w * x + b`` by ordinary least squares.

    Closed-form solution from Zhou Zhihua, *Machine Learning* (the
    "watermelon book"), p. 54::

        w = sum(y_i * (x_i - mean_x)) / sum((x_i - mean_x) ** 2)
        b = (1 / m) * sum(y_i - w * x_i)

    The slope denominator is written in its centred form. It is algebraically
    equal to the book's ``sum(x_i ** 2) - (sum(x_i)) ** 2 / m``, but it avoids
    catastrophic cancellation when the x values are large compared with
    their spread.

    The fit runs eagerly in ``__init__``.

    Attributes:
        x: Independent-variable samples, stored as passed in.
        y: Dependent-variable samples, stored as passed in.
        w: Fitted slope.
        b: Fitted intercept.

    Raises:
        ValueError: if ``x`` is empty, if ``x`` and ``y`` differ in length,
            or if every ``x`` value is identical (slope undefined).
    """

    def __init__(self, x, y):
        """Store the samples and fit ``w`` and ``b`` immediately.

        Args:
            x: Sized, indexable sequence of numbers (list or 1-D array).
            y: Sequence of numbers with the same length as ``x``.
        """
        self.x = x
        self.y = y
        self.__m = len(self.x)
        # Validate before calling numpy's mean, which only warns and
        # returns nan on empty input.
        if self.__m == 0:
            raise ValueError("x must contain at least one sample")
        if len(self.y) != self.__m:
            raise ValueError(
                f"x and y must have the same length, "
                f"got {self.__m} and {len(self.y)}"
            )
        self.__mean_x = mean(self.x)
        # b is derived from w, so w must be computed first.
        self.w = self.calculate_w()
        self.b = self.calculate_b()

    def calculate_w(self):
        """Return the least-squares slope ``w``.

        Raises:
            ValueError: if all x values coincide, so the denominator is zero.
        """
        numerator = sum(
            yi * (xi - self.__mean_x) for xi, yi in zip(self.x, self.y)
        )
        # Centred sum of squares: numerically stable equivalent of
        # sum(x**2) - sum(x)**2 / m.
        denominator = sum((xi - self.__mean_x) ** 2 for xi in self.x)
        if denominator == 0:
            raise ValueError("all x values are identical; the slope is undefined")
        return numerator / denominator

    def calculate_b(self):
        """Return the least-squares intercept ``b`` for the already-fitted ``self.w``."""
        return sum(yi - self.w * xi for xi, yi in zip(self.x, self.y)) / self.__m
测试代码
数据集由 faker_xy_data() 生成，其实现见：https://www.cnblogs.com/boran/p/16757677.html
# Demo: fit a regression line to synthetic data and plot it against the data.
import matplotlib

# Select the Qt5 backend before pyplot is imported (the conventional order).
# This backend requires PyQt5: pip install PyQt5
matplotlib.use('Qt5Agg')

import numpy as np
from matplotlib import pyplot as plt

from practice1.create_xy_data import faker_xy_data
from practice1.least_square_method import LeastSquare

x1 = faker_xy_data()
# NOTE(review): a/b are presumably the generator's true slope/intercept,
# printed for comparison with the fitted w/b -- confirm against faker_xy_data.
print(f"a={x1.a} b={x1.b}")
E = LeastSquare(x=x1.x, y=x1.y)
print(f"w={E.w} b={E.b}")
# 100 evenly spaced points spanning the observed x range, so the fitted
# line covers exactly the plotted samples whatever range the data has.
x = np.linspace(np.min(x1.x), np.max(x1.x), 100)
y = x * E.w + E.b
plt.plot(x, y, color='red', linewidth=1.0)
plt.scatter(x1.x, x1.y)
plt.show()
一次测试结果