Skip to article frontmatterSkip to article content

The Discrete Bayes Tree

An example of building a Bayes net, then eliminating it into a Bayes tree. It mirrors the code in `testDiscreteBayesTree.cpp`.

Open In Colab

from gtsam import DiscreteBayesTree, DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering
from gtsam.symbol_shorthand import S
def P(*args):
    """Bundle any number of DiscreteKey pairs into a single ``DiscreteKeys``.

    Each positional argument (a pair such as ``(j, 2)``) is appended in the
    order given, so ``P()`` yields an empty ``DiscreteKeys`` container.
    """
    # TODO: variadic construction could be offered by the C++ wrapper itself,
    # which would make this helper unnecessary.
    result = DiscreteKeys()
    for discrete_key in args:
        result.push_back(discrete_key)
    return result
import graphviz
class show(graphviz.Source):
    """Render any object that has a ``dot()`` method as a Graphviz graph.

    The string returned by ``obj.dot()`` becomes the source of this
    ``graphviz.Source``. That class provides ``_repr_mimebundle_``, which is
    how the graph gets rendered inline in a notebook.
    """

    def __init__(self, obj):
        """Initialize from the DOT text produced by ``obj.dot()``."""
        dot_source = obj.dot()
        super().__init__(dot_source)
# Define DiscreteKey pairs: keys 0..14, each with cardinality 2.
keys = [(j, 2) for j in range(15)]

# Create thin-tree Bayes net from a table of (child, parents, conditional spec).
# Rows are listed leaves first and the root (key 14) last, i.e. in elimination
# order; the order of this table is the order conditionals are added.
_conditionals = [
    (0, (8, 12), "2/3 1/4 3/2 4/1"),
    (1, (8, 12), "4/1 2/3 3/2 1/4"),
    (2, (9, 12), "1/4 8/2 2/3 4/1"),
    (3, (9, 12), "1/4 2/3 3/2 4/1"),

    (4, (10, 13), "2/3 1/4 3/2 4/1"),
    (5, (10, 13), "4/1 2/3 3/2 1/4"),
    (6, (11, 13), "1/4 3/2 2/3 4/1"),
    (7, (11, 13), "1/4 2/3 3/2 4/1"),

    (8, (12, 14), "T 1/4 3/2 4/1"),
    (9, (12, 14), "4/1 2/3 F 1/4"),
    (10, (13, 14), "1/4 3/2 2/3 4/1"),
    (11, (13, 14), "1/4 2/3 3/2 4/1"),

    (12, (14,), "3/1 3/1"),
    (13, (14,), "1/3 3/1"),

    (14, (), "1/3"),
]

bayesNet = DiscreteBayesNet()
for child, parents, table in _conditionals:
    # Expand the parent indices into their DiscreteKey pairs.
    bayesNet.add(keys[child], P(*(keys[p] for p in parents)), table)

show(bayesNet)
Loading...
# Draw five samples from the Bayes net and print each one.
# Sampling needs the conditionals to have been added in elimination order!
for _sample_index in range(5):
    sample = bayesNet.sample()
    print(sample)
DiscreteValues{0: 1, 1: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 0}
DiscreteValues{0: 0, 1: 1, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}
DiscreteValues{0: 1, 1: 0, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1, 9: 0, 10: 1, 11: 1, 12: 0, 13: 1, 14: 0}
DiscreteValues{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 1}
DiscreteValues{0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}
# Build a DiscreteFactorGraph directly from the Bayes net, then render it.
factorGraph = DiscreteFactorGraph(bayesNet)
show(factorGraph)
Loading...
# Create a BayesTree out of the factor graph.
# Eliminate the variables in the same order their conditionals were added to
# the Bayes net (keys 0 through 14). The ordering is derived from `keys`
# rather than a repeated literal 15, so the two cannot drift apart.
ordering = Ordering()
for key, _cardinality in keys:
    ordering.push_back(key)
bayesTree = factorGraph.eliminateMultifrontal(ordering)
show(bayesTree)
Loading...