SimplexRL · Architecture Note

The Solver-Feature Attention Model

A pivot-selection policy that scores entering-variable candidates by combining a graph encoding of the LP with explicit, hand-derived simplex signals — reduced costs, basis status, pricing scores, and rankings — fused through attention.

Audience: DL fundamentals + general SimplexRL familiarity
Code: src/simplexrl/models/solver_feature_attention/
Selected by: [model].architecture = "solver_feature_attention"

1 Background and context

The primal simplex algorithm solves a linear program by walking from vertex to vertex of the feasible polytope. At each iteration it must choose an entering variable: a non-basic variable to bring into the basis, which determines the edge the algorithm traverses next. Classical rules make this choice with a fixed heuristic. Dantzig's rule picks the variable with the most negative reduced cost; steepest-edge rules normalize each reduced cost by the norm of the corresponding edge direction, and Devex uses a cheaper approximation of that norm. The choice is consequential: it does not change the optimal objective value, but it can change the number of iterations needed to reach it by a large factor.
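A toy comparison of the two pricing families makes the difference concrete. The sketch assumes a minimization problem in which every candidate sits at its lower bound (so only negative reduced costs are improving) and takes the pricing weights as given instead of computing Devex or steepest-edge norms; the function names and numbers are illustrative, not SimplexRL or HiGHS code.

```python
def dantzig_choice(reduced_costs):
    """Index of the most negative reduced cost (Dantzig's rule, minimization)."""
    return min(range(len(reduced_costs)), key=lambda j: reduced_costs[j])


def normalized_choice(reduced_costs, weights):
    """Index maximizing d_j**2 / w_j over improving (d_j < 0) candidates.

    Devex and steepest-edge both rank candidates this way; they differ in how
    the weights w_j are maintained (approximate vs. exact squared edge norms).
    """
    improving = [j for j, d in enumerate(reduced_costs) if d < 0]
    return max(improving, key=lambda j: reduced_costs[j] ** 2 / weights[j])


# Toy data: variable 0 has the largest |d_j| but also a much longer edge.
d = [-5.0, -3.0, 0.5]
w = [100.0, 1.0, 1.0]
print(dantzig_choice(d))        # 0
print(normalized_choice(d, w))  # 1
```

In this toy the two rules disagree: variable 0 has the largest reduced cost in magnitude, but its large weight (a long edge) means it improves the objective less per unit of distance moved, so the normalized rule prefers variable 1.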

SimplexRL frames this choice as a reinforcement learning problem. An agent observes the solver's internal state at each pivot and selects an index into the list of eligible entering candidates; the environment wraps the HiGHS LP solver and returns a reward that penalizes each iteration taken (so fewer pivots is better). The goal is a learned pivot rule that outperforms the classical heuristics on a target distribution of problems.

The project maintains four peer model architectures for this task, all sharing the same environment, action space, and training cores. They differ in how they turn the observation into candidate scores.

The defining idea of solver-feature attention is to lean explicitly on quantities the simplex solver already computes. Rather than asking a network to rediscover, say, the notion of a reduced cost or a pricing ratio from raw problem structure, the model feeds those quantities in directly as per-candidate features, alongside a learned graph embedding of the LP. It then uses attention to let candidates compare themselves against one another and against the constraint rows before scoring. The bet is that injecting solver-native signal makes the policy both easier to learn and more transferable; the cost is a heavier dependence on the semantics and quality of those signals.

Scope

This note describes the model as wired in configs/solver-feature-attention/ppo.toml. Concrete dimension numbers below come from that config; the dataclass defaults differ in places (e.g. embedding width defaults to 32 but the config uses 64) and are called out where useful. Optional features are all disabled in the default config and are described separately in §3.4.

2 High-level overview

The model factors cleanly into two encoding stages followed by lightweight scoring heads. This split mirrors the structure of the problem: the LP itself is fixed for an entire episode, while the solver state changes at every pivot.

Stage 1 · once per problem

Problem encoder

A graph neural network reads the static LP — objective, bounds, and the constraint matrix — as a bipartite variable–constraint graph and produces an embedding for every variable and constraint row, plus one pooled problem-level vector. Because the LP does not change within an episode, this stage is computed once and cached.

Stage 2 · every pivot

State encoder

For the current iteration, the encoder builds candidate tokens (one per entering candidate) and row tokens (one per constraint). Candidate tokens carry a 21-dimensional vector of explicit solver features. Attention then lets candidates attend to rows and to each other, and a learned state token pools the whole scene into a single iteration summary.

Stage 3 · every pivot

Heads

The policy head scores each candidate from its token, the iteration summary, and the problem vector, producing masked logits over the action space. The value head pools the scene into a scalar state value for the RL baseline. An optional Q-head produces per-candidate action values when enabled (required for IQL).

[Figure: three-stage architecture diagram. Stage 1 · per problem (cached): static LP (objective, bounds, matrix A) → bipartite GNN (ProblemEncoderGNN, 4× GATv2) → variable, row, and problem embeddings. Stage 2 · per pivot (state encoder): row tokens (row + slack + basic variable + basic solution + globals) and candidate tokens (gathered variable embedding + 21-dim solver feature vector) → row self-attention → candidates cross-attend to rows (×2) → candidate self-attention → learned state token cross-attends (×2) to memory = [candidate tokens ‖ row tokens] with a combined padding mask → RMSNorm → iteration_embedding. Stage 3 · heads: policy head (masked logits over candidates), value head (scalar V(s) via attention pools), optional Q-head (per-candidate Q(s,a), dueling).]
The two-stage encoder feeds three scoring heads. Blue = static problem path (cached); green = per-pivot state path; purple = heads. Dashed boxes/edges are optional or context links.

3 Input features

The model consumes a dictionary of padded observation tensors from the environment. They fall into four groups by what they describe and how they are indexed. Throughout, B is the batch size, and the padded sizes come from the [padding] block: in the reference PPO config, max_vars = 160, max_constraints = 224, and max_candidates = 384. Many per-variable tensors span the full simplex variable space (structural variables followed by row/slack variables), so their length is max_vars + max_constraints, which is 160 + 224 = 384 in the reference config.

3.1 Static problem data (Stage 1 input)

These describe the LP itself and are read by the GNN as graph node and edge features. They are fetched by problem index from the ProblemGraphStore rather than re-sent every step.

| Feature | Indexing | LP / simplex meaning |
|---|---|---|
| objective_coeffs | per variable | The cost vector c: the linear objective being minimized. |
| var_lower_bounds, var_upper_bounds | per variable | Box bounds on each structural variable; finiteness flags are derived as graph features. |
| constraint_matrix | variable × constraint | The matrix A, supplied as the sparse edge set of the bipartite graph (a non-zero A[i,j] becomes an edge between variable j and row i, carrying the coefficient as an edge feature). |
| num_vars, num_constraints | scalar | True problem dimensions; used to unpad and to remap HiGHS indices into the unified layout. |
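The remapping mentioned in the num_vars row can be pictured with the sketch below. It rests on an assumption this note does not spell out: that the unified layout reserves max_vars slots for structural variables followed by max_constraints slots for row (slack) variables, whereas HiGHS numbers them contiguously (structural variables 0 to num_vars - 1, then rows). Both function names are hypothetical.

```python
MAX_VARS = 160          # [padding].max_vars in the reference PPO config
MAX_CONSTRAINTS = 224   # [padding].max_constraints in the reference PPO config


def highs_to_unified(j, num_vars, num_constraints, max_vars=MAX_VARS):
    """Map a contiguous HiGHS variable index to an (assumed) padded unified index.

    Assumed layout: [0, max_vars) structural, [max_vars, max_vars + max_constraints) slacks.
    """
    if not 0 <= j < num_vars + num_constraints:
        raise IndexError(f"HiGHS index {j} out of range")
    if j < num_vars:
        return j                          # structural variable keeps its position
    return max_vars + (j - num_vars)      # slack of row (j - num_vars)


def unified_valid_mask(num_vars, num_constraints,
                       max_vars=MAX_VARS, max_constraints=MAX_CONSTRAINTS):
    """True for unified slots that hold a real variable, False for padding."""
    return ([i < num_vars for i in range(max_vars)]
            + [r < num_constraints for r in range(max_constraints)])


mask = unified_valid_mask(num_vars=3, num_constraints=2)
print(len(mask), sum(mask))                            # 384 5
print([highs_to_unified(j, 3, 2) for j in range(5)])   # [0, 1, 2, 160, 161]
```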

3.2 Dynamic solver state (Stage 2 input)

These change at every pivot and are the heart of what makes this architecture "solver-feature" oriented. The reduced-cost-style vectors use HiGHS internal ordering: the first num_vars entries are structural variables, the next num_constraints are row (slack) variables.

| Feature | Indexing | Meaning |
|---|---|---|
| reduced_costs | per variable | Marginal objective change per unit increase of a non-basic variable; the primary pivot signal. For minimization, a negative reduced cost marks an improving direction for a variable that can increase (at its lower bound); for a variable at its upper bound, a positive reduced cost is improving. |
| dual_infeasibility | per variable | How much each variable violates the dual-feasibility (optimality) condition; defines the candidate set. |
| edge_weight | per variable | HiGHS pricing weight (the Devex / steepest-edge denominator). The model forms a pricing score = reduced_cost² / edge_weight from it. |
| nonbasic_move | per variable | Direction a non-basic variable would move: -1 toward its lower bound, +1 toward its upper, 0 free/degenerate. |
| basis_status | per variable | Five-class status (basic / nonbasic-at-lower / -at-upper / -free / superbasic); fed as a one-hot. |
| basic_solution | per constraint | The current value of the basic variable in each row, i.e. the present vertex. One value per row, since each row has exactly one basic variable. |
| basic_index | per constraint | Which variable is basic in each row; used to attach each basic variable's learned embedding back onto its row token. |
| entering_candidates | per candidate | Variable indices of the non-basic variables eligible to enter. This is the action space: the policy selects one index. |
| action_mask | per candidate | Boolean validity mask separating real candidates from padding; gates attention and masks logits. |
| incumbent_index | scalar | The candidate Dantzig's rule would pick (max reduced-cost magnitude); a baseline marker, exposed as a per-candidate flag. |
| iteration, solve_phase, num_candidates | scalar | Iteration counter, primal phase (0 init / 1 feasibility / 2 optimality), and the count of valid candidates. |
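The candidate-level tensors can be derived from the per-variable vectors roughly as follows. This is a sketch under stated assumptions: a variable counts as a candidate when its dual_infeasibility exceeds a small tolerance, padded candidate slots hold -1, and incumbent_index is reported as a position in the candidate list. The tolerance, the padding value, and build_candidates itself are hypothetical; the real environment obtains these quantities from HiGHS.

```python
def build_candidates(dual_infeasibility, reduced_costs, max_candidates, tol=1e-9):
    """Return (entering_candidates, action_mask, incumbent_index) as plain lists."""
    cands = [j for j, inf in enumerate(dual_infeasibility) if inf > tol]
    if len(cands) > max_candidates:
        raise ValueError("more candidates than padded slots")
    pad = max_candidates - len(cands)
    entering = cands + [-1] * pad                 # -1 marks a padded slot (assumed)
    mask = [True] * len(cands) + [False] * pad
    # Dantzig incumbent: position of the candidate with the largest |reduced cost|.
    incumbent = max(range(len(cands)),
                    key=lambda k: abs(reduced_costs[cands[k]]),
                    default=-1)
    return entering, mask, incumbent


dual_inf = [0.0, 2.5, 0.0, 0.7, 4.0]
rc = [0.3, -2.5, 0.0, 0.7, -4.0]
ent, mask, inc = build_candidates(dual_inf, rc, max_candidates=6)
print(ent)   # [1, 3, 4, -1, -1, -1]
print(mask)  # [True, True, True, False, False, False]
print(inc)   # 2  (variable 4 sits at position 2 of the candidate list)
```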

3.3 The 21-dimensional candidate feature vector

The model's signature step is to compress the dynamic state into one explicit feature vector per candidate. These 21 numbers are built inside the state encoder (features.py / encoders.py), projected by a linear layer into the embedding width, passed through an MLP, and added to the candidate's gathered graph embedding. The vector is:

| # | Feature | Definition |
|---|---|---|
| 1–2 | reduced cost & magnitude | The candidate's reduced cost, signed and as an absolute value. |
| 3 | dual infeasibility | The candidate's optimality-condition violation. |
| 4 | log edge weight | log1p of the pricing weight. |
| 5 | log pricing score | log1p(reduced_cost² / edge_weight), the steepest-edge-style merit. |
| 6 | nonbasic move | Movement direction, clamped to [-1, 1]. |
| 7–8 | is-structural / is-slack | Whether the candidate is a structural or a slack variable. |
| 9–10 | dual rank & pricing rank | Each candidate's descending rank (normalized to [0,1]) by dual infeasibility and by pricing score, among the valid candidates. Enabled by use_rank_features. |
| 11 | incumbent flag | 1.0 if this candidate is the Dantzig choice. |
| 12–16 | basis-status one-hot | Five channels for the candidate's basis status. |
| 17–21 | problem-context scalars | log1p(iteration), phase/2, candidate-count fraction, and log-normalized variable / constraint counts, broadcast to every candidate. |
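The table can be read as a recipe. The pure-Python sketch below assembles one 21-element list per valid candidate in the table's order, using plain lists in place of tensors. Several details are assumptions rather than documented behavior: the rank convention (0.0 for the largest value), the integer codes 0 to 4 for the five basis statuses, the log-normalization of the problem-size counts (log1p(n) / log1p(max)), and the small epsilon guarding the pricing-score division. The observation keys are the names from §3.2; the helper names are hypothetical.

```python
import math

BASIS_CLASSES = 5  # basic / at-lower / at-upper / free / superbasic (integer codes assumed)


def desc_rank(values):
    """Descending rank of each value, normalized to [0, 1] (0.0 = largest; assumed)."""
    n = len(values)
    order = sorted(range(n), key=lambda k: -values[k])
    ranks = [0.0] * n
    for r, k in enumerate(order):
        ranks[k] = r / (n - 1) if n > 1 else 0.0
    return ranks


def candidate_features(cands, incumbent, obs, pad):
    """Build one 21-element feature list per valid candidate."""
    rc, dinf = obs["reduced_costs"], obs["dual_infeasibility"]
    w, move, status = obs["edge_weight"], obs["nonbasic_move"], obs["basis_status"]
    nv, nc = obs["num_vars"], obs["num_constraints"]
    eps = 1e-12                                          # guard against zero weights
    pricing = [rc[j] ** 2 / max(w[j], eps) for j in cands]
    dual_rank = desc_rank([dinf[j] for j in cands])
    price_rank = desc_rank(pricing)
    context = [math.log1p(obs["iteration"]),
               obs["solve_phase"] / 2,
               len(cands) / pad["max_candidates"],
               math.log1p(nv) / math.log1p(pad["max_vars"]),          # assumed normalization
               math.log1p(nc) / math.log1p(pad["max_constraints"])]   # assumed normalization
    rows = []
    for k, j in enumerate(cands):
        one_hot = [0.0] * BASIS_CLASSES
        one_hot[status[j]] = 1.0
        rows.append([rc[j], abs(rc[j]),                      # 1-2
                     dinf[j],                                # 3
                     math.log1p(w[j]),                       # 4
                     math.log1p(pricing[k]),                 # 5
                     max(-1.0, min(1.0, float(move[j]))),    # 6
                     float(j < nv), float(j >= nv),          # 7-8
                     dual_rank[k], price_rank[k],            # 9-10
                     float(k == incumbent),                  # 11
                     *one_hot,                               # 12-16
                     *context])                              # 17-21
    return rows


obs = {"reduced_costs": [-2.0, 0.0, 1.0], "dual_infeasibility": [2.0, 0.0, 1.0],
       "edge_weight": [4.0, 1.0, 1.0], "nonbasic_move": [1, 0, -1],
       "basis_status": [1, 0, 2], "num_vars": 2, "num_constraints": 1,
       "iteration": 7, "solve_phase": 2}
pad = {"max_vars": 160, "max_constraints": 224, "max_candidates": 384}
feats = candidate_features([0, 2], incumbent=0, obs=obs, pad=pad)
print(len(feats), len(feats[0]))  # 2 21
```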
Why rankings matter

The rank features (dual rank, pricing rank) are relative: they encode a candidate's standing among its peers this iteration, not an absolute value. This gives the policy a scale-free view of the candidate set and is exactly the kind of comparison classical pricing rules make implicitly when they take an argmax.
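A small check of the scale-free property: multiplying every dual infeasibility by a positive constant (as rescaling the objective would) changes the raw features but leaves the normalized ranks unchanged. The rank convention here (0.0 for the largest value, 1.0 for the smallest, ties broken by position) is an assumption, not necessarily the one the encoder uses.

```python
def desc_rank(values):
    """Descending rank normalized to [0, 1]: 0.0 for the largest value (assumed)."""
    n = len(values)
    order = sorted(range(n), key=lambda k: -values[k])
    ranks = [0.0] * n
    for r, k in enumerate(order):
        ranks[k] = r / (n - 1) if n > 1 else 0.0
    return ranks


infeas = [0.4, 3.0, 1.2, 0.05]
scaled = [1000.0 * v for v in infeas]           # e.g. objective multiplied by 1000
print(desc_rank(infeas))                        # [0.6666666666666666, 0.0, 0.3333333333333333, 1.0]
print(desc_rank(infeas) == desc_rank(scaled))   # True
```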

3.4 Optional derived features

Beyond the base set, the environment can compute additional derived features, registered in env/optional_features.py and toggled per run under [env.optional_features]. They span per-variable, per-row, per-candidate, and scalar granularities, and range from cheap trajectory-memory signals (e.g. iterations since a variable last changed basis status, two-cycle detectors) to richer outputs of a HiGHS ratio-test analyzer (predicted step lengths, degeneracy and tie counts, predicted leaving variable, tableau-column norms). When enabled, the state encoder fuses each family into the appropriate token stream via dedicated fusion modules. All optional features are disabled in the default config; the 21-dimensional vector above is the standard candidate input.
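As an illustration of the cheap trajectory-memory family, here is a hypothetical tracker for two such signals. Its definitions are assumptions: for a variable that has never changed status, "iterations since last change" counts from the start of the episode, and a candidate is flagged as a two-cycle risk when it is the variable that left the basis on the previous pivot, so entering it would undo that pivot. The class and method names are illustrative and not taken from env/optional_features.py.

```python
class BasisMemory:
    """Hypothetical tracker for two trajectory-memory signals."""

    def __init__(self, num_total_vars):
        self.iteration = 0
        self.last_change = [0] * num_total_vars   # iteration of each variable's last status change
        self.last_left = None                      # variable that left the basis at the last pivot

    def record_pivot(self, entering, leaving):
        self.iteration += 1
        self.last_change[entering] = self.iteration
        self.last_change[leaving] = self.iteration
        self.last_left = leaving

    def since_change(self, j):
        """Iterations since variable j last changed basis status."""
        return self.iteration - self.last_change[j]

    def would_two_cycle(self, candidate):
        """True if entering `candidate` would immediately undo the previous pivot."""
        return candidate == self.last_left


mem = BasisMemory(num_total_vars=5)
mem.record_pivot(entering=2, leaving=4)
mem.record_pivot(entering=0, leaving=1)
print(mem.since_change(2), mem.since_change(3))        # 1 2
print(mem.would_two_cycle(1), mem.would_two_cycle(4))  # True False
```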

4 Stage 1 · Problem encoder

The problem encoder (SolverFeatureProblemEncoder) turns the static LP into reusable embeddings. It is built around ProblemEncoderGNN, the bipartite graph backbone shared with the transformer-gnn architecture.

4.1 The bipartite graph

The LP becomes a graph with three node types — structural variables, slack variables, and constraint rows — connected by edges wherever the constraint matrix has a non-zero. Each node type gets its own input projection from an 8-dimensional raw feature vector (objective and bound information for variables; row statistics for constraints), and edges carry a 6-dimensional feature (the coefficient, its sign and magnitude, an identity indicator, and local sparsity statistics).
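A minimal pure-Python sketch of turning A into that edge set follows. The note specifies only that the six edge features are the coefficient, its sign and magnitude, an identity indicator, and local sparsity statistics; the exact encodings here (log1p magnitude, row and column non-zero fractions), the choice to connect each slack to its own row through an identity edge, and the node numbering are all assumptions for illustration.

```python
import math


def build_bipartite_edges(A, tol=0.0):
    """Return (edges, features) for a dense m x n constraint matrix A.

    edges[k] = (var_node, row_node). Structural variables are nodes 0..n-1 and
    slack variables are nodes n..n+m-1 (the slack of row i is node n + i).
    """
    m, n = len(A), len(A[0])
    row_nnz = [sum(1 for a in A[i] if abs(a) > tol) for i in range(m)]
    col_nnz = [sum(1 for i in range(m) if abs(A[i][j]) > tol) for j in range(n)]
    edges, feats = [], []
    for i in range(m):
        for j in range(n):
            a = A[i][j]
            if abs(a) > tol:                      # one edge per non-zero A[i, j]
                edges.append((j, i))
                feats.append([a,
                              math.copysign(1.0, a),
                              math.log1p(abs(a)),  # magnitude (assumed log scale)
                              0.0,                 # identity indicator: structural edge
                              row_nnz[i] / n,      # row density (assumed statistic)
                              col_nnz[j] / m])     # column density (assumed statistic)
        # Assumed: the slack of row i has coefficient 1 in row i only (identity block).
        edges.append((n + i, i))
        feats.append([1.0, 1.0, math.log1p(1.0), 1.0, row_nnz[i] / n, 1.0 / m])
    return edges, feats


A = [[1.0, 0.0, -2.0],
     [0.0, 3.0, 1.0]]
edges, feats = build_bipartite_edges(A)
print(edges)          # [(0, 0), (2, 0), (3, 0), (1, 1), (2, 1), (4, 1)]
print(len(feats[0]))  # 6
```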

4.2 Graph attention layers

The backbone applies a stack of GATv2Conv graph-attention layers (4 in the reference config), each followed by a ReLU and dropout, with residual connections after the first layer and a final RMSNorm. Message passing along the bipartite edges lets each variable embedding accumulate information about the constraints it participates in, and vice versa — so a variable's representation reflects its structural role in the LP, not just its own coefficients.

# ProblemEncoderGNN.forward — schematic
x  = relu(type_specific_init(node_features))      # per node-type linear
ea = relu(edge_init(edge_features))
for i, layer in enumerate(gnn_layers):           # 4× GATv2Conv
    h = dropout(relu(layer(x, edge_index, ea)))
    x = x + h if i > 0 else h            # residual after layer 0
x = rms_norm(x)                                   # (total_nodes, d)
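To make the per-layer attention concrete, the sketch below performs one single-head GATv2-style update for a single target node in pure Python. It follows the GATv2 scoring rule, in which the LeakyReLU is applied before the attention vector a, with an edge-feature term added inside the nonlinearity. Self-loops, bias, multiple heads, dropout, and the residual/RMSNorm wrapping in the schematic above are omitted; the 0.2 negative slope is an assumed default and all weights are toy values.

```python
import math


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def leaky_relu(v, slope=0.2):
    return [a if a > 0 else slope * a for a in v]


def gatv2_update(x_target, neighbours, W_src, W_tgt, W_edge, a):
    """One single-head GATv2-style aggregation for one target node.

    neighbours: list of (x_j, e_ij) pairs. Each neighbour is scored as
        s_j = a . LeakyReLU(W_tgt x_i + W_src x_j + W_edge e_ij),
    the scores are softmax-normalized, and the output is sum_j alpha_j * W_src x_j.
    """
    t = matvec(W_tgt, x_target)
    msgs, scores = [], []
    for x_j, e_ij in neighbours:
        m = matvec(W_src, x_j)
        z = [ti + mi + ei for ti, mi, ei in zip(t, m, matvec(W_edge, e_ij))]
        scores.append(sum(ai * zi for ai, zi in zip(a, leaky_relu(z))))
        msgs.append(m)
    peak = max(scores)                                   # stabilize the softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    alpha = [w / total for w in weights]
    out = [sum(al * m[d] for al, m in zip(alpha, msgs)) for d in range(len(t))]
    return out, alpha


# A row node with two variable neighbours; edge features are 1-dim and zero-weighted.
I2 = [[1.0, 0.0], [0.0, 1.0]]
W_edge = [[0.0], [0.0]]
out, alpha = gatv2_update([0.0, 0.0],
                          [([1.0, 0.0], [0.5]), ([0.0, 1.0], [-0.5])],
                          I2, I2, W_edge, a=[1.0, 0.0])
print(alpha, out)
```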

4.3 Outputs

The encoder repackages the node embeddings into padded, masked tensors (per-variable and per-row embeddings of width d, each with a validity mask) and pools a single problem-level vector of shape (B, d).

Because these depend only on the LP, the training core caches them and reuses them across every pivot of the episode (and across the trajectories of a GRPO group, which all share one problem).
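The caching pattern can be sketched as a dictionary keyed by problem index. ProblemEncodingCache is a hypothetical name, and the string stand-in replaces the real GNN call. One design point any such cache must respect: the embeddings are a function of the GNN weights, so entries cannot safely outlive a parameter update.

```python
class ProblemEncodingCache:
    """Hypothetical per-problem cache for Stage 1 outputs."""

    def __init__(self, encode_fn):
        self.encode_fn = encode_fn      # stands in for the GNN forward pass
        self.store = {}
        self.misses = 0

    def get(self, problem_index):
        if problem_index not in self.store:
            self.misses += 1
            self.store[problem_index] = self.encode_fn(problem_index)
        return self.store[problem_index]

    def clear(self):
        # Cached embeddings depend on the GNN weights; clear after any update.
        self.store.clear()


cache = ProblemEncodingCache(lambda p: f"embeddings[{p}]")
# One GRPO group: 4 trajectories on problem 7, 30 pivots each.
for _trajectory in range(4):
    for _pivot in range(30):
        enc = cache.get(7)
print(cache.misses)  # 1
```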

5 Stage 2 · State encoder

The state encoder (SolverFeatureStateEncoder) is where the per-iteration computation happens. It assembles two token sequences — rows and candidates — enriches them, lets them attend to one another, and distills the result into a single iteration summary. All attention respects padding masks so that padded slots never influence real ones.

5.1 Building row tokens

Each constraint row starts from its GNN embedding and is augmented to reflect the current basis. To the row embedding the encoder adds: the slack variable's embedding, the embedding of whichever variable is currently basic in that row (gathered through basic_index), an MLP encoding of the row's basic_solution value, and an encoding of the 5-dimensional global feature vector (iteration, phase, candidate count, and the two problem-size scalars) broadcast to every row. The resulting row tokens then pass through one round of self-attention, so rows can share context about the current vertex.

5.2 Building candidate tokens

Each entering candidate's token is the sum of two pieces: its gathered variable embedding from Stage 1 (looked up via the candidate's variable index) and an MLP encoding of the 21-dimensional solver feature vector from §3.3. This is the fusion point where learned LP structure meets explicit solver signal.
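Both token sums are gathers followed by element-wise addition. The sketch below assumes every piece has already been brought to width d (the basic_solution and feature encodings stand in for MLP outputs), that var_emb is indexed in the same variable space as basic_index and entering_candidates, and that -1 marks padded candidate slots; the helper names are hypothetical.

```python
def vadd(*vs):
    """Element-wise sum of equal-length vectors."""
    return [sum(parts) for parts in zip(*vs)]


def row_tokens(row_emb, slack_emb, var_emb, basic_index, basic_sol_enc, global_enc):
    """row_tok[r] = row + slack + basic-variable embedding + basic-solution enc + globals."""
    return [vadd(row_emb[r], slack_emb[r], var_emb[basic_index[r]],
                 basic_sol_enc[r], global_enc)
            for r in range(len(row_emb))]


def candidate_tokens(var_emb, entering_candidates, feat_enc):
    """cand_tok[k] = var_emb[entering_candidates[k]] + encoded 21-dim features."""
    d = len(var_emb[0])
    out = []
    for j, f in zip(entering_candidates, feat_enc):
        if j < 0:
            out.append([0.0] * d)       # padded slot (masked downstream anyway)
        else:
            out.append(vadd(var_emb[j], f))
    return out


var_emb = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]   # 2 structural + 1 slack, d = 2
rows = row_tokens(row_emb=[[0.5, 0.5]], slack_emb=[[2.0, 2.0]], var_emb=var_emb,
                  basic_index=[1], basic_sol_enc=[[0.5, 0.5]], global_enc=[0.0, 1.0])
cands = candidate_tokens(var_emb, [0, 2, -1], [[1.0, 1.0], [0.0, -1.0], [9.0, 9.0]])
print(rows)   # [[3.0, 5.0]]
print(cands)  # [[2.0, 1.0], [2.0, 1.0], [0.0, 0.0]]
```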

5.3 Attention: candidates against rows, then against each other

The candidate tokens are refined in two attention phases:

  1. Cross-attention to rows (2 layers): each candidate queries the row tokens, gathering information about how it relates to the current basis and constraints — a learned analogue of looking down a column of the tableau.
  2. Self-attention among candidates (1 layer by default, 2 in the reference config): candidates compare directly against one another, which is what lets the policy reason about relative merit rather than scoring each candidate in isolation.
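Both phases are masked attention. The sketch below is single-head scaled dot-product attention with a key padding mask, with the learned query/key/value projections omitted (treated as identity) to keep it short; the function name and toy vectors are illustrative. It assumes at least one valid key per call, since otherwise the softmax is undefined.

```python
import math


def masked_cross_attention(queries, keys, values, key_mask):
    """Single-head scaled dot-product attention with a key padding mask.

    Projections are omitted (identity). Keys with key_mask False get exactly zero weight.
    """
    d = len(queries[0])
    out, weights = [], []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) if ok else -math.inf
                  for k, ok in zip(keys, key_mask)]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]    # exp(-inf) == 0.0
        total = sum(exps)
        w = [e / total for e in exps]
        weights.append(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, values)) for c in range(len(values[0]))])
    return out, weights


rows = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]   # third row is padding
row_mask = [True, True, False]
cands = [[2.0, 0.0], [0.0, 0.0]]              # second candidate is padding
cand_mask = [True, False]

# Phase 1: candidates cross-attend to rows.
ctx, w = masked_cross_attention(cands, rows, rows, row_mask)
print(w[0][2])   # 0.0
# Phase 2: candidates self-attend under the candidate mask.
mixed, w2 = masked_cross_attention(ctx, ctx, ctx, cand_mask)
print(w2)        # [[1.0, 0.0], [1.0, 0.0]]
```

Padded rows receive exactly zero weight, and once candidates self-attend under the candidate mask, every query draws only on the one real candidate. Outputs for padded query slots are still computed; they are ignored downstream by the mask.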

5.4 The state token

Finally, a single learned state token — seeded by the global feature encoding — cross-attends (2 layers) over a combined memory of all candidate tokens and all row tokens. After an RMSNorm, this becomes the iteration_embedding: a fixed-size (B, d) summary of everything the solver knows at this pivot. The encoder returns the refined candidate embeddings, the row embeddings, this iteration embedding, and the candidate mask.

# SolverFeatureStateEncoder.forward — schematic
g        = global_enc(globals)                               # (B, 1, d): the 5 global scalars encoded to width d
row_tok  = row_emb + slack_emb + basic_var_emb + enc(basic_solution) + g
row_tok  = row_self_attention(row_tok)                       # (B, max_cons, d)

cand_tok = gather(var_emb, entering_candidates) \
         + cand_feature_mlp(candidate_features_21d)         # (B, max_cand, d)
for blk in candidate_row_cross_attn:   cand_tok = blk(cand_tok, row_tok)
cand_tok = candidate_self_attention(cand_tok)

memory   = concat([cand_tok, row_tok], dim=1)                # keys masked by [cand_mask ‖ row_mask]
state    = state_token_seed + g                              # (B, 1, d)
for blk in state_cross_attn:          state = blk(state, memory)
iteration_embedding = rms_norm(state.squeeze(1))            # (B, d)

6 Output heads

6.1 Policy head

The policy head (SolverFeaturePolicyHead) scores each candidate independently but conditioned on global context. It projects the iteration embedding and problem embedding into a shared context vector, broadcasts it across candidates, and forms a feature-rich score input by concatenating the candidate embedding, the context, and their element-wise product (a simple bilinear interaction). An MLP maps this to one logit per candidate; invalid positions are set to -inf so the downstream masked categorical never selects padding.

context = state_proj(iteration_emb) + problem_proj(problem_emb)         # (B, d)
context = context.unsqueeze(1).expand_as(cand_emb)                      # (B, max_cand, d)
scores  = mlp(concat([cand_emb, context, cand_emb * context], dim=-1))  # (B, max_cand, 1)
scores  = scores.squeeze(-1)                                            # (B, max_cand)
logits  = scores.masked_fill(~action_mask, float("-inf"))               # action_mask: bool (B, max_cand)
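The same computation on plain lists, with a single linear layer standing in for the MLP scorer (an assumption made to keep it short), shows how the [candidate, context, product] input is formed and why padded slots end up with zero probability. Weight shapes and function names are illustrative.

```python
import math


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def policy_logits(cand_emb, action_mask, iteration_emb, problem_emb,
                  W_state, W_problem, score_weights):
    """Masked logits; a single linear layer (length 3d) stands in for the MLP."""
    context = [s + p for s, p in zip(matvec(W_state, iteration_emb),
                                     matvec(W_problem, problem_emb))]
    logits = []
    for c, valid in zip(cand_emb, action_mask):
        if not valid:
            logits.append(-math.inf)            # padding can never be sampled
            continue
        x = c + context + [ci * ki for ci, ki in zip(c, context)]  # list concat: [c, ctx, c*ctx]
        logits.append(sum(w * xi for w, xi in zip(score_weights, x)))
    return logits


def masked_softmax(logits):
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


I2 = [[1.0, 0.0], [0.0, 1.0]]
Z2 = [[0.0, 0.0], [0.0, 0.0]]
logits = policy_logits(cand_emb=[[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]],
                       action_mask=[True, True, False],
                       iteration_emb=[1.0, 2.0], problem_emb=[5.0, 5.0],
                       W_state=I2, W_problem=Z2,
                       score_weights=[0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
probs = masked_softmax(logits)
print(logits)  # [1.0, 2.0, -inf]
```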

6.2 Value head

The value head (SolverFeatureValueHead) estimates the scalar state value used as the RL baseline. It pools the candidate embeddings and the row embeddings into two fixed vectors using small learned single-query attention pools (each a masked, softmax-weighted average), concatenates them with the problem and iteration embeddings, and runs the result through an MLP to a single scalar. Using attention pooling rather than a plain mean lets the value estimate concentrate on the rows and candidates that matter most at this vertex.
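A masked single-query attention pool, and the value computation built on it, can be sketched as below. The plain dot-product score against a learned query vector and the linear readout in place of the MLP are assumptions; the function names are hypothetical.

```python
import math


def attention_pool(tokens, mask, query):
    """Masked single-query pool: softmax(query . token)-weighted average of valid tokens."""
    scores = [sum(q * t for q, t in zip(query, tok)) if ok else -math.inf
              for tok, ok in zip(tokens, mask)]
    peak = max(scores)
    w = [math.exp(s - peak) for s in scores]
    total = sum(w)
    return [sum(wi * tok[c] for wi, tok in zip(w, tokens)) / total
            for c in range(len(tokens[0]))]


def value_estimate(cand_emb, cand_mask, row_emb, row_mask, problem_emb, iteration_emb,
                   q_cand, q_row, readout):
    pooled = (attention_pool(cand_emb, cand_mask, q_cand)
              + attention_pool(row_emb, row_mask, q_row)
              + problem_emb + iteration_emb)             # list concatenation of four d-vectors
    return sum(w * x for w, x in zip(readout, pooled))   # linear readout stands in for the MLP


v = value_estimate(cand_emb=[[2.0], [4.0], [100.0]], cand_mask=[True, True, False],
                   row_emb=[[1.0], [3.0]], row_mask=[True, True],
                   problem_emb=[0.5], iteration_emb=[0.25],
                   q_cand=[0.0], q_row=[0.0], readout=[1.0, 1.0, 1.0, 1.0])
print(v)  # 5.75
```

With a zero query the pool reduces to a masked mean (the padded 100.0 is ignored); a trained query instead shifts weight toward the tokens it scores highly.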

6.3 Optional Q-head

When [model.q_head].enabled = true — required for offline IQL — the model adds SolverFeatureQHead, producing a Q-value per candidate. It builds a richer interaction (candidate, state, and problem projections plus their products) scored by an MLP, and by default uses a dueling decomposition: a state-value term plus a mean-centered per-candidate advantage, recombined under the action mask. This head is disabled in the default PPO config; the value head's V(s) is what PPO and A2C use.
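The dueling recombination can be sketched in a few lines. Mean-centering over the valid candidates only, and writing -inf into padded slots, are assumptions about what "recombined under the action mask" means; dueling_q is a hypothetical name.

```python
import math


def dueling_q(value, advantages, mask):
    """Q(s, a) = V(s) + (A(s, a) - mean of A over valid candidates); padding -> -inf."""
    valid = [a for a, ok in zip(advantages, mask) if ok]
    mean_adv = sum(valid) / len(valid)
    return [value + (a - mean_adv) if ok else -math.inf
            for a, ok in zip(advantages, mask)]


q = dueling_q(value=-12.0, advantages=[1.0, 3.0, 2.0, 50.0], mask=[True, True, True, False])
print(q)  # [-13.0, -11.0, -12.0, -inf]
```

Because the advantages are centered over valid candidates, the mean Q over those candidates equals the state-value term, and the padded advantage (50.0) has no influence.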

7 End-to-end data flow

The training core _SolverFeatureAttentionCore wires the pieces together behind the uniform interface every architecture implements (policy_and_value, policy, value, and — when enabled — a Q path). A single forward pass for one observation runs:

| Step | Operation | Produces |
|---|---|---|
| 1 | encode_problem(problem_index) (cached) | variable, row, problem embeddings (B, ·, d) |
| 2 | encode_state(problem_enc, **obs) | candidate & row embeddings, iteration_embedding |
| 3 | policy_head(…) | masked logits (B, max_cand) |
| 4 | value_head(…) | state value (B,) |
| 5 | q_head(…) (optional) | per-candidate Q (B, max_cand) |

The logits feed a masked categorical distribution; sampling from it yields the index into entering_candidates that the environment translates back into a HiGHS variable index to perform the pivot. The (B, max_cand) shape means the network always emits a score for every padded slot too, but masking guarantees only valid candidates carry probability mass.
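Sampling and index translation can be sketched with the standard library. The sketch assumes entering_candidates already holds HiGHS variable indices and that -inf logits mark masked slots; sample_action and to_highs_index are hypothetical names.

```python
import math
import random


def sample_action(logits, rng):
    """Sample a slot index from a masked categorical (-inf logits get zero probability)."""
    peak = max(logits)
    probs = [math.exp(l - peak) for l in logits]
    total = sum(probs)
    probs = [p / total for p in probs]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


def to_highs_index(action, entering_candidates, action_mask):
    """Translate a slot index back to the variable the environment pivots on."""
    if not action_mask[action]:
        raise ValueError("sampled a padded slot")
    return entering_candidates[action]


rng = random.Random(0)
logits = [0.5, 1.5, -math.inf, -math.inf]
entering = [7, 12, -1, -1]
mask = [True, True, False, False]
actions = [sample_action(logits, rng) for _ in range(1000)]
print(set(actions) <= {0, 1})                              # True
print(to_highs_index(actions[0], entering, mask) in (7, 12))  # True
```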

Embedding width d = 64 · GNN layers = 4 · Attention heads = 4 · Candidate cross / self-attention layers = 2 / 2 · State-token layers = 2 · Candidate features = 21-d

Reference values from configs/solver-feature-attention/ppo.toml. Dataclass defaults are smaller (width 32, single-layer row/candidate-self attention); the config overrides them.

8 Design notes & positioning

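It is worth being explicit about what distinguishes this architecture from its siblings, since all four share the environment and several share the GNN backbone. Its distinctive choices are to feed solver-computed quantities (reduced costs, pricing scores, basis status, and rankings) in directly as per-candidate features rather than leaving the network to rediscover them, and to fuse those features with the learned LP embedding through candidate-to-row, candidate-to-candidate, and state-token attention.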

The trade-offs follow directly from these choices. Supplying solver-native signal should shorten what the policy has to learn and may transfer better across problem families, because reduced costs and pricing scores mean the same thing everywhere. The flip side is a tighter coupling to the semantics and numerical quality of those inputs — the model is only as well-grounded as the reduced_costs, edge_weight, and (when enabled) ratio-test features it is handed. Whether that exchange is favorable on the project's target distribution is an empirical question, and is exactly what comparative training runs across the four architectures are meant to settle.


Source of record: src/simplexrl/models/solver_feature_attention/ (config, encoders, features, heads), src/simplexrl/models/common/problem_encoder_gnn.py (shared GNN backbone), src/simplexrl/train/cores/solver_feature_attention.py (training-core wiring), and configs/solver-feature-attention/ppo.toml (reference hyperparameters). This note describes the model as of the current searchpi branch; if the code and this document disagree, the code wins.