{
"cells": [
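{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Problem\n",
"1. Two servers (S1, S2) each send data one-way to a user U. S1 and S2 also exchange data with each other, but the direction of that data flow is unknown. Each packet transmission takes one of two values: T=1 (success) or F=2 (failure). Assume a Bayesian network made up of the three nodes S1, S2, and U. 100 transmission samples from this network were collected and are given in the file server_data.txt, where each line is one sample of the three-node network. Use Bayesian learning to obtain the structure and parameters of the network. (30 points)"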
]
},
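{
"cell_type": "markdown",
"metadata": {},
"source": [
"# References\n",
"Reference tutorial: https://blog.csdn.net/leida_wt/article/details/88743323\n",
"\n",
"Learning a network structure automatically involves two core problems: a metric for judging how good a network is, and a method for searching over candidates. Exhaustive enumeration is infeasible because the number of possible structures is too large, so the search must rely on heuristics or on restricted search conditions that shrink the search space. This gives two major families of methods, score-based structure learning and constraint-based structure learning, plus their combination, hybrid structure learning."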
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.433710Z",
"start_time": "2019-12-29T02:45:08.854096Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from pgmpy.models import BayesianModel\n",
"from pgmpy.estimators import MaximumLikelihoodEstimator, BayesianEstimator\n",
"from pgmpy.estimators import BdeuScore, K2Score, BicScore\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.452516Z",
"start_time": "2019-12-29T02:45:11.433710Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>S1</th>\n",
" <th>S2</th>\n",
" <th>U</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
"    S1  S2   U\n",
"0    1   2   1\n",
"1    2   2   2\n",
"2    2   1   1\n",
"3    2   1   1\n",
"4    2   1   1\n",
"..  ..  ..  ..\n",
"95   2   1   1\n",
"96   2   1   1\n",
"97   2   1   1\n",
"98   2   1   1\n",
"99   2   1   1\n",
"\n",
"[100 rows x 3 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Each line of server_data.txt is one sample: whitespace-separated values for S1, S2, U\n",
"data_list = []\n",
"with open('server_data.txt') as f:\n",
"    lines = f.readlines()\n",
"    for line in lines:\n",
"        data_list.append(line.strip().split())\n",
"data_list = np.array(data_list, dtype=np.int32)\n",
"data = pd.DataFrame(data_list, columns=['S1', 'S2', 'U'])\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.459497Z",
"start_time": "2019-12-29T02:45:11.454480Z"
}
},
"outputs": [],
"source": [
"def showBN(model, save=False):\n",
"    '''Draw the structure of a BayesianModel with graphviz; the returned Digraph renders inline in Jupyter.'''\n",
"    from graphviz import Digraph\n",
"    node_attr = dict(\n",
"        style='filled',\n",
"        shape='box',\n",
"        align='left',\n",
"        fontsize='12',\n",
"        ranksep='0.1',\n",
"        height='0.2'\n",
"    )\n",
"    dot = Digraph(node_attr=node_attr, graph_attr=dict(size=\"12,12\"))\n",
"    # Add one directed edge per (parent, child) pair in the model\n",
"    edges = model.edges()\n",
"    for a, b in edges:\n",
"        dot.edge(a, b)\n",
"    if save:\n",
"        # Render the graph and open it in the default viewer\n",
"        dot.view(cleanup=True)\n",
"    return dot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define the two candidate networks allowed by the problem statement"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.608070Z",
"start_time": "2019-12-29T02:45:11.461462Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: %3 Pages: 1 -->\r\n",
"<svg width=\"90pt\" height=\"143pt\"\r\n",
" viewBox=\"0.00 0.00 90.00 143.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 139)\">\r\n",
"<title>%3</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-139 86,-139 86,4 -4,4\"/>\r\n",
"<!-- S1 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>S1</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"54,-135 0,-135 0,-114 54,-114 54,-135\"/>\r\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-121.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">S1</text>\r\n",
"</g>\r\n",
"<!-- U -->\r\n",
"<g id=\"node2\" class=\"node\"><title>U</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"54,-21 0,-21 0,-0 54,-0 54,-21\"/>\r\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-7.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">U</text>\r\n",
"</g>\r\n",
"<!-- S1&#45;&gt;U -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>S1&#45;&gt;U</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M24.7894,-113.802C22.8203,-104.634 20.0874,-90.4646 19,-78 18.1888,-68.702 18.1888,-66.298 19,-57 19.7476,-48.4306 21.2729,-39.0555 22.7931,-31.0538\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"26.2344,-31.6942 24.7894,-21.1984 19.3737,-30.3045 26.2344,-31.6942\"/>\r\n",
"</g>\r\n",
"<!-- S2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>S2</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"82,-78 28,-78 28,-57 82,-57 82,-78\"/>\r\n",
"<text text-anchor=\"middle\" x=\"55\" y=\"-64.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">S2</text>\r\n",
"</g>\r\n",
"<!-- S1&#45;&gt;S2 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>S1&#45;&gt;S2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M31.8772,-113.92C35.6053,-106.597 40.8621,-96.2709 45.4216,-87.3147\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"48.5723,-88.8404 49.9901,-78.3408 42.3342,-85.6646 48.5723,-88.8404\"/>\r\n",
"</g>\r\n",
"<!-- S2&#45;&gt;U -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>S2&#45;&gt;U</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M50.1228,-56.9197C46.3947,-49.5967 41.1379,-39.2709 36.5784,-30.3147\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"39.6658,-28.6646 32.0099,-21.3408 33.4277,-31.8404 39.6658,-28.6646\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x17515213780>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_1 = BayesianModel([('S1', 'U'), ('S2', 'U'), ('S1', 'S2')])\n",
"model_1.fit(data)\n",
"showBN(model_1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.709798Z",
"start_time": "2019-12-29T02:45:11.609066Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: %3 Pages: 1 -->\r\n",
"<svg width=\"89pt\" height=\"143pt\"\r\n",
" viewBox=\"0.00 0.00 89.00 143.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 139)\">\r\n",
"<title>%3</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-139 85,-139 85,4 -4,4\"/>\r\n",
"<!-- S1 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>S1</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"54,-78 0,-78 0,-57 54,-57 54,-78\"/>\r\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-64.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">S1</text>\r\n",
"</g>\r\n",
"<!-- U -->\r\n",
"<g id=\"node2\" class=\"node\"><title>U</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"81,-21 27,-21 27,-0 81,-0 81,-21\"/>\r\n",
"<text text-anchor=\"middle\" x=\"54\" y=\"-7.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">U</text>\r\n",
"</g>\r\n",
"<!-- S1&#45;&gt;U -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>S1&#45;&gt;U</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M31.703,-56.9197C35.2597,-49.6746 40.2593,-39.4903 44.6231,-30.601\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47.9041,-31.8599 49.169,-21.3408 41.6204,-28.7751 47.9041,-31.8599\"/>\r\n",
"</g>\r\n",
"<!-- S2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>S2</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"81,-135 27,-135 27,-114 81,-114 81,-135\"/>\r\n",
"<text text-anchor=\"middle\" x=\"54\" y=\"-121.4\" font-family=\"Times New Roman,serif\" font-size=\"12.00\">S2</text>\r\n",
"</g>\r\n",
"<!-- S2&#45;&gt;S1 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>S2&#45;&gt;S1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M49.297,-113.92C45.7403,-106.675 40.7407,-96.4903 36.3769,-87.601\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"39.3796,-85.7751 31.831,-78.3408 33.0959,-88.8599 39.3796,-85.7751\"/>\r\n",
"</g>\r\n",
"<!-- S2&#45;&gt;U -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>S2&#45;&gt;U</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M56.4864,-113.819C58.7013,-104.663 61.7754,-90.5018 63,-78 63.9099,-68.7111 63.9099,-66.2889 63,-57 62.1581,-48.405 60.442,-39.0257 58.7319,-31.0279\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"62.1223,-30.1528 56.4864,-21.1813 55.2975,-31.7092 62.1223,-30.1528\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x175152c2128>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_2 = BayesianModel([('S1', 'U'), ('S2', 'U'), ('S2', 'S1')])\n",
"model_2.fit(data)\n",
"showBN(model_2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scoring: evaluate both networks with the K2, BDeu, and BIC scores"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.715781Z",
"start_time": "2019-12-29T02:45:11.710794Z"
}
},
"outputs": [],
"source": [
"bdeu = BdeuScore(data, equivalent_sample_size=5)\n",
"k2 = K2Score(data)\n",
"bic = BicScore(data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.775621Z",
"start_time": "2019-12-29T02:45:11.716779Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-127.81019191674014\n",
"-130.82411202574002\n",
"-129.03972756462477\n"
]
}
],
"source": [
"print(bdeu.score(model_1))\n",
"print(k2.score(model_1))\n",
"print(bic.score(model_1))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.811526Z",
"start_time": "2019-12-29T02:45:11.777617Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-127.81019191674014\n",
"-130.99837093511061\n",
"-129.0397275646248\n"
]
}
],
"source": [
"print(bdeu.score(model_2))\n",
"print(k2.score(model_2))\n",
"print(bic.score(model_2))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.818531Z",
"start_time": "2019-12-29T02:45:11.812523Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bdeu.score(model_1)>bdeu.score(model_2)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.826510Z",
"start_time": "2019-12-29T02:45:11.819504Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"k2.score(model_1)>k2.score(model_2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.837484Z",
"start_time": "2019-12-29T02:45:11.827483Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bic.score(model_1)>bic.score(model_2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inspect each model's conditional probability tables"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.845461Z",
"start_time": "2019-12-29T02:45:11.838453Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+------+\n",
"| S1(1) | 0.28 |\n",
"+-------+------+\n",
"| S1(2) | 0.72 |\n",
"+-------+------+\n",
"+-------+---------------------+-------+\n",
"| S1    | S1(1)               | S1(2) |\n",
"+-------+---------------------+-------+\n",
"| S2(1) | 0.17857142857142858 | 0.75  |\n",
"+-------+---------------------+-------+\n",
"| S2(2) | 0.8214285714285714  | 0.25  |\n",
"+-------+---------------------+-------+\n",
"+------+-------+-------+-------+-------+\n",
"| S1   | S1(1) | S1(1) | S1(2) | S1(2) |\n",
"+------+-------+-------+-------+-------+\n",
"| S2   | S2(1) | S2(2) | S2(1) | S2(2) |\n",
"+------+-------+-------+-------+-------+\n",
"| U(1) | 0.0   | 1.0   | 1.0   | 0.0   |\n",
"+------+-------+-------+-------+-------+\n",
"| U(2) | 1.0   | 0.0   | 0.0   | 1.0   |\n",
"+------+-------+-------+-------+-------+\n"
]
}
],
"source": [
"print(model_1.get_cpds('S1'))\n",
"print(model_1.get_cpds('S2'))\n",
"print(model_1.get_cpds('U'))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2019-12-29T02:45:11.854410Z",
"start_time": "2019-12-29T02:45:11.846432Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+--------------------+---------------------+\n",
"| S2    | S2(1)              | S2(2)               |\n",
"+-------+--------------------+---------------------+\n",
"| S1(1) | 0.0847457627118644 | 0.5609756097560976  |\n",
"+-------+--------------------+---------------------+\n",
"| S1(2) | 0.9152542372881356 | 0.43902439024390244 |\n",
"+-------+--------------------+---------------------+\n",
"+-------+------+\n",
"| S2(1) | 0.59 |\n",
"+-------+------+\n",
"| S2(2) | 0.41 |\n",
"+-------+------+\n",
"+------+-------+-------+-------+-------+\n",
"| S1   | S1(1) | S1(1) | S1(2) | S1(2) |\n",
"+------+-------+-------+-------+-------+\n",
"| S2   | S2(1) | S2(2) | S2(1) | S2(2) |\n",
"+------+-------+-------+-------+-------+\n",
"| U(1) | 0.0   | 1.0   | 1.0   | 0.0   |\n",
"+------+-------+-------+-------+-------+\n",
"| U(2) | 1.0   | 0.0   | 0.0   | 1.0   |\n",
"+------+-------+-------+-------+-------+\n"
]
}
],
"source": [
"print(model_2.get_cpds('S1'))\n",
"print(model_2.get_cpds('S2'))\n",
"print(model_2.get_cpds('U'))"
]
},
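{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conclusion\n",
"The scores barely separate the two candidates: the BDeu scores are identical, the BIC scores agree up to floating-point rounding, and the K2 scores differ by only about 0.17. This is expected, because both candidates are fully connected DAGs over the same three nodes and are therefore Markov equivalent, and BDeu and BIC assign equal scores to equivalent structures while K2 does not. For this dataset, then, the two structures assumed by the problem cannot be told apart by score, and both remain plausible."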
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tensorflow",
"language": "python",
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}