diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 8e78ff0843..c4f845b977 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2496,8 +2496,11 @@ namespace Slang // of the parameters of the generic. // CheckConstraintSubType(decl->sub); - decl->sub = TranslateTypeNodeForced(decl->sub); - decl->sup = TranslateTypeNodeForced(decl->sup); + + if (!decl->sub.type) + decl->sub = TranslateTypeNodeForced(decl->sub); + if (!decl->sup.type) + decl->sup = TranslateTypeNodeForced(decl->sup); if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && !as(decl->sub.type)) { getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup); @@ -3762,6 +3765,10 @@ namespace Slang ThisExpr*& synThis) { auto synGenericDecl = m_astBuilder->create(); + synGenericDecl->parentDecl = context->parentDecl; + synGenericDecl->ownedScope = m_astBuilder->create(); + synGenericDecl->ownedScope->containerDecl = synGenericDecl; + synGenericDecl->ownedScope->parent = getScope(context->parentDecl); // For now our synthesized method will use the name and source // location of the requirement we are trying to satisfy. @@ -3779,7 +3786,7 @@ namespace Slang } // Dictionary to map from the original type parameters to the synthesized ones. - Dictionary mapOrigToSynTypeParams; + Dictionary mapOrigToSynTypeParams; // Our synthesized method will have parameters matching the names // and types of those on the requirement, and it will use expressions @@ -3788,27 +3795,77 @@ namespace Slang // for (auto member : requiredMemberDeclRef.getDecl()->members) { - if (auto typeParamDecl = as(member)) + if (auto typeParamDeclBase = as(member)) { - auto synTypeParamDecl = m_astBuilder->create(); - synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc(); - synTypeParamDecl->initType = typeParamDecl->initType; - synTypeParamDecl->parentDecl = synGenericDecl; - synGenericDecl->members.add(synTypeParamDecl); + auto synTypeParamDeclBase = (GenericTypeParamDeclBase*)m_astBuilder->createByNodeType(typeParamDeclBase->astNodeType); + synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc(); + synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex; + synTypeParamDeclBase->parentDecl = synGenericDecl; + + // Note: we intentionally do not copy GenericTypeParamDecl::initType here, + // because initType maybe dependent on the original type parameters, + // and if we copy we must also substitute all the original type parameters with the synthesized ones. + // It shouldn't be required for the implementing declaration to define initType anyways, so we'll just + // save ourselves from the trouble. + // + synGenericDecl->members.add(synTypeParamDeclBase); - mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl); + mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase); // Construct a DeclRefExpr from the type parameter. - auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl); + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDeclBase); auto synTypeParamDeclRefExpr = m_astBuilder->create(); synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); synGenericArgs.add(synTypeParamDeclRefExpr); - } + } + else if (auto valParamDecl = as(member)) + { + auto synValParamDecl = m_astBuilder->create(); + synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc; + synValParamDecl->parentDecl = synGenericDecl; + synValParamDecl->parameterIndex = valParamDecl->parameterIndex; + synValParamDecl->type = valParamDecl->type; + + // Note: we intentionally do not copy GenericValueParamDecl::initExpr here, + // because initType maybe dependent on the original type/value parameters, + // and if we copy we must also substitute all the original type parameters with the synthesized ones. + // It shouldn't be required for the implementing declaration to define initType anyways, so we'll just + // save ourselves from the trouble. + // + synGenericDecl->members.add(synValParamDecl); + + mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl); + + // Construct a DeclRefExpr from the value parameter. + auto synValParamDeclRef = makeDeclRef(synValParamDecl); + + auto synValParamDeclRefExpr = m_astBuilder->create(); + synValParamDeclRefExpr->declRef = synValParamDeclRef; + synValParamDeclRefExpr->type = synValParamDecl->type.type; + + synGenericArgs.add(synValParamDeclRefExpr); + } } + // With all generic parameters in place, we can now form a partial substitution argument list + // without taking into account all the generic constraints. + + // Given `requiredMemberDeclRef` that is `Lookup(ConcreteType:IFoo, IFoo::bar)`, we can now + // form a partial specialized declref to `IFoo::bar` with substitution args comming + // from the synthesized generic decl, i.e. we want to form: + // `Lookup(ConcreteType:IFoo, IFoo::bar)` where `UImpl` is a synthesized generic parameter. + // + auto partialDefaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); + DeclRef partiallySpecializedRequiredGenericDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredMemberDeclRef, partialDefaultArgs.getArrayView()).as(); + + // With `partiallySpecializedRequiredGenericDeclRef`, we can obtain the right specialized types + // from the original requirement decl. For example, we can simply apply declref substituion on + // the original type constraint `U:IDerived` to get `UImpl : IDerived`. + // for (auto member : requiredMemberDeclRef.getDecl()->members) { if (auto constraintDecl = as(member)) @@ -3817,28 +3874,16 @@ namespace Slang synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); synConstraintDecl->parentDecl = synGenericDecl; - // For constraints of type T : Interface, where T is a simple type parameter, - // find the declaration of T - // - if (auto typeParamDecl = as(constraintDecl->sub.type)->getDeclRef().as().getDecl()) - { - auto synTypeParamDecl = mapOrigToSynTypeParams.getValue(typeParamDecl); - - // Construct a DeclRefExpr from the type parameter. - auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl); - - auto synTypeParamDeclRefExpr = m_astBuilder->create(); - synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; - synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); - - synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr); - synConstraintDecl->sup = constraintDecl->sup; - synGenericDecl->members.add(synConstraintDecl); - } - else - { - SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints."); - } + // For generic constraint Sub : Sup, we need to substitute them with + // synthesized generic parameters. + // + synConstraintDecl->sub = TypeExp( + (Type*)constraintDecl->sub.type->substitute(m_astBuilder, + SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); + synConstraintDecl->sup = TypeExp( + (Type*)constraintDecl->sup.type->substitute(m_astBuilder, + SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); + synGenericDecl->members.add(synConstraintDecl); } } @@ -3847,18 +3892,18 @@ namespace Slang // original parameters. // auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); - DeclRef requiredFuncDeclRef = m_astBuilder->getGenericAppDeclRef( - requiredMemberDeclRef, defaultArgs.getArrayView()).as(); + DeclRef requiredFuncDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredMemberDeclRef, defaultArgs.getArrayView()).as(); SLANG_ASSERT(requiredFuncDeclRef); + ConformanceCheckingContext subContext = *context; + subContext.parentDecl = synGenericDecl; - synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness( - context, + synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitnessInner( + &subContext, requiredFuncDeclRef, synArgs, synThis); - synGenericDecl->inner->parentDecl = synGenericDecl; - return synGenericDecl; } @@ -4023,12 +4068,38 @@ namespace Slang DeclRef requiredMemberDeclRef, List& synArgs, ThisExpr*& synThis) + { + if (auto genericDeclRef = as(requiredMemberDeclRef.getParent())) + { + List synGenericArgs; + auto genericDecl = synthesizeGenericSignatureForRequirementWitness( + context, + genericDeclRef, + synArgs, + synGenericArgs, + synThis); + return (CallableDecl*)genericDecl->inner; + } + return synthesizeMethodSignatureForRequirementWitnessInner( + context, + requiredMemberDeclRef, + synArgs, + synThis); + } + + CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitnessInner( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis) { CallableDecl* synFuncDecl = as(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); SLANG_ASSERT(synFuncDecl); + synFuncDecl->ownedScope = m_astBuilder->create(); synFuncDecl->ownedScope->containerDecl = synFuncDecl; synFuncDecl->ownedScope->parent = getScope(context->parentDecl); + synFuncDecl->parentDecl = context->parentDecl; // For now our synthesized method will use the name and source // location of the requirement we are trying to satisfy. @@ -4179,16 +4250,16 @@ namespace Slang // and we really just need to wrap that result up as an overloaded // expression. // - auto synBase = m_astBuilder->create(); - synBase->name = requiredMemberDeclRef.getDecl()->getName(); + auto baseOverloadedExpr = m_astBuilder->create(); + baseOverloadedExpr->name = requiredMemberDeclRef.getDecl()->getName(); if (isInWrapperType) { auto aggTypeDecl = as(context->parentDecl); - synBase->lookupResult2 = lookUpMember( + baseOverloadedExpr->lookupResult2 = lookUpMember( m_astBuilder, this, - synBase->name, + baseOverloadedExpr->name, aggTypeDecl->wrappedType.type, aggTypeDecl->ownedScope, LookupMask::Default, @@ -4199,7 +4270,7 @@ namespace Slang } else { - synBase->lookupResult2 = lookupResult; + baseOverloadedExpr->lookupResult2 = lookupResult; } // If `synThis` is non-null, then we will use it as the base of @@ -4217,23 +4288,16 @@ namespace Slang auto innerExpr = m_astBuilder->create(); innerExpr->scope = synThis->scope; innerExpr->name = getName("inner"); - synBase->base = CheckExpr(innerExpr); + baseOverloadedExpr->base = CheckExpr(innerExpr); SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, baseOverloadedExpr->base->type); } else { - synBase->base = synThis; + baseOverloadedExpr->base = synThis; } } - // We now have the reference to the overload group we plan to call, - // and we already built up the argument list, so we can construct - // an `InvokeExpr` that represents the call we want to make. - // - auto synCall = m_astBuilder->create(); - synCall->functionExpr = synBase; - synCall->arguments = synArgs; // In order to know if our call is well-formed, we need to run // the semantic checking logic for overload resolution. If it @@ -4248,6 +4312,48 @@ namespace Slang ExprLocalScope localScope; SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(synFuncDecl).withExprLocalScope(&localScope)); + Expr* synBase = baseOverloadedExpr; + + // If the requirement is a generic decl, fill in all generic arguments explicitly. + if (auto genericDeclRef = as(synFuncDecl->parentDecl)) + { + auto genericAppExpr = m_astBuilder->create(); + genericAppExpr->functionExpr = synBase; + for (auto member : genericDeclRef->members) + { + if (auto typeParamDecl = as(member)) + { + auto synTypeParamDeclRef = makeDeclRef(typeParamDecl); + auto synTypeParamDeclRefExpr = m_astBuilder->create(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + genericAppExpr->arguments.add(synTypeParamDeclRefExpr); + } + else if (auto valParamDecl = as(member)) + { + auto synValParamDeclRef = makeDeclRef(valParamDecl); + auto synValParamDeclRefExpr = m_astBuilder->create(); + synValParamDeclRefExpr->declRef = synValParamDeclRef; + synValParamDeclRefExpr->type = getType(m_astBuilder, synValParamDeclRef); + genericAppExpr->arguments.add(synValParamDeclRefExpr); + } + } + synBase = subVisitor.checkGenericAppWithCheckedArgs(genericAppExpr); + + // If checking the generic app failed, we can't synthesize the witness. + // + if (tempSink.getErrorCount() != 0) + return false; + } + + // We now have the reference to the overload group we plan to call, + // and we already built up the argument list, so we can construct + // an `InvokeExpr` that represents the call we want to make. + // + auto synCall = m_astBuilder->create(); + synCall->functionExpr = synBase; + synCall->arguments = synArgs; + // With our temporary diagnostic sink soaking up any messages // from overload resolution, we can now try to resolve // the call to see what happens. @@ -4294,17 +4400,16 @@ namespace Slang synFuncDecl->body = synReturn; - // Once we are sure that we want to use the declaration - // we've synthesized, aew can go ahead and wire it up - // to the AST so that subsequent stages can generate - // IR code from it. - // // Note: we set the parent of the synthesized declaration // to the parent of the inheritance declaration being // validated (which is either a type declaration or // an `extension`), but we do *not* add the syntehsized // declaration to the list of child declarations at // this point. + // + // The synthesized decl already has its parent set to + // the current parent decl, so we don't need more actions + // to wire it up to the AST hierarchy. // // By leaving the synthesized declaration off of the list // of members, we ensure that it doesn't get found @@ -4317,7 +4422,6 @@ namespace Slang // declaration into the hierarchy, but then attach a modifier // to it to indicate that it should be ignored by things like lookup. // - synFuncDecl->parentDecl = context->parentDecl; // If the synthesized func is differentiable, make sure to populate its // differential type dictionary. @@ -4332,7 +4436,10 @@ namespace Slang // difference between our synthetic method and a hand-written // one with the same behavior. // - _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(synFuncDecl)); + auto containerDecl = getParentDecl(synFuncDecl); + auto containerDeclRef = getDefaultDeclRef(containerDecl); + auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, synFuncDecl); + _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); return true; } @@ -4354,31 +4461,28 @@ namespace Slang bool isDefaultInitializableType = requiredMemberDeclRef.getParent() == getASTBuilder()->getDefaultInitializableTypeInterfaceDecl(); bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); - if (!isInWrapperType && !isDefaultInitializableType) + if (!isInWrapperType && !isDefaultInitializableType && !satisfyingMemberLookupResult.isValid()) { return false; } - auto ctorDecl = m_astBuilder->create(); - ctorDecl->ownedScope = m_astBuilder->create(); - ctorDecl->ownedScope->containerDecl = ctorDecl; - ctorDecl->ownedScope->parent = getScope(context->parentDecl); + List synArgs; + ThisExpr* synThis = nullptr; + + auto ctorDecl = (ConstructorDecl*)synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); ctorDecl->loc = context->parentDecl->loc; ctorDecl->closingSourceLoc = ctorDecl->loc; - ctorDecl->parentDecl = context->parentDecl; auto ctorName = getName("$init"); ctorDecl->nameAndLoc.name = ctorName; - ctorDecl->nameAndLoc.loc = ctorDecl->loc; + ctorDecl->nameAndLoc.loc = context->parentDecl->loc; auto seqStmt = m_astBuilder->create(); ctorDecl->body = seqStmt; - ctorDecl->returnType.type = DeclRefType::create(m_astBuilder, makeDeclRef(context->parentDecl)); - - List synArgs; - addRequiredParamsToSynthesizedDecl(requiredMemberDeclRef, ctorDecl, synArgs); - ThisExpr* synThis = nullptr; - addModifiersToSynthesizedDecl(context, requiredMemberDeclRef, ctorDecl, synThis); if (isInWrapperType) { @@ -4424,11 +4528,69 @@ namespace Slang } } } + else if (synArgs.getCount()) + { + // The body of our synthesized method is going to try to + // make a ctor call with the specified arguments (e.g., + // the name `increment` in our example at the top of this function). + // + auto synBase = m_astBuilder->create(); + synBase->name = requiredMemberDeclRef.getDecl()->getName(); + + synBase->lookupResult2 = satisfyingMemberLookupResult; + + // We now have the reference to the overload group we plan to call, + // and we already built up the argument list, so we can construct + // an `InvokeExpr` that represents the call we want to make. + // + auto synCall = m_astBuilder->create(); + synCall->functionExpr = synBase; + synCall->arguments = synArgs; + + // In order to know if our call is well-formed, we need to run + // the semantic checking logic for overload resolution. If it + // runs into an error, we don't want that being reported back + // to the user as some kind of overload-resolution failure. + // + // In order to protect the user from whatever errors might + // occur, we will perform the checking in the context of + // a temporary diagnostic sink. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + ExprLocalScope localScope; + SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(ctorDecl).withExprLocalScope(&localScope)); + + // With our temporary diagnostic sink soaking up any messages + // from overload resolution, we can now try to resolve + // the call to see what happens. + // + auto checkedCall = subVisitor.ResolveInvoke(synCall); + + // If any error occurs during overload resolution, we can't synthesize the witness. + if (tempSink.getErrorCount() != 0) + return false; + + // If we were able to type-check the call, then we should + // be able to finish construction of a suitable ctor witness, + // by emitting `this = resolvedCtorCall()`. + // + AssignExpr* assignExpr = m_astBuilder->create(); + assignExpr->left = synThis; + assignExpr->right = checkedCall; + assignExpr->type = m_astBuilder->getVoidType(); + ExpressionStmt* exprStmt = m_astBuilder->create(); + exprStmt->expression = assignExpr; + seqStmt->stmts.add(exprStmt); + } + if (isDefaultInitializableType) context->parentDecl->addMember(ctorDecl); - - _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(ctorDecl)); - + + auto containerDecl = getParentDecl(ctorDecl); + auto containerDeclRef = getDefaultDeclRef(containerDecl); + auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, ctorDecl); + _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); + return true; } @@ -4497,17 +4659,8 @@ namespace Slang synPropertyDecl->parentDecl = context->parentDecl; - // The type of our synthesized property will be the expected type - // of the interface requirement. - // - // TODO: This logic can/will run into problems if the type is, - // or uses, an associated type or `This`. - // - // Ideally we should be looking up the type using a `DeclRef` that - // refers to the interface requirement using a `LookupDeclRef` - // that refers to the satisfying type declaration, and requirement - // checking for non-associated-type requirements should be done *after* - // requirement checking for associated-type requirements. + // The type of our synthesized property can be derived from the + // specialized declref to the requirement decl. // auto propertyType = getType(m_astBuilder, requiredMemberDeclRef); synPropertyDecl->type.type = propertyType; @@ -5059,12 +5212,18 @@ namespace Slang // difference between our synthetic property and a hand-written // one with the same behavior. // + auto containerDecl = getParentDecl(synAccesorContainer); + auto containerDeclRef = getDefaultDeclRef(containerDecl); for (auto& [key, value] : mapRequiredAccessorToSynAccessor) { - witnessTable->add(key.getDecl(), RequirementWitness(getDefaultDeclRef(value))); + witnessTable->add( + key.getDecl(), + RequirementWitness( + m_astBuilder->getMemberDeclRef(containerDeclRef, value))); } + witnessTable->add(requiredMemberDeclRef.getDecl(), - RequirementWitness(getDefaultDeclRef(synAccesorContainer))); + RequirementWitness(m_astBuilder->getMemberDeclRef(containerDeclRef, synAccesorContainer))); return true; } @@ -5092,15 +5251,17 @@ namespace Slang // List synArgs; ThisExpr* synThis; - auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness(context, requiredMemberDeclRef, - synArgs, synThis); + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); synThis->checked = true; // Form a `this[args...]` expression that we will use to coerce from // in the synthesized subscript accessors. // - synSubscriptDecl->parentDecl = context->parentDecl; DiagnosticSink tempSink(getSourceManager(), nullptr); SemanticsVisitor subVisitor(withSink(&tempSink)); auto base = m_astBuilder->create(); @@ -5127,8 +5288,6 @@ namespace Slang if (!canSynAccessors) return false; - synSubscriptDecl->parentDecl = context->parentDecl; - // The visibility of synthesized decl should be the min of the parent decl and the requirement. if (requiredMemberDeclRef.getDecl()->findModifier()) { @@ -5143,6 +5302,7 @@ namespace Slang bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( ConformanceCheckingContext* context, + const LookupResult& lookupResult, DeclRef requiredMemberDeclRef, RefPtr witnessTable) { @@ -5196,11 +5356,13 @@ namespace Slang List synArgs; ThisExpr* synThis; - auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness(context, requiredMemberDeclRef, - synArgs, synThis); + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); synThis->type.isLeftValue = true; synThis->checked = true; - synSubscriptDecl->parentDecl = context->parentDecl; auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); @@ -5209,11 +5371,24 @@ namespace Slang // DiagnosticSink tempSink(getSourceManager(), nullptr); SemanticsVisitor subVisitor(withSink(&tempSink)); - IndexExpr* indexExpr = m_astBuilder->create(); - indexExpr->baseExpression = synThis; - indexExpr->indexExprs = _Move(synArgs); - auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); - + Expr* synBaseStorageExpr = nullptr; + if (lookupResult.isValid()) + { + auto calleeExpr = m_astBuilder->create(); + calleeExpr->base = synThis; + calleeExpr->lookupResult2 = lookupResult; + auto invokeExpr = m_astBuilder->create(); + invokeExpr->functionExpr = calleeExpr; + invokeExpr->arguments = _Move(synArgs); + synBaseStorageExpr = subVisitor.ResolveInvoke(invokeExpr); + } + else + { + IndexExpr* indexExpr = m_astBuilder->create(); + indexExpr->baseExpression = synThis; + indexExpr->indexExprs = _Move(synArgs); + synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); + } if (tempSink.getErrorCount() != 0) return false; @@ -5296,7 +5471,10 @@ namespace Slang // For generic decl, check if we match DMulFunc, and synthesize the method. if (auto requiredGenericDeclRef = requiredMemberDeclRef.as()) { - if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier()) + auto inner = getInner(requiredGenericDeclRef); + + // TODO: we should be able to remove DMul synthesis logic. + if (auto builtinAttr = inner->findModifier()) { switch (builtinAttr->kind) { @@ -5308,6 +5486,15 @@ namespace Slang SynthesisPattern::FixedFirstArg); } } + + if (as(inner)) + { + return trySynthesizeRequirementWitness( + context, + lookupResult, + m_astBuilder->getMemberDeclRef(requiredGenericDeclRef, inner), + witnessTable); + } return false; } @@ -5324,6 +5511,7 @@ namespace Slang { return trySynthesizeSubscriptRequirementWitness( context, + lookupResult, requiredSubscriptDeclRef, witnessTable); } @@ -5496,7 +5684,6 @@ namespace Slang default: SLANG_ASSERT("unknown builtin requirement kind."); } - synFunc->parentDecl = context->parentDecl; synFunc->loc = context->parentDecl->closingSourceLoc; synFunc->nameAndLoc.loc = synFunc->loc; context->parentDecl->members.add(synFunc); @@ -5576,10 +5763,6 @@ namespace Slang addModifier(synFunc, m_astBuilder->create()); - if (synGeneric) - synGeneric->parentDecl = context->parentDecl; - else - synFunc->parentDecl = context->parentDecl; synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 29a57ae356..8edcbb16a7 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1743,6 +1743,12 @@ namespace Slang CallableDecl* synthesized, List& synArgs); + CallableDecl* synthesizeMethodSignatureForRequirementWitnessInner( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis); + CallableDecl* synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, DeclRef requiredMemberDeclRef, @@ -1803,6 +1809,7 @@ namespace Slang bool trySynthesizeSubscriptRequirementWitness( ConformanceCheckingContext* context, + const LookupResult& lookupResult, DeclRef requiredMemberDeclRef, RefPtr witnessTable); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b2173cd7b9..c738f5b49e 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2674,8 +2674,7 @@ namespace Slang else if (auto overloadedExpr = as(baseExpr)) { // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) + for (auto item : overloadedExpr->lookupResult2) { AddGenericOverloadCandidate(item, context); } diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index bee8f088a2..0c7f10ad7b 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -23,7 +23,7 @@ static const char* kDeclKeywords[] = { "protected", "typedef", "typealias", "uniform", "export", "groupshared", "extension", "associatedtype", "namespace", "This", "using", "__generic", "__exported", "import", "enum", "cbuffer", "tbuffer", "func", - "functype"}; + "functype", "typename", "each", "expand" }; static const char* kStmtKeywords[] = { "if", "else", "switch", "case", "default", "return", "try", "throw", "throws", "catch", "while", "for", @@ -35,7 +35,8 @@ static const char* kStmtKeywords[] = { "__generic", "__exported", "import", "enum", "break", "continue", "discard", "defer", "cbuffer", "tbuffer", "func", "is", "as", "nullptr", "none", "true", "false", "functype", - "sizeof", "alignof", "__target_switch", "__intrinsic_asm"}; + "sizeof", "alignof", "__target_switch", "__intrinsic_asm", + "each", "expand" }; static const char* hlslSemanticNames[] = { "register", diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 5729adb29d..926b3c1d92 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1765,6 +1765,27 @@ namespace Slang expr->baseExpression->accept(this, nullptr); expr->scope = scope; } + void visitAppExprBase(AppExprBase* expr) + { + expr->functionExpr->accept(this, nullptr); + for (auto arg : expr->arguments) + arg->accept(this, nullptr); + } + void visitIsTypeExpr(IsTypeExpr* expr) + { + if (expr->typeExpr.exp) + expr->typeExpr.exp->accept(this, nullptr); + } + void visitAsTypeExpr(AsTypeExpr* expr) + { + if (expr->typeExpr) + expr->typeExpr->accept(this, nullptr); + } + void visiSizeOfLikeExpr(SizeOfLikeExpr* expr) + { + if (expr->value) + expr->value->accept(this, nullptr); + } void visitExpr(Expr* /*expr*/) {} }; diff --git a/tests/language-feature/interfaces/generic-requirement-synth-2.slang b/tests/language-feature/interfaces/generic-requirement-synth-2.slang new file mode 100644 index 0000000000..9066fc5b77 --- /dev/null +++ b/tests/language-feature/interfaces/generic-requirement-synth-2.slang @@ -0,0 +1,35 @@ +// Test that we allow type conformances whose base interface is generic. + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-dx11 -compute -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -output-using-type + +interface IStack +{ + IStack popN(); + + int get(); +} +struct StackImpl : IStack +{ + // member 'popN' does not match interface requirement. + StackImpl popN() { return StackImpl(); } + + int get() { return D; } +} + +int helper>(T stack) +{ + return stack.popN<2>().get(); +} + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4); +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain() +{ + StackImpl<5> obj = StackImpl<5>(); + + // CHECK: 3 + outputBuffer[0] = helper(obj); +} \ No newline at end of file diff --git a/tests/language-feature/interfaces/generic-requirement-synth.slang b/tests/language-feature/interfaces/generic-requirement-synth.slang new file mode 100644 index 0000000000..1ac1180142 --- /dev/null +++ b/tests/language-feature/interfaces/generic-requirement-synth.slang @@ -0,0 +1,90 @@ +// Test that we can synthesize requirements for generic methods. + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-dx11 -compute -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -output-using-type + +interface IBase +{ + static float get(); +} +interface IBar : IBase +{ + float derivedMethod(); +} + +struct Bar : IBar +{ + static float get() { return 1.0f; } + float derivedMethod() { return 2.0f; } +} + +interface ITestInterface +{ + Real sample(T t); + + __init(T t); + + __generic + __subscript(T t)->Real { get; } +} + +struct TestInterfaceImpl : ITestInterface +{ + // The signature of this sample method is different from the one in the + // interface. However, we should be able to form a call into this method + // from the synthesized implementation matching the interface definition, + // so the conformance should hold. + Real sample(T t) + { + return x + Real(T.get()); + } + + // Test the same thing for constructors. + __init(T t) + { + x = Real(T.get()); + } + + // Test the same thing for subscript operators. + __generic + __subscript(T t)->Real { get { return x + Real(T.get()); } } + Real x; +} + +float test(ITestInterface obj) +{ + Bar b = {}; + return obj.sample(b); +} + +float test1(ITestInterface obj) +{ + Bar b = {}; + return obj[b]; +} + +float test2>() +{ + Bar b = {}; + T obj = T(b); + return obj[b]; +} + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4); +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain() +{ + TestInterfaceImpl obj; + obj.x = 1.0f; + + // CHECK: 2 + outputBuffer[0] = int(test(obj)); + + // CHECK: 2 + outputBuffer[1] = int(test1(obj)); + + // CHECK: 2 + outputBuffer[3] = int(test2>()); +} \ No newline at end of file