From d183885b8a70aaf0ce1b3c32f66779390465b24a Mon Sep 17 00:00:00 2001 From: fubark Date: Thu, 1 Aug 2024 19:04:12 -0400 Subject: [PATCH] Dependent template params. --- src/cte.zig | 69 +++++++++--- src/sema.zig | 119 ++++++++++++++++---- src/sema_func.zig | 1 + test/behavior_test.zig | 2 + test/types/template_dep_param_type.cy | 12 ++ test/types/template_dep_param_type_error.cy | 12 ++ 6 files changed, 178 insertions(+), 37 deletions(-) create mode 100644 test/types/template_dep_param_type.cy create mode 100644 test/types/template_dep_param_type_error.cy diff --git a/src/cte.zig b/src/cte.zig index 8ae714e14..258bd0d43 100644 --- a/src/cte.zig +++ b/src/cte.zig @@ -16,6 +16,47 @@ pub fn expandTemplateOnCallExpr(c: *cy.Chunk, node: *ast.CallExpr) !*cy.Sym { return cte.expandTemplateOnCallArgs(c, callee.cast(.template), node.args, @ptrCast(node)); } +pub fn pushNodeValuesCstr(c: *cy.Chunk, args: []const *ast.Node, template: *cy.sym.Template, ct_arg_start: usize, node: *ast.Node) !void { + const params = c.sema.getFuncSig(template.sigId).params(); + if (args.len != params.len) { + const params_s = try c.sema.allocFuncParamsStr(params, c); + defer c.alloc.free(params_s); + return c.reportErrorFmt( + \\Expected template signature `{}[{}]`. + , &.{v(template.head.name()), v(params_s)}, node); + } + for (args, 0..) |arg, i| { + const param = params[i]; + const exp_type = try resolveTemplateParamType(c, param.type, ct_arg_start); + const res = try resolveCtValue(c, arg); + try c.valueStack.append(c.alloc, res.value); + + if (res.type != exp_type) { + const cstrName = try c.sema.allocTypeName(exp_type); + defer c.alloc.free(cstrName); + const typeName = try c.sema.allocTypeName(res.type); + defer c.alloc.free(typeName); + return c.reportErrorFmt("Expected type `{}`, got `{}`.", &.{v(cstrName), v(typeName)}, arg); + } + } +} + +/// This is similar to `sema_func.resolveTemplateParamType` except it only cares about ct_ref. +fn resolveTemplateParamType(c: *cy.Chunk, type_id: cy.TypeId, ct_arg_start: usize) !cy.TypeId { + const type_e = c.sema.types.items[type_id]; + if (type_e.kind == .ct_ref) { + const ct_arg = c.valueStack.items[ct_arg_start + type_e.data.ct_infer.ct_param_idx]; + if (ct_arg.getTypeId() != bt.Type) { + return error.TODO; + } + return ct_arg.asHeapObject().type.type; + } else if (type_e.info.ct_ref) { + return error.TODO; + } else { + return type_id; + } +} + pub fn pushNodeValues(c: *cy.Chunk, args: []const *ast.Node) !void { for (args) |arg| { const res = try resolveCtValue(c, arg); @@ -26,11 +67,8 @@ pub fn pushNodeValues(c: *cy.Chunk, args: []const *ast.Node) !void { pub fn expandTemplateOnCallArgs(c: *cy.Chunk, template: *cy.sym.Template, args: []const *ast.Node, node: *ast.Node) !*cy.Sym { // Accumulate compile-time args. - const typeStart = c.typeStack.items.len; const valueStart = c.valueStack.items.len; defer { - c.typeStack.items.len = typeStart; - // Values need to be released. const values = c.valueStack.items[valueStart..]; for (values) |val| { @@ -39,21 +77,8 @@ pub fn expandTemplateOnCallArgs(c: *cy.Chunk, template: *cy.sym.Template, args: c.valueStack.items.len = valueStart; } - try pushNodeValues(c, args); - - const argTypes = c.typeStack.items[typeStart..]; + try pushNodeValuesCstr(c, args, template, valueStart, node); const arg_vals = c.valueStack.items[valueStart..]; - - // Check against template signature. - if (!cy.types.isTypeFuncSigCompat(c.compiler, @ptrCast(argTypes), .not_void, template.sigId)) { - const sig = c.sema.getFuncSig(template.sigId); - const params_s = try c.sema.allocFuncParamsStr(sig.params(), c); - defer c.alloc.free(params_s); - return c.reportErrorFmt( - \\Expected template signature `{}[{}]`. - , &.{v(template.head.name()), v(params_s)}, node); - } - return expandTemplate(c, template, arg_vals); } @@ -263,6 +288,14 @@ pub const CtValue = struct { // TODO: Evaluate const expressions. pub fn resolveCtValue(c: *cy.Chunk, expr: *ast.Node) !CtValue { switch (expr.type()) { + .floatLit => { + const literal = c.ast.nodeString(expr); + const val = try std.fmt.parseFloat(f64, literal); + return .{ + .type = bt.Float, + .value = cy.Value.initF64(val), + }; + }, .decLit => { const literal = c.ast.nodeString(expr); const val = try std.fmt.parseInt(i64, literal, 10); @@ -352,7 +385,7 @@ pub fn resolveCtValue(c: *cy.Chunk, expr: *ast.Node) !CtValue { return c.reportErrorFmt("Unexpected compile-time expression.", &.{}, expr); }, else => { - return c.reportErrorFmt("Unsupported expr: `{}`", &.{v(expr.type())}, expr); + return c.reportErrorFmt("Unsupported compile-time expression: `{}`", &.{v(expr.type())}, expr); } } } diff --git a/src/sema.zig b/src/sema.zig index 2c306dc77..983af490b 100644 --- a/src/sema.zig +++ b/src/sema.zig @@ -3332,6 +3332,19 @@ fn semaLocal(c: *cy.Chunk, id: LocalVarId, node: *ast.Node) !ExprResult { } } +pub fn semaCtValue(c: *cy.Chunk, ct_value: cte.CtValue, node: *ast.Node) !ExprResult { + switch (ct_value.type) { + bt.Integer => { + return c.semaInt(ct_value.value.asBoxInt(), node); + }, + else => { + const type_n = try c.sema.allocTypeName(ct_value.type); + defer c.alloc.free(type_n); + return c.reportErrorFmt("Unsupported compile-time value: `{}`.", &.{v(type_n)}, node); + } + } +} + fn semaIdent(c: *cy.Chunk, node: *ast.Node, symAsValue: bool, prefer_addressable: bool) !ExprResult { const name = c.ast.nodeString(node); const res = try getOrLookupVar(c, name, node); @@ -3355,6 +3368,16 @@ fn semaIdent(c: *cy.Chunk, node: *ast.Node, symAsValue: bool, prefer_addressable .static => |sym| { return sema.symbol(c, sym, node, symAsValue); }, + .ct_value => |ct_value| { + defer c.vm.release(ct_value.value); + if (ct_value.type == bt.Type) { + const type_id = ct_value.value.castHeapObject(*cy.heap.Type).type; + const sym = c.sema.getTypeSym(type_id); + return sema.symbol(c, sym, node, symAsValue); + } else { + return semaCtValue(c, ct_value, node); + } + }, } } @@ -3373,24 +3396,45 @@ pub fn getLocalDistinctSym(c: *cy.Chunk, name: []const u8, node: *ast.Node) !?*S return null; } -pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, distinct: bool) !?*Sym { +const NameResultType = enum { + sym, + ct_value, +}; + +const NameResult = struct { + type: NameResultType, + data: union { + sym: *Sym, + ct_value: cte.CtValue, + }, + + fn initSym(sym: *Sym) NameResult { + return .{ .type = .sym, .data = .{ .sym = sym, }}; + } + + fn initCtValue(ct_value: cte.CtValue) NameResult { + return .{ .type = .ct_value, .data = .{ .ct_value = ct_value, }}; + } +}; + +pub fn getResolvedSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, distinct: bool) !?NameResult { if (c.sym_cache.get(name)) |sym| { if (distinct and !sym.isDistinct()) { return c.reportErrorFmt("`{}` is not a unique symbol.", &.{v(name)}, node); } - return sym; + return NameResult.initSym(sym); } // Look in the current chunk module. if (distinct) { if (try c.getResolvedDistinctSym(@ptrCast(c.sym), name, node, false)) |res| { try c.sym_cache.putNoClobber(c.alloc, name, res); - return res; + return NameResult.initSym(res); } } else { if (try c.getOptResolvedSym(@ptrCast(c.sym), name)) |sym| { try c.sym_cache.putNoClobber(c.alloc, name, sym); - return sym; + return NameResult.initSym(sym); } } @@ -3400,14 +3444,8 @@ pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, dist const ctx = c.resolve_stack.items[resolve_ctx_idx]; if (ctx.ct_params.size > 0) { if (ctx.ct_params.get(name)) |param| { - if (param.getTypeId() != bt.Type) { - const param_type_name = c.sema.getTypeBaseName(param.getTypeId()); - return c.reportErrorFmt("Can not use a `{}` template param here.", &.{v(param_type_name)}, node); - } - const sym = c.sema.getTypeSym(param.asHeapObject().type.type); - if (!distinct or sym.isDistinct()) { - return sym; - } + c.vm.retain(param); + return NameResult.initCtValue(.{ .type = param.getTypeId(), .value = param }); } } if (!ctx.has_parent_ctx) { @@ -3424,12 +3462,12 @@ pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, dist if (distinct) { if (try c.getResolvedDistinctSym(@ptrCast(mod_sym), name, node, false)) |res| { try c.sym_cache.putNoClobber(c.alloc, name, res); - return res; + return NameResult.initSym(res); } } else { if (try c.getOptResolvedSym(@ptrCast(mod_sym), name)) |sym| { try c.sym_cache.putNoClobber(c.alloc, name, sym); - return sym; + return NameResult.initSym(sym); } } } @@ -3537,9 +3575,23 @@ pub fn resolveSym(c: *cy.Chunk, expr: *ast.Node) anyerror!*cy.Sym { } } - return (try getResolvedLocalSym(c, name, expr, true)) orelse { + const res = (try getResolvedSym(c, name, expr, true)) orelse { return c.reportErrorFmt("Could not find the symbol `{}`.", &.{v(name)}, expr); }; + switch (res.type) { + .sym => return res.data.sym, + .ct_value => { + defer c.vm.release(res.data.ct_value.value); + if (res.data.ct_value.type == bt.Type) { + const type_id = res.data.ct_value.value.castHeapObject(*cy.heap.Type).type; + return c.sema.getTypeSym(type_id); + } else { + const type_n = try c.sema.allocTypeName(res.data.ct_value.type); + defer c.alloc.free(type_n); + return c.reportErrorFmt("Expected symbol, found compile-time value `{}`.", &.{v(type_n)}, expr); + } + } + } }, .accessExpr => { const access_expr = expr.cast(.accessExpr); @@ -3647,10 +3699,16 @@ fn resolveTemplateSig(c: *cy.Chunk, params: []*ast.FuncParam, outSigId: *FuncSig } const typeId = try resolveTypeSpecNode(c, param.typeSpec); try c.typeStack.append(c.alloc, typeId); + const param_name = c.ast.funcParamName(param); tparams[i] = .{ - .name = c.ast.funcParamName(param), + .name = param_name, .type = typeId, }; + + const ct_param_idx = getResolveContext(c).ct_params.size; + const ref_t = try c.sema.ensureCtRefType(ct_param_idx); + const param_v = try c.vm.allocType(ref_t); + try setResolveCtParam(c, param_name, param_v); } const retType = bt.Type; @@ -4182,6 +4240,7 @@ fn referenceSym(c: *cy.Chunk, sym: *Sym, node: *ast.Node) !void { const VarLookupResult = union(enum) { global: *Sym, static: *Sym, + ct_value: cte.CtValue, /// Local, parent local alias, or parent object member alias. local: LocalVarId, @@ -4285,16 +4344,28 @@ pub fn getOrLookupVar(self: *cy.Chunk, name: []const u8, node: *ast.Node) !VarLo } return self.reportErrorFmt("Undeclared variable `{}`.", &.{v(name)}, node); }; - _ = try pushStaticVarAlias(self, name, res.static); + switch (res) { + .static => |sym| { + _ = try pushStaticVarAlias(self, name, sym); + }, + else => {} + } return res; } } fn lookupStaticVar(c: *cy.Chunk, name: []const u8, node: *ast.Node) !?VarLookupResult { - const res = (try getResolvedLocalSym(c, name, node, false)) orelse { + const res = (try getResolvedSym(c, name, node, false)) orelse { return null; }; - return VarLookupResult{ .static = res }; + switch (res.type) { + .sym => { + return VarLookupResult{ .static = res.data.sym }; + }, + .ct_value => { + return VarLookupResult{ .ct_value = res.data.ct_value }; + }, + } } const LookupParentLocalResult = struct { @@ -5992,6 +6063,16 @@ pub const ChunkExt = struct { .static => |sym| { return callSym(c, sym, node.callee, node.args, expr.getRetCstr(), expr.node); }, + .ct_value => |ct_value| { + defer c.vm.release(ct_value.value); + if (ct_value.type == bt.Type) { + const type_id = ct_value.value.castHeapObject(*cy.heap.Type).type; + const sym = c.sema.getTypeSym(type_id); + return callSym(c, sym, node.callee, node.args, expr.getRetCstr(), expr.node); + } else { + return error.TODO; + } + }, } } else { // preCall. diff --git a/src/sema_func.zig b/src/sema_func.zig index 785682e2e..c790b6380 100644 --- a/src/sema_func.zig +++ b/src/sema_func.zig @@ -378,6 +378,7 @@ fn inferCtArgValueEq(c: *cy.Chunk, arg: cy.Value, infer: cy.Value) bool { } /// Returns the target param type at index. Returns null if type should be deduced from arg. +/// resolveSymType isn't used because that performs resolving on a node rather than a type. fn resolveTargetParam(c: *cy.Chunk, param_t: cy.TypeId, ct_arg_start: usize) !?cy.TypeId { const type_e = c.sema.types.items[param_t]; if (type_e.info.ct_ref or type_e.info.ct_infer) { diff --git a/test/behavior_test.zig b/test/behavior_test.zig index b42fc81cf..04e341f68 100644 --- a/test/behavior_test.zig +++ b/test/behavior_test.zig @@ -203,6 +203,8 @@ if (!aot) { run.case("types/struct_circular_dep_error.cy"); run.case("types/structs.cy"); run.case("types/template_choices.cy"); + run.case("types/template_dep_param_type.cy"); + run.case("types/template_dep_param_type_error.cy"); run.case("types/template_object_init_noexpand_error.cy"); run.case("types/template_object_spec_noexpand_error.cy"); run.case("types/template_object_expand_error.cy"); diff --git a/test/types/template_dep_param_type.cy b/test/types/template_dep_param_type.cy new file mode 100644 index 000000000..7f0a25ded --- /dev/null +++ b/test/types/template_dep_param_type.cy @@ -0,0 +1,12 @@ +use test + +type Foo[T type, Value T] struct: + a T + + func get(self) T: + return self.a + Value + +var f = Foo[int, 10]{a=123} +test.eq(f.get(), 133) + +--cytest: pass \ No newline at end of file diff --git a/test/types/template_dep_param_type_error.cy b/test/types/template_dep_param_type_error.cy new file mode 100644 index 000000000..9fb847875 --- /dev/null +++ b/test/types/template_dep_param_type_error.cy @@ -0,0 +1,12 @@ +type Foo[T type, Value T] struct: + a T + +var f = Foo[int, 10.0]{a=123} + +--cytest: error +--CompileError: Expected type `int`, got `float`. +-- +--main:4:18: +--var f = Foo[int, 10.0]{a=123} +-- ^ +-- \ No newline at end of file