Skip to content

Commit

Permalink
Support for SPIR-V code generation (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottslaughter authored Jun 19, 2024
1 parent e824949 commit dff6c92
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 30 deletions.
81 changes: 55 additions & 26 deletions src/tcompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,34 +762,36 @@ struct CCallingConv {
lua_State *L;
terra_CompilerState *C;
Types *Ty;
bool pass_struct_as_exploded_values;
bool return_empty_struct_as_void;
bool wasm_cconv;
bool aarch64_cconv;
bool amdgpu_cconv;
bool ppc64_cconv;
int ppc64_float_limit;
int ppc64_int_limit;
bool ppc64_count_used;
bool spirv_cconv;
bool wasm_cconv;

CCallingConv(TerraCompilationUnit *CU_, Types *Ty_)
: CU(CU_),
T(CU_->T),
L(CU_->T->L),
C(CU_->T->C),
Ty(Ty_),
pass_struct_as_exploded_values(false),
return_empty_struct_as_void(false),
wasm_cconv(false),
aarch64_cconv(false),
amdgpu_cconv(false),
ppc64_cconv(false),
ppc64_float_limit(0),
ppc64_int_limit(0),
ppc64_count_used(false) {
ppc64_count_used(false),
spirv_cconv(false),
wasm_cconv(false) {
auto Triple = CU->TT->tm->getTargetTriple();
switch (Triple.getArch()) {
case Triple::ArchType::amdgcn: {
return_empty_struct_as_void = true;
pass_struct_as_exploded_values = true;
amdgpu_cconv = true;
} break;
case Triple::ArchType::aarch64:
case Triple::ArchType::aarch64_be: {
Expand All @@ -806,16 +808,27 @@ struct CCallingConv {
ppc64_int_limit = 8;
ppc64_count_used = true;
} break;
#if LLVM_VERSION >= 150
case Triple::ArchType::spirv32:
case Triple::ArchType::spirv64: {
return_empty_struct_as_void = true;
spirv_cconv = true;
} break;
#endif
case Triple::ArchType::wasm32:
case Triple::ArchType::wasm64: {
wasm_cconv = true;
} break;
default:
break;
}

switch (Triple.getOS()) {
case Triple::OSType::Win32: {
return_empty_struct_as_void = true;
} break;
default:
break;
}
}

Expand Down Expand Up @@ -1088,11 +1101,11 @@ struct CCallingConv {
return Argument(C_PRIMITIVE, t, usei1 ? Type::getInt1Ty(*CU->TT->ctx) : NULL);
}

if (wasm_cconv && !WasmIsSingletonOrEmpty(t->type)) {
if ((wasm_cconv && !WasmIsSingletonOrEmpty(t->type)) || spirv_cconv) {
return Argument(C_AGGREGATE_MEM, t);
}

if (pass_struct_as_exploded_values) {
if (amdgpu_cconv) {
return Argument(C_AGGREGATE_REG, t, t->type);
}

Expand Down Expand Up @@ -1128,7 +1141,7 @@ struct CCallingConv {

return Argument(C_AGGREGATE_REG, t, StructType::get(*CU->TT->ctx, elements));
}
void Classify(Obj *ftype, Obj *params, Classification *info) {
void Classify(Obj *ftype, CallingConv::ID cconv, Obj *params, Classification *info) {
Obj fparams, returntype;
ftype->obj("parameters", &fparams);
ftype->obj("returntype", &returntype);
Expand Down Expand Up @@ -1161,13 +1174,13 @@ struct CCallingConv {
CreateFunctionType(info, fparams.size(), ftype->boolean("isvararg"));
}

Classification *ClassifyFunction(Obj *fntyp) {
Classification *ClassifyFunction(Obj *fntyp, CallingConv::ID cconv) {
Classification *info = (Classification *)CU->symbols->getud(fntyp);
if (!info) {
info = new Classification(); // TODO: fix leak
Obj params;
fntyp->obj("parameters", &params);
Classify(fntyp, &params, info);
Classify(fntyp, cconv, &params, info);
CU->symbols->setud(fntyp, info);
}
return info;
Expand Down Expand Up @@ -1279,8 +1292,9 @@ struct CCallingConv {
}
}

Function *CreateFunction(Module *M, Obj *ftype, const Twine &name) {
Classification *info = ClassifyFunction(ftype);
Function *CreateFunction(Module *M, Obj *ftype, CallingConv::ID cconv,
const Twine &name) {
Classification *info = ClassifyFunction(ftype, cconv);
Function *fn = Function::Create(info->fntype, Function::InternalLinkage, name, M);
AttributeFnOrCall(fn, info);
return fn;
Expand Down Expand Up @@ -1311,7 +1325,7 @@ struct CCallingConv {
}
void EmitEntry(IRBuilder<> *B, Obj *ftype, Function *func,
std::vector<Value *> *variables) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, func->getCallingConv());
assert(info->paramtypes.size() == variables->size());
Function::arg_iterator ai = func->arg_begin();
if (info->returntype.kind == C_AGGREGATE_MEM)
Expand Down Expand Up @@ -1359,7 +1373,7 @@ struct CCallingConv {
}
}
void EmitReturn(IRBuilder<> *B, Obj *ftype, Function *function, Value *result) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, function->getCallingConv());
ArgumentKind kind = info->returntype.kind;

if (C_AGGREGATE_REG == kind &&
Expand Down Expand Up @@ -1420,10 +1434,10 @@ struct CCallingConv {
}
}

Value *EmitCall(IRBuilder<> *B, Obj *ftype, Obj *paramtypes, Value *callee,
std::vector<Value *> *actuals) {
Value *EmitCall(IRBuilder<> *B, Obj *ftype, CallingConv::ID cconv, Obj *paramtypes,
Value *callee, std::vector<Value *> *actuals) {
Classification info;
Classify(ftype, paramtypes, &info);
Classify(ftype, cconv, paramtypes, &info);

std::vector<Value *> arguments;

Expand Down Expand Up @@ -1867,16 +1881,25 @@ struct FunctionEmitter {
if (fstate->func) return fstate;
}

CallingConv::ID callingconv = CallingConv::MaxID;
if (funcobj->hasfield("callingconv")) {
callingconv = ParseCallingConv(funcobj->string("callingconv"));
}

Obj ftype;
funcobj->obj("type", &ftype);
// function name is $+name so that it can't conflict with any symbols imported
// from the C namespace
fstate->func = CC->CreateFunction(
M, &ftype, Twine(StringRef((isextern) ? "" : "$"), name));
fstate->func =
CC->CreateFunction(M, &ftype, callingconv,
Twine(StringRef((isextern) ? "" : "$"), name));
if (isextern) {
// Set external linkage for extern functions.
fstate->func->setLinkage(GlobalValue::ExternalLinkage);
}
if (callingconv != CallingConv::MaxID) {
fstate->func->setCallingConv(callingconv);
}

if (funcobj->hasfield("alwaysinline")) {
if (funcobj->boolean("alwaysinline")) {
Expand All @@ -1891,10 +1914,6 @@ struct FunctionEmitter {
fstate->func->addFnAttr(Attribute::NoInline);
}
}
if (funcobj->hasfield("callingconv")) {
const char *callingconv = funcobj->string("callingconv");
fstate->func->setCallingConv(ParseCallingConv(callingconv));
}
if (funcobj->hasfield("noreturn")) {
if (funcobj->boolean("noreturn")) {
fstate->func->addFnAttr(Attribute::NoReturn);
Expand Down Expand Up @@ -2810,7 +2829,12 @@ struct FunctionEmitter {
#if LLVM_VERSION < 170
return B->CreateBitCast(v, toT->type);
#else
return v;
if (fromT->type->getPointerAddressSpace() !=
toT->type->getPointerAddressSpace()) {
return B->CreateAddrSpaceCast(v, toT->type);
} else {
return v;
}
#endif
} else {
assert(toT->type->isIntegerTy());
Expand Down Expand Up @@ -3205,6 +3229,11 @@ struct FunctionEmitter {

call->obj("value", &func);

CallingConv::ID callingconv = CallingConv::MaxID;
if (func.hasfield("callingconv")) {
callingconv = ParseCallingConv(func.string("callingconv"));
}

Value *fn = emitExp(&func);

Obj fnptrtyp;
Expand All @@ -3220,7 +3249,7 @@ struct FunctionEmitter {
setInsertBlock(bb);
deferred.push_back(bb);
}
Value *r = CC->EmitCall(B, &fntyp, &paramtypes, fn, &actuals);
Value *r = CC->EmitCall(B, &fntyp, callingconv, &paramtypes, fn, &actuals);
setInsertBlock(cur); // defer may have changed it
return r;
}
Expand Down
13 changes: 9 additions & 4 deletions src/terralib.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,12 @@ do
if status then return r end
end
return self.name
elseif self:ispointer() then return "&"..tostring(self.type)
elseif self:ispointer() then
if not self.addressspace or self.addressspace == 0 then
return "&"..tostring(self.type)
else
return "pointer("..tostring(self.type)..","..tostring(self.addressspace)..")"
end
elseif self:isvector() then return "vector("..tostring(self.type)..","..tostring(self.N)..")"
elseif self:isfunction() then return mkstring(self.parameters,"{",",",self.isvararg and " ...}" or "}").." -> "..tostring(self.returntype)
elseif self:isarray() then
Expand Down Expand Up @@ -2455,15 +2460,15 @@ function typecheck(topexp,luaenv,simultaneousdefinitions)
end
local function ascompletepointer(exp) --convert pointer like things into pointers to _complete_ types
exp.type.type:tcomplete(exp)
return (insertcast(exp,terra.types.pointer(exp.type.type))) --parens are to truncate to 1 argument
return (insertcast(exp,terra.types.pointer(exp.type.type, exp.type.addressspace))) --parens are to truncate to 1 argument
end
-- subtracting 2 pointers
if pointerlike(l.type) and pointerlike(r.type) and l.type.type == r.type.type and e.operator == tokens["-"] then
return e:copy { operands = List {ascompletepointer(l),ascompletepointer(r)} }:withtype(terra.types.ptrdiff)
elseif pointerlike(l.type) and r.type:isintegral() then -- adding or subtracting a int to a pointer
return e:copy {operands = List {ascompletepointer(l),r} }:withtype(terra.types.pointer(l.type.type))
return e:copy {operands = List {ascompletepointer(l),r} }:withtype(terra.types.pointer(l.type.type, l.type.addressspace))
elseif l.type:isintegral() and pointerlike(r.type) then
return e:copy {operands = List {ascompletepointer(r),l} }:withtype(terra.types.pointer(r.type.type))
return e:copy {operands = List {ascompletepointer(r),l} }:withtype(terra.types.pointer(r.type.type, r.type.addressspace))
else
return meetbinary(e,"isarithmeticorvector",l,r)
end
Expand Down
26 changes: 26 additions & 0 deletions tests/addressspace.t
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Tests of pointers with address spaces.

-- The exact meaning of this depends on the target, but at least basic
-- code compilation should work.

local function ptr1(ty)
-- A pointer in address space 1.
return terralib.types.pointer(ty, 1)
end

terra test(x : &int, y : ptr1(int))
-- Should be able to do math on pointers with non-zero address spaces:
var a = [ptr1(int8)](y)
var b = a + 8
var c = [ptr1(int)](b)
var d = c - y
y = c

-- Casts should work:
y = [ptr1(int)](x)
x = [&int](y)

return d
end
test:compile()
print(test)

0 comments on commit dff6c92

Please sign in to comment.