# Listing metadata: 50 lines, 1.4 KiB, Python
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from sklearn.metrics import classification_report
|
||
from sklearn import svm
|
||
# Load the data set from a comma-separated text file. Every column except the
# last is kept as the feature matrix; the last column is kept as the labels.
data = np.genfromtxt(fname="LR-testSet2.txt", delimiter=",")
x_data, y_data = data[:, :-1], data[:, -1]
|
||
def plot():
    """Scatter-plot the module-level samples, one marker style per class.

    Rows of ``x_data`` whose ``y_data`` entry equals 0 are drawn as blue
    circles under the legend entry 'label0'; every other row is drawn as a
    red cross under 'label1'. Column 0 of ``x_data`` is the horizontal
    coordinate and column 1 the vertical one.

    The function only issues the drawing calls; it does not call
    ``plt.show()``, so the caller decides when the figure is displayed.
    """
    # Split the samples by class with a boolean mask instead of an index loop.
    is_class0 = y_data == 0
    class0 = x_data[is_class0]
    class1 = x_data[~is_class0]

    # Draw both classes.
    scatter0 = plt.scatter(class0[:, 0], class0[:, 1], c='b', marker='o')
    scatter1 = plt.scatter(class1[:, 0], class1[:, 1], c='r', marker='x')

    # Draw the legend.
    plt.legend(handles=[scatter0, scatter1], labels=['label0', 'label1'], loc='best')
|
||
# Show the raw samples before any model is fitted.
plot()
plt.show()

# Fit the model.
# C and gamma are the hyperparameters set here; the kernel can be one of
# 'linear', 'poly', 'rbf', 'sigmoid'.
model = svm.SVC(kernel='rbf', C=2, gamma=1)
model.fit(x_data, y_data)

# Score the model on the same data it was fitted on (training accuracy, not a
# held-out estimate). Print it explicitly: a bare expression statement shows
# nothing when the file is run as a script.
train_accuracy = model.score(x_data, y_data)
print(train_accuracy)
|
||
# Range covered by the samples, padded by one unit on every side.
first_feature = x_data[:, 0]
second_feature = x_data[:, 1]
x_min = first_feature.min() - 1
x_max = first_feature.max() + 1
y_min = second_feature.min() - 1
y_max = second_feature.max() + 1

# Build the grid matrices over that range with a spacing of 0.02.
x_ticks = np.arange(x_min, x_max, 0.02)
y_ticks = np.arange(y_min, y_max, 0.02)
xx, yy = np.meshgrid(x_ticks, y_ticks)

# Predict a class for every grid point. ravel() turns each coordinate matrix
# into a 1-D array (like flatten(), except that ravel() returns a view when it
# can while flatten() always returns a copy); the two 1-D arrays are then
# paired up column-wise into (x, y) rows for the model.
grid_points = np.c_[xx.ravel(), yy.ravel()]
z = model.predict(grid_points)
z = z.reshape(xx.shape)

# Filled contour plot of the predicted classes, with the samples drawn on top.
cs = plt.contourf(xx, yy, z)
plot()
plt.show()
|
||
|
||
|
||
|