Skip to content

Commit

Permalink
add --dotcmds
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamilcuk committed Nov 29, 2024
1 parent d56fe18 commit 53b25af
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 40 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ L_lib:
xdot:
L_bash_profile analyze profile.txt --dot profile.dot --dotlimit 3
xdot profile.dot
xdot2 xdot_L_argparse:
L_bash_profile analyze profile.txt --dot profile.dot --dotfunction L_argparse --dotcmds --dotlimit 6
xdot profile.dot
xdot_L_asa_has:
L_bash_profile analyze profile.txt --dot profile.dot --dotcmds --dotfunction L_asa_has
xdot profile.do
snakeviz:
L_bash_profile analyze profile.txt --pstats profile.pstats
snakeviz profile.pstats
196 changes: 156 additions & 40 deletions src/L_bash_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import hashlib
import io
import marshal
import multiprocessing
Expand All @@ -16,7 +17,7 @@
from dataclasses import astuple, dataclass, field
from datetime import timedelta
from functools import cached_property
from typing import Iterable, List, Optional, TypeVar
from typing import Iterable, List, Optional, TypeVar, Union, cast

import click
import clickdc
Expand All @@ -26,7 +27,15 @@
###############################################################################

T = TypeVar("T")
V1 = TypeVar("V1")
V = TypeVar("V")


def md5sum(data: str) -> str:
return hashlib.md5(data.encode("utf-8")).hexdigest()


def clamp(n, minn, maxn):
return max(min(maxn, n), minn)


def dots_trim(v: str, width: int = 50) -> str:
Expand Down Expand Up @@ -54,7 +63,7 @@ def maybe_take_n(generator: Iterable[T], n: Optional[int]) -> Iterable[T]:
return generator


def getdefault(e: list[T], idx: int, default: V1 = None) -> T | V1:
def getdefault(e: list[T], idx: int, default: V = None) -> T | V:
try:
return e[idx]
except (KeyError, IndexError):
Expand All @@ -72,13 +81,29 @@ def fmtus(us: int) -> str:
###############################################################################


@dataclass
class RedGreenHue:
elems: int

def color(self, idx: int) -> Optional[str]:
if self.elems == 0:
return None
val = int(0xFF * 2 / self.elems * idx)
color = "#%02x%02x%02x" % (
0xFF - val if 0 <= val < 0xFF else 0x00,
val - 0xFF if 0xFF <= val else 0x00,
0x00,
)
return color


@dataclass(frozen=True, order=True)
class FunctionKey:
"""To uniquely identify a function."""

filename: str = "~"
lineno: int = 0
funcname: str = ""
funcname: str = "~"

def __str__(self):
return f"{self.filename}:{self.lineno}({self.funcname})"
Expand Down Expand Up @@ -116,14 +141,19 @@ def sum_spent_us(self):
return sum(x.spent_us for x in self)


@dataclass
class CmdStats:
cmd: str = ""
callcount: int = 0
totaltime: int = 0


@dataclass
class CallgraphNode:
"""Single node in the callgraph tree"""

function: FunctionKey = field(default_factory=FunctionKey)
"""An index to the function to unique identify the node"""
parent: Optional[CallgraphNode] = None
"""guess"""
callcount: int = 0
"""How many times this function was called from the parent?"""
recursivecallcount: int = 0
Expand All @@ -132,13 +162,33 @@ class CallgraphNode:
"""How much time was spent in this node excluding subcalls"""
childtime: int = 0
"""How much time was spent in this node only in subcalls"""
childs: dict[FunctionKey, CallgraphNode] = field(default_factory=dict)
children: dict[FunctionKey, CallgraphNode] = field(default_factory=dict)
"""functions called by this function"""
cmdstats: dict[str, CmdStats] = field(default_factory=dict)
"""The commands executed by the function"""

def add_record(self, r: Record):
s = self.cmdstats.setdefault(r.cmd, CmdStats(r.cmd))
s.callcount += 1
s.totaltime += r.spent_us

@property
def totaltime(self):
return self.inlinetime + self.childtime

def merge(self, o: CallgraphNode):
assert self.function == o.function
self.callcount += o.callcount
self.recursivecallcount += o.recursivecallcount
self.inlinetime += o.inlinetime
self.childtime += o.childtime
for k, v in o.cmdstats.items():
s = self.cmdstats.setdefault(k, CmdStats(k))
s.callcount += v.callcount
s.totaltime += v.totaltime
for k, v in o.children.items():
self.children.setdefault(k, CallgraphNode(k)).merge(v)


@dataclass
class Pstatsnocallers:
Expand Down Expand Up @@ -184,6 +234,13 @@ class AnalyzeArgs:
pstats: Optional[str] = clickdc.option(
help="TODO: Generate python pstats file just like python cProfile file"
)
dotfunction: Optional[str] = clickdc.option(
help="""
The callgraph is generated with functions matching given regex as roots.
Implies --filterfunction
""",
)
dotcmds: Optional[bool] = clickdc.option(help="Add commands to dot graph nodes")
profilefile: Optional[io.TextIOBase] = clickdc.argument(
type=click.File(errors="replace", lazy=True), required=False
)
Expand Down Expand Up @@ -286,7 +343,6 @@ class Analyzer:
records: list[Record] = field(default_factory=list)
functions: dict[str, FunctionData] = field(default_factory=dict)
commands: dict[str, CommandData] = field(default_factory=dict)
callgraph: Optional[CallgraphNode] = None

