Differentiable Discrete Sampling using AgentTorch

Introduction

Traditional neural networks excel at continuous computations but struggle with discrete sampling operations because the sampling process is inherently non-differentiable. This tutorial demonstrates how to use AgentTorch to implement differentiable discrete sampling, enabling gradient-based optimization through stochastic discrete operations.
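
To see the problem concretely, here is a minimal PyTorch sketch (not part of the tutorial itself) in which a hard Bernoulli sample cuts the computation graph:

import torch

p = torch.tensor(0.7, requires_grad=True)
sample = torch.bernoulli(p)   # draws 0.0 or 1.0; sampling has no derivative
print(sample.requires_grad)   # False: the graph is cut at the sampling step
# sample.backward() would raise a RuntimeError, since no gradient path
# leads back to p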

Key Concepts:

  • Discrete Sampling: Operations that produce discrete outcomes (e.g., Bernoulli trials, categorical choices)
  • Differentiable Relaxations: Techniques to estimate gradients through non-differentiable operations
  • Straight-Through Estimator: A simple gradient approximation that copies gradients from output to input (see the sketch after this list)
  • Stochastic Triples Method: More sophisticated gradient estimation using probability-aware weightings
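
As a minimal sketch of the straight-through idea (written in plain PyTorch here, and not necessarily how AgentTorch implements it internally), a custom autograd Function samples in the forward pass and copies the incoming gradient directly to the probabilities in the backward pass:

import torch

class STBernoulli(torch.autograd.Function):
    """Bernoulli sampling with a straight-through gradient (illustrative)."""

    @staticmethod
    def forward(ctx, probs):
        # Forward pass: draw hard 0/1 samples from the given probabilities.
        return torch.bernoulli(probs)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: treat the sampling step as the identity function
        # and pass the gradient through to the probabilities unchanged.
        return grad_output

probs = torch.full((4,), 0.5, requires_grad=True)
samples = STBernoulli.apply(probs)
samples.sum().backward()
print(probs.grad)  # tensor([1., 1., 1., 1.]): gradients flow despite discrete samples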

Example: Markovian Random Walk with Categorical Sampling

Let's implement a 1D Markovian random walk X_0, X_1, ..., X_n using categorical sampling. At each step, the agent moves right or left with the following probabilities:

  • X_{n+1} = X_n + 1 with probability e^(-X_n / p)
  • X_{n+1} = X_n - 1 with probability 1 - e^(-X_n / p)

Note that at X_n = 0 the up-probability is e^0 = 1, so the walk never moves below the origin; the code below also enforces this explicitly as a safeguard.

First, let's import the required modules:

import torch
import math
from agent_torch.core.distributions import Categorical

We are interested in studying the asymptotic behavior of the variance of the automatically derived gradient estimator, so we set p = n so that the transition probability varies appreciably over the range of the walk for all n.

Let's define the main function:

def simulate_markovian_random_walk_categorical(n, p, device):
    x = 0.0  # initial state X_0
    path = [x]
    for _ in range(n):
        # Probability of moving up: e^(-x / p).
        q = math.exp(-x / p)
        prob = torch.tensor([q, 1.0 - q], dtype=torch.float32, device=device).unsqueeze(0)
        # Sample an action using AgentTorch's differentiable Categorical function.
        sample = Categorical.apply(prob)
        move = 1 if sample.item() == 0 else -1
        # At x == 0 the up-probability is e^0 = 1, so a downward move can only
        # arise from numerical round-off; override it as a safeguard.
        if x == 0 and move == -1:
            move = 1
        x += move
        path.append(x)
    return path

This random walk can be generated by:

n = 20  # a 20-step simulation
p = n
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random_path = simulate_markovian_random_walk_categorical(n, p, device)

# The resulting path looks like [0.0, 1.0, 2.0, 1.0, ...]
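
As an illustrative extension (not part of the original tutorial), the walk can be studied empirically with a plain Monte Carlo estimate of the mean final position; num_runs is an arbitrary choice here:

num_runs = 1000  # arbitrary sample size for this illustration
final_positions = [
    simulate_markovian_random_walk_categorical(n, p, device)[-1]
    for _ in range(num_runs)
]
mean_final = sum(final_positions) / num_runs
print(f"Estimated mean final position after {n} steps: {mean_final:.2f}")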

Conclusion

This tutorial demonstrated how to implement differentiable discrete sampling with AgentTorch, using its Categorical function to drive a Markovian random walk while keeping the sampling step compatible with gradient-based optimization.