Skip to content

Commit

Permalink
Kill time based heuristics no versions (#1285)
Browse files Browse the repository at this point in the history
* Remove the time-based heuristics
* Remove invocation time
* Remove type feedback version
  • Loading branch information
fikovnik authored Jun 5, 2024
1 parent be83f77 commit 2d36a73
Show file tree
Hide file tree
Showing 17 changed files with 107 additions and 271 deletions.
2 changes: 0 additions & 2 deletions rir/src/compiler/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,6 @@ rir::Function* Backend::doCompile(ClosureVersion* cls, ClosureLog& log) {
// here we only set the current version used to compile this function
auto feedback = rir::TypeFeedback::empty();
PROTECT(feedback->container());
feedback->version(
cls->optFunction->dispatchTable()->currentTypeFeedbackVersion());

function.finalize(body, signature, cls->context(), feedback);
for (auto& c : done)
Expand Down
4 changes: 1 addition & 3 deletions rir/src/compiler/native/builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ static FunctionSignature
deoptSentinelSig(FunctionSignature::Environment::CallerProvided,
FunctionSignature::OptimizationLevel::Optimized);
static Function* deoptSentinel;
static SEXP deoptSentinelContainer = []() {
SEXP deoptSentinelContainer = []() {
auto c = rir::Code::NewNative(0);
PROTECT(c->container());
SEXP store = Rf_allocVector(EXTERNALSXP, sizeof(Function));
Expand Down Expand Up @@ -1464,7 +1464,6 @@ static SEXP nativeCallTrampolineImpl(ArglistOrder::CallId callId, rir::Code* c,
R_ReturnedValue = R_NilValue; /* remove restart token */
fun->registerInvocation();
result = code->nativeCode()(code, args, env, callee);
fun->registerEndInvocation();
} else {
result = R_ReturnedValue;
}
Expand All @@ -1482,7 +1481,6 @@ static SEXP nativeCallTrampolineImpl(ArglistOrder::CallId callId, rir::Code* c,
ostack_popn(missing);

SLOWASSERT(t == R_BCNodeStackTop);
fun->registerEndInvocation();
return result;
}

Expand Down
78 changes: 31 additions & 47 deletions rir/src/compiler/native/lower_function_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ using namespace llvm;
extern "C" size_t R_NSize;
extern "C" size_t R_NodesInUse;

extern SEXP deoptSentinelContainer;

static_assert(sizeof(unsigned long) == sizeof(uint64_t),
"sizeof(unsigned long) and sizeof(uint64_t) should match");

Expand Down Expand Up @@ -3402,7 +3404,6 @@ void LowerFunctionLLVM::compile() {
auto calli = StaticCall::Cast(i);
calli->eachArg([](Value* v) { assert(!ExpandDots::Cast(v)); });
auto target = calli->tryDispatch();
auto bestTarget = calli->tryOptimisticDispatch();
std::vector<Value*> args;
calli->eachCallArg([&](Value* v) { args.push_back(v); });
Context asmpt = calli->inferAvailableAssumptions();
Expand All @@ -3424,32 +3425,34 @@ void LowerFunctionLLVM::compile() {
break;
}

if (target == bestTarget) {
auto callee = target->owner()->rirClosure();
auto dt = DispatchTable::check(BODY(callee));
rir::Function* nativeTarget = nullptr;
for (size_t i = 0; i < dt->size(); i++) {
auto entry = dt->get(i);
if (entry->context() == target->context() &&
entry->signature().numArguments >= args.size()) {
nativeTarget = entry;
}
auto callee = target->owner()->rirClosure();
auto dt = DispatchTable::check(BODY(callee));
rir::Function* nativeTarget = nullptr;
for (size_t i = 0; i < dt->size(); i++) {
auto entry = dt->get(i);
if (entry->context() == target->context() &&
entry->signature().numArguments >= args.size() &&
!entry->disabled()) {
nativeTarget = entry;
}
if (nativeTarget) {
assert(
asmpt.includes(Assumption::StaticallyArgmatched));
auto idx = Pool::makeSpace();
NativeBuiltins::targetCaches.push_back(idx);
Pool::patch(idx, nativeTarget->container());
auto missAsmptStore =
Rf_allocVector(RAWSXP, sizeof(Context));
auto missAsmptIdx = Pool::insert(missAsmptStore);
new (DATAPTR(missAsmptStore))
Context(nativeTarget->context() - asmpt);
assert(asmpt.smaller(nativeTarget->context()));
auto res = withCallFrame(args, [&]() {
return call(
NativeBuiltins::get(
}
SEXP container = deoptSentinelContainer;
if (nativeTarget) {
container = nativeTarget->container();
}

assert(asmpt.includes(Assumption::StaticallyArgmatched));
auto idx = Pool::makeSpace();
NativeBuiltins::targetCaches.push_back(idx);
Pool::patch(idx, container);
auto missAsmptStore = Rf_allocVector(RAWSXP, sizeof(Context));
auto missAsmptIdx = Pool::insert(missAsmptStore);
new (DATAPTR(missAsmptStore)) Context();
if (nativeTarget) {
assert(asmpt.smaller(nativeTarget->context()));
}
auto res = withCallFrame(args, [&]() {
return call(NativeBuiltins::get(
NativeBuiltins::Id::nativeCallTrampoline),
{
c(callId),
Expand All @@ -3462,27 +3465,8 @@ void LowerFunctionLLVM::compile() {
c(asmpt.toI()),
c(missAsmptIdx),
});
});
setVal(i, res);
break;
}
}

assert(asmpt.includes(Assumption::StaticallyArgmatched));
setVal(i, withCallFrame(args, [&]() -> llvm::Value* {
return call(
NativeBuiltins::get(NativeBuiltins::Id::call),
{
c(callId),
paramCode(),
c(calli->srcIdx),
builder.CreateIntToPtr(
c(calli->cls()->rirClosure()), t::SEXP),
loadSxp(calli->env()),
c(calli->nCallArgs()),
c(asmpt.toI()),
});
}));
});
setVal(i, res);
break;
}

Expand Down
13 changes: 8 additions & 5 deletions rir/src/compiler/rir2pir/rir2pir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,16 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos,
}

if (ti.taken != (size_t)-1 &&
insert.function->optFunction->invocationCount()) {
// the reason to take the baseline version is that we only
// increment the taken type feedback while running baseline
// FIXME: refactor
insert.function->owner()->rirFunction()->invocationCount()) {
if (auto c = CallInstruction::CastCall(top())) {
// invocation count is already incremented before calling jit
c->taken =
(double)ti.taken /
(double)(insert.function->optFunction->invocationCount() -
1);
c->taken = (double)ti.taken / (double)(insert.function->owner()
->rirFunction()
->invocationCount() -
1);
}
}
break;
Expand Down
7 changes: 0 additions & 7 deletions rir/src/interpreter/interp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,10 +818,6 @@ static void supplyMissingArgs(CallContext& call, const Function* fun) {

const unsigned pir::Parameter::PIR_WARMUP =
getenv("PIR_WARMUP") ? atoi(getenv("PIR_WARMUP")) : 100;
const unsigned pir::Parameter::PIR_OPT_TIME =
getenv("PIR_OPT_TIME") ? atoi(getenv("PIR_OPT_TIME")) : 3e6;
const unsigned pir::Parameter::PIR_REOPT_TIME =
getenv("PIR_REOPT_TIME") ? atoi(getenv("PIR_REOPT_TIME")) : 5e7;
const unsigned pir::Parameter::DEOPT_ABANDON =
getenv("PIR_DEOPT_ABANDON") ? atoi(getenv("PIR_DEOPT_ABANDON")) : 12;
const unsigned pir::Parameter::PIR_OPT_BC_SIZE =
Expand Down Expand Up @@ -1137,7 +1133,6 @@ SEXP doCall(CallContext& call, bool popArgs) {
assert(result);
if (popArgs)
ostack_popn(call.passedArgs - call.suppliedArgs);
fun->registerEndInvocation();
return result;
}
default:
Expand Down Expand Up @@ -3982,14 +3977,12 @@ SEXP rirEval(SEXP what, SEXP env) {
Function* fun = table->baseline();
fun->registerInvocation();
auto res = evalRirCodeExtCaller(fun->body(), env);
fun->registerEndInvocation();
return res;
}

if (auto fun = Function::check(what)) {
fun->registerInvocation();
auto res = evalRirCodeExtCaller(fun->body(), env);
fun->registerEndInvocation();
return res;
}

Expand Down
30 changes: 11 additions & 19 deletions rir/src/interpreter/interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ inline RCNTXT* findFunctionContextFor(SEXP e) {
return nullptr;
}

inline bool RecompileHeuristic(Function* fun,
Function* funMaybeDisabled = nullptr) {
inline bool RecompileHeuristic(Function* fun, Function* disabledFun = nullptr) {

auto flags = fun->flags;
if (flags.contains(Function::MarkOpt)) {
Expand All @@ -67,31 +66,24 @@ inline bool RecompileHeuristic(Function* fun,
if (flags.contains(Function::NotOptimizable))
return false;

if (!funMaybeDisabled)
funMaybeDisabled = fun;
if (!disabledFun)
disabledFun = fun;

auto abandon =
funMaybeDisabled->deoptCount() >= pir::Parameter::DEOPT_ABANDON;

auto wt = fun->isOptimized() ? pir::Parameter::PIR_REOPT_TIME
: pir::Parameter::PIR_OPT_TIME;
if (fun->invocationCount() >= 3 && fun->invocationTime() > wt) {
REC_HOOK(recording::recordInvocationCountTimeReason(
fun->invocationCount(), 3, fun->invocationTime(), wt));

fun->clearInvocationTime();
return !abandon;
}

if (abandon || fun->isOptimized())
if (disabledFun->deoptCount() >= pir::Parameter::DEOPT_ABANDON) {
return false;
}

auto wu = pir::Parameter::PIR_WARMUP;
if (wu == 0 || fun->invocationCount() == wu) {
if (wu == 0) {
REC_HOOK(recording::recordPirWarmupReason(wu));
return true;
}

if (fun->invocationCount() % wu == 0) {
REC_HOOK(recording::recordPirWarmupReason(fun->invocationCount()));
return true;
}

return false;
}

Expand Down
47 changes: 13 additions & 34 deletions rir/src/recording.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,6 @@
namespace rir {
namespace recording {

SEXP InvocationCountTimeReason::toSEXP() const {
auto vec = PROTECT(this->CompileReasonImpl::toSEXP());

size_t i = 0;
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(count));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(minimalCount));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(time));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(minimalTime));

UNPROTECT(1);
return vec;
}

void InvocationCountTimeReason::fromSEXP(SEXP sexp){
this->CompileReasonImpl::fromSEXP(sexp);

size_t i = 0;
this->count = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->minimalCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->time = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->minimalTime = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
}

SEXP PirWarmupReason::toSEXP() const {
auto vec = PROTECT(this->CompileReasonImpl::toSEXP());

Expand All @@ -57,10 +34,11 @@ SEXP PirWarmupReason::toSEXP() const {
return vec;
}

void PirWarmupReason::fromSEXP(SEXP sexp){
void PirWarmupReason::fromSEXP(SEXP sexp) {
this->CompileReasonImpl::fromSEXP(sexp);

this->invocationCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
this->invocationCount =
serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
}

SEXP OSRLoopReason::toSEXP() const {
Expand All @@ -72,7 +50,7 @@ SEXP OSRLoopReason::toSEXP() const {
return vec;
}

void OSRLoopReason::fromSEXP(SEXP sexp){
void OSRLoopReason::fromSEXP(SEXP sexp) {
this->CompileReasonImpl::fromSEXP(sexp);

this->loopCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
Expand Down Expand Up @@ -155,8 +133,8 @@ void Record::recordSpeculativeContext(const Code* code,
}
}

std::pair<size_t, FunRecording&> Record::initOrGetRecording(const SEXP cls,
const std::string& name) {
std::pair<size_t, FunRecording&>
Record::initOrGetRecording(const SEXP cls, const std::string& name) {
assert(Rf_isFunction(cls));
auto& body = *BODY(cls);

Expand Down Expand Up @@ -329,11 +307,13 @@ std::ostream& operator<<(std::ostream& out, const FunRecording& that) {
return out;
}

const char* ClosureEvent::targetName(const std::vector<FunRecording>& mapping) const {
const char*
ClosureEvent::targetName(const std::vector<FunRecording>& mapping) const {
return mapping[closureIndex].name.c_str();
}

const char* DtEvent::targetName(const std::vector<FunRecording>& mapping) const {
const char*
DtEvent::targetName(const std::vector<FunRecording>& mapping) const {
return mapping[dispatchTableIndex].name.c_str();
}

Expand Down Expand Up @@ -455,19 +435,19 @@ void CompilationEvent::print(const std::vector<FunRecording>& mapping,
out << "\n";
}
out << " ],\n opt_reasons=[\n";
if(this->compile_reasons.heuristic){
if (this->compile_reasons.heuristic) {
out << " heuristic=";
this->compile_reasons.heuristic->print(out);
out << "\n";
}

if(this->compile_reasons.condition){
if (this->compile_reasons.condition) {
out << " condition=";
this->compile_reasons.condition->print(out);
out << "\n";
}

if(this->compile_reasons.osr){
if (this->compile_reasons.osr) {
out << " osr_reason=";
this->compile_reasons.osr->print(out);
out << "\n";
Expand Down Expand Up @@ -681,7 +661,6 @@ void InvocationEvent::print(const std::vector<FunRecording>& mapping,
out << " }";
}


std::string getEnvironmentName(SEXP env) {
if (env == R_GlobalEnv) {
return GLOBAL_ENV_NAME;
Expand Down
28 changes: 0 additions & 28 deletions rir/src/recording.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,34 +99,6 @@ struct MarkOptReason : public CompileReasonImpl<MarkOptReason, 0> {
virtual ~MarkOptReason() = default;
};

struct InvocationCountTimeReason
: public CompileReasonImpl<InvocationCountTimeReason, 4> {
static constexpr const char* NAME = "InvocationCountTime";
virtual ~InvocationCountTimeReason() = default;

InvocationCountTimeReason(size_t count, size_t minimalCount,
unsigned long time, unsigned long minimalTime)
: count(count), minimalCount(minimalCount), time(time),
minimalTime(minimalTime) {}

InvocationCountTimeReason() {}

size_t count = 0;
size_t minimalCount = 0;
unsigned long time = 0;
unsigned long minimalTime = 0;

virtual SEXP toSEXP() const override;
virtual void fromSEXP(SEXP sexp) override;

virtual void print(std::ostream& out) const override {
this->CompileReasonImpl::print(out);

out << ", count=" << count << ", minimalCount=" << minimalCount
<< ", time=" << time << ", minimalTime=" << minimalTime;
}
};

struct PirWarmupReason : public CompileReasonImpl<PirWarmupReason, 1> {
static constexpr const char* NAME = "PirWarmupReason";
virtual ~PirWarmupReason() = default;
Expand Down
Loading

0 comments on commit 2d36a73

Please sign in to comment.