diff --git a/src/terralib.lua b/src/terralib.lua index 7fb4320f..d2e1d6f6 100644 --- a/src/terralib.lua +++ b/src/terralib.lua @@ -2785,6 +2785,39 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return stmts end + local function hasmetacopyassignment(typ) + if typ and typ:isstruct() and typ.metamethods.__copy then + return true + end + return false + end + + local function checkmetacopyassignment(anchor, from, to) + --if neither `from` or `to` are a struct then return + if not (hasmetacopyassignment(from.type) or hasmetacopyassignment(to.type)) then + return + end + --if `to` is an allocvar then set type and turn into corresponding `var` + if to:is "allocvar" then + local typ = from.type or terra.types.error + to:settype(typ) + to = newobject(anchor,T.var,to.name,to.symbol):setlvalue(true):withtype(to.type) + end + --list of overloaded __copy metamethods + local overloads = terra.newlist() + local function checkoverload(v) + if hasmetacopyassignment(v.type) then + overloads:insert(asterraexpression(anchor, v.type.metamethods.__copy, "luaobject")) + end + end + --add overloaded methods based on left- and right-hand-side of the assignment + checkoverload(from) + checkoverload(to) + if #overloads > 0 then + return checkcall(anchor, overloads, terralib.newlist{from, to}, "all", true, "expression") + end + end + local function checkmethod(exp, location) local methodname = checklabel(exp.name,true).value assert(type(methodname) == "string" or terra.islabel(methodname)) @@ -3211,7 +3244,28 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return newobject(anchor,T.letin, stmts, List {}, true):withtype(terra.types.unit) end + --divide assignment into regular assignments and copy assignments + local function assignmentkinds(lhs, rhs) + local regular = {lhs = terralib.newlist(), rhs = terralib.newlist()} + local byfcall = {lhs = terralib.newlist(), rhs = terralib.newlist()} + for i=1,#lhs do + local rhstype = rhs[i] and rhs[i].type + local lhstype = lhs[i].type + if rhstype and (hasmetacopyassignment(lhstype) or hasmetacopyassignment(rhstype)) then + --add assignment by __copy call + byfcall.lhs:insert(lhs[i]) + byfcall.rhs:insert(rhs[i]) + else + --default to regular assignment + regular.lhs:insert(lhs[i]) + regular.rhs:insert(rhs[i]) + end + end + return regular, byfcall + end + local function createassignment(anchor,lhs,rhs) + --special case where a rhs struct is unpacked if #lhs > #rhs and #rhs > 0 then local last = rhs[#rhs] if last.type:isstruct() and last.type.convertible == "tuple" and #last.type.entries + #rhs - 1 == #lhs then @@ -3231,20 +3285,73 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return createstatementlist(anchor, List {a1, a2}) end end - local vtypes = lhs:map(function(v) return v.type or "passthrough" end) - rhs = insertcasts(anchor,vtypes,rhs) - for i,v in ipairs(lhs) do - local rhstype = rhs[i] and rhs[i].type or terra.types.error - if v:is "setteru" then - local rv,r = allocvar(v,rhstype,"") - lhs[i] = newobject(v,T.setter, rv,v.setter(r)) - elseif v:is "allocvar" then - v:settype(rhstype) + + if #lhs < #rhs then + --an error may be reported later during type-checking: 'expected #lhs parameters (...), but found #rhs (...)' + local vtypes = lhs:map(function(v) return v.type or "passthrough" end) + rhs = insertcasts(anchor, vtypes, rhs) + for i,v in ipairs(lhs) do + local rhstype = rhs[i] and rhs[i].type or terra.types.error + if v:is "setteru" then + local rv,r = allocvar(v,rhstype,"") + lhs[i] = newobject(v,T.setter, rv,v.setter(r)) + elseif v:is "allocvar" then + v:settype(rhstype) + else + ensurelvalue(v) + end + end + return newobject(anchor,T.assignment,lhs,rhs) + else + --standard case #lhs == #rhs + --first take care of regular assignments + local regular, byfcall = assignmentkinds(lhs, rhs) + local vtypes = regular.lhs:map(function(v) return v.type or "passthrough" end) + regular.rhs = insertcasts(anchor, vtypes, regular.rhs) + for i,v in ipairs(regular.lhs) do + local rhstype = regular.rhs[i] and regular.rhs[i].type or terra.types.error + if v:is "setteru" then + local rv,r = allocvar(v,rhstype,"") + regular.lhs[i] = newobject(v,T.setter, rv,v.setter(r)) + elseif v:is "allocvar" then + v:settype(rhstype) + else + ensurelvalue(v) + end + end + --take care of copy assignments using metamethods.__copy + local stmts = terralib.newlist() + for i,v in ipairs(byfcall.lhs) do + local rhstype = byfcall.rhs[i] and byfcall.rhs[i].type or terra.types.error + if v:is "setteru" then + local rv,r = allocvar(v,rhstype,"") + stmts:insert(checkmetacopyassignment(anchor, byfcall.rhs[i], r)) + stmts:insert(newobject(v,T.setter, rv, v.setter(r))) + elseif v:is "allocvar" then + v:settype(rhstype) + stmts:insert(v) + local init = checkmetainit(anchor, v) + if init then + stmts:insert(init) + end + stmts:insert(checkmetacopyassignment(anchor, byfcall.rhs[i], v)) + else + ensurelvalue(v) + stmts:insert(checkmetacopyassignment(anchor, byfcall.rhs[i], v)) + end + end + if #stmts==0 then + --standard case, no meta-copy-assignments + return newobject(anchor,T.assignment, regular.lhs, regular.rhs) else - ensurelvalue(v) + --managed case using meta-copy-assignments + --the calls to `__copy` are in `stmts` + if #regular.lhs>0 then + stmts:insert(newobject(anchor,T.assignment, regular.lhs, regular.rhs)) + end + return createstatementlist(anchor, stmts) end end - return newobject(anchor,T.assignment,lhs,rhs) end local function checkmetadtors(anchor, stats)