Skip to content

Commit

Permalink
Mutable variables in dynamic branches prevent full constant folding i…
Browse files Browse the repository at this point in the history
…n partial evaluation (#2089)

This fixes the bug by having partial evaluation more explicitly track
different variable mappings to literals across branches and recombining
those mappings that match (ie: are constant) when all branches are done.
This also includes partial eval and RIR SSA pass fixes to correctly
support immutable and mutable copies of dynamic variables. New test
cases for several combinations of constant folding at partial eval are
included, as well as a new test case confirming RIR SSA fix.

Fixes #2087
  • Loading branch information
swernli authored Jan 9, 2025
1 parent 189523e commit d7f8962
Show file tree
Hide file tree
Showing 14 changed files with 1,206 additions and 393 deletions.
4 changes: 2 additions & 2 deletions compiler/qsc/src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,8 +1230,8 @@ mod adaptive_ri_profile {
block_2:
br label %block_3
block_3:
%var_3 = phi i64 [0, %block_1], [1, %block_2]
call void @__quantum__rt__int_record_output(i64 %var_3, i8* null)
%var_4 = phi i64 [0, %block_1], [1, %block_2]
call void @__quantum__rt__int_record_output(i64 %var_4, i8* null)
ret void
}
Expand Down
52 changes: 30 additions & 22 deletions compiler/qsc_partial_eval/src/evaluation_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,10 @@ pub struct Scope {
pub args_value_kind: Vec<ValueKind>,
/// The classical environment of the callable, which holds values corresponding to local variables.
pub env: Env,
// Consider optimizing `hybrid_vars` and `mutable_vars` by removing them and enlightening the evaluator on how to
// properly handle `Value::Var`, which could be either `static` or `dynamic`.
/// Map that holds the values of local variables.
hybrid_vars: FxHashMap<LocalVarId, Value>,
/// Maps variable IDs to mutable variables, which contain their current kind.
mutable_vars: FxHashMap<VariableId, MutableKind>,
/// Maps variable IDs to static literal values, if any.
static_vars: FxHashMap<VariableId, Literal>,
/// Number of currently active blocks (starting from where this scope was created).
active_block_count: usize,
}
Expand Down Expand Up @@ -161,18 +159,37 @@ impl Scope {
env,
active_block_count: 1,
hybrid_vars,
mutable_vars: FxHashMap::default(),
static_vars: FxHashMap::default(),
}
}

