Machine Learning: Predicting Sales with Linear Regression


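Background: given the advertising spend on TV, Radio, and Newspaper together with the resulting sales, use linear regression to predict sales from ad spend.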
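
Written out, the model fitted in this post has one coefficient per advertising channel plus an intercept. The symbols below (θ for the coefficients, m for the number of rows, ŷ for predicted Sales) are chosen here for illustration and do not come from the original post. Ordinary least squares, which scikit-learn's `LinearRegression` performs, picks the coefficients that minimize the sum of squared residuals:

```latex
\hat{y} = \theta_0 + \theta_1 x_{\mathrm{TV}} + \theta_2 x_{\mathrm{Radio}} + \theta_3 x_{\mathrm{Newspaper}},
\qquad
\hat{\theta} = \arg\min_{\theta} \sum_{i=1}^{m} \left( y^{(i)} - \hat{y}^{(i)} \right)^{2}
```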

Data:

  TV Radio Newspaper Sales
1 230.1 37.8 69.2 22.1
2 44.5 39.3 45.1 10.4
3 17.2 45.9 69.3 9.3
4 151.5 41.3 58.5 18.5
5 180.8 10.8 58.4 12.9
6 8.7 48.9 75 7.2
7 57.5 32.8 23.5 11.8
8 120.2 19.6 11.6 13.2
9 8.6 2.1 1 4.8
10 199.8 2.6 21.2 10.6
11 66.1 5.8 24.2 8.6
12 214.7 24 4 17.4
13 23.8 35.1 65.9 9.2
14 97.5 7.6 7.2 9.7
15 204.1 32.9 46 19
16 195.4 47.7 52.9 22.4
17 67.8 36.6 114 12.5
18 281.4 39.6 55.8 24.4
19 69.2 20.5 18.3 11.3
20 147.3 23.9 19.1 14.6
21 218.4 27.7 53.4 18
22 237.4 5.1 23.5 12.5
23 13.2 15.9 49.6 5.6
24 228.3 16.9 26.2 15.5
25 62.3 12.6 18.3 9.7
26 262.9 3.5 19.5 12
27 142.9 29.3 12.6 15
28 240.1 16.7 22.9 15.9
29 248.8 27.1 22.9 18.9
30 70.6 16 40.8 10.5
31 292.9 28.3 43.2 21.4
32 112.9 17.4 38.6 11.9
33 97.2 1.5 30 9.6
34 265.6 20 0.3 17.4
35 95.7 1.4 7.4 9.5
36 290.7 4.1 8.5 12.8
37 266.9 43.8 5 25.4
38 74.7 49.4 45.7 14.7
39 43.1 26.7 35.1 10.1
40 228 37.7 32 21.5
41 202.5 22.3 31.6 16.6
42 177 33.4 38.7 17.1
43 293.6 27.7 1.8 20.7
44 206.9 8.4 26.4 12.9
45 25.1 25.7 43.3 8.5
46 175.1 22.5 31.5 14.9
47 89.7 9.9 35.7 10.6
48 239.9 41.5 18.5 23.2
49 227.2 15.8 49.9 14.8
50 66.9 11.7 36.8 9.7
51 199.8 3.1 34.6 11.4
52 100.4 9.6 3.6 10.7
53 216.4 41.7 39.6 22.6
54 182.6 46.2 58.7 21.2
55 262.7 28.8 15.9 20.2
56 198.9 49.4 60 23.7
57 7.3 28.1 41.4 5.5
58 136.2 19.2 16.6 13.2
59 210.8 49.6 37.7 23.8
60 210.7 29.5 9.3 18.4
61 53.5 2 21.4 8.1
62 261.3 42.7 54.7 24.2
63 239.3 15.5 27.3 15.7
64 102.7 29.6 8.4 14
65 131.1 42.8 28.9 18
66 69 9.3 0.9 9.3
67 31.5 24.6 2.2 9.5
68 139.3 14.5 10.2 13.4
69 237.4 27.5 11 18.9
70 216.8 43.9 27.2 22.3
71 199.1 30.6 38.7 18.3
72 109.8 14.3 31.7 12.4
73 26.8 33 19.3 8.8
74 129.4 5.7 31.3 11
75 213.4 24.6 13.1 17
76 16.9 43.7 89.4 8.7
77 27.5 1.6 20.7 6.9
78 120.5 28.5 14.2 14.2
79 5.4 29.9 9.4 5.3
80 116 7.7 23.1 11
81 76.4 26.7 22.3 11.8
82 239.8 4.1 36.9 12.3
83 75.3 20.3 32.5 11.3
84 68.4 44.5 35.6 13.6
85 213.5 43 33.8 21.7
86 193.2 18.4 65.7 15.2
87 76.3 27.5 16 12
88 110.7 40.6 63.2 16
89 88.3 25.5 73.4 12.9
90 109.8 47.8 51.4 16.7
91 134.3 4.9 9.3 11.2
92 28.6 1.5 33 7.3
93 217.7 33.5 59 19.4
94 250.9 36.5 72.3 22.2
95 107.4 14 10.9 11.5
96 163.3 31.6 52.9 16.9
97 197.6 3.5 5.9 11.7
98 184.9 21 22 15.5
99 289.7 42.3 51.2 25.4
100 135.2 41.7 45.9 17.2
101 222.4 4.3 49.8 11.7
102 296.4 36.3 100.9 23.8
103 280.2 10.1 21.4 14.8
104 187.9 17.2 17.9 14.7
105 238.2 34.3 5.3 20.7
106 137.9 46.4 59 19.2
107 25 11 29.7 7.2
108 90.4 0.3 23.2 8.7
109 13.1 0.4 25.6 5.3
110 255.4 26.9 5.5 19.8
111 225.8 8.2 56.5 13.4
112 241.7 38 23.2 21.8
113 175.7 15.4 2.4 14.1
114 209.6 20.6 10.7 15.9
115 78.2 46.8 34.5 14.6
116 75.1 35 52.7 12.6
117 139.2 14.3 25.6 12.2
118 76.4 0.8 14.8 9.4
119 125.7 36.9 79.2 15.9
120 19.4 16 22.3 6.6
121 141.3 26.8 46.2 15.5
122 18.8 21.7 50.4 7
123 224 2.4 15.6 11.6
124 123.1 34.6 12.4 15.2
125 229.5 32.3 74.2 19.7
126 87.2 11.8 25.9 10.6
127 7.8 38.9 50.6 6.6
128 80.2 0 9.2 8.8
129 220.3 49 3.2 24.7
130 59.6 12 43.1 9.7
131 0.7 39.6 8.7 1.6
132 265.2 2.9 43 12.7
133 8.4 27.2 2.1 5.7
134 219.8 33.5 45.1 19.6
135 36.9 38.6 65.6 10.8
136 48.3 47 8.5 11.6
137 25.6 39 9.3 9.5
138 273.7 28.9 59.7 20.8
139 43 25.9 20.5 9.6
140 184.9 43.9 1.7 20.7
141 73.4 17 12.9 10.9
142 193.7 35.4 75.6 19.2
143 220.5 33.2 37.9 20.1
144 104.6 5.7 34.4 10.4
145 96.2 14.8 38.9 11.4
146 140.3 1.9 9 10.3
147 240.1 7.3 8.7 13.2
148 243.2 49 44.3 25.4
149 38 40.3 11.9 10.9
150 44.7 25.8 20.6 10.1
151 280.7 13.9 37 16.1
152 121 8.4 48.7 11.6
153 197.6 23.3 14.2 16.6
154 171.3 39.7 37.7 19
155 187.8 21.1 9.5 15.6
156 4.1 11.6 5.7 3.2
157 93.9 43.5 50.5 15.3
158 149.8 1.3 24.3 10.1
159 11.7 36.9 45.2 7.3
160 131.7 18.4 34.6 12.9
161 172.5 18.1 30.7 14.4
162 85.7 35.8 49.3 13.3
163 188.4 18.1 25.6 14.9
164 163.5 36.8 7.4 18
165 117.2 14.7 5.4 11.9
166 234.5 3.4 84.8 11.9
167 17.9 37.6 21.6 8
168 206.8 5.2 19.4 12.2
169 215.4 23.6 57.6 17.1
170 284.3 10.6 6.4 15
171 50 11.6 18.4 8.4
172 164.5 20.9 47.4 14.5
173 19.6 20.1 17 7.6
174 168.4 7.1 12.8 11.7
175 222.4 3.4 13.1 11.5
176 276.9 48.9 41.8 27
177 248.4 30.2 20.3 20.2
178 170.2 7.8 35.2 11.7
179 276.7 2.3 23.7 11.8
180 165.6 10 17.6 12.6
181 156.6 2.6 8.3 10.5
182 218.5 5.4 27.4 12.2
183 56.2 5.7 29.7 8.7
184 287.6 43 71.8 26.2
185 253.8 21.3 30 17.6
186 205 45.1 19.6 22.6
187 139.5 2.1 26.6 10.3
188 191.1 28.7 18.2 17.3
189 286 13.9 3.7 15.9
190 18.7 12.1 23.4 6.7
191 39.5 41.1 5.8 10.8
192 75.5 10.8 6 9.9
193 17.2 4.1 31.6 5.9
194 166.8 42 3.6 19.6
195 149.7 35.6 6 17.3
196 38.2 3.7 13.8 7.6
197 94.2 4.9 8.1 9.7
198 177 9.3 6.4 12.8
199 283.6 42 66.2 25.5
200 232.1 8.6 8.7 13.4
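
The script in this post reads a comma-separated `Advertising.csv`, while the table above is shown separated by whitespace. A minimal sketch of one way to bridge the two, assuming the table has been pasted into a string (only the first three rows are repeated here, and the variable names `raw`, `table`, and `csv_text` are illustrative):

```python
import io

import pandas as pd

# Header plus the first three data rows of the table above.
# The header names four columns; each data row has five fields.
raw = """TV Radio Newspaper Sales
1 230.1 37.8 69.2 22.1
2 44.5 39.3 45.1 10.4
3 17.2 45.9 69.3 9.3
"""

# sep=r'\s+' splits on runs of whitespace. Because the header has one field
# fewer than the data rows, pandas treats the first column as the index.
table = pd.read_csv(io.StringIO(raw), sep=r'\s+')
print(table)

# to_csv() with no path returns the comma-separated text as a string;
# passing a file name instead would write it to disk.
csv_text = table.to_csv()
print(csv_text)
```

Reading the resulting CSV back with `pd.read_csv` gives an unnamed first column for the row number, which the script ignores because it selects columns by name.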


#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import csv
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Ridge
from pprint import pprint

if __name__ == "__main__":
    path = 'Advertising.csv'
    # # Option 1: read the data by hand
    # f = open(path)
    # x = []
    # y = []
    # for i, d in enumerate(f):
    #     if i == 0:  # skip the header row
    #         continue
    #     d = d.strip()
    #     if not d:
    #         continue
    #     d = list(map(float, d.split(',')))
    #     x.append(d[1:-1])  # drop the row number, keep the three features
    #     y.append(d[-1])    # last field is Sales
    # f.close()
    # pprint(x)
    # pprint(y)
    # x = np.array(x)
    # y = np.array(y)

    # # Option 2: the standard-library csv module
    # f = open(path, 'r')
    # print(f)
    # d = csv.reader(f)
    # for line in d:
    #     print(line)
    # f.close()

    # # Option 3: read with numpy
    # p = np.loadtxt(path, delimiter=',', skiprows=1)
    # print(p)
    # print('\n\n===============\n\n')

    # Option 4 (used here): read with pandas
    data = pd.read_csv(path)  # columns: TV, Radio, Newspaper, Sales
    x = data[['TV', 'Radio', 'Newspaper']]
    # x = data[['TV', 'Radio']]
    y = data['Sales']
    print(x)
    print(y)

    # Only needed when labels contain Chinese text: use the SimHei font
    mpl.rcParams['font.sans-serif'] = ['simHei']
    # Render the minus sign '-' correctly instead of as a box
    mpl.rcParams['axes.unicode_minus'] = False

    # Plot 1: see how the three channels are distributed against sales
    plt.figure(facecolor='w')  # white canvas
    plt.plot(data['TV'], y, 'ro', label='TV')
    plt.plot(data['Radio'], y, 'g^', label='Radio')
    plt.plot(data['Newspaper'], y, 'mv', label='Newspaper')
    plt.legend(loc='lower right')  # legend in the lower right corner
    plt.xlabel('Advertising spend', fontsize=16)
    plt.ylabel('Sales', fontsize=16)
    plt.title('Advertising spend vs. sales', fontsize=20)
    plt.grid()  # show grid lines
    plt.show()

    # Plot 2: one subplot per channel
    plt.figure(facecolor='w', figsize=(9, 10))
    plt.subplot(311)  # first cell of a 3-row, 1-column grid
    plt.plot(data['TV'], y, 'ro')
    plt.title('TV')
    plt.grid()
    plt.subplot(312)
    plt.plot(data['Radio'], y, 'g^')
    plt.title('Radio')
    plt.grid()
    plt.subplot(313)
    plt.plot(data['Newspaper'], y, 'b*')
    plt.title('Newspaper')
    plt.grid()
    plt.tight_layout()  # tighten the spacing between subplots
    plt.show()

    # train_test_split randomly splits the samples into training and test
    # data by the given proportion.
    # x: features to split; y: targets to split;
    # train_size: fraction of samples used for training (an integer is
    # interpreted as an absolute number of samples);
    # random_state: random seed, so the split is the same on every run.
    x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.8, random_state=1)
    print(type(x_test))
    # .shape gives the size of every dimension; shape[0] is the number of rows
    print(x_train.shape, y_train.shape)
    linreg = LinearRegression()  # linear regression
    model = linreg.fit(x_train, y_train)  # train
    print(model)
    print(linreg.coef_, linreg.intercept_)  # fitted coefficients and intercept

    # Sort the test set by actual sales so the final plot is readable
    order = np.argsort(y_test.values)
    y_test = y_test.values[order]
    x_test = x_test.iloc[order]
    y_hat = linreg.predict(x_test)
    mse = np.average((y_hat - y_test) ** 2)  # Mean Squared Error
    rmse = np.sqrt(mse)  # Root Mean Squared Error
    print('MSE = ', mse, end=' ')
    print('RMSE = ', rmse)
    print('R2 (train) = ', linreg.score(x_train, y_train))
    print('R2 (test) = ', linreg.score(x_test, y_test))

    # Plot 3: actual vs. predicted sales on the test set
    plt.figure(facecolor='w')
    t = np.arange(len(x_test))
    plt.plot(t, y_test, 'r-', linewidth=2, label='Actual')
    plt.plot(t, y_hat, 'g-', linewidth=2, label='Predicted')
    plt.legend(loc='upper right')
    plt.title('Sales predicted by linear regression', fontsize=18)
    plt.grid(True)
    plt.show()

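Figure 1: advertising spend on TV, Radio, and Newspaper plotted against sales on a single set of axes.

Figure 2: the same data as three stacked subplots, one per channel.

Figure 3: actual versus predicted sales on the test set.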
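
The MSE, RMSE, and R² values the script prints can be cross-checked against `sklearn.metrics`. The sketch below does this on rows 1 to 12 of the table above. The split (fit on the first 8 rows, evaluate on the last 4) is an assumption made only to keep the example short and is not the 80/20 random split used in the script, so the numbers it prints will differ from the script's.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# Rows 1-12 of the table above: TV, Radio, Newspaper, Sales.
rows = np.array([
    [230.1, 37.8, 69.2, 22.1],
    [44.5, 39.3, 45.1, 10.4],
    [17.2, 45.9, 69.3, 9.3],
    [151.5, 41.3, 58.5, 18.5],
    [180.8, 10.8, 58.4, 12.9],
    [8.7, 48.9, 75.0, 7.2],
    [57.5, 32.8, 23.5, 11.8],
    [120.2, 19.6, 11.6, 13.2],
    [8.6, 2.1, 1.0, 4.8],
    [199.8, 2.6, 21.2, 10.6],
    [66.1, 5.8, 24.2, 8.6],
    [214.7, 24.0, 4.0, 17.4],
])
X, y = rows[:, :3], rows[:, 3]

# Illustrative split only: fit on the first 8 rows, evaluate on the last 4.
model = LinearRegression().fit(X[:8], y[:8])
y_true = y[8:]
y_hat = model.predict(X[8:])

# The metrics written out by hand.
mse = np.mean((y_hat - y_true) ** 2)
rmse = np.sqrt(mse)
r2 = 1 - np.sum((y_true - y_hat) ** 2) / np.sum((y_true - y_true.mean()) ** 2)

# The same quantities computed by scikit-learn.
mse_sk = mean_squared_error(y_true, y_hat)
r2_sk = r2_score(y_true, y_hat)

print('MSE  =', mse, mse_sk)
print('RMSE =', rmse)
print('R2   =', r2, r2_sk, model.score(X[8:], y_true))
```

`LinearRegression.score` returns the same R² as `r2_score`, which is why the script can report R² without importing `sklearn.metrics`.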

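One point worth mentioning: when training a model, the data is usually divided into training data, validation data, and test data. The training and validation data are used together for cross-validation. For example, split everything except the test data into three parts, A, B, and C. When A and B are the training data, C is the validation data; when A and C are the training data, B is the validation data; and when B and C are the training data, A is the validation data.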
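
The A/B/C rotation described above can be sketched with scikit-learn's `KFold`. This is a minimal illustration on rows 1 to 12 of the table, not part of the original script. Holding out the last 3 rows as the test set, using validation MSE as the score, and the names `X_rest`, `val_folds`, and `val_scores` are all assumptions made for the example.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

# Rows 1-12 of the table above: TV, Radio, Newspaper, Sales.
rows = np.array([
    [230.1, 37.8, 69.2, 22.1],
    [44.5, 39.3, 45.1, 10.4],
    [17.2, 45.9, 69.3, 9.3],
    [151.5, 41.3, 58.5, 18.5],
    [180.8, 10.8, 58.4, 12.9],
    [8.7, 48.9, 75.0, 7.2],
    [57.5, 32.8, 23.5, 11.8],
    [120.2, 19.6, 11.6, 13.2],
    [8.6, 2.1, 1.0, 4.8],
    [199.8, 2.6, 21.2, 10.6],
    [66.1, 5.8, 24.2, 8.6],
    [214.7, 24.0, 4.0, 17.4],
])
X, y = rows[:, :3], rows[:, 3]

# Illustrative hold-out: the last 3 rows are the test set and are never
# touched during cross-validation. The remaining 9 rows form A, B, and C.
X_rest, y_rest = X[:9], y[:9]
X_test, y_test = X[9:], y[9:]

kfold = KFold(n_splits=3, shuffle=False)
val_folds = []
val_scores = []
# Without shuffling, the validation part is A, then B, then C.
for name, (train_idx, val_idx) in zip('ABC', kfold.split(X_rest)):
    model = LinearRegression().fit(X_rest[train_idx], y_rest[train_idx])
    residual = model.predict(X_rest[val_idx]) - y_rest[val_idx]
    val_mse = np.mean(residual ** 2)
    val_folds.append(val_idx)
    val_scores.append(val_mse)
    print('validate on', name, 'rows', val_idx, 'MSE =', val_mse)

print('mean validation MSE =', np.mean(val_scores))

# Once the model choice is settled, refit on A+B+C and score once on the test set.
final = LinearRegression().fit(X_rest, y_rest)
test_mse = np.mean((final.predict(X_test) - y_test) ** 2)
print('test MSE =', test_mse)
```

With so few rows the individual fold scores are noisy. The point is only the rotation: each of A, B, and C serves as validation data exactly once.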
 

