Differentiable Discrete Sampling using AgentTorch
Introduction
Traditional neural networks excel at continuous computations but struggle with discrete sampling operations because the sampling process is inherently non-differentiable. This tutorial demonstrates how to use AgentTorch to implement differentiable discrete sampling, enabling gradient-based optimization through stochastic discrete operations.
Key Concepts:
- Discrete Sampling: Operations that produce discrete outcomes (e.g., Bernoulli trials, categorical choices)
- Differentiable Relaxations: Techniques to estimate gradients through non-differentiable operations
- Straight-Through Estimator: A simple gradient approximation that copies gradients from output to input
- Stochastic Triples Method: More sophisticated gradient estimation using probability-aware weightings
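To make the straight-through idea concrete, here is a minimal PyTorch sketch (not AgentTorch's implementation) for a Bernoulli sample: the forward pass returns a hard 0/1 value, while the backward pass treats the sampler as the identity, so the gradient is copied from output to input. The helper name `bernoulli_straight_through` is made up for this illustration.

```python
import torch

def bernoulli_straight_through(p):
    """Sample a hard 0/1 value, but let gradients flow to p as if the
    sampler were the identity (straight-through estimator)."""
    sample = torch.bernoulli(p)  # hard, non-differentiable sample
    # Forward value is `sample`; in backward, d(out)/d(p) = 1.
    return sample + p - p.detach()

p = torch.tensor([0.7], requires_grad=True)
out = bernoulli_straight_through(p)
out.backward()
print(p.grad)  # tensor([1.]) -- gradient copied straight through
```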
Example: Markovian Random Walk with Categorical Sampling
Let's implement a 1D Markovian random walk X_0, X_1, ..., X_n using categorical sampling. At each step the agent moves up or down with the following probabilities:
- X_{n+1} = X_n + 1 with probability e^(-X_n/p)
- X_{n+1} = X_n - 1 with probability 1 - e^(-X_n/p)
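Before writing the simulation, a quick numeric sanity check of these transition probabilities can help build intuition; the values of x and p below are illustrative only.

```python
import math

p = 20  # decay parameter (illustrative choice)
for x in (0, 10, 20):
    q = math.exp(-x / p)  # probability of moving up from state x
    print(f"x={x}: P(up)={q:.4f}, P(down)={1 - q:.4f}")
# At x = 0, P(up) = e^0 = 1, so the walk can never step below zero.
```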
First, let's import the required modules:
import torch
import math
from agent_torch.core.distributions import Categorical
We are interested in the asymptotic behavior of the variance of our automatically derived gradient estimator, so we set p = n; this ensures the transition probability varies appreciably over the range of the walk for all n.
Let's define the main function:
def simulate_markovian_random_walk_categorical(n, p, device):
    x = 0.0  # initial state
    path = [0.0]
    for _ in range(n):
        # Probability of moving up: q = e^(-x/p).
        q = math.exp(-x / p)
        prob = torch.tensor([q, 1.0 - q], dtype=torch.float32, device=device).unsqueeze(0)
        # Sample an action using the custom Categorical function.
        sample = Categorical.apply(prob)
        move = 1 if sample.item() == 0 else -1
        # At x == 0, q = 1, so an upward move is certain; override any
        # stray downward sample caused by floating-point rounding.
        if x == 0 and move == -1:
            move = 1
        x += move
        path.append(x)
    return path
This random walk can be generated by:
n = 20  # a 20-step simulation
p = n
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random_path = simulate_markovian_random_walk_categorical(n, p, device)
# The resulting path looks like [0, 1, 2, 1, ...]
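As a sanity check on the dynamics, the same walk can be reproduced with PyTorch's built-in torch.distributions.Categorical, which does not propagate gradients through the sample but is useful for verifying the path's structural properties before switching to AgentTorch's differentiable sampler. The helper name `random_walk_plain` is hypothetical.

```python
import torch

def random_walk_plain(n, p):
    """Non-differentiable baseline using torch.distributions.Categorical."""
    x = 0.0
    path = [0.0]
    for _ in range(n):
        q = torch.exp(torch.tensor(-x / p))
        dist = torch.distributions.Categorical(probs=torch.stack([q, 1 - q]))
        move = 1 if dist.sample().item() == 0 else -1
        if x == 0 and move == -1:  # an upward move is certain at the origin
            move = 1
        x += move
        path.append(x)
    return path

path = random_walk_plain(20, 20)
# Every step is +/-1 and the walk never goes below 0.
assert all(abs(b - a) == 1 for a, b in zip(path, path[1:]))
assert min(path) >= 0
```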
Conclusion
This tutorial demonstrated how to implement differentiable discrete sampling operations with AgentTorch, using a Markovian random walk with categorical sampling as a running example.