## Tracing JIT and real world Python

### aka: what we can learn from PyPy

---

# Motivation

- CPython's JIT has a lot in common with PyPy
- "Optimize for PyPy" ==> my job for ~7 years
- Real world code != pyperformance
- Challenges & lessons learned

---

# Assumption

- The JIT revolutionizes performance characteristics
- CPython perf will look like PyPy's
- ==> Some results are surprising

---

# Context

- High Frequency Trading firm (sports betting)
  * every ms counts
- Python 2.7
- Multi process system: stateful server + dispatcher + stateless workers (long running processes)
- "big" messages passed around

---

## PyPy JIT 101

- Interpreter written in RPython
- RPython -> `*.c` -> gcc -> `./pypy`
- RPython -> "jit codegen" -> "jitcodes" (~RPython IR)
- RPython jitcodes ~= CPython microops
  * Slightly higher level than C
- Tracing means *executing jitcodes*
  * we have an interpreter for that, super slow

---

### Problem 1: trace blockers

```python [1-100]
def get_pi():
    """
    Compute an approximation of PI using the Leibniz series
    """
    tol = 0.0000001
    pi_approx = 0.0
    k = 0
    term = 1.0  # Initial term to enter the loop

    while abs(term) > tol:
        if k % 2 == 0:
            term = 1.0 / (2 * k + 1)
        else:
            term = -1 * 1.0 / (2 * k + 1)

        pi_approx = pi_approx + term
        k = k + 1

    return 4 * pi_approx
```

---

### Problem 1: trace blockers

```python [18]
def get_pi():
    """
    Compute an approximation of PI using the Leibniz series
    """
    tol = 0.0000001
    pi_approx = 0.0
    k = 0
    term = 1.0  # Initial term to enter the loop

    while abs(term) > tol:
        if k % 2 == 0:
            term = 1.0 / (2 * k + 1)
        else:
            term = -1 * 1.0 / (2 * k + 1)

        pi_approx = pi_approx + term
        k = k + 1
        hic_sunt_leones()  # the JIT cannot enter here

    return 4 * pi_approx
```

---

## Hic sunt leones

```
import pypyjit

def empty():
    pass

# the JIT cannot enter here
def hic_sunt_leones():
    pypyjit.residual_call(empty)
```

- Any call to a non-traceable function
- C builtins, C extensions
- (for PyPy): RPython instructions not understood by the JIT

---

```
❯ python3.13 pi.py
2.1712 secs, pi = 3.1415928535897395

❯ pypy pi.py
0.0518 secs, pi = 3.1415928535897395

❯ # with "hic_sunt_leones()"
❯ pypy pi.py
1.1808 secs, pi = 3.1415928535897395
```

---

### Problem 2: data driven control flow

```
def fn(v=None, a=None, b=None, c=None, d=None,
       e=None, f=None, g=None, h=None):
    "Random nonsense computation generated by ChatGPT"
    if v is None: v = 0
    if a is None: a = 1.25
    if b is None: b = -0.75
    [...]
    y = a * v + b
    if y < f: y = f
    [...]
    return y

def main():
    [...]
    for row in DATA:
        acc += fn(*row)
```

---

### Problem 2: data driven control flow

```
❯ python3.13 data_driven.py
0.1274 secs

❯ pypy --jit off data_driven.py
0.2953 secs

❯ pypy data_driven.py
1.6414 secs
```

---

## Exponential tracing

- Every combination of "`None`ness" must be compiled separately

```
❯ PYPYLOG=jit-summary:- pypy jit_explosion.py
1.6387 secs
[a625ea04910] {jit-summary
...
Total # of loops:   11
Total # of bridges: 527
...
[a625ea507bc] jit-summary}
```

---

## Exponential tracing

- Mitigation: branchless code

```
if x < 0:
    x = 100

# ===>
x = (x < 0)*100 + (x >= 0)*x
```

- Ugly, unreadable, not always possible
- Never found a good solution
- Happens quite a lot
- *Fundamental problem of tracing JITs*?

---

### Problem 3: generators (and async?)

```
import math

def count_triples_loop(P):
    """
    Counts how many integer right triangles (Pythagorean triples) have
    perimeter <= P.
    """
    m_max = int(math.isqrt(2 * P))  # loose but safe upper bound for m
    count = 0
    for m in range(1, m_max + 1):
        for n in range(1, m_max + 1):
            if ((m - n) & 1) and math.gcd(m, n) == 1:
                p0 = 2 * m * (m + n)  # a+b+c
                if p0 > P:
                    continue
                count += P // p0
    return count
```

---

### Problem 3: generators (and async?)

```
def range_product(a, b):
    for i in range(*a):
        for j in range(*b):
            yield i, j

def count_triples_gen(P):
    m_max = int(math.isqrt(2 * P))
    count = 0
    for m, n in range_product((1, m_max + 1), (1, m_max + 1)):
        if ((m - n) & 1) and math.gcd(m, n) == 1:
            p0 = 2 * m * (m + n)  # a+b+c
            if p0 > P:
                continue
            count += P // p0
    return count
```

---

### Problem 3: generators (and async?)

```
class RangeProductIter:

    def __init__(self, a, b):
        self.i, self.n = a
        self.j0, self.m = b  # remember where the inner range starts
        self.j = self.j0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration
        value = (self.i, self.j)
        self.j += 1
        if self.j >= self.m:
            self.j = self.j0  # restart the inner range, like range(*b)
            self.i += 1
        return value
```

---

### Problem 3: generators (and async?)

```
❯ python3.13 pythagorean.py
loop: 0.4560 secs (1x)
gen:  0.5884 secs (1.29x)
iter: 1.0126 secs (2.22x)

❯ pypy pythagorean.py
loop: 0.1199 secs (1x)
gen:  0.1550 secs (1.29x)
iter: 0.1264 secs (1.05x)
```

- Generators force the creation of a frame
- The JIT cannot see "through" generators
- In real code, much worse slowdowns

---

## Other misc problems

- Tooling, profilers
- Warmup
- Performance instability (link to paper?)
- Long tail of jitting

```
for n in itertools.count():
    job = accept_job()
    do(job)
    if n > 12345:
        pypyjit.disable()
```

---

## Bonus slides

### (Avoid) allocations is all you need

---

### Task

- Compute the barycenter of a series of triangles serialized according to a binary protocol
- Simulate protobuf, capnproto, etc.

```
struct Point {
    double x;
    double y;
};

struct Triangle {
    Point a;
    Point b;
    Point c;
};
```

---

### Bare loop

```
import struct

def read_loop():
    fmt = 'dddddd'
    size = struct.calcsize(fmt)
    tot_x = 0
    tot_y = 0
    n = 0
    with open('poly.bin', 'rb') as f:
        while True:
            buf = f.read(size)
            if not buf:
                break
            points = struct.unpack_from(fmt, buf)
            ax, ay, bx, by, cx, cy = points
            tot_x += (ax + bx + cx)
            tot_y += (ay + by + cy)
            n += 1

    print(n)
    x = tot_x/n
    y = tot_y/n
    return x, y
```

---

### Schema-aware protocol

```
class Triangle:
    def __init__(self, buf, offset):
        self.buf = buf
        self.offset = offset

    @property
    def a(self):
        # first point sits at the start of the triangle
        return Point(self.buf, self.offset)

    [...]

class Point:
    def __init__(self, buf, offset):
        self.buf = buf
        self.offset = offset

    @property
    def x(self):
        return struct.unpack_from('d', self.buf, self.offset)[0]
```

---

### Schema-aware protocol

```
while True:
    buf = f.read(size)
    if not buf:
        break
    t = Triangle(buf, 0)
    tot_x += t.a.x + t.b.x + t.c.x
    tot_y += t.a.y + t.b.y + t.c.y
    n += 1
```

---

```
❯ python3.13 readpoly.py
read_loop:  0.5444 secs
read_proto: 3.0307 secs

❯ pypy readpoly.py
read_loop:  0.2945 secs
read_proto: 0.1183 secs
```
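
---

### Sketch: both readers side by side

The slides elide parts of the schema-aware reader with `[...]`. The sketch below is one possible way to fill them in so that both variants can be run and compared. It is not the benchmark code behind the timings above. The `y`, `b` and `c` accessors, the `POINT_SIZE` constant, the `make_stream` helper and the body of `read_proto` are assumptions, and an in-memory `io.BytesIO` stands in for `poly.bin`.

```python
import io
import struct

TRIANGLE_FMT = 'dddddd'                        # ax, ay, bx, by, cx, cy
TRIANGLE_SIZE = struct.calcsize(TRIANGLE_FMT)  # six doubles
POINT_SIZE = struct.calcsize('dd')             # hypothetical constant: two doubles


class Point:
    def __init__(self, buf, offset):
        self.buf = buf
        self.offset = offset

    @property
    def x(self):
        return struct.unpack_from('d', self.buf, self.offset)[0]

    @property
    def y(self):
        # assumed layout: y is stored right after x
        return struct.unpack_from('d', self.buf, self.offset + 8)[0]


class Triangle:
    def __init__(self, buf, offset):
        self.buf = buf
        self.offset = offset

    @property
    def a(self):
        return Point(self.buf, self.offset)

    @property
    def b(self):
        # assumed: points are laid out back to back
        return Point(self.buf, self.offset + POINT_SIZE)

    @property
    def c(self):
        return Point(self.buf, self.offset + 2 * POINT_SIZE)


def make_stream(triangles):
    """Hypothetical helper: serialize triangles into an in-memory file."""
    data = b''.join(struct.pack(TRIANGLE_FMT, *t) for t in triangles)
    return io.BytesIO(data)


def read_loop(f):
    """Bare loop: one unpack per triangle, no intermediate objects."""
    tot_x = tot_y = 0.0
    n = 0
    while True:
        buf = f.read(TRIANGLE_SIZE)
        if not buf:
            break
        ax, ay, bx, by, cx, cy = struct.unpack_from(TRIANGLE_FMT, buf)
        tot_x += ax + bx + cx
        tot_y += ay + by + cy
        n += 1
    # same normalisation as the slide: totals divided by the triangle count
    return tot_x / n, tot_y / n


def read_proto(f):
    """Schema-aware loop: one Triangle and six Point objects per record."""
    tot_x = tot_y = 0.0
    n = 0
    while True:
        buf = f.read(TRIANGLE_SIZE)
        if not buf:
            break
        t = Triangle(buf, 0)
        tot_x += t.a.x + t.b.x + t.c.x
        tot_y += t.a.y + t.b.y + t.c.y
        n += 1
    return tot_x / n, tot_y / n


if __name__ == '__main__':
    tris = [(0, 0, 1, 0, 0, 1), (1, 1, 2, 1, 1, 2)]
    print(read_loop(make_stream(tris)))
    print(read_proto(make_stream(tris)))
```

Both functions should return the same pair for the same input. The only difference is how many short-lived objects each iteration allocates, which is the cost the timings above compare.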