Skip to content

Commit

Permalink
Dependent template params.
Browse files Browse the repository at this point in the history
  • Loading branch information
fubark committed Aug 1, 2024
1 parent 919447b commit d183885
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 37 deletions.
69 changes: 51 additions & 18 deletions src/cte.zig
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,47 @@ pub fn expandTemplateOnCallExpr(c: *cy.Chunk, node: *ast.CallExpr) !*cy.Sym {
return cte.expandTemplateOnCallArgs(c, callee.cast(.template), node.args, @ptrCast(node));
}

pub fn pushNodeValuesCstr(c: *cy.Chunk, args: []const *ast.Node, template: *cy.sym.Template, ct_arg_start: usize, node: *ast.Node) !void {
const params = c.sema.getFuncSig(template.sigId).params();
if (args.len != params.len) {
const params_s = try c.sema.allocFuncParamsStr(params, c);
defer c.alloc.free(params_s);
return c.reportErrorFmt(
\\Expected template signature `{}[{}]`.
, &.{v(template.head.name()), v(params_s)}, node);
}
for (args, 0..) |arg, i| {
const param = params[i];
const exp_type = try resolveTemplateParamType(c, param.type, ct_arg_start);
const res = try resolveCtValue(c, arg);
try c.valueStack.append(c.alloc, res.value);

if (res.type != exp_type) {
const cstrName = try c.sema.allocTypeName(exp_type);
defer c.alloc.free(cstrName);
const typeName = try c.sema.allocTypeName(res.type);
defer c.alloc.free(typeName);
return c.reportErrorFmt("Expected type `{}`, got `{}`.", &.{v(cstrName), v(typeName)}, arg);
}
}
}

/// This is similar to `sema_func.resolveTemplateParamType` except it only cares about ct_ref.
fn resolveTemplateParamType(c: *cy.Chunk, type_id: cy.TypeId, ct_arg_start: usize) !cy.TypeId {
const type_e = c.sema.types.items[type_id];
if (type_e.kind == .ct_ref) {
const ct_arg = c.valueStack.items[ct_arg_start + type_e.data.ct_infer.ct_param_idx];
if (ct_arg.getTypeId() != bt.Type) {
return error.TODO;
}
return ct_arg.asHeapObject().type.type;
} else if (type_e.info.ct_ref) {
return error.TODO;
} else {
return type_id;
}
}

