Skip to content

Commit

Permalink
Model type expressions as regions (#4698)
Browse files Browse the repository at this point in the history
This is a precondition for enabling the new pattern-matching subsystem
to support binding patterns that have `if` expressions in the type
position.

---------

Co-authored-by: Richard Smith <[email protected]>
  • Loading branch information
geoffromer and zygoloid authored Dec 19, 2024
1 parent 95c9634 commit a112cbd
Show file tree
Hide file tree
Showing 504 changed files with 5,699 additions and 8,719 deletions.
96 changes: 79 additions & 17 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ auto Context::AddConvergenceBlockAndPush(Parse::NodeId node_id, int num_blocks)
inst_block_stack().Pop();
}
inst_block_stack().Push(new_block_id);
AddToRegion(new_block_id, node_id);
}

auto Context::AddConvergenceBlockWithArgAndPush(
Expand All @@ -787,6 +788,7 @@ auto Context::AddConvergenceBlockWithArgAndPush(
inst_block_stack().Pop();
}
inst_block_stack().Push(new_block_id);
AddToRegion(new_block_id, node_id);

// Acquire the result value.
SemIR::TypeId result_type_id = insts().Get(*block_args.begin()).type_id();
Expand Down Expand Up @@ -823,30 +825,90 @@ auto Context::SetBlockArgResultBeforeConstantUse(SemIR::InstId select_id,
}
}

