Skip to content

Commit

Permalink
Basic support for implementing and using a parameterized interface. (#…
Browse files Browse the repository at this point in the history
…4203)

The main change here is to form a specific when checking an interface
function against an impl function, instead of just substituting the
`Self` type.
  • Loading branch information
zygoloid authored Aug 9, 2024
1 parent b2a13af commit 4a21b6a
Show file tree
Hide file tree
Showing 54 changed files with 998 additions and 266 deletions.
15 changes: 5 additions & 10 deletions toolchain/check/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
#include "toolchain/check/function.h"

#include "toolchain/check/merge.h"
#include "toolchain/check/subst.h"
#include "toolchain/sem_ir/ids.h"

namespace Carbon::Check {

auto CheckFunctionTypeMatches(Context& context,
const SemIR::Function& new_function,
const SemIR::Function& prev_function,
Substitutions substitutions, bool check_syntax)
-> bool {
SemIR::SpecificId prev_specific_id,
bool check_syntax) -> bool {
if (!CheckRedeclParamsMatch(context, DeclParams(new_function),
DeclParams(prev_function), substitutions,
DeclParams(prev_function), prev_specific_id,
check_syntax)) {
return false;
}
Expand All @@ -25,16 +24,12 @@ auto CheckFunctionTypeMatches(Context& context,
// use it here.
auto new_return_type_id =
new_function.GetDeclaredReturnType(context.sem_ir());
auto prev_return_type_id = prev_function.GetDeclaredReturnType(
context.sem_ir(), SemIR::SpecificId::Invalid);
auto prev_return_type_id =
prev_function.GetDeclaredReturnType(context.sem_ir(), prev_specific_id);
if (new_return_type_id == SemIR::TypeId::Error ||
prev_return_type_id == SemIR::TypeId::Error) {
return false;
}
if (prev_return_type_id.is_valid()) {
prev_return_type_id =
SubstType(context, prev_return_type_id, substitutions);
}
if (!context.types().AreEqualAcrossDeclarations(new_return_type_id,
prev_return_type_id)) {
CARBON_DIAGNOSTIC(
Expand Down
14 changes: 7 additions & 7 deletions toolchain/check/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ struct SuspendedFunction {
};

// Checks that `new_function` has the same parameter types and return type as
// `prev_function`, applying the specified set of substitutions to the
// previous function. Prints a suitable diagnostic and returns false if not.
auto CheckFunctionTypeMatches(Context& context,
const SemIR::Function& new_function,
const SemIR::Function& prev_function,
Substitutions substitutions, bool check_syntax)
-> bool;
// `prev_function`, or if `prev_function_id` is specified, a specific version of
// `prev_function`. Prints a suitable diagnostic and returns false if not.
auto CheckFunctionTypeMatches(
Context& context, const SemIR::Function& new_function,
const SemIR::Function& prev_function,
SemIR::SpecificId prev_specific_id = SemIR::SpecificId::Invalid,
bool check_syntax = true) -> bool;

// Checks that the return type of the specified function is complete, issuing an
// error if not. This computes the return slot usage for the function if
Expand Down
3 changes: 1 addition & 2 deletions toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ static auto MergeFunctionRedecl(Context& context, SemIRLoc new_loc,
SemIR::ImportIRId prev_import_ir_id) -> bool {
auto& prev_function = context.functions().Get(prev_function_id);

if (!CheckFunctionTypeMatches(context, new_function, prev_function, {},
/*check_syntax=*/true)) {
if (!CheckFunctionTypeMatches(context, new_function, prev_function)) {
return false;
}

Expand Down
113 changes: 84 additions & 29 deletions toolchain/check/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/function.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/import_ref.h"
#include "toolchain/check/subst.h"
#include "toolchain/diagnostics/diagnostic_emitter.h"
#include "toolchain/sem_ir/generic.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/impl.h"
#include "toolchain/sem_ir/inst.h"
Expand All @@ -27,13 +28,59 @@ static auto NoteAssociatedFunction(Context& context,
builder.Note(function.decl_id, ImplAssociatedFunctionHere, function.name_id);
}

// Gets the self specific of a generic declaration that is an interface member,
// given a specific for an enclosing generic, plus a type to use as `Self`.
static auto GetSelfSpecificForInterfaceMemberWithSelfType(
Context& context, SemIR::SpecificId enclosing_specific_id,
SemIR::GenericId generic_id, SemIR::TypeId self_type_id)
-> SemIR::SpecificId {
const auto& generic = context.generics().Get(generic_id);
auto bindings = context.inst_blocks().Get(generic.bindings_id);

llvm::SmallVector<SemIR::InstId> arg_ids;
arg_ids.reserve(bindings.size());

// Start with the enclosing arguments.
if (enclosing_specific_id.is_valid()) {
auto enclosing_specific_args_id =
context.specifics().Get(enclosing_specific_id).args_id;
auto enclosing_specific_args =
context.inst_blocks().Get(enclosing_specific_args_id);
arg_ids.assign(enclosing_specific_args.begin(),
enclosing_specific_args.end());
}

// Add the `Self` argument.
CARBON_CHECK(
context.entity_names()
.Get(context.insts()
.GetAs<SemIR::BindSymbolicName>(bindings[arg_ids.size()])
.entity_name_id)
.name_id == SemIR::NameId::SelfType)
<< "Expected a Self binding, found "
<< context.insts().Get(bindings[arg_ids.size()]);
arg_ids.push_back(context.types().GetInstId(self_type_id));

// Take any trailing argument values from the self specific.
// TODO: If these refer to outer arguments, for example in their types, we may
// need to perform extra substitutions here.
auto self_specific_args = context.inst_blocks().Get(
context.specifics().Get(generic.self_specific_id).args_id);
for (auto arg_id : self_specific_args.drop_front(arg_ids.size())) {
arg_ids.push_back(context.constant_values().GetConstantInstId(arg_id));
}

auto args_id = context.inst_blocks().AddCanonical(arg_ids);
return MakeSpecific(context, generic_id, args_id);
}

// Checks that `impl_function_id` is a valid implementation of the function
// described in the interface as `interface_function_id`. Returns the value to
// put into the corresponding slot in the witness table, which can be
// `BuiltinError` if the function is not usable.
static auto CheckAssociatedFunctionImplementation(
Context& context, SemIR::FunctionId interface_function_id,
SemIR::InstId impl_decl_id, Substitutions substitutions) -> SemIR::InstId {
Context& context, SemIR::FunctionType interface_function_type,
SemIR::InstId impl_decl_id, SemIR::TypeId self_type_id) -> SemIR::InstId {
auto impl_function_decl =
context.insts().TryGetAs<SemIR::FunctionDecl>(impl_decl_id);
if (!impl_function_decl) {
Expand All @@ -42,19 +89,32 @@ static auto CheckAssociatedFunctionImplementation(
SemIR::NameId);
auto builder = context.emitter().Build(
impl_decl_id, ImplFunctionWithNonFunction,
context.functions().Get(interface_function_id).name_id);
NoteAssociatedFunction(context, builder, interface_function_id);
context.functions().Get(interface_function_type.function_id).name_id);
NoteAssociatedFunction(context, builder,
interface_function_type.function_id);
builder.Emit();

return SemIR::InstId::BuiltinError;
}

// Map from the specific for the function type to the specific for the
// function signature. The function signature may have additional generic
// parameters.
auto interface_function_specific_id =
GetSelfSpecificForInterfaceMemberWithSelfType(
context, interface_function_type.specific_id,
context.functions()
.Get(interface_function_type.function_id)
.generic_id,
self_type_id);

// TODO: This should be a semantic check rather than a syntactic one. The
// functions should be allowed to have different signatures as long as we can
// synthesize a suitable thunk.
if (!CheckFunctionTypeMatches(
context, context.functions().Get(impl_function_decl->function_id),
context.functions().Get(interface_function_id), substitutions,
context.functions().Get(interface_function_type.function_id),
interface_function_specific_id,
/*check_syntax=*/false)) {
return SemIR::InstId::BuiltinError;
}
Expand All @@ -63,18 +123,17 @@ static auto CheckAssociatedFunctionImplementation(

// Builds a witness that the specified impl implements the given interface.
static auto BuildInterfaceWitness(
Context& context, const SemIR::Impl& impl,
Context& context, const SemIR::Impl& impl, SemIR::TypeId interface_type_id,
SemIR::InterfaceType interface_type,
llvm::SmallVectorImpl<SemIR::InstId>& used_decl_ids) -> SemIR::InstId {
const auto& interface = context.interfaces().Get(interface_type.interface_id);
if (!interface.is_defined()) {
CARBON_DIAGNOSTIC(ImplOfUndefinedInterface, Error,
"Implementation of undefined interface {0}.",
SemIR::NameId);
auto builder = context.emitter().Build(
impl.definition_id, ImplOfUndefinedInterface, interface.name_id);
context.NoteUndefinedInterface(interface_type.interface_id, builder);
builder.Emit();
if (!context.TryToDefineType(interface_type_id, [&] {
CARBON_DIAGNOSTIC(ImplOfUndefinedInterface, Error,
"Implementation of undefined interface {0}.",
SemIR::NameId);
return context.emitter().Build(
impl.definition_id, ImplOfUndefinedInterface, interface.name_id);
})) {
return SemIR::InstId::BuiltinError;
}

Expand All @@ -85,18 +144,11 @@ static auto BuildInterfaceWitness(
context.inst_blocks().Get(interface.associated_entities_id);
table.reserve(assoc_entities.size());

// Substitute `Self` with the impl's self type when associated functions.
// TODO: Also substitute the arguments from interface_type.specific_id.
auto self_bind =
context.insts().GetAs<SemIR::BindSymbolicName>(interface.self_param_id);
Substitution substitutions[1] = {
{.bind_id =
context.entity_names().Get(self_bind.entity_name_id).bind_index,
.replacement_id = context.types().GetConstantId(impl.self_id)}};

for (auto decl_id : assoc_entities) {
LoadImportRef(context, decl_id);
decl_id = context.constant_values().GetConstantInstId(decl_id);
decl_id =
context.constant_values().GetInstId(SemIR::GetConstantValueInSpecific(
context.sem_ir(), interface_type.specific_id, decl_id));
CARBON_CHECK(decl_id.is_valid()) << "Non-constant associated entity";
auto decl = context.insts().Get(decl_id);
CARBON_KIND_SWITCH(decl) {
Expand All @@ -115,7 +167,7 @@ static auto BuildInterfaceWitness(
if (impl_decl_id.is_valid()) {
used_decl_ids.push_back(impl_decl_id);
table.push_back(CheckAssociatedFunctionImplementation(
context, fn_type->function_id, impl_decl_id, substitutions));
context, *fn_type, impl_decl_id, impl.self_id));
} else {
CARBON_DIAGNOSTIC(
ImplMissingFunction, Error,
Expand All @@ -137,7 +189,10 @@ static auto BuildInterfaceWitness(
"impl of interface with associated constant");
return SemIR::InstId::BuiltinError;
default:
CARBON_FATAL() << "Unexpected kind of associated entity " << decl;
CARBON_CHECK(decl_id == SemIR::InstId::BuiltinError)
<< "Unexpected kind of associated entity " << decl;
table.push_back(SemIR::InstId::BuiltinError);
break;
}
}

Expand All @@ -162,8 +217,8 @@ auto BuildImplWitness(Context& context, SemIR::ImplId impl_id)

llvm::SmallVector<SemIR::InstId> used_decl_ids;

auto witness_id =
BuildInterfaceWitness(context, impl, *interface_type, used_decl_ids);
auto witness_id = BuildInterfaceWitness(context, impl, impl.constraint_id,
*interface_type, used_decl_ids);

// TODO: Diagnose if any declarations in the impl are not in used_decl_ids.

Expand Down
15 changes: 11 additions & 4 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,8 @@ class ImportRefResolver {

// Make a declaration of a function. This is done as a separate step from
// importing the function declaration in order to resolve cycles.
auto MakeFunctionDecl(const SemIR::Function& import_function)
auto MakeFunctionDecl(const SemIR::Function& import_function,
SemIR::SpecificId specific_id)
-> std::pair<SemIR::FunctionId, SemIR::ConstantId> {
SemIR::FunctionDecl function_decl = {
.type_id = SemIR::TypeId::Invalid,
Expand All @@ -1375,8 +1376,6 @@ class ImportRefResolver {
.is_extern = import_function.is_extern,
.builtin_function_kind = import_function.builtin_function_kind}});

// TODO: Import this or recompute it.
auto specific_id = SemIR::SpecificId::Invalid;
function_decl.type_id =
context_.GetFunctionType(function_decl.function_id, specific_id);

Expand All @@ -1393,14 +1392,22 @@ class ImportRefResolver {

SemIR::FunctionId function_id = SemIR::FunctionId::Invalid;
if (!function_const_id.is_valid()) {
auto import_specific_id = import_ir_.types()
.GetAs<SemIR::FunctionType>(inst.type_id)
.specific_id;
auto specific_data = GetLocalSpecificData(import_specific_id);
if (HasNewWork()) {
// This is the end of the first phase. Don't make a new function yet if
// we already have new work.
return Retry();
}

auto specific_id =
GetOrAddLocalSpecific(import_specific_id, specific_data);

// On the second phase, create a forward declaration of the interface.
std::tie(function_id, function_const_id) =
MakeFunctionDecl(import_function);
MakeFunctionDecl(import_function, specific_id);
} else {
// On the third phase, compute the function ID from the constant value of
// the declaration.
Expand Down
18 changes: 6 additions & 12 deletions toolchain/check/member_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/convert.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/import_ref.h"
#include "toolchain/check/subst.h"
#include "toolchain/diagnostics/diagnostic_emitter.h"
#include "toolchain/sem_ir/generic.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/typed_insts.h"
Expand Down Expand Up @@ -191,16 +190,11 @@ static auto PerformImplLookup(Context& context, Parse::NodeId node_id,
return SemIR::InstId::BuiltinError;
}

// Substitute into the type declared in the interface.
// TODO: Also substitute the arguments from interface_type.specific_id.
auto self_param =
context.insts().GetAs<SemIR::BindSymbolicName>(interface.self_param_id);
Substitution substitutions[1] = {
{.bind_id =
context.entity_names().Get(self_param.entity_name_id).bind_index,
.replacement_id = type_const_id}};
auto subst_type_id =
SubstType(context, assoc_type.entity_type_id, substitutions);
// TODO: This produces the type of the associated entity with no value for
// `Self`. The type `Self` might appear in the type of an associated constant,
// and if so, we'll need to substitute it here somehow.
auto subst_type_id = SemIR::GetTypeInSpecific(
context.sem_ir(), interface_type.specific_id, assoc_type.entity_type_id);

return context.AddInst(
SemIR::LocIdAndInst::NoLoc<SemIR::InterfaceWitnessAccess>(
Expand Down
Loading

0 comments on commit 4a21b6a

Please sign in to comment.