A DiscreteFactorGraph is a factor graph consisting of DiscreteFactors, such as DecisionTreeFactor. It represents a joint probability distribution (or a general function) over a set of discrete variables $X$ as a product of factors:

$$P(X) \propto \prod_i f_i(X_i)$$

where each $f_i$ is a factor over a subset $X_i$ of the variables (its scope).
This is an undirected graphical model (a Markov Random Field), and it is used to perform inference tasks like finding the most likely state (optimize) or computing marginals (via elimination).
import gtsam
import numpy as np
import graphviz
from gtsam.symbol_shorthand import X, Y, Z

Creating a DiscreteFactorGraph¶
A DiscreteFactorGraph is created by adding DiscreteFactors. We will create a simple graph with three variables and three factors.
# Define keys for three binary variables
KeyX = (X(0), 2)
KeyY = (Y(0), 2)
KeyZ = (Z(0), 2)
# Create an empty Factor Graph
dfg = gtsam.DiscreteFactorGraph()
# Add a unary factor on X (a prior)
dfg.add(KeyX, "0.3 0.7")
# Add a binary factor on X and Y
# Values for (X,Y) = (0,0), (0,1), (1,0), (1,1)
dfg.add([KeyX, KeyY], "0.5 0.2 0.1 0.9")
# Add a binary factor on Y and Z
# Values for (Y,Z) = (0,0), (0,1), (1,0), (1,1)
dfg.add([KeyY, KeyZ], "0.8 0.1 0.3 0.6")
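The spec strings above list values with the last-listed variable varying fastest, as the inline comments show. A minimal pure-Python sketch of that convention (not the GTSAM API; `parse_spec` is a hypothetical helper for illustration):

```python
from itertools import product

def parse_spec(spec: str, cardinalities: list) -> dict:
    """Parse a whitespace-separated spec string into a table mapping
    assignment tuples to values, with the last variable varying fastest."""
    values = [float(v) for v in spec.split()]
    assignments = list(product(*[range(c) for c in cardinalities]))
    assert len(values) == len(assignments)
    return dict(zip(assignments, values))

# The binary factor on (X, Y) from the cell above:
f_xy = parse_spec("0.5 0.2 0.1 0.9", [2, 2])
print(f_xy[(1, 1)])  # value for (X=1, Y=1)
```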
print("Manually Constructed DiscreteFactorGraph:")
dfg.print()

Manually Constructed DiscreteFactorGraph:
size: 3
factor 0: f[ (x0,2), ]
Choice(x0)
0 Leaf 0.3
1 Leaf 0.7
factor 1: f[ (x0,2), (y0,2), ]
Choice(y0)
0 Choice(x0)
0 0 Leaf 0.5
0 1 Leaf 0.1
1 Choice(x0)
1 0 Leaf 0.2
1 1 Leaf 0.9
factor 2: f[ (y0,2), (z0,2), ]
Choice(z0)
0 Choice(y0)
0 0 Leaf 0.8
0 1 Leaf 0.3
1 Choice(y0)
1 0 Leaf 0.1
1 1 Leaf 0.6
# Visualize the Factor Graph structure
# Circles are variables, squares are factors.
graphviz.Source(dfg.dot())
Rich Display in Jupyter¶
A factor graph can be displayed as a list of its component factor tables.
dfg
Inference on a DiscreteFactorGraph¶
# --- Evaluation ---
# Calculate the product of all factor values for a given assignment.
values = gtsam.DiscreteValues()
values[X(0)] = 1
values[Y(0)] = 1
values[Z(0)] = 0
# P(X=1,Y=1,Z=0) ∝ f1(1) * f2(1,1) * f3(1,0)
# = 0.7 * 0.9 * 0.3 = 0.189
prob = dfg(values)
print(f"Unnormalized probability at (X=1, Y=1, Z=0): {prob:.4f}")
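The same number can be reproduced without GTSAM by multiplying the three table entries directly. A quick sanity-check sketch, using plain dicts for the factor tables defined above:

```python
# Factor tables from the cells above, keyed by assignment tuples.
f_x = {0: 0.3, 1: 0.7}
f_xy = {(0, 0): 0.5, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.9}
f_yz = {(0, 0): 0.8, (0, 1): 0.1, (1, 0): 0.3, (1, 1): 0.6}

x, y, z = 1, 1, 0
prob = f_x[x] * f_xy[(x, y)] * f_yz[(y, z)]
print(f"{prob:.4f}")  # 0.7 * 0.9 * 0.3 = 0.1890
```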
# --- Optimization (MPE) ---
# Find the assignment with the highest product of probabilities.
# This is done via max-product elimination.
mpe_solution = dfg.optimize()
print("\nMost Probable Explanation (MPE) Solution:")
print(mpe_solution)

Unnormalized probability at (X=1, Y=1, Z=0): 0.1890
Most Probable Explanation (MPE) Solution:
DiscreteValues{8646911284551352320: 1, 8718968878589280256: 1, 8791026472627208192: 1}
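The large integers are raw GTSAM keys; the solution assigns 1 to all three variables. With only $2^3$ assignments, the MPE is easy to verify by brute force — a sketch using the same plain-dict tables as the factor values above:

```python
from itertools import product

# Factor tables from the cells above.
f_x = {0: 0.3, 1: 0.7}
f_xy = {(0, 0): 0.5, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.9}
f_yz = {(0, 0): 0.8, (0, 1): 0.1, (1, 0): 0.3, (1, 1): 0.6}

# Enumerate all assignments (x, y, z) and keep the one with the largest product.
best = max(product(range(2), repeat=3),
           key=lambda a: f_x[a[0]] * f_xy[(a[0], a[1])] * f_yz[(a[1], a[2])])
print(best)  # (1, 1, 1): all variables equal 1, matching the MPE above
```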
# --- Elimination (Sum-Product) ---
# Eliminate variables from the graph to obtain a Bayes Net (or just a marginal).
# This computes the marginal probability distribution.
ordering = gtsam.Ordering()
ordering.push_back(Z(0))
ordering.push_back(X(0))
ordering.push_back(Y(0))
# Eliminating the graph in this order produces a Bayes net P(X,Y,Z) = P(Z|Y) P(X|Y) P(Y)
bayes_net = dfg.eliminateSequential(ordering)
print("--- Resulting Bayes Net after elimination ---")
bayes_net
# Visualize the Bayes net structure using graphviz
graphviz.Source(bayes_net.dot())
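Sum-product elimination can also be mimicked by hand: summing Z out of the (Y,Z) factor and X out of the remaining factors leaves an unnormalized marginal on Y. A sketch of that computation (not GTSAM's implementation), again using plain-dict tables:

```python
# Factor tables from the cells above.
f_x = {0: 0.3, 1: 0.7}
f_xy = {(0, 0): 0.5, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.9}
f_yz = {(0, 0): 0.8, (0, 1): 0.1, (1, 0): 0.3, (1, 1): 0.6}

# Eliminate Z: tau(y) = sum_z f_yz(y, z)
tau_y = {y: sum(f_yz[(y, z)] for z in range(2)) for y in range(2)}
# Eliminate X: sigma(y) = sum_x f_x(x) * f_xy(x, y)
sigma_y = {y: sum(f_x[x] * f_xy[(x, y)] for x in range(2)) for y in range(2)}
# Multiply the remaining factors on Y and normalize to get P(Y).
m = {y: tau_y[y] * sigma_y[y] for y in range(2)}
total = sum(m.values())
p_y = {y: m[y] / total for y in range(2)}
print(p_y)
```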
# --- Elimination into a Bayes tree (Sum-Product) ---
# We can also produce a Bayes tree, which is typically more efficient.
bayes_tree = dfg.eliminateMultifrontal(ordering)
print("--- Resulting Bayes Tree after elimination ---")
bayes_tree
# Visualize the Bayes tree structure using graphviz
graphviz.Source(bayes_tree.dot())