Discrete Bayes net example using the famous Asia Bayes network
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteFactorGraph, DiscreteKeys,
Ordering)

Helper Functions
def create_discrete_keys(*args):
    """Bundle any number of DiscreteKey pairs into a DiscreteKeys container.

    Args:
        *args: The discrete keys to collect, in the order given.

    Returns:
        A ``DiscreteKeys`` instance holding every argument, in order.
    """
    collected = DiscreteKeys()
    for discrete_key in args:
        collected.push_back(discrete_key)
    return collected


# Asia Bayes Network Example
This example demonstrates the famous Asia Bayes Network using discrete Bayes nets.
The Asia network is a classic example in probabilistic reasoning that models the relationships between:
- Visiting Asia (travel history)
- Smoking habits
- Diseases: Tuberculosis, Lung Cancer, Bronchitis
- Symptoms: Dyspnea (shortness of breath)
- Tests: X-Ray results
The network shows how these variables are conditionally dependent on each other.
# The Bayes net that the priors and conditionals are added to.
asia = DiscreteBayesNet()

# One discrete key per variable, written as (key_id, num_states).
# Every variable is binary (0 = No, 1 = Yes); listed here in key_id order.
Asia = (0, 2)          # Has been to Asia
Dyspnea = (1, 2)       # Dyspnea (shortness of breath)
XRay = (2, 2)          # X-ray is positive
Tuberculosis = (3, 2)  # Tuberculosis
Smoking = (4, 2)       # Smoking
Either = (5, 2)        # Either tuberculosis or lung cancer
LungCancer = (6, 2)    # Lung cancer
Bronchitis = (7, 2)    # Bronchitis

# Add prior probabilities
# Priors on the two root variables.
asia.add(Asia, "99/1")      # P(Asia) = [0.99, 0.01]
asia.add(Smoking, "50/50")  # P(Smoking) = [0.5, 0.5]

# Conditionals, listed as (variable, parents, table) and added in this order.
conditionals = [
    (Tuberculosis, (Asia,), "99/1 95/5"),     # P(Tuberculosis | Asia)
    (LungCancer, (Smoking,), "99/1 90/10"),   # P(LungCancer | Smoking)
    (Bronchitis, (Smoking,), "70/30 40/60"),  # P(Bronchitis | Smoking)
    # P(Either | Tuberculosis, LungCancer) is an OR gate: "F T T T" makes
    # Either False, True, True, True for the parent combinations
    # (TB=0,LC=0), (TB=0,LC=1), (TB=1,LC=0), (TB=1,LC=1).
    (Either, (Tuberculosis, LungCancer), "F T T T"),
    (XRay, (Either,), "95/5 2/98"),           # P(XRay | Either)
    # P(Dyspnea | Either, Bronchitis)
    (Dyspnea, (Either, Bronchitis), "9/1 2/8 3/7 1/9"),
]
for child, parents, table in conditionals:
    asia.add(child, create_discrete_keys(*parents), table)

# Print the network with pretty variable names
# Display names indexed by variable key id (position i names key i).
pretty_names = ["Asia", "Dyspnea", "XRay", "Tuberculosis",
                "Smoking", "Either", "LungCancer", "Bronchitis"]


def formatter(key):
    """Return the display name for a variable key.

    Args:
        key: Integer key id of a variable.

    Returns:
        The matching entry of ``pretty_names`` when ``key`` is a valid
        index into it, otherwise ``str(key)``.
    """
    # Guard the lookup: an out-of-range key would raise IndexError and a
    # negative key would silently wrap around to the wrong name.
    if 0 <= key < len(pretty_names):
        return pretty_names[key]
    return str(key)
asia.print("Asia", formatter)
Asia
size: 8
conditional 0: P( Asia ):
Choice(Asia)
0 Leaf 0.99
1 Leaf 0.01
conditional 1: P( Smoking ):
Leaf 0.5
conditional 2: P( Tuberculosis | Asia ):
Choice(Tuberculosis)
0 Choice(Asia)
0 0 Leaf 0.99
0 1 Leaf 0.95
1 Choice(Asia)
1 0 Leaf 0.01
1 1 Leaf 0.05
conditional 3: P( LungCancer | Smoking ):
Choice(LungCancer)
0 Choice(Smoking)
0 0 Leaf 0.99
0 1 Leaf 0.9
1 Choice(Smoking)
1 0 Leaf 0.01
1 1 Leaf 0.1
conditional 4: P( Bronchitis | Smoking ):
Choice(Bronchitis)
0 Choice(Smoking)
0 0 Leaf 0.7
0 1 Leaf 0.4
1 Choice(Smoking)
1 0 Leaf 0.3
1 1 Leaf 0.6
conditional 5: P( Either | Tuberculosis LungCancer ):
Choice(LungCancer)
0 Choice(Either)
0 0 Choice(Tuberculosis)
0 0 0 Leaf 1
0 0 1 Leaf 0
0 1 Choice(Tuberculosis)
0 1 0 Leaf 0
0 1 1 Leaf 1
1 Choice(Either)
1 0 Leaf 0
1 1 Leaf 1
conditional 6: P( XRay | Either ):
Choice(Either)
0 Choice(XRay)
0 0 Leaf 0.95
0 1 Leaf 0.05
1 Choice(XRay)
1 0 Leaf 0.02
1 1 Leaf 0.98
conditional 7: P( Dyspnea | Either Bronchitis ):
Choice(Bronchitis)
0 Choice(Either)
0 0 Choice(Dyspnea)
0 0 0 Leaf 0.9
0 0 1 Leaf 0.1
0 1 Choice(Dyspnea)
0 1 0 Leaf 0.3
0 1 1 Leaf 0.7
1 Choice(Either)
1 0 Choice(Dyspnea)
1 0 0 Leaf 0.2
1 0 1 Leaf 0.8
1 1 Choice(Dyspnea)
1 1 0 Leaf 0.1
1 1 1 Leaf 0.9
Convert to Factor Graph and solve
# Build a factor graph from the Bayes net.
fg = DiscreteFactorGraph(asia)