auto Context::AddCurrentCodeBlockToFunction(Parse::NodeId node_id) -> void {
CARBON_CHECK(!inst_block_stack().empty(), "no current code block");

if (return_scope_stack().empty()) {
CARBON_CHECK(node_id.is_valid(),
"No current function, but node_id not provided");
TODO(node_id,
auto Context::AddToRegion(SemIR::InstBlockId block_id, SemIR::LocId loc_id)
-> void {
if (region_stack_.empty()) {
TODO(loc_id,
"Control flow expressions are currently only supported inside "
"functions.");
return;
}

if (!inst_block_stack().is_current_block_reachable()) {
// Don't include unreachable blocks in the function.
if (block_id == SemIR::InstBlockId::Unreachable) {
return;
}

auto function_id =
insts()
.GetAs<SemIR::FunctionDecl>(return_scope_stack().back().decl_id)
.function_id;
functions()
.Get(function_id)
.body_block_ids.push_back(inst_block_stack().PeekOrAdd());
region_stack_.AppendToTop(block_id);
}

auto Context::BeginSubpattern() -> void {
inst_block_stack().Push();
PushRegion(inst_block_stack().PeekOrAdd());
}

auto Context::EndSubpatternAsExpr(SemIR::InstId result_id)
-> SemIR::ExprRegionId {
if (region_stack_.PeekArray().size() > 1) {
// End the exit block with a branch to a successor block, whose contents
// will be determined later.
AddInst(SemIR::LocIdAndInst::NoLoc<SemIR::Branch>(
{.target_id = inst_blocks().AddDefaultValue()}));
} else {
// This single-block region will be inserted as a SpliceBlock, so we don't
// need control flow out of it.
}
auto block_id = inst_block_stack().Pop();
CARBON_CHECK(block_id == region_stack_.PeekArray().back());

// TODO: Is it possible to validate that this region is genuinely
// single-entry, single-exit?
return sem_ir().expr_regions().Add(
{.block_ids = PopRegion(), .result_id = result_id});
}

auto Context::EndSubpatternAsEmpty() -> void {
auto block_id = inst_block_stack().Pop();
CARBON_CHECK(block_id == region_stack_.PeekArray().front());
CARBON_CHECK(inst_blocks().Get(block_id).empty());
region_stack_.PopArray();
}

auto Context::InsertHere(SemIR::ExprRegionId region_id) -> SemIR::InstId {
auto region = sem_ir_->expr_regions().Get(region_id);
auto loc_id = insts().GetLocId(region.result_id);
auto exit_block = inst_blocks().Get(region.block_ids.back());
if (region.block_ids.size() == 1) {
// TODO: Is it possible to avoid leaving an "orphan" block in the IR in the
// first two cases?
if (exit_block.size() == 0) {
return region.result_id;
}
if (exit_block.size() == 1) {
inst_block_stack_.AddInstId(exit_block.front());
return region.result_id;
}
return AddInst<SemIR::SpliceBlock>(
loc_id, {.type_id = insts().Get(region.result_id).type_id(),
.block_id = region.block_ids.front(),
.result_id = region.result_id});
}
if (region_stack_.empty()) {
TODO(loc_id,
"Control flow expressions are currently only supported inside "
"functions.");
return SemIR::ErrorInst::SingletonInstId;
}
AddInst(SemIR::LocIdAndInst::NoLoc<SemIR::Branch>(
{.target_id = region.block_ids.front()}));
inst_block_stack_.Pop();
// TODO: this will cumulatively cost O(MN) running time for M blocks
// at the Nth level of the stack. Figure out how to do better.
region_stack_.AppendToTop(region.block_ids);
auto resume_with_block_id =
insts().GetAs<SemIR::Branch>(exit_block.back()).target_id;
CARBON_CHECK(inst_blocks().GetOrEmpty(resume_with_block_id).empty());
inst_block_stack_.Push(resume_with_block_id);
AddToRegion(resume_with_block_id, loc_id);
return region.result_id;
}

auto Context::is_current_position_reachable() -> bool {
Expand Down
88 changes: 70 additions & 18 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class Context {

// Adds an instruction to the current pattern block, returning the produced
// ID.
// TODO: Is it possible to remove this and pattern_block_stack, now that
// we have BeginSubpattern etc. instead?
auto AddPatternInst(SemIR::LocIdAndInst loc_id_and_inst) -> SemIR::InstId {
auto inst_id = AddInstInNoBlock(loc_id_and_inst);
pattern_block_stack_.AddInstId(inst_id);
Expand Down Expand Up @@ -276,6 +278,27 @@ class Context {
return scope_stack().GetCurrentScopeAs<InstT>(sem_ir());
}

// Mark the start of a new single-entry region with the given entry block.
auto PushRegion(SemIR::InstBlockId entry_block_id) -> void {
region_stack_.PushArray();
region_stack_.AppendToTop(entry_block_id);
}

// Add `block_id` to the most recently pushed single-entry region. To preserve
// the single-entry property, `block_id` must not be directly reachable from
// any block outside the region. To ensure the region's blocks are in lexical
// order, this should be called when the first parse node associated with this
// block is handled, or as close as possible.
auto AddToRegion(SemIR::InstBlockId block_id, SemIR::LocId loc_id) -> void;

// Complete creation of the most recently pushed single-entry region, and
// return a list of its blocks.
auto PopRegion() -> llvm::SmallVector<SemIR::InstBlockId> {
llvm::SmallVector<SemIR::InstBlockId> result(region_stack_.PeekArray());
region_stack_.PopArray();
return result;
}

// Adds a `Branch` instruction branching to a new instruction block, and
// returns the ID of the new block. All paths to the branch target must go
// through the current block, though not necessarily through this branch.
Expand All @@ -297,16 +320,18 @@ class Context {

// Handles recovergence of control flow. Adds branches from the top
// `num_blocks` on the instruction block stack to a new block, pops the
// existing blocks, and pushes the new block onto the instruction block stack.
// existing blocks, pushes the new block onto the instruction block stack,
// and adds it to the most recently pushed region.
auto AddConvergenceBlockAndPush(Parse::NodeId node_id, int num_blocks)
-> void;

// Handles recovergence of control flow with a result value. Adds branches
// from the top few blocks on the instruction block stack to a new block, pops
// the existing blocks, and pushes the new block onto the instruction block
// stack. The number of blocks popped is the size of `block_args`, and the
// corresponding result values are the elements of `block_args`. Returns an
// instruction referring to the result value.
// the existing blocks, pushes the new block onto the instruction block
// stack, and adds it to the most recently pushed region. The number of blocks
// popped is the size of `block_args`, and the corresponding result values are
// the elements of `block_args`. Returns an instruction referring to the
// result value.
auto AddConvergenceBlockWithArgAndPush(
Parse::NodeId node_id, std::initializer_list<SemIR::InstId> block_args)
-> SemIR::InstId;
Expand All @@ -322,13 +347,6 @@ class Context {
SemIR::InstId if_true,
SemIR::InstId if_false) -> void;

// Add the current code block to the enclosing function.
// TODO: The node_id is taken for expressions, which can occur in
// non-function contexts. This should be refactored to support non-function
// contexts, and node_id removed.
auto AddCurrentCodeBlockToFunction(
Parse::NodeId node_id = Parse::NodeId::Invalid) -> void;

// Returns whether the current position in the current block is reachable.
auto is_current_position_reachable() -> bool;

Expand Down Expand Up @@ -619,12 +637,46 @@ class Context {

auto global_init() -> GlobalInit& { return global_init_; }

// Marks the start of a region of insts in a pattern context that might
// represent an expression or a pattern.
auto BeginSubpattern() -> void;

// Ends a region started by BeginSubpattern (in stack order), treating it as
// an expression with the given result, and returns the ID of the region. The
// region will not yet have any control-flow edges into or out of it.
auto EndSubpatternAsExpr(SemIR::InstId result_id) -> SemIR::ExprRegionId;

// Ends a region started by BeginSubpattern (in stack order), asserting that
// it was empty.
auto EndSubpatternAsEmpty() -> void;

// TODO: Add EndSubpatternAsPattern, when needed.

// Inserts the given region into the current code block. If the region
// consists of a single block, this will be implemented as a `splice_block`
// inst. Otherwise, this will end the current block with a branch to the entry
// block of the region, and add future insts to a new block which is the
// immediate successor of the region's exit block. As a result, this cannot be
// called more than once for the same region.
auto InsertHere(SemIR::ExprRegionId region_id) -> SemIR::InstId;

auto import_ref_ids() -> llvm::SmallVector<SemIR::InstId>& {
return import_ref_ids_;
}

auto bind_name_cache() -> Map<SemIR::EntityNameId, SemIR::InstId>& {
return bind_name_cache_;
// Map from an AnyBindingPattern inst to precomputed parts of the
// pattern-match SemIR for it.
//
// TODO: Consider putting this behind a narrower API to guard against emitting
// multiple times.
struct BindingPatternInfo {
// The corresponding AnyBindName inst.
SemIR::InstId bind_name_id;
// The region of insts that computes the type of the binding.
SemIR::ExprRegionId type_expr_id;
};
auto bind_name_map() -> Map<SemIR::InstId, BindingPatternInfo>& {
return bind_name_map_;
}

private:
Expand Down Expand Up @@ -738,10 +790,10 @@ class Context {
// FinalizeImportRefBlock() will produce an inst block for them.
llvm::SmallVector<SemIR::InstId> import_ref_ids_;

// Cache of allocated AnyBindName insts, keyed by the entity names they refer
// to. These are allocated while generating the pattern IR, but are emitted
// later as part of the pattern-match IR.
Map<SemIR::EntityNameId, SemIR::InstId> bind_name_cache_;
Map<SemIR::InstId, BindingPatternInfo> bind_name_map_;

// Stack of single-entry regions being built.
ArrayStack<SemIR::InstBlockId> region_stack_;
};

} // namespace Carbon::Check
Expand Down
13 changes: 9 additions & 4 deletions toolchain/check/handle_binding_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ static auto HandleAnyBindingPattern(Context& context, Parse::NodeId node_id,

// TODO: Handle `_` bindings.

SemIR::ExprRegionId type_expr_region_id =
context.EndSubpatternAsExpr(cast_type_inst_id);

// Every other kind of pattern binding has a name.
auto [name_node, name_id] = context.node_stack().PopNameWithNodeId();

Expand Down Expand Up @@ -212,10 +215,6 @@ static auto HandleAnyBindingPattern(Context& context, Parse::NodeId node_id,
context.AddNameToLookup(name_id, bind_id);
auto entity_name_id =
context.insts().GetAs<SemIR::AnyBindName>(bind_id).entity_name_id;
bool inserted = context.bind_name_cache()
.Insert(entity_name_id, bind_id)
.is_inserted();
CARBON_CHECK(inserted);
auto pattern_inst_id = SemIR::InstId::Invalid;
if (is_generic) {
pattern_inst_id =
Expand All @@ -227,6 +226,12 @@ static auto HandleAnyBindingPattern(Context& context, Parse::NodeId node_id,
name_node,
{.type_id = cast_type_id, .entity_name_id = entity_name_id});
}
bool inserted =
context.bind_name_map()
.Insert(pattern_inst_id, {.bind_name_id = bind_id,
.type_expr_id = type_expr_region_id})
.is_inserted();
CARBON_CHECK(inserted);
param_pattern_id = context.AddPatternInst<SemIR::ValueParamPattern>(
node_id,
{
Expand Down
6 changes: 4 additions & 2 deletions toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,9 @@ static auto HandleFunctionDefinitionAfterSignature(
// Create the function scope and the entry block.
context.return_scope_stack().push_back({.decl_id = decl_id});
context.inst_block_stack().Push();
context.PushRegion(context.inst_block_stack().PeekOrAdd());
context.scope_stack().Push(decl_id);
StartGenericDefinition(context);
context.AddCurrentCodeBlockToFunction();

CheckFunctionDefinitionSignature(context, function);

Expand Down Expand Up @@ -441,8 +441,10 @@ auto HandleParseNode(Context& context, Parse::FunctionDefinitionId node_id)
context.return_scope_stack().pop_back();
context.decl_name_stack().PopScope();

// If this is a generic function, collect information about the definition.
auto& function = context.functions().Get(function_id);
function.body_block_ids = context.PopRegion();

// If this is a generic function, collect information about the definition.
FinishGenericDefinition(context, function.generic_id);

return true;
Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/handle_if_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ auto HandleParseNode(Context& context, Parse::IfExprIfId node_id) -> bool {
// Start emitting the `then` block.
context.inst_block_stack().Pop();
context.inst_block_stack().Push(then_block_id);
context.AddCurrentCodeBlockToFunction(node_id);
context.AddToRegion(then_block_id, node_id);

context.node_stack().Push(if_node, else_block_id);
return true;
Expand Down Expand Up @@ -56,13 +56,18 @@ auto HandleParseNode(Context& context, Parse::IfExprThenId node_id) -> bool {

// Start emitting the `else` block.
context.inst_block_stack().Push(else_block_id);
context.AddCurrentCodeBlockToFunction(node_id);
context.AddToRegion(else_block_id, node_id);

context.node_stack().Push(node_id, then_value_id);
return true;
}

auto HandleParseNode(Context& context, Parse::IfExprElseId node_id) -> bool {
if (context.return_scope_stack().empty()) {
context.TODO(node_id,
"Control flow expressions are currently only supported inside "
"functions.");
}
// Alias node_id for if/then/else consistency.
auto& else_node = node_id;

Expand All @@ -84,7 +89,6 @@ auto HandleParseNode(Context& context, Parse::IfExprElseId node_id) -> bool {
if_node, {else_value_id, then_value_id});
context.SetBlockArgResultBeforeConstantUse(chosen_value_id, cond_value_id,
then_value_id, else_value_id);
context.AddCurrentCodeBlockToFunction(node_id);

// Push the result value.
context.node_stack().Push(else_node, chosen_value_id);
Expand Down
6 changes: 3 additions & 3 deletions toolchain/check/handle_if_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ auto HandleParseNode(Context& context, Parse::IfConditionId node_id) -> bool {
// Start emitting the `then` block.
context.inst_block_stack().Pop();
context.inst_block_stack().Push(then_block_id);
context.AddCurrentCodeBlockToFunction();
context.AddToRegion(then_block_id, node_id);

context.node_stack().Push(node_id, else_block_id);
return true;
Expand All @@ -40,7 +40,7 @@ auto HandleParseNode(Context& context, Parse::IfStatementElseId node_id)

// Switch to emitting the `else` block.
context.inst_block_stack().Push(else_block_id);
context.AddCurrentCodeBlockToFunction();
context.AddToRegion(else_block_id, node_id);

context.node_stack().Push(node_id);
return true;
Expand All @@ -56,6 +56,7 @@ auto HandleParseNode(Context& context, Parse::IfStatementId node_id) -> bool {
context.AddInst<SemIR::Branch>(node_id, {.target_id = else_block_id});
context.inst_block_stack().Pop();
context.inst_block_stack().Push(else_block_id);
context.AddToRegion(else_block_id, node_id);
break;
}

Expand All @@ -72,7 +73,6 @@ auto HandleParseNode(Context& context, Parse::IfStatementId node_id) -> bool {
}
}

context.AddCurrentCodeBlockToFunction();
return true;
}

Expand Down
1 change: 1 addition & 0 deletions toolchain/check/handle_let_and_var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static auto HandleIntroducer(Context& context, Parse::NodeId node_id) -> bool {
// Push a bracketing node and pattern block to establish the pattern context.
context.node_stack().Push(node_id);
context.pattern_block_stack().Push();
context.BeginSubpattern();
return true;
}

Expand Down
Loading

0 comments on commit a112cbd

Please sign in to comment.