def run(self):
with self.timeit(f"Reading {self.args.profilefile}"):
Expand Down Expand Up @@ -424,51 +480,111 @@ def print_top_longest_functions(self):
print(tabulate(longest_functions, headers="keys"))
print()

@cached_property
def get_callgraph(self):
# Create callgraph
if self.callgraph is None:
self.callgraph = CallgraphNode()
prevlevel = 0
for record in self.records:
call = self.callgraph
for t in reversed(record.trace):
call.childtime += record.spent_us
call = call.childs.setdefault(t, CallgraphNode(t, call))
if record.level > prevlevel:
call.callcount += 1
if len(record.trace) > 2 and record.trace[1] == record.trace[0]:
call.recursivecallcount += 1
prevlevel = record.level
call.inlinetime += record.spent_us
return self.callgraph
callgraph = CallgraphNode()
prevlevel = 0
for record in self.records:
call = callgraph
for t in reversed(record.trace):
call.childtime += record.spent_us
call = call.children.setdefault(t, CallgraphNode(t))
if record.level > prevlevel:
call.callcount += 1
if len(record.trace) > 2 and record.trace[1] == record.trace[0]:
call.recursivecallcount += 1
else:
if call.inlinetime != 0: # trickery!
call.add_record(record)
call.inlinetime += record.spent_us
prevlevel = record.level

if self.args.dotfunction:
# Filter dotfunction by merging the trees from the top node that matches the regex.
dotfunctionrgx = re.compile(self.args.dotfunction)
newcallgraph = CallgraphNode()

def walk(node: CallgraphNode):
if dotfunctionrgx.match(node.function.funcname):
newcallgraph.childtime += node.totaltime
newcallgraph.children.setdefault(
node.function, CallgraphNode(node.function)
).merge(node)
else:
for c in node.children.values():
walk(c)

walk(callgraph)
callgraph = newcallgraph

return callgraph

def extract_callgraph(self, outputfile: str):
callgraph = self.get_callgraph()
dot = Digraph()

def callgraph_printer(parents: str, x: CallgraphNode, color: str = "#ffffff"):
def callgraph_printer(
parents: str, x: CallgraphNode, color: Optional[str] = None
):
me = f"{parents}_{x.function.funcname}"
dot.node(
me,
f"{x.function.funcname}\ncalls={x.callcount} total={x.totaltime:_}us\ninline={x.inlinetime:_}us childs={x.childtime:_}us",
"\n".join(
[
f"{x.function.funcname}",
(
f"calls={x.callcount} total={x.totaltime:_}us percall={int(x.totaltime / (x.callcount or 1)):_}us"
if x.callcount
else f"total={x.totaltime:_}us"
),
" ".join(
([f"inline={x.inlinetime:_}us"] if x.inlinetime else [])
+ ([f"childs={x.childtime:_}us"] if x.childtime else [])
),
]
),
color=color,
)
inc = 255 / len(x.childs) if x.childs else 0
for idx, child in enumerate(
nodechildren = list(x.children.values())
children: list[Union[CallgraphNode, CmdStats]] = cast(
list[Union[CallgraphNode, CmdStats]], nodechildren
)
if self.args.dotcmds:
children.extend(list(x.cmdstats.values()))
children = list(
maybe_take_n(
sorted(x.childs.values(), key=lambda x: -x.totaltime),
sorted(children, key=lambda x: -x.totaltime),
self.args.dotcallgraphlimit,
)
):
color = "#%02x%02x%02x" % (0xFF - int(inc * idx), 0x00, 0x00)
dot.edge(
me,
f"{me}_{child.function.funcname}",
color=color,
)
callgraph_printer(me, child, color)
)
redgreenhue = RedGreenHue(len(children))
for idx, child in enumerate(children):
# print(val, inc, idx, len(x.childs), color)
color = redgreenhue.color(idx)
if isinstance(child, CallgraphNode):
dot.edge(
me,
f"{me}_{child.function.funcname}",
color=color,
)
callgraph_printer(me, child, color)
else:
childname = f"{me}_{md5sum(child.cmd)}"
dot.edge(me, childname, color=color)
dot.node(
childname,
"\n".join(
[
repr(child.cmd),
f"calls={child.callcount} spent={child.totaltime:_}us",
f"percall={int(child.totaltime / child.callcount):_}us",
]
),
color=color,
shape="box",
)

callgraph_printer("", callgraph)
callgraph_printer("", self.get_callgraph)
with open(outputfile, "w") as f:
print(dot.source, file=f)

Expand All @@ -478,7 +594,7 @@ def create_python_pstats_file(self, file: str):
https://github.com/python/cpython/blob/main/Lib/cProfile.py#L63
"""
# Extract function calls
callgraph = self.get_callgraph()
callgraph = self.get_callgraph
statsroot: dict[FunctionKey, Pstats] = {}

def fillstats(node: CallgraphNode):
Expand All @@ -487,7 +603,7 @@ def fillstats(node: CallgraphNode):
nodestats.primitivecallcount += node.callcount - node.recursivecallcount
nodestats.totaltime += us2s(node.totaltime)
nodestats.inlinetime += us2s(node.inlinetime)
for child in node.childs.values():
for child in node.children.values():
fillstats(child)
childstats = statsroot.setdefault(
child.function, Pstats()
Expand Down

0 comments on commit 53b25af

Please sign in to comment.