pub fn pushNodeValues(c: *cy.Chunk, args: []const *ast.Node) !void {
for (args) |arg| {
const res = try resolveCtValue(c, arg);
Expand All @@ -26,11 +67,8 @@ pub fn pushNodeValues(c: *cy.Chunk, args: []const *ast.Node) !void {

pub fn expandTemplateOnCallArgs(c: *cy.Chunk, template: *cy.sym.Template, args: []const *ast.Node, node: *ast.Node) !*cy.Sym {
// Accumulate compile-time args.
const typeStart = c.typeStack.items.len;
const valueStart = c.valueStack.items.len;
defer {
c.typeStack.items.len = typeStart;

// Values need to be released.
const values = c.valueStack.items[valueStart..];
for (values) |val| {
Expand All @@ -39,21 +77,8 @@ pub fn expandTemplateOnCallArgs(c: *cy.Chunk, template: *cy.sym.Template, args:
c.valueStack.items.len = valueStart;
}

try pushNodeValues(c, args);

const argTypes = c.typeStack.items[typeStart..];
try pushNodeValuesCstr(c, args, template, valueStart, node);
const arg_vals = c.valueStack.items[valueStart..];

// Check against template signature.
if (!cy.types.isTypeFuncSigCompat(c.compiler, @ptrCast(argTypes), .not_void, template.sigId)) {
const sig = c.sema.getFuncSig(template.sigId);
const params_s = try c.sema.allocFuncParamsStr(sig.params(), c);
defer c.alloc.free(params_s);
return c.reportErrorFmt(
\\Expected template signature `{}[{}]`.
, &.{v(template.head.name()), v(params_s)}, node);
}

return expandTemplate(c, template, arg_vals);
}

Expand Down Expand Up @@ -263,6 +288,14 @@ pub const CtValue = struct {
// TODO: Evaluate const expressions.
pub fn resolveCtValue(c: *cy.Chunk, expr: *ast.Node) !CtValue {
switch (expr.type()) {
.floatLit => {
const literal = c.ast.nodeString(expr);
const val = try std.fmt.parseFloat(f64, literal);
return .{
.type = bt.Float,
.value = cy.Value.initF64(val),
};
},
.decLit => {
const literal = c.ast.nodeString(expr);
const val = try std.fmt.parseInt(i64, literal, 10);
Expand Down Expand Up @@ -352,7 +385,7 @@ pub fn resolveCtValue(c: *cy.Chunk, expr: *ast.Node) !CtValue {
return c.reportErrorFmt("Unexpected compile-time expression.", &.{}, expr);
},
else => {
return c.reportErrorFmt("Unsupported expr: `{}`", &.{v(expr.type())}, expr);
return c.reportErrorFmt("Unsupported compile-time expression: `{}`", &.{v(expr.type())}, expr);
}
}
}
119 changes: 100 additions & 19 deletions src/sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,19 @@ fn semaLocal(c: *cy.Chunk, id: LocalVarId, node: *ast.Node) !ExprResult {
}
}

pub fn semaCtValue(c: *cy.Chunk, ct_value: cte.CtValue, node: *ast.Node) !ExprResult {
switch (ct_value.type) {
bt.Integer => {
return c.semaInt(ct_value.value.asBoxInt(), node);
},
else => {
const type_n = try c.sema.allocTypeName(ct_value.type);
defer c.alloc.free(type_n);
return c.reportErrorFmt("Unsupported compile-time value: `{}`.", &.{v(type_n)}, node);
}
}
}

fn semaIdent(c: *cy.Chunk, node: *ast.Node, symAsValue: bool, prefer_addressable: bool) !ExprResult {
const name = c.ast.nodeString(node);
const res = try getOrLookupVar(c, name, node);
Expand All @@ -3355,6 +3368,16 @@ fn semaIdent(c: *cy.Chunk, node: *ast.Node, symAsValue: bool, prefer_addressable
.static => |sym| {
return sema.symbol(c, sym, node, symAsValue);
},
.ct_value => |ct_value| {
defer c.vm.release(ct_value.value);
if (ct_value.type == bt.Type) {
const type_id = ct_value.value.castHeapObject(*cy.heap.Type).type;
const sym = c.sema.getTypeSym(type_id);
return sema.symbol(c, sym, node, symAsValue);
} else {
return semaCtValue(c, ct_value, node);
}
},
}
}

Expand All @@ -3373,24 +3396,45 @@ pub fn getLocalDistinctSym(c: *cy.Chunk, name: []const u8, node: *ast.Node) !?*S
return null;
}

pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, distinct: bool) !?*Sym {
const NameResultType = enum {
sym,
ct_value,
};

const NameResult = struct {
type: NameResultType,
data: union {
sym: *Sym,
ct_value: cte.CtValue,
},

fn initSym(sym: *Sym) NameResult {
return .{ .type = .sym, .data = .{ .sym = sym, }};
}

fn initCtValue(ct_value: cte.CtValue) NameResult {
return .{ .type = .ct_value, .data = .{ .ct_value = ct_value, }};
}
};

pub fn getResolvedSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, distinct: bool) !?NameResult {
if (c.sym_cache.get(name)) |sym| {
if (distinct and !sym.isDistinct()) {
return c.reportErrorFmt("`{}` is not a unique symbol.", &.{v(name)}, node);
}
return sym;
return NameResult.initSym(sym);
}

// Look in the current chunk module.
if (distinct) {
if (try c.getResolvedDistinctSym(@ptrCast(c.sym), name, node, false)) |res| {
try c.sym_cache.putNoClobber(c.alloc, name, res);
return res;
return NameResult.initSym(res);
}
} else {
if (try c.getOptResolvedSym(@ptrCast(c.sym), name)) |sym| {
try c.sym_cache.putNoClobber(c.alloc, name, sym);
return sym;
return NameResult.initSym(sym);
}
}

Expand All @@ -3400,14 +3444,8 @@ pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, dist
const ctx = c.resolve_stack.items[resolve_ctx_idx];
if (ctx.ct_params.size > 0) {
if (ctx.ct_params.get(name)) |param| {
if (param.getTypeId() != bt.Type) {
const param_type_name = c.sema.getTypeBaseName(param.getTypeId());
return c.reportErrorFmt("Can not use a `{}` template param here.", &.{v(param_type_name)}, node);
}
const sym = c.sema.getTypeSym(param.asHeapObject().type.type);
if (!distinct or sym.isDistinct()) {
return sym;
}
c.vm.retain(param);
return NameResult.initCtValue(.{ .type = param.getTypeId(), .value = param });
}
}
if (!ctx.has_parent_ctx) {
Expand All @@ -3424,12 +3462,12 @@ pub fn getResolvedLocalSym(c: *cy.Chunk, name: []const u8, node: *ast.Node, dist
if (distinct) {
if (try c.getResolvedDistinctSym(@ptrCast(mod_sym), name, node, false)) |res| {
try c.sym_cache.putNoClobber(c.alloc, name, res);
return res;
return NameResult.initSym(res);
}
} else {
if (try c.getOptResolvedSym(@ptrCast(mod_sym), name)) |sym| {
try c.sym_cache.putNoClobber(c.alloc, name, sym);
return sym;
return NameResult.initSym(sym);
}
}
}
Expand Down Expand Up @@ -3537,9 +3575,23 @@ pub fn resolveSym(c: *cy.Chunk, expr: *ast.Node) anyerror!*cy.Sym {
}
}

return (try getResolvedLocalSym(c, name, expr, true)) orelse {
const res = (try getResolvedSym(c, name, expr, true)) orelse {
return c.reportErrorFmt("Could not find the symbol `{}`.", &.{v(name)}, expr);
};
switch (res.type) {
.sym => return res.data.sym,
.ct_value => {
defer c.vm.release(res.data.ct_value.value);
if (res.data.ct_value.type == bt.Type) {
const type_id = res.data.ct_value.value.castHeapObject(*cy.heap.Type).type;
return c.sema.getTypeSym(type_id);
} else {
const type_n = try c.sema.allocTypeName(res.data.ct_value.type);
defer c.alloc.free(type_n);
return c.reportErrorFmt("Expected symbol, found compile-time value `{}`.", &.{v(type_n)}, expr);
}
}
}
},
.accessExpr => {
const access_expr = expr.cast(.accessExpr);
Expand Down Expand Up @@ -3647,10 +3699,16 @@ fn resolveTemplateSig(c: *cy.Chunk, params: []*ast.FuncParam, outSigId: *FuncSig
}
const typeId = try resolveTypeSpecNode(c, param.typeSpec);
try c.typeStack.append(c.alloc, typeId);
const param_name = c.ast.funcParamName(param);
tparams[i] = .{
.name = c.ast.funcParamName(param),
.name = param_name,
.type = typeId,
};

const ct_param_idx = getResolveContext(c).ct_params.size;
const ref_t = try c.sema.ensureCtRefType(ct_param_idx);
const param_v = try c.vm.allocType(ref_t);
try setResolveCtParam(c, param_name, param_v);
}

const retType = bt.Type;
Expand Down Expand Up @@ -4182,6 +4240,7 @@ fn referenceSym(c: *cy.Chunk, sym: *Sym, node: *ast.Node) !void {
const VarLookupResult = union(enum) {
global: *Sym,
static: *Sym,
ct_value: cte.CtValue,

/// Local, parent local alias, or parent object member alias.
local: LocalVarId,
Expand Down Expand Up @@ -4285,16 +4344,28 @@ pub fn getOrLookupVar(self: *cy.Chunk, name: []const u8, node: *ast.Node) !VarLo
}
return self.reportErrorFmt("Undeclared variable `{}`.", &.{v(name)}, node);
};
_ = try pushStaticVarAlias(self, name, res.static);
switch (res) {
.static => |sym| {
_ = try pushStaticVarAlias(self, name, sym);
},
else => {}
}
return res;
}
}

fn lookupStaticVar(c: *cy.Chunk, name: []const u8, node: *ast.Node) !?VarLookupResult {
const res = (try getResolvedLocalSym(c, name, node, false)) orelse {
const res = (try getResolvedSym(c, name, node, false)) orelse {
return null;
};
return VarLookupResult{ .static = res };
switch (res.type) {
.sym => {
return VarLookupResult{ .static = res.data.sym };
},
.ct_value => {
return VarLookupResult{ .ct_value = res.data.ct_value };
},
}
}

const LookupParentLocalResult = struct {
Expand Down Expand Up @@ -5992,6 +6063,16 @@ pub const ChunkExt = struct {
.static => |sym| {
return callSym(c, sym, node.callee, node.args, expr.getRetCstr(), expr.node);
},
.ct_value => |ct_value| {
defer c.vm.release(ct_value.value);
if (ct_value.type == bt.Type) {
const type_id = ct_value.value.castHeapObject(*cy.heap.Type).type;
const sym = c.sema.getTypeSym(type_id);
return callSym(c, sym, node.callee, node.args, expr.getRetCstr(), expr.node);
} else {
return error.TODO;
}
},
}
} else {
// preCall.
Expand Down
1 change: 1 addition & 0 deletions src/sema_func.zig
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ fn inferCtArgValueEq(c: *cy.Chunk, arg: cy.Value, infer: cy.Value) bool {
}

/// Returns the target param type at index. Returns null if type should be deduced from arg.
/// resolveSymType isn't used because that performs resolving on a node rather than a type.
fn resolveTargetParam(c: *cy.Chunk, param_t: cy.TypeId, ct_arg_start: usize) !?cy.TypeId {
const type_e = c.sema.types.items[param_t];
if (type_e.info.ct_ref or type_e.info.ct_infer) {
Expand Down
2 changes: 2 additions & 0 deletions test/behavior_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ if (!aot) {
run.case("types/struct_circular_dep_error.cy");
run.case("types/structs.cy");
run.case("types/template_choices.cy");
run.case("types/template_dep_param_type.cy");
run.case("types/template_dep_param_type_error.cy");
run.case("types/template_object_init_noexpand_error.cy");
run.case("types/template_object_spec_noexpand_error.cy");
run.case("types/template_object_expand_error.cy");
Expand Down
12 changes: 12 additions & 0 deletions test/types/template_dep_param_type.cy
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use test

type Foo[T type, Value T] struct:
a T

func get(self) T:
return self.a + Value

var f = Foo[int, 10]{a=123}
test.eq(f.get(), 133)

--cytest: pass
12 changes: 12 additions & 0 deletions test/types/template_dep_param_type_error.cy
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
type Foo[T type, Value T] struct:
a T

var f = Foo[int, 10.0]{a=123}

--cytest: error
--CompileError: Expected type `int`, got `float`.
--
--main:4:18:
--var f = Foo[int, 10.0]{a=123}
-- ^
--

0 comments on commit d183885

Please sign in to comment.