AlphaStar Architecture
TL;DR
AlphaStar’s recipe is pragmatic:
(1) structure the state according to what it actually is (scalars, entities, spatial).
(2) tame the combinatorial action space by factoring it into an autoregressive head stack.
(3) keep rich side-channels (pointer keys, map skips) so heads don’t have to squeeze everything through a single vector.
Relational Inductive Bias
Fully connected nets are expressive, but when locality or relations dominate, you win by choosing a structure that matches the world. Images reward convolution; multi-entity worlds reward set reasoning and spatial fusion. The key claim here: the architecture itself acts as a constraint, nudging the model to extract the right relations instead of memorizing spurious ones.
Encoders: three views of the world
1) Scalar Encoder (Dense, with context gating)
- Inputs:
Agent stats, race, upgrade flags, time (positional), valid-action indicators, unit-count transforms (e.g., sqrt/log).
- Outputs:
- Embedded Scalar: 1D summary to the core.
- Scalar Context: a small vector that gates the Action-Type head (GLU-style) so macro numeric state can nudge top-level choices.
- Practice:
Rescale heavy-tailed values (log/sqrt) and prefer discrete buckets/one-hots, even for “scalars”. Treat a “scalar” as information to be enumerated, not a raw float to be guessed.
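A minimal Python sketch of the discretize-and-rescale advice, assuming a worker count and a mineral amount as example inputs (both hypothetical feature choices). It buckets the count by floor(sqrt) into a fixed-length one-hot and compresses the amount with log1p; the bucket count of 8 is an arbitrary illustration, not AlphaStar's setting.

```python
import math
from typing import List


def sqrt_one_hot(count: int, num_buckets: int) -> List[int]:
    """Bucket a non-negative count by floor(sqrt(count)) and one-hot encode it.

    Counts beyond the last bucket are clipped into it, so the length stays fixed.
    """
    bucket = min(math.isqrt(max(count, 0)), num_buckets - 1)
    one_hot = [0] * num_buckets
    one_hot[bucket] = 1
    return one_hot


def log_scale(value: float) -> float:
    """Compress a heavy-tailed non-negative quantity with log(1 + x)."""
    return math.log1p(max(value, 0.0))


def encode_scalars(worker_count: int, minerals: float, num_buckets: int = 8) -> List[float]:
    """Hypothetical scalar feature vector: bucketed count followed by a log-scaled amount."""
    return [float(b) for b in sqrt_one_hot(worker_count, num_buckets)] + [log_scale(minerals)]


# Usage: 16 workers land in bucket 4 (sqrt(16) = 4); 1000 workers clip to the last bucket.
print(sqrt_one_hot(16, 8))  # [0, 0, 0, 0, 1, 0, 0, 0]
print(encode_scalars(1000, 0.0))
```

Bucketing by sqrt keeps small counts distinct (0, 1, 4 and 9 workers fall into separate buckets) while large counts share buckets, which matches how much each extra unit usually matters.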
2) Entity Encoder (Multi-head self-attention)
- Inputs:
Per-unit features: type one-hot, boolean attrs, health (sqrt one-hot + ratio), grid coordinates (binary/one-hot).
- Outputs:
- Embedded Entity: global set summary to the core.
- Entity Embeddings: per-entity vectors used as keys by pointer-style heads (unit/target selection).
- Design note:
Coordinates as discrete codes make the relation to the grid explicit, easing downstream fusion.
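The entity encoder can be sketched in plain Python as scaled dot-product self-attention followed by mean pooling. This is a simplified illustration, assuming a single head with identity query/key/value projections; AlphaStar's encoder is a learned multi-head transformer. The hypothetical `binary_code` helper shows the discrete coordinate encoding from the design note (least significant bit first, an arbitrary choice).

```python
import math
from typing import List, Tuple

Vector = List[float]


def binary_code(value: int, num_bits: int) -> Vector:
    """Encode a non-negative grid coordinate as bits, least significant bit first."""
    return [float((value >> i) & 1) for i in range(num_bits)]


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(entities: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product self-attention with identity Q/K/V projections.

    Each output row is a similarity-weighted average of all entity vectors, so the
    result is permutation-equivariant: reordering the input reorders the output.
    """
    d = len(entities[0])
    out = []
    for q in entities:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in entities])
        out.append([sum(w * v[j] for w, v in zip(weights, entities)) for j in range(d)])
    return out


def entity_encoder(entities: List[Vector]) -> Tuple[List[Vector], Vector]:
    """Return (Entity Embeddings, Embedded Entity)."""
    embeddings = self_attention(entities)  # per-entity vectors, later used as pointer keys
    n, d = len(embeddings), len(embeddings[0])
    embedded_entity = [sum(e[j] for e in embeddings) / n for j in range(d)]  # mean over the set
    return embeddings, embedded_entity


# Usage: each entity = 3-bit x code + 3-bit y code (a hypothetical 8x8 grid).
units = [binary_code(x, 3) + binary_code(y, 3) for x, y in [(1, 2), (5, 2), (7, 7)]]
keys, summary = entity_encoder(units)
print(len(keys), len(summary))  # 3 6
```

Mean pooling keeps the set summary independent of entity order, which is the point of treating units as a set rather than a fixed-length vector.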
3) Spatial Encoder (CNN + ResBlocks, with Scatter Connections)
- Idea:
Write (scatter) each Entity Embedding onto the feature map at its (x, y) position, then process the map with a CNN pyramid.
- Outputs:
- Embedded Spatial: 1D spatial summary to the core.
- Map Skips: multi-scale tensors reserved for the Location head.
- Design note:
Scatter is simple and effective (in ablations it notably boosts the supervised agent's win rate).
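A minimal sketch of a scatter connection in plain Python, assuming integer (x, y) cell positions already on the map grid and summing entities that share a cell (one reasonable reduction, not necessarily AlphaStar's). The scattered channels are then concatenated with the other spatial feature planes; the CNN pyramid that follows is not shown.

```python
from typing import List, Sequence, Tuple

Grid = List[List[List[float]]]  # indexed [y][x][channel]


def scatter_entities(embeddings: Sequence[Sequence[float]],
                     positions: Sequence[Tuple[int, int]],
                     height: int, width: int) -> Grid:
    """Write each entity embedding into an H x W x D grid at its (x, y) cell.

    Cells with no entity stay zero; entities sharing a cell are summed.
    """
    d = len(embeddings[0])
    grid = [[[0.0] * d for _ in range(width)] for _ in range(height)]
    for emb, (x, y) in zip(embeddings, positions):
        if not (0 <= x < width and 0 <= y < height):
            raise ValueError(f"position {(x, y)} is outside the {width}x{height} map")
        cell = grid[y][x]
        for j in range(d):
            cell[j] += emb[j]
    return grid


def concat_channels(scattered: Grid, planes: Grid) -> Grid:
    """Concatenate scattered entity channels with other spatial planes, cell by cell."""
    return [[s + p for s, p in zip(s_row, p_row)] for s_row, p_row in zip(scattered, planes)]


# Usage: two entities share cell (2, 1); the third sits at (0, 0) on a 2x3 map.
grid = scatter_entities([[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]], [(0, 0), (2, 1), (2, 1)], 2, 3)
print(grid[1][2])  # [3.5, 4.5]
```

After scattering, a convolution's local receptive field sees neighboring units directly, which is exactly the geometry a pooled 1D entity vector hides.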
Core (LSTM): temporal glue
The LSTM consumes the concatenation of Embedded Scalar / Entity / Spatial, together with its previous hidden state, and outputs the Embedded State: the base autoregressive embedding for the head stack.
- When memory helps:
In POMDPs (fog-of-war, economy/HP trends, motion), the core tracks dynamics that aren’t visible in a single frame.
- When memory can be optional:
In fully observed settings, recurrence is less critical.
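A single core step can be sketched in plain Python: concatenate the three embeddings into the LSTM input, then apply the standard LSTM gate equations with the previous (h, c) state. The `params` layout (gate name mapped to a weight matrix and bias) is a hypothetical choice for illustration; a real core would use a deep-learning framework and learned, much larger weights.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]
Params = Dict[str, Tuple[List[Vector], Vector]]  # gate -> (W: hidden x (input + hidden), b)


def sigmoid(z: float) -> float:
    # Logistic sigmoid via the tanh identity; stable for large |z|.
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def lstm_step(x: Vector, h_prev: Vector, c_prev: Vector, params: Params) -> Tuple[Vector, Vector]:
    """Standard LSTM cell: every gate reads the concatenation [x, h_prev]."""
    xh = x + h_prev

    def gate(name, act):
        w, b = params[name]
        return [act(sum(wi * v for wi, v in zip(row, xh)) + bi) for row, bi in zip(w, b)]

    i = gate("input", sigmoid)
    f = gate("forget", sigmoid)
    o = gate("output", sigmoid)
    g = gate("cell", math.tanh)
    c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, g)]  # keep old memory + write new
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c


def core_step(embedded_scalar: Vector, embedded_entity: Vector, embedded_spatial: Vector,
              state: Tuple[Vector, Vector], params: Params) -> Tuple[Vector, Tuple[Vector, Vector]]:
    """Return (Embedded State, new recurrent state)."""
    x = embedded_scalar + embedded_entity + embedded_spatial
    h, c = lstm_step(x, state[0], state[1], params)
    return h, (h, c)


# Usage: with all-zero weights every gate is 0.5, so the cell state simply halves.
zero = {name: ([[0.0] * 5 for _ in range(2)], [0.0, 0.0]) for name in ("input", "forget", "output", "cell")}
embedded_state, state = core_step([1.0], [2.0], [3.0], ([0.0, 0.0], [1.0, -1.0]), zero)
print(state[1])  # [0.5, -0.5]
```

The cell state c is what carries information such as last-seen enemy positions across frames; in a fully observed setting the same step could be replaced by a feed-forward layer over x alone.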
Heads: an autoregressive control surface
Predicting the whole action as one joint choice explodes combinatorially (action type × units × target × location). Instead, predict a sequence and update the context after each choice:
- Action Type (ResNet/MLP + GLU gate)
Uses Scalar Context to gate the top-level decision (build/move/attack/cast…).
- Delay / Queue (MLP)
Delay is often discretized even though it’s a scalar; this avoids committing early to a specific continuous distribution and tends to train more cleanly. Queue toggles “now vs later”, i.e., whether the action executes immediately or is queued behind the unit’s current orders.
- Selected Units (Pointer Network)
Query = current AR embedding (optionally extended with a small LSTM when longer selection sequences help).
Keys = Entity Embeddings.
Samples which controllable units will execute the action.
- Target Unit (Pointer Network)
Same mechanism, different role: choose the target unit for attack/heal, etc. Skipped for action types without a unit target.
- Location (Deconv/ResNet with FiLM-like modulation)
Consumes the Map Skips directly; masks invalid coordinates based on the chosen action type. Feature-wise modulation helps combine multi-scale spatial evidence with the current AR context.
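The pointer heads can be sketched as masked dot-product attention over the Entity Embeddings. This plain-Python version is a simplification built on assumptions: it picks greedily (argmax) instead of sampling, uses a fixed number of picks rather than an end-of-selection signal, and folds each chosen key into the query by addition (a hypothetical update rule). With `num_picks=1` the same function plays the Target Unit role.

```python
import math
from typing import List, Sequence

Vector = List[float]


def masked_softmax(scores: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Softmax over valid entries only; masked entries get probability 0."""
    m = max(s for s, ok in zip(scores, mask) if ok)
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def pointer_select(query: Vector, keys: List[Vector], mask: Sequence[bool], num_picks: int) -> List[int]:
    """Greedy pointer head: point at up to num_picks entities, one at a time.

    After each pick the chosen entity is masked out and its key is added to the
    query, so later picks are conditioned on earlier ones.
    """
    query, mask, picks = list(query), list(mask), []
    for _ in range(num_picks):
        if not any(mask):
            break  # nothing left that this action may select
        scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
        probs = masked_softmax(scores, mask)
        idx = max(range(len(probs)), key=probs.__getitem__)  # greedy stand-in for sampling
        picks.append(idx)
        mask[idx] = False
        query = [q + k for q, k in zip(query, keys[idx])]
    return picks


# Usage: entity 2 is an enemy, so it is masked out of unit selection.
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
controllable = [True, True, False, True]
print(pointer_select([1.0, 0.1], keys, controllable, num_picks=2))  # [0, 3]
```

Masking is what keeps the pointer honest: the network never has to learn that enemy or invalid units are unselectable, because they receive zero probability by construction.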
Critical loop. After each head samples, fold the choice back into the AR embedding before moving on. That way, later heads “know” what earlier heads picked even without immediate environment feedback.
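The critical loop can be sketched as follows, covering only Action Type, then Target Unit, then Location (Delay/Queue and Selected Units are omitted for brevity). Everything here is hypothetical for illustration: the three-entry action schema, greedy argmax in place of sampling, additive folding of each choice into the AR embedding, and heads passed in as plain callables. The GLU-style gate uses Scalar Context to scale the AR embedding before the action-type logits.

```python
import math
from typing import Callable, Dict, List, Tuple

Vector = List[float]

# Hypothetical schema: action type -> (uses Target Unit head, uses Location head).
ACTION_ARGS: Dict[str, Tuple[bool, bool]] = {
    "no_op": (False, False),
    "attack_unit": (True, False),
    "move": (False, True),
}
ACTIONS = list(ACTION_ARGS)


def sigmoid(z: float) -> float:
    return 0.5 * (1.0 + math.tanh(0.5 * z))  # stable logistic sigmoid


def glu_gate(x: Vector, context: Vector, gate_w: List[Vector]) -> Vector:
    """GLU-style gating: x * sigmoid(W_g @ context), elementwise."""
    return [xi * sigmoid(sum(w * c for w, c in zip(row, context))) for xi, row in zip(x, gate_w)]


def fold(ar: Vector, choice: Vector) -> Vector:
    """Fold a head's choice back into the AR embedding (additive, an assumption)."""
    return [a + c for a, c in zip(ar, choice)]


def decode(embedded_state: Vector, scalar_context: Vector, gate_w: List[Vector],
           action_w: List[Vector], action_emb: List[Vector],
           target_head: Callable[[Vector], Tuple[int, Vector]],
           location_head: Callable[[Vector], Tuple[int, int]]) -> Dict[str, object]:
    ar = list(embedded_state)
    gated = glu_gate(ar, scalar_context, gate_w)
    logits = [sum(w * g for w, g in zip(row, gated)) for row in action_w]
    a = max(range(len(logits)), key=logits.__getitem__)  # greedy stand-in for sampling
    ar = fold(ar, action_emb[a])  # later heads now "know" the action type
    uses_target, uses_location = ACTION_ARGS[ACTIONS[a]]
    out: Dict[str, object] = {"action": ACTIONS[a], "target": None, "location": None}
    if uses_target:
        out["target"], target_key = target_head(ar)
        ar = fold(ar, target_key)  # a following Location head would see the chosen target
    if uses_location:
        out["location"] = location_head(ar)
    return out


# Usage: stub heads record what they see; "attack_unit" must skip the Location head.
seen = []

def target_head(ar):
    seen.append(ar)
    return 7, [0.0, 0.0]

def location_head(ar):
    raise AssertionError("attack_unit should skip the Location head")

result = decode([1.0, 0.0], [0.0], gate_w=[[0.0], [0.0]],          # every gate = 0.5
                action_w=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],     # logits [0, 0.5, 0]
                action_emb=[[0.0, 0.0], [0.0, 5.0], [0.0, 0.0]],
                target_head=target_head, location_head=location_head)
print(result)  # {'action': 'attack_unit', 'target': 7, 'location': None}
```

The target head receives [1.0, 5.0] rather than the raw Embedded State [1.0, 0.0]: the action-type choice has already been folded in, which is the whole point of the loop.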
Scatter Connections — deeper dive
- Entity features already contain positions, but literally writing each entity vector into the grid at its (x, y) cell makes the geometry explicit, so the CNN pyramid can learn spatial relations directly with its local filters.
Practical patterns that transfer
- Don’t over-bottleneck. Heads should tap the right representation (entity set for pointers, multi-scale maps for coordinates) instead of relying only on a single 1D latent.
- Discretize aggressively and rescale. One-hot/bucket, log/sqrt—this makes signals legible to the net.
- Autoregression everywhere a choice is sequential. “Pick units → maybe a target → maybe a location” benefits from AR context updates between sub-decisions.
- Be explicit. If a piece of information is crucial, route it where it’s needed (e.g., Scalar Context → Action Type) rather than hoping the core latent carries it intact.
Closing
AlphaStar’s architecture is less about exotic tricks and more about matching representation to reality, then letting an AR head stack express complex actions step by step. If your environment is multi-entity and partially observed, this template travels with minimal ceremony.