/// Gets a mutable variable.
pub fn find_mutable_kind(&self, var_id: VariableId) -> Option<&MutableKind> {
self.mutable_vars.get(&var_id)
/// Gets the static literal value for the given variable.
pub fn get_static_value(&self, var_id: VariableId) -> Option<&Literal> {
self.static_vars.get(&var_id)
}

/// Gets a mutable mutable variable.
pub fn find_mutable_var_mut(&mut self, var_id: VariableId) -> Option<&mut MutableKind> {
self.mutable_vars.get_mut(&var_id)
/// Removes the static literal value for the given variable.
pub fn remove_static_value(&mut self, var_id: VariableId) {
self.static_vars.remove(&var_id);
}

/// Clones the static literal value mappings, which allows callers to cache the mappings across branches.
pub fn clone_static_var_mappings(&self) -> FxHashMap<VariableId, Literal> {
self.static_vars.clone()
}

/// Sets the static literal value mappings to the given mapping, overwriting any existing mappings.
pub fn set_static_var_mappings(&mut self, static_vars: FxHashMap<VariableId, Literal>) {
self.static_vars = static_vars;
}

/// Keeps only the static literal value mappings that are also present in the provided other mapping.
pub fn keep_matching_static_var_mappings(
&mut self,
other_mappings: &FxHashMap<VariableId, Literal>,
) {
self.static_vars
.retain(|var_id, lit| Some(&*lit) == other_mappings.get(var_id));
}

/// Gets the value of a hybrid local variable.
Expand All @@ -197,11 +214,8 @@ impl Scope {
}

// Insert a variable into the mutable variables map.
pub fn insert_mutable_var(&mut self, var_id: VariableId, mutable_kind: MutableKind) {
let Entry::Vacant(vacant) = self.mutable_vars.entry(var_id) else {
panic!("mutable variable should not already exist");
};
vacant.insert(mutable_kind);
pub fn insert_static_var_mapping(&mut self, var_id: VariableId, literal: Literal) {
self.static_vars.insert(var_id, literal);
}

/// Determines whether we are currently evaluating a branch within the scope.
Expand Down Expand Up @@ -317,9 +331,3 @@ fn map_eval_value_to_value_kind(value: &Value) -> ValueKind {
| Value::String(_) => ValueKind::Element(RuntimeKind::Static),
}
}

#[derive(Clone, Copy, Debug)]
pub enum MutableKind {
Static(Literal),
Dynamic,
}
96 changes: 72 additions & 24 deletions compiler/qsc_partial_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod management;

use core::panic;
use evaluation_context::{
Arg, BlockNode, BranchControlFlow, EvalControlFlow, EvaluationContext, MutableKind, Scope,
Arg, BlockNode, BranchControlFlow, EvalControlFlow, EvaluationContext, Scope,
};
use management::{QuantumIntrinsicsChecker, ResourceManager};
use miette::Diagnostic;
Expand Down Expand Up @@ -51,7 +51,7 @@ use qsc_rir::{
builder,
rir::{
self, Callable, CallableId, CallableType, ConditionCode, Instruction, Literal, Operand,
Program,
Program, VariableId,
},
};
use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -204,8 +204,19 @@ impl<'a> PartialEvaluator<'a> {
fn bind_value_to_ident(&mut self, mutability: Mutability, ident: &Ident, value: Value) {
// We do slightly different things depending on the mutability of the identifier.
match mutability {
Mutability::Immutable => self.bind_value_to_immutable_ident(ident, value),
Mutability::Mutable => self.bind_value_to_mutable_ident(ident, value),
Mutability::Immutable => {
let current_scope = self.eval_context.get_current_scope();
if matches!(value, Value::Var(var) if current_scope.get_static_value(var.id.into()).is_none())
{
// An immutable identifier is being bound to a dynamic value, so treat the identifier as mutable.
// This allows it to represent a point-in-time copy of the mutable value during evaluation.
self.bind_value_to_mutable_ident(ident, value);
} else {
// The value is static, so bind it to the classical map.
self.bind_value_to_immutable_ident(ident, value);
}
}
};
}

Expand All @@ -226,11 +237,13 @@ impl<'a> PartialEvaluator<'a> {
}

// Always bind the value to the hybrid map but do it differently depending of the value type.
if let Some((var_id, mutable_kind)) = self.try_create_mutable_variable(ident.id, &value) {
// Keep track of whether the mutable variable is static or dynamic.
self.eval_context
.get_current_scope_mut()
.insert_mutable_var(var_id, mutable_kind);
if let Some((var_id, literal)) = self.try_create_mutable_variable(ident.id, &value) {
// If the variable maps to a know static literal, track that mapping.
if let Some(literal) = literal {
self.eval_context
.get_current_scope_mut()
.insert_static_var_mapping(var_id, literal);
}
} else {
self.bind_value_in_hybrid_map(ident, value);
}
Expand Down Expand Up @@ -1510,6 +1523,8 @@ impl<'a> PartialEvaluator<'a> {
};

// Evaluate the body expression.
// First, we cache the current static variable mappings so that we can restore them later.
let cached_mappings = self.clone_current_static_var_map();
let if_true_branch_control_flow =
self.eval_expr_if_branch(body_expr_id, continuation_block_node_id, maybe_if_expr_var)?;
let if_true_block_id = match if_true_branch_control_flow {
Expand All @@ -1519,16 +1534,28 @@ impl<'a> PartialEvaluator<'a> {

// Evaluate the otherwise expression (if any), and determine the block to branch to if the condition is false.
let if_false_block_id = if let Some(otherwise_expr_id) = otherwise_expr_id {
// Cache the mappings after the true block so we can compare afterwards.
let post_if_true_mappings = self.clone_current_static_var_map();
// Restore the cached mappings from before evaluating the true block.
self.overwrite_current_static_var_map(cached_mappings);
let if_false_branch_control_flow = self.eval_expr_if_branch(
otherwise_expr_id,
continuation_block_node_id,
maybe_if_expr_var,
)?;
// Only keep the static mappings that are the same in both blocks; when they are different,
// the variable is no longer static across the if expression.
self.keep_matching_static_var_mappings(&post_if_true_mappings);
match if_false_branch_control_flow {
BranchControlFlow::Block(block_id) => block_id,
BranchControlFlow::Return(value) => return Ok(EvalControlFlow::Return(value)),
}
} else {
// Only keep the static mappings that are the same after the true block as before; when they are different,
// the variable is no longer static across the if expression.
self.keep_matching_static_var_mappings(&cached_mappings);

// Since there is no otherwise block, we branch to the continuation block.
continuation_block_node_id
};

Expand Down Expand Up @@ -1814,9 +1841,7 @@ impl<'a> PartialEvaluator<'a> {
// the variable if it is static at this moment.
if let Value::Var(var) = bound_value {
let current_scope = self.eval_context.get_current_scope();
if let Some(MutableKind::Static(literal)) =
current_scope.find_mutable_kind(var.id.into())
{
if let Some(literal) = current_scope.get_static_value(var.id.into()) {
map_rir_literal_to_eval_value(*literal)
} else {
bound_value.clone()
Expand Down Expand Up @@ -2229,7 +2254,7 @@ impl<'a> PartialEvaluator<'a> {
&mut self,
local_var_id: LocalVarId,
value: &Value,
) -> Option<(rir::VariableId, MutableKind)> {
) -> Option<(rir::VariableId, Option<Literal>)> {
// Check if we can create a mutable variable for this value.
let var_ty = try_get_eval_var_type(value)?;

Expand All @@ -2249,13 +2274,13 @@ impl<'a> PartialEvaluator<'a> {
let store_ins = Instruction::Store(value_operand, rir_var);
self.get_current_rir_block_mut().0.push(store_ins);

// Create a mutable variable.
let mutable_kind = match value_operand {
Operand::Literal(literal) => MutableKind::Static(literal),
Operand::Variable(_) => MutableKind::Dynamic,
// Create a mutable variable, mapping it to the static value if any.
let static_value = match value_operand {
Operand::Literal(literal) => Some(literal),
Operand::Variable(_) => None,
};

Some((var_id, mutable_kind))
Some((var_id, static_value))
}

fn get_or_insert_callable(&mut self, callable: Callable) -> CallableId {
Expand Down Expand Up @@ -2625,14 +2650,16 @@ impl<'a> PartialEvaluator<'a> {

// If this is a mutable variable, make sure to update whether it is static or dynamic.
let current_scope = self.eval_context.get_current_scope_mut();
if matches!(rhs_operand, Operand::Variable(_))
|| current_scope.is_currently_evaluating_branch()
{
if let Some(mutable_kind) = current_scope.find_mutable_var_mut(rir_var.variable_id)
{
*mutable_kind = MutableKind::Dynamic;
match rhs_operand {
Operand::Literal(literal) => {
// The variable maps to a static literal here, so track that literal value.
current_scope.insert_static_var_mapping(rir_var.variable_id, literal);
}
}
Operand::Variable(_) => {
// The variable is not known to be some literal value, so remove the static mapping.
current_scope.remove_static_value(rir_var.variable_id);
}
};
} else {
// Verify that we are not updating a value that does not have a backing variable from a dynamic branch
// because it is unsupported.
Expand Down Expand Up @@ -2918,6 +2945,27 @@ impl<'a> PartialEvaluator<'a> {
_ => panic!("{value} cannot be mapped to a RIR operand"),
}
}

fn clone_current_static_var_map(&self) -> FxHashMap<VariableId, Literal> {
self.eval_context
.get_current_scope()
.clone_static_var_mappings()
}

fn overwrite_current_static_var_map(&mut self, static_vars: FxHashMap<VariableId, Literal>) {
self.eval_context
.get_current_scope_mut()
.set_static_var_mappings(static_vars);
}

fn keep_matching_static_var_mappings(
&mut self,
other_mappings: &FxHashMap<VariableId, Literal>,
) {
self.eval_context
.get_current_scope_mut()
.keep_matching_static_var_mappings(other_mappings);
}
}

fn eval_un_op_with_literals(un_op: UnOp, value: Value) -> Value {
Expand Down
Loading

0 comments on commit d7f8962

Please sign in to comment.