Advanced_artificial_intelli.../SVM最优分类面.ipynb

901 lines
29 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.653016Z",
"start_time": "2019-12-30T05:03:52.292979Z"
}
},
"outputs": [],
"source": [
"# All imports for the notebook in one place (sympy is used by the symbolic section).\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sympy"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.660514Z",
"start_time": "2019-12-30T05:03:52.655012Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0, 0],\n",
" [ 0, -1],\n",
" [ 1, 1],\n",
" [-1, 0],\n",
" [ 0, 1]]), array([ 1, 1, 1, -1, -1]))"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training set: five 2-D points and their class labels (+1 / -1).\n",
"x = np.array([[0, 0], [0, -1], [1, 1], [-1, 0], [0, 1]])\n",
"y = np.array([1, 1, 1, -1, -1])\n",
"x, y"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [],
"source": [
"# Candidate separating line x2 = x1 + 0.5, sampled across the whole [-2, 2] plotting window.\n",
"xx = np.linspace(-2, 2, 41)\n",
"yy = xx + 0.5"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.862963Z",
"start_time": "2019-12-30T05:03:52.661502Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(-2, 2)"
]
},
"execution_count": 206,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Explicit Axes API; the trailing ';' suppresses the stray repr output.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xx, yy, label=\"x2 = x1 + 0.5\")\n",
"ax.scatter(x[:, 0], x[:, 1], c=y)\n",
"ax.set(xlim=(-2, 2), ylim=(-2, 2), xlabel=\"x1\", ylabel=\"x2\",\n",
"       title=\"Training points (colour = label) and separating line\")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.866980Z",
"start_time": "2019-12-30T05:03:52.863960Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1],\n",
" [ 1],\n",
" [-1],\n",
" [-1]])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = np.reshape(y, (-1, 1))  # column vector so that y.dot(y.T) is the outer product y_i y_j\n",
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 第一种用numpy的做法 $A(\\alpha)=\\sum_{i=1}^{n}\\alpha_i - \\frac{1}{2}\\sum_{i=1}^{n}\\sum_{j=1}^{n}\\alpha_i\\alpha_j y_i y_j x_i^T x_j$"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.874931Z",
"start_time": "2019-12-30T05:03:52.867950Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, 0],\n",
" [ 0, 2, -1, 1],\n",
" [ 0, -1, 1, 0],\n",
" [ 0, 1, 0, 1]])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xij = x @ x.T  # Gram matrix of inner products x_i^T x_j\n",
"xij"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.880947Z",
"start_time": "2019-12-30T05:03:52.875928Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1, 1, -1, -1],\n",
" [ 1, 1, -1, -1],\n",
" [-1, -1, 1, 1],\n",
" [-1, -1, 1, 1]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Label products y_i y_j. np.outer flattens its inputs, so this is correct whether or not\n",
"# y has been reshaped into a column (y.dot(y.T) on a 1-D y would give a scalar).\n",
"yij = np.outer(y, y)\n",
"yij"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.886899Z",
"start_time": "2019-12-30T05:03:52.882937Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, 0],\n",
" [ 0, 2, 1, -1],\n",
" [ 0, 1, 1, 0],\n",
" [ 0, -1, 0, 1]])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = np.multiply(yij, xij)  # element-wise: A_ij = y_i y_j x_i^T x_j\n",
"A"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.894878Z",
"start_time": "2019-12-30T05:03:52.888894Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 a2^2\n",
"2 a2a3\n",
"-2 a2a4\n",
"1 a3^2\n",
"1 a4^2\n"
]
}
],
"source": [
"# Expand sum_i sum_j A_ij a_i a_j into monomials: diagonal terms a_i^2 and, for i < j,\n",
"# the merged cross term (A_ij + A_ji) a_i a_j. Zero coefficients are skipped;\n",
"# the derivatives are then taken by hand from this list.\n",
"for i in range(A.shape[0]):\n",
"    for j in range(i, A.shape[1]):\n",
"        if i == j:\n",
"            coef, term = A[i, i], f\"a{i+1}^2\"\n",
"        else:\n",
"            coef, term = A[i, j] + A[j, i], f\"a{i+1}a{j+1}\"\n",
"        if coef != 0:\n",
"            print(f\"{coef} {term}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 第二种sympy符号计算 $A(\\alpha)=\\sum_{i=1}^{n}\\alpha_i - \\frac{1}{2}\\sum_{i=1}^{n}\\sum_{j=1}^{n}\\alpha_i\\alpha_j y_i y_j x_i^T x_j$"
]
},
{
"cell_type": "code",
"execution_count": 255,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:52.900890Z",
"start_time": "2019-12-30T05:03:52.895875Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0, 0],\n",
" [ 0, -1],\n",
" [ 1, 1],\n",
" [-1, 0],\n",
" [ 0, 1]]), array([ 1, 1, 1, -1, -1]))"
]
},
"execution_count": 255,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Re-create the data: the numpy section reshaped y into a column, but the symbolic\n",
"# section below indexes y[i] as scalars, so it needs the 1-D labels again.\n",
"x = np.array([[0, 0], [0, -1], [1, 1], [-1, 0], [0, 1]])\n",
"y = np.array([1, 1, 1, -1, -1])\n",
"x, y"
]
},
{
"cell_type": "code",
"execution_count": 256,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.426485Z",
"start_time": "2019-12-30T05:03:52.901859Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{1}$"
],
"text/plain": [
"a_1"
]
},
"execution_count": 256,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sympy\n",
"# One Lagrange-multiplier symbol a_1 ... a_n per training point (n = len(x)).\n",
"a = tuple(sympy.Symbol(f\"a_{i + 1}\") for i in range(len(x)))\n",
"a[0]"
]
},
{
"cell_type": "code",
"execution_count": 257,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.468474Z",
"start_time": "2019-12-30T05:03:53.427455Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{1} + a_{2} + a_{3} + a_{4} + a_{5}$"
],
"text/plain": [
"a_1 + a_2 + a_3 + a_4 + a_5"
]
},
"execution_count": 257,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First term of A(alpha): the plain sum of all multipliers.\n",
"eq1 = sum(a, sympy.S.Zero)\n",
"eq1"
]
},
{
"cell_type": "code",
"execution_count": 258,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.683928Z",
"start_time": "2019-12-30T05:03:53.469470Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{2}^{2} - 2 a_{2} a_{3} + 2 a_{2} a_{5} + 2 a_{3}^{2} + 2 a_{3} a_{4} - 2 a_{3} a_{5} + a_{4}^{2} + a_{5}^{2}$"
],
"text/plain": [
"a_2**2 - 2*a_2*a_3 + 2*a_2*a_5 + 2*a_3**2 + 2*a_3*a_4 - 2*a_3*a_5 + a_4**2 + a_5**2"
]
},
"execution_count": 258,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Double sum sum_i sum_j a_i a_j y_i y_j x_i^T x_j (the 1/2 factor is applied in the next cell).\n",
"# int(...) turns the numpy scalars into plain Python ints before they meet sympy.\n",
"n_points = len(x)\n",
"eq2 = sympy.Add(*[a[i] * a[j] * int(y[i] * y[j]) * int(x[i] @ x[j])\n",
"                  for i in range(n_points) for j in range(n_points)])\n",
"eq2"
]
},
{
"cell_type": "code",
"execution_count": 259,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.694868Z",
"start_time": "2019-12-30T05:03:53.684935Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{1} - 0.5 a_{2}^{2} + 1.0 a_{2} a_{3} - 1.0 a_{2} a_{5} + a_{2} - 1.0 a_{3}^{2} - 1.0 a_{3} a_{4} + 1.0 a_{3} a_{5} + a_{3} - 0.5 a_{4}^{2} + a_{4} - 0.5 a_{5}^{2} + a_{5}$"
],
"text/plain": [
"a_1 - 0.5*a_2**2 + 1.0*a_2*a_3 - 1.0*a_2*a_5 + a_2 - 1.0*a_3**2 - 1.0*a_3*a_4 + 1.0*a_3*a_5 + a_3 - 0.5*a_4**2 + a_4 - 0.5*a_5**2 + a_5"
]
},
"execution_count": 259,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dual objective A(alpha) = sum_i a_i - 1/2 * double sum. sympy.Rational keeps the\n",
"# coefficients exact; a float 1/2 turns every later coefficient into 0.5 / 1.0 floats.\n",
"eq = eq1 - sympy.Rational(1, 2) * eq2\n",
"eq"
]
},
{
"cell_type": "code",
"execution_count": 260,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.701888Z",
"start_time": "2019-12-30T05:03:53.695865Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1$"
],
"text/plain": [
"1"
]
},
"execution_count": 260,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eq.diff(\"a_1\")  # dA/da_1"
]
},
{
"cell_type": "code",
"execution_count": 261,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.721829Z",
"start_time": "2019-12-30T05:03:53.702847Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.0 a_{2} + 1.0 a_{3} - 1.0 a_{5} + 1$"
],
"text/plain": [
"-1.0*a_2 + 1.0*a_3 - 1.0*a_5 + 1"
]
},
"execution_count": 261,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eq.diff(\"a_2\")  # dA/da_2"
]
},
{
"cell_type": "code",
"execution_count": 262,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.736783Z",
"start_time": "2019-12-30T05:03:53.722793Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 a_{2} - 2.0 a_{3} - 1.0 a_{4} + 1.0 a_{5} + 1$"
],
"text/plain": [
"1.0*a_2 - 2.0*a_3 - 1.0*a_4 + 1.0*a_5 + 1"
]
},
"execution_count": 262,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eq.diff(\"a_3\")  # dA/da_3"
]
},
{
"cell_type": "code",
"execution_count": 263,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.747754Z",
"start_time": "2019-12-30T05:03:53.737753Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.0 a_{3} - 1.0 a_{4} + 1$"
],
"text/plain": [
"-1.0*a_3 - 1.0*a_4 + 1"
]
},
"execution_count": 263,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eq.diff(\"a_4\")  # dA/da_4"
]
},
{
"cell_type": "code",
"execution_count": 264,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.0 a_{2} + 1.0 a_{3} - 1.0 a_{5} + 1$"
],
"text/plain": [
"-1.0*a_2 + 1.0*a_3 - 1.0*a_5 + 1"
]
},
"execution_count": 264,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eq.diff(\"a_5\")  # dA/da_5"
]
},
{
"cell_type": "code",
"execution_count": 265,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.754711Z",
"start_time": "2019-12-30T05:03:53.748723Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{1} + a_{2} + a_{3} - a_{4} - a_{5}$"
],
"text/plain": [
"a_1 + a_2 + a_3 - a_4 - a_5"
]
},
"execution_count": 265,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum((y[i] * a[i] for i in range(len(a))), sympy.S.Zero)  # equality constraint sum_i y_i a_i"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 求 $\\alpha_i$\n",
"约束包括:\n",
"- $\\sum_{i=1}^{n} y_i \\alpha_i=0$\n",
"- $\\frac{\\partial A(\\alpha)}{\\partial\\alpha_i}=0$\n",
"\n",
"中间过程不等式约束需要手算"
]
},
{
"cell_type": "code",
"execution_count": 266,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.762686Z",
"start_time": "2019-12-30T05:03:53.755705Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[1,\n",
" -1.0*a_2 + 1.0*a_3 - 1.0*a_5 + 1,\n",
" 1.0*a_2 - 2.0*a_3 - 1.0*a_4 + 1.0*a_5 + 1,\n",
" -1.0*a_3 - 1.0*a_4 + 1,\n",
" -1.0*a_2 + 1.0*a_3 - 1.0*a_5 + 1,\n",
" a_1 + a_2 + a_3 - a_4 - a_5]"
]
},
"execution_count": 266,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Stationarity expressions dA/da_i for every multiplier, then the constraint sum_i y_i a_i.\n",
"lin = [eq.diff(ai) for ai in a]\n",
"lin.append(sympy.Add(*(y[i] * a[i] for i in range(len(a)))))\n",
"lin"
]
},
{
"cell_type": "code",
"execution_count": 267,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.770665Z",
"start_time": "2019-12-30T05:03:53.763683Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.0 a_{2} + 1.0 a_{3} - 1.0 a_{5} + 1$"
],
"text/plain": [
"-1.0*a_2 + 1.0*a_3 - 1.0*a_5 + 1"
]
},
"execution_count": 267,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin[len(lin) - 2]  # dA/da_5, the entry just before the equality constraint"
]
},
{
"cell_type": "code",
"execution_count": 268,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.777647Z",
"start_time": "2019-12-30T05:03:53.771663Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a_{4}$"
],
"text/plain": [
"a_4"
]
},
"execution_count": 268,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a[3]  # symbol a_4 (0-based index 3)"
]
},
{
"cell_type": "code",
"execution_count": 269,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:03:53.821559Z",
"start_time": "2019-12-30T05:03:53.779642Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left\\{\\left( 3.0 a_{4} + 2.0 a_{5} - 5.0, \\ - 1.0 a_{4} - 1.0 a_{5} + 3.0, \\ 2.0 - 1.0 a_{4}, \\ a_{4}, \\ a_{5}\\right)\\right\\}$"
],
"text/plain": [
"FiniteSet((3.0*a_4 + 2.0*a_5 - 5.0, -1.0*a_4 - 1.0*a_5 + 3.0, 2.0 - 1.0*a_4, a_4, a_5))"
]
},
"execution_count": 269,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# dA/da_1 is the constant 1 (x_1 is the origin), so it cannot be set to 0; solve a\n",
"# hand-picked subset instead: dA/da_3 = 0, dA/da_5 = 0 and sum_i y_i a_i = 0.\n",
"equations = [lin[k] for k in (2, 4, 5)]\n",
"sympy.linsolve(equations, a)"
]
},
{
"cell_type": "code",
"execution_count": 243,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.20000000000000018, 0.0, 0.8, 0.2, 0.8)"
]
},
"execution_count": 243,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the parametric linsolve solution directly instead of retyping it by hand,\n",
"# so the numbers always match the solution printed above.\n",
"(solution,) = sympy.linsolve([lin[2], lin[4], lin[5]], a)\n",
"a4, a5 = 0.2, 0.8\n",
"values = tuple(solution.subs({a[3]: a4, a[4]: a5}))\n",
"# A usable multiplier vector also needs every component >= 0; report that alongside.\n",
"values, all(v >= 0 for v in values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**不等式约束需要手算：观察数据可知 $x_2=(0,-1)$ 不是支持向量，取 $\\alpha_2=0$，再为其余 $\\alpha_i$ 取一组满足 $\\alpha_i \\ge 0$ 的合法值。**"
]
},
{
"cell_type": "code",
"execution_count": 252,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:05:39.202328Z",
"start_time": "2019-12-30T05:05:39.197312Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[2, 0, 0, 1, 1]"
]
},
"execution_count": 252,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hand-picked dual solution. x_2 = (0, -1) is not a support vector, so alpha_2 = 0.\n",
"# For this data every alpha = (4 - t, 0, t, 2 - t, t + 2) with 0 <= t <= 2 satisfies the\n",
"# KKT conditions (W* = (2, -2), b* = 1); t = 0 is used here. [2, 0, 0, 1, 1] meets\n",
"# sum_i y_i alpha_i = 0 but gives W = (1, -1), for which no single b puts x_1, x_4 and\n",
"# x_5 on the margin y_i (W^T x_i + b) = 1.\n",
"a = [4, 0, 0, 2, 2]\n",
"assert sum(ai * yi for ai, yi in zip(a, y)) == 0, \"equality constraint sum_i y_i alpha_i = 0 violated\"\n",
"a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 计算参数 $W^{*}=\\sum_{i=1}^{n}\\alpha_i y_i x_i$"
]
},
{
"cell_type": "code",
"execution_count": 253,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:05:40.901388Z",
"start_time": "2019-12-30T05:05:40.896401Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1, -1])"
]
},
"execution_count": 253,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# W* = sum_i alpha_i y_i x_i, vectorised. Unlike an in-place `w += ...` on an int array,\n",
"# this also works when the multipliers mix ints and floats.\n",
"w = (np.asarray(a) * np.ravel(y)) @ x\n",
"w"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 计算参数b\n",
"$\\alpha_i > 0$ 对应的样本 $i$ 是支持向量。\n",
"\n",
"支持向量满足约束条件：$y_i(W^{*T}x_i + b^*) - 1 = 0$\n",
"\n",
"当样本属于 $y_i = +1$ 类时，有 $b^* = 1 - W^{*T}x_i$\n",
"\n",
"当样本属于 $y_j = -1$ 类时，有 $b^* = -1 - W^{*T}x_j$\n",
"\n",
"合并两个公式\n",
"$b^*=-\\frac{1}{2}W^{*T}(x^+ + x^-)$\n",
"- $x^+$ 和 $x^-$ 是正负例各一个支持向量"
]
},
{
"cell_type": "code",
"execution_count": 254,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-30T05:09:07.821819Z",
"start_time": "2019-12-30T05:09:07.817801Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 254,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# b* = -1/2 W*^T (x+ + x-), using one support vector (alpha_i > 0) from each class\n",
"# instead of hard-coding sample indices.\n",
"alpha = np.asarray(a)\n",
"y_flat = np.ravel(y)\n",
"i_pos = np.flatnonzero((alpha > 0) & (y_flat == 1))[0]\n",
"i_neg = np.flatnonzero((alpha > 0) & (y_flat == -1))[0]\n",
"b = -1/2 * w @ (x[i_pos] + x[i_neg])\n",
"b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 最优分类面:$W^{*T}x+b^*=0$\n",
"- $x_1 - x_2 + 0.5 = 0$，即 $x_2 = x_1 + 0.5$（与开头绘制的直线一致）"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tensorflow",
"language": "python",
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}