Chudnovsky RAM optimization in Python 3
I am looking for RAM optimizations for a Chudnovsky pi approximation. My code is very fast (about 0.21 s for 10^6 digits, 50 s for 10^8 digits, and 11 min 34 s for 10^9 digits), but even with 16 GB of RAM I quickly run out of memory on the largest calculations. I would appreciate any memory optimizations that do not cost too much speed.
Here is the code (sorry for the English in the comments; I'm not a native speaker):
import gmpy2
from gmpy2 import mpz, isqrt
from multiprocessing import Pool, cpu_count
from math import ceil
import time
# Chudnovsky constants, kept as mpz so that arithmetic with them never has to
# convert a Python int first
A = mpz(13591409)
B = mpz(545140134)
C = mpz(640320)
# C cubed over 24, the per-term factor multiplied into Q by bs()
C3 = C**3 // 24
# optimized binary splitting
def bs(a, b):
    """Binary-split the Chudnovsky series over the half-open term range [a, b).

    Args:
        a: index of the first term (inclusive).
        b: index one past the last term (exclusive).

    Returns:
        A tuple ``(P, Q, T)`` of mpz values. Results of two adjacent ranges
        ``[a, m)`` and ``[m, b)`` combine as ``P = P1 * P2``,
        ``Q = Q1 * Q2`` and ``T = Q2 * T1 + P1 * T2``.
        An empty range (``b <= a``) returns ``(1, 1, 0)``, which is the
        neutral element of that combination rule.
    """
    # An empty range used to recurse forever (m == a, so bs(a, a) called
    # itself again); return the neutral triple instead.
    if b <= a:
        return mpz(1), mpz(1), mpz(0)

    if b - a == 1:
        # Leaf: a single term of the series.
        if a == 0:
            return mpz(1), mpz(1), mpz(A)
        k = mpz(a)
        P = (6*k - 5) * (2*k - 1) * (6*k - 1)
        Q = k * k * k * C3
        T = P * (A + B * k)
        # Terms alternate in sign: odd indices are negative.
        if k & 1:
            T = -T
        return P, Q, T

    # Split in the middle and combine the two halves.
    m = (a + b) >> 1
    P1, Q1, T1 = bs(a, m)
    P2, Q2, T2 = bs(m, b)
    P = P1 * P2
    Q = Q1 * Q2
    T = Q2 * T1 + P1 * T2
    return P, Q, T
# multiprocessing worker
def worker(args):
    """Run ``bs`` on one packed argument tuple and return its result.

    The pool passes each task as a single object, so the arguments of
    ``bs`` arrive packed in ``args`` and are expanded here.
    """
    partial = bs(*args)
    return partial
def compute_pi(digits):
    """Return pi as a decimal string with ``digits`` digits after the point.

    The series terms are split into contiguous blocks, each block is
    evaluated by ``bs`` in a worker process, and the partial results are
    merged pairwise in order. The value is truncated (floor division and
    integer square root), not rounded, and no guard digits are used.

    Args:
        digits: number of decimal places wanted (non-negative integer).

    Returns:
        A string made of the leading digit, a ".", then the decimals.

    Raises:
        ValueError: if ``digits`` is negative.
    """
    if digits < 0:
        raise ValueError("digits must be non-negative")

    # number of necessary terms; always at least one so the sum is not empty
    N = max(1, int(ceil(digits / 14.181647462)))

    # Never use more blocks than terms: otherwise some ranges are empty.
    blocks = max(1, min(cpu_count(), N))

    # splitting into balanced contiguous blocks (sizes differ by at most one)
    bounds = [i * N // blocks for i in range(blocks + 1)]
    ranges = list(zip(bounds, bounds[1:]))

    # parallel calculation
    with Pool(blocks) as p:
        results = p.map(worker, ranges)

    # balanced merge of adjacent pairs; the order of the operands matters
    while len(results) > 1:
        # The product P is only needed to merge again later, so the very
        # last merge skips it.
        final = len(results) == 2
        merged = []
        for i in range(0, len(results) - 1, 2):
            P1, Q1, T1 = results[i]
            P2, Q2, T2 = results[i + 1]
            # Drop the list's references so each operand can be freed as
            # soon as it has been used for the last time.
            results[i] = results[i + 1] = None
            T = Q2 * T1
            del T1
            T += P1 * T2
            del T2
            Q = Q1 * Q2
            del Q1, Q2
            P = None if final else P1 * P2
            del P1, P2
            merged.append((P, Q, T))
        if len(results) % 2:
            # An unpaired last element is carried over unchanged.
            merged.append(results[-1])
        results = merged

    _, Q, T = results[0]
    del results

    # integer value of sqrt(10005) * 10^digits
    sqrtC = isqrt(mpz(10005) * mpz(10) ** (2 * digits))

    # final computation: (426880 * sqrtC * Q) // T, releasing each big
    # intermediate as soon as it is no longer needed
    num = mpz(426880) * sqrtC
    del sqrtC
    num *= Q
    del Q
    pi = num // T
    del num, T

    s = pi.digits()
    del pi
    return s[0] + "." + s[1:]
# --- main pipeline ---
def main():
    """Ask for a digit count, compute pi, print the timing and save the result.

    Prints the elapsed wall-clock time and the first 200 characters of the
    result, then writes the full string to ``pi.txt``.
    """
    digits = int(input("Décimales : "))
    t0 = time.time()
    pi = compute_pi(digits)
    t1 = time.time()
    print("Time :", t1 - t0)
    print(pi[:200])
    with open("pi.txt", "w") as f:
        f.write(pi)


# The guard is required with multiprocessing: when worker processes are
# started with the "spawn" method they import this module again, and without
# the guard each of them would re-run the prompt and start its own pool.
if __name__ == "__main__":
    main()