
Performance and Memory Tuning

Two questions decide how a model runs on accelerators: does it fit in memory, and are the devices used well. pylcm keeps them separate as two independent knobs on every grid — batch_size (splay) and distributed (shard) — plus a forward-simulation chunk size and a handful of XLA environment flags. This page explains what each does, when it helps, and the trade-offs that are easy to get backwards.

The one-line model:

Keeping these straight is the whole game: splaying never speeds anything up, and sharding is the only knob that does.

The two grid knobs

Every grid — DiscreteGrid and every continuous grid (LinSpacedGrid, LogSpacedGrid, IrregSpacedGrid, the piecewise variants) — takes both:

```python
from lcm.grids import DiscreteGrid, LinSpacedGrid

# PrefType: the model's own category class for preference types (defined elsewhere, not shown).
# A permanent (never-transitioning) discrete state, sharded one block per device (speed):
pref_type = DiscreteGrid(PrefType, distributed=True)

# A continuous state, scan-chunked into pieces to save memory (time-neutral):
assets = LinSpacedGrid(start=0.0, stop=1_000.0, n_points=200, batch_size=50)
```

| knob | what it does | what it buys | applies to |
|---|---|---|---|
| batch_size=k (splay) | lax.scan the per-period work over chunks of k points along that axis | lower peak memory | any axis |
| distributed=True (shard) | place that axis’s blocks on separate devices | parallel speedup | discrete, non-transitioning axes only |

batch_size=0 (the default) means “no splay” — one kernel per period over the full axis. distributed=False (the default) means “not sharded”.

batch_size: splay for memory, time-neutral

At each period, backward induction builds the value array over every (state, action) combination and maximizes over actions. batch_size=k only changes how that work is tiled: instead of one big vmap, it runs a lax.scan over chunks of k points along the chosen axis. The total FLOPs are identical (every combination is still evaluated exactly once), so wall-clock time barely moves. What drops is peak memory, because only one chunk’s intermediate is live at a time.
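The tiling idea can be sketched in plain JAX. This is not pylcm's internal code: value_of, solve_full and solve_splayed are hypothetical stand-ins, and the sketch assumes batch_size divides the length of the splayed axis. Both versions evaluate every (state, action) pair once and return the same maxima; the scanned one only ever holds a (batch_size, n_actions) intermediate.

```python
import jax
import jax.numpy as jnp


def value_of(state, action):
    # Hypothetical stand-in for the per-(state, action) value; any elementwise function works.
    return -((state - action) ** 2)


def solve_full(states, actions):
    # One big vmap: materializes the full (n_states, n_actions) array at once.
    values = jax.vmap(lambda s: value_of(s, actions))(states)
    return values.max(axis=1)


def solve_splayed(states, actions, batch_size):
    # Split the state axis into chunks of batch_size points (assumes an even split).
    chunks = states.reshape(-1, batch_size)

    def body(carry, chunk):
        # Only this chunk's (batch_size, n_actions) intermediate is live here.
        values = jax.vmap(lambda s: value_of(s, actions))(chunk)
        return carry, values.max(axis=1)

    _, out = jax.lax.scan(body, None, chunks)  # out: (n_chunks, batch_size)
    return out.reshape(-1)


states = jnp.linspace(0.0, 1_000.0, 200)
actions = jnp.linspace(0.0, 1_000.0, 300)
full = solve_full(states, actions)
splayed = solve_splayed(states, actions, batch_size=50)
print(bool(jnp.allclose(full, splayed)))  # True
```

The scan runs its four chunks one after another, so the work per chunk must still be large enough to keep the device busy; that is exactly the condition discussed next.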

Splay stays time-neutral as long as each chunk still has enough parallel work to saturate the device — and in a real model it does, because the other grid dimensions (assets × savings × shocks × …) provide ample parallelism inside every chunk.

It stops being free only at the extremes:

Which axis to splay. Prefer a large, uniform axis:

Rule: use the fewest chunks that fit. Halving memory needs only two chunks (batch_size = n_points / 2), not batch_size = 1.
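The rule can be turned into a small helper. This is a hypothetical sketch, not part of pylcm: fewest_chunks is an illustrative name, bytes_per_point is an estimate you would supply yourself (the size of one slice of the per-period intermediate along the splayed axis, i.e. the product of the other grid dimensions times the element size), and the helper only considers batch sizes that divide n_points evenly.

```python
def fewest_chunks(n_points, bytes_per_point, budget_bytes):
    """Return (batch_size, n_chunks) using the fewest even chunks whose
    per-chunk intermediate fits within budget_bytes."""
    for n_chunks in range(1, n_points + 1):
        if n_points % n_chunks:
            continue  # skip splits that would leave a ragged final chunk
        batch_size = n_points // n_chunks
        if batch_size * bytes_per_point <= budget_bytes:
            # n_chunks == 1 corresponds to leaving batch_size at its default of 0.
            return batch_size, n_chunks
    raise ValueError("even one point per chunk does not fit the budget")


GiB = 2**30
# 200 points at 0.5 GiB each need 100 GiB unsplayed; a 60 GiB budget needs two chunks.
print(fewest_chunks(200, GiB // 2, 60 * GiB))  # (100, 2)
```

Because splay is time-neutral only while each chunk saturates the device, stopping at the first split that fits, rather than shrinking further, is the safe default.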

distributed: shard for speed (discrete, non-transitioning axes)

distributed=True places the blocks of an axis on separate devices and solves them in parallel. It is the only knob that reduces wall-clock — but it is legal only for a narrow class of axes, and pylcm enforces the boundaries at construction time.

It runs communication-free only for axes the agent never transitions along. If an agent’s position on the axis is fixed for life (a permanent type, a fixed group), each block’s value function is independent of the others, so the blocks sit on different devices with zero cross-device traffic. An axis the agent moves along (health, wealth, a lagged choice) couples the blocks: every period would need an all-to-all exchange, and the communication swamps the compute.
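The communication-free case can be illustrated with JAX's own sharding API. This sketch does not show how pylcm places blocks internally; it only demonstrates the principle with jax.sharding.Mesh and NamedSharding. The "types" axis name, the 64 points per block, and solve_blocks are assumptions for illustration. Each row stands for one permanent type, and the jitted function reduces every row on its own, so no device needs another device's block.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("types",))

# One block (row) per device along the hypothetical permanent-type axis.
n_types = len(devices)
values = jnp.arange(n_types * 64, dtype=jnp.float32).reshape(n_types, 64)
sharded = jax.device_put(values, NamedSharding(mesh, PartitionSpec("types")))


@jax.jit
def solve_blocks(v):
    # Each row is reduced independently; no row reads another row,
    # so each device works only on the block it already holds.
    return jnp.max(v * 0.5 + 1.0, axis=1)


out = solve_blocks(sharded)
```

On a single-device machine this runs with one block; on CPU, setting XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing JAX simulates four devices. If the function instead mixed rows, for example applying a transition matrix along the type axis, XLA would have to exchange data between devices, which is the coupling the paragraph above warns about.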

Two guards make this concrete — both raise GridInitializationError at construction:

Forward simulation: subject_batch_size

Solving is one memory profile; simulating a large panel forward is another. Model.simulate(..., subject_batch_size=k) chunks the simulated subjects so only one chunk is resident at a time:

Like grid batch_size, this is a time-neutral memory knob — raise the chunk count if the simulated panel does not fit, and otherwise leave it at a single pass.
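The effect of chunking subjects can be sketched outside pylcm. Everything here is hypothetical: step is a toy one-period update, and simulate_panel only mimics the idea behind subject_batch_size; it is not Model.simulate. Each chunk is rolled forward with lax.scan and its finished path is copied to host memory, so the device holds one chunk's path at a time, and the result matches a single pass.

```python
import jax
import numpy as np


def step(wealth, shock):
    # Hypothetical one-period update for a panel of subjects.
    return 0.95 * wealth + shock


def simulate_panel(wealth0, shocks, subject_batch_size=None):
    """Roll wealth0 (n_subjects,) forward through shocks (n_periods, n_subjects)."""
    n_subjects = wealth0.shape[0]
    k = subject_batch_size or n_subjects  # None or 0 means a single pass

    def body(w, e):
        w_next = step(w, e)
        return w_next, w_next  # (carry, per-period output)

    paths = []
    for start in range(0, n_subjects, k):
        _, path = jax.lax.scan(body, wealth0[start:start + k], shocks[:, start:start + k])
        paths.append(np.asarray(path))  # copy the finished chunk to host memory
    return np.concatenate(paths, axis=1)


rng = np.random.default_rng(0)
wealth0 = np.ones(1_000, dtype=np.float32)
shocks = rng.normal(size=(10, 1_000)).astype(np.float32)
full = simulate_panel(wealth0, shocks)
chunked = simulate_panel(wealth0, shocks, subject_batch_size=300)
```

With 1,000 subjects and chunks of 300, the last chunk holds only 100 subjects; in a jitted version that differently shaped chunk would trigger one extra compilation, a small cost next to the memory saved.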

Worked example

Measured on 80 GB A100s, one six-regime lifecycle model:

The takeaway is the one-line model: the multiplicative speedup comes from sharding across devices, not from any choice of batch_size.

Environment flags

pylcm sets two JAX defaults at import and leaves the rest to the environment.

Set by pylcm (override before importing lcm):

Knobs you set yourself, with the trade-off each carries:

A stable multi-GPU configuration. The following environment holds up at production scale, trading compile-time kernel search and launch batching for memory headroom:

```bash
export XLA_PYTHON_CLIENT_PREALLOCATE=true
export XLA_PYTHON_CLIENT_ALLOCATOR=default          # pooled BFC
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.90
export XLA_FLAGS='--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer='
```

Command buffers are the one knob to revisit once a model fits comfortably: re-enabling them amortizes launch overhead, at the cost of the non-pool driver memory they consume. Autotuning, by contrast, has not been observed to speed these gather-bound solves, so leaving it off costs nothing and keeps the memory headroom.
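When a job is launched from Python rather than a shell script, the same configuration can be applied through os.environ. This is a minimal sketch: STABLE_MULTI_GPU_ENV and apply_env are illustrative names, not part of pylcm. JAX reads these variables when its backend initializes, so the safe choice is to run this before the first import of jax, and therefore before importing lcm, which imports JAX.

```python
import os

# Same settings as the shell exports above.
STABLE_MULTI_GPU_ENV = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "true",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "default",  # pooled BFC allocator
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.90",
    # Autotuning off; an empty value disables command buffers.
    "XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=",
}


def apply_env(settings):
    # Overwrites existing values, including any XLA_FLAGS set earlier.
    for name, value in settings.items():
        os.environ[name] = value


apply_env(STABLE_MULTI_GPU_ENV)
# Import jax / lcm only after this point.
```

Because this overwrites XLA_FLAGS, append to the existing string instead if other XLA flags are already in use; re-enabling command buffers later means dropping the --xla_gpu_enable_command_buffer= entry.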

Checklist