# Solve for the most probable explanation (MPE) and print it as
# "(key, value)" pairs on a single line.
mpe = fg.optimize()
print("mpe:" + "".join(f" ({i}, {mpe[i]})" for i in range(8)))

# Elimination ordering over keys 0..7, in increasing key order.
ordering = Ordering()
for key in range(8):
    ordering.push_back(key)

# Build a Bayes tree (directed junction tree)
bayes_tree = fg.eliminateMultifrontal(ordering)
bayes_tree.print("bayesTree", formatter)
mpe: (0, 0) (1, 0) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 0)
bayesTree: cliques: 6, variables: 8
bayesTree- P( Smoking LungCancer Bronchitis ):
Choice(Bronchitis)
0 Choice(LungCancer)
0 0 Choice(Smoking)
0 0 0 Leaf 0.3465
0 0 1 Leaf 0.18
0 1 Choice(Smoking)
0 1 0 Leaf 0.0035
0 1 1 Leaf 0.02
1 Choice(LungCancer)
1 0 Choice(Smoking)
1 0 0 Leaf 0.1485
1 0 1 Leaf 0.27
1 1 Choice(Smoking)
1 1 0 Leaf 0.0015
1 1 1 Leaf 0.03
bayesTree| - P( Either | LungCancer Bronchitis ):
Choice(LungCancer)
0 Choice(Either)
0 0 Leaf 0.9896
0 1 Leaf 0.0104
1 Choice(Either)
1 0 Leaf 0
1 1 Leaf 1
bayesTree| | - P( Tuberculosis | Either LungCancer ):
Choice(LungCancer)
0 Choice(Either)
0 0 Choice(Tuberculosis)
0 0 0 Leaf 1
0 0 1 Leaf 0
0 1 Choice(Tuberculosis)
0 1 0 Leaf 0
0 1 1 Leaf 1
1 Choice(Either)
1 0 Leaf 0
1 1 Choice(Tuberculosis)
1 1 0 Leaf 0.9896
1 1 1 Leaf 0.0104
bayesTree| | | - P( Asia | Tuberculosis ):
Choice(Tuberculosis)
0 Choice(Asia)
0 0 Leaf 0.99040016
0 1 Leaf 0.0095998383
1 Choice(Asia)
1 0 Leaf 0.95192308
1 1 Leaf 0.048076923
bayesTree| | - P( XRay | Either ):
Choice(Either)
0 Choice(XRay)
0 0 Leaf 0.95
0 1 Leaf 0.05
1 Choice(XRay)
1 0 Leaf 0.02
1 1 Leaf 0.98
bayesTree| | - P( Dyspnea | Either Bronchitis ):
Choice(Bronchitis)
0 Choice(Either)
0 0 Choice(Dyspnea)
0 0 0 Leaf 0.9
0 0 1 Leaf 0.1
0 1 Choice(Dyspnea)
0 1 0 Leaf 0.3
0 1 1 Leaf 0.7
1 Choice(Either)
1 0 Choice(Dyspnea)
1 0 0 Leaf 0.2
1 0 1 Leaf 0.8
1 1 Choice(Dyspnea)
1 1 0 Leaf 0.1
1 1 1 Leaf 0.9
Add evidence, solve again
# Evidence is added as extra factors: we were in Asia and we have dyspnea.
fg.add(Asia, "0 1")     # Evidence: Asia = 1 (Yes, been to Asia)
fg.add(Dyspnea, "0 1")  # Evidence: Dyspnea = 1 (Yes, have dyspnea)

# Solve for the MPE again, now with the evidence included.
mpe2 = fg.optimize()
# end="" leaves the line open; the print() that follows terminates it.
print("mpe2:" + "".join(f" ({i}, {mpe2[i]})" for i in range(8)), end="")
print()
mpe2: (0, 1) (1, 1) (2, 0) (3, 0) (4, 1) (5, 0) (6, 0) (7, 1)
Sample from Posterior Distribution
# Eliminate sequentially with the same ordering, then draw samples from
# the result.
chordal = fg.eliminateSequential(ordering)

print("\n10 samples:")
for _ in range(10):
    sample = chordal.sample()
    # One line per sample, listing "(key, value)" for keys 0..7.
    print("sample:" + "".join(f" ({j}, {sample[j]})" for j in range(8)))
10 samples:
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 1) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 0)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 1) (5, 0) (6, 0) (7, 0)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 0)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 0) (5, 0) (6, 0) (7, 1)
sample: (0, 1) (1, 1) (2, 0) (3, 0) (4, 1) (5, 0) (6, 0) (7, 1)