Skip to content

Commit 44f66cb

Browse files
committed
Optimize fft performance (pyscf#2276)
* Use scipy.fft module by default * Optimize cache utilization of fft * lint error * fix flake F841 warning * disable flake8
1 parent 2ba39a3 commit 44f66cb

File tree

3 files changed

+63
-56
lines changed

3 files changed

+63
-56
lines changed

.github/workflows/run_tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export PYTHONPATH=$(pwd):$PYTHONPATH
44
ulimit -s 20000
55

66
mkdir -p pyscftmpdir
7-
echo 'pbc_tools_pbc_fft_engine = "NUMPY"' > .pyscf_conf.py
7+
echo 'pbc_tools_pbc_fft_engine = "NUMPY+BLAS"' > .pyscf_conf.py
88
echo "dftd3_DFTD3PATH = './pyscf/lib/deps/lib'" >> .pyscf_conf.py
99
echo "scf_hf_SCF_mute_chkfile = True" >> .pyscf_conf.py
1010
echo 'TMPDIR = "./pyscftmpdir"' >> .pyscf_conf.py

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ license = { text = "Apache-2.0" }
3636

3737
dependencies = [
3838
'numpy>=1.13,!=1.16,!=1.17',
39-
'scipy!=1.5.0,!=1.5.1',
39+
'scipy>=1.6.0',
4040
'h5py>=2.7',
4141
'setuptools',
4242
]

pyscf/pbc/tools/pbc.py

+61-54
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,63 @@
1616
import warnings
1717
import ctypes
1818
import numpy as np
19+
import scipy
1920
import scipy.linalg
2021
from pyscf import lib
2122
from pyscf.lib import logger
2223
from pyscf.gto import ATM_SLOTS, BAS_SLOTS, ATOM_OF, PTR_COORD
2324
from pyscf.pbc.lib.kpts_helper import get_kconserv, get_kconserv3 # noqa
2425
from pyscf import __config__
2526

26-
FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'BLAS')
27+
FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'NUMPY+BLAS')
2728

2829
def _fftn_blas(f, mesh):
29-
Gx = np.fft.fftfreq(mesh[0])
30-
Gy = np.fft.fftfreq(mesh[1])
31-
Gz = np.fft.fftfreq(mesh[2])
32-
expRGx = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[0]), Gx))
33-
expRGy = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[1]), Gy))
34-
expRGz = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[2]), Gz))
35-
out = np.empty(f.shape, dtype=np.complex128)
36-
buf = np.empty(mesh, dtype=np.complex128)
37-
for i, fi in enumerate(f):
38-
buf[:] = fi.reshape(mesh)
39-
g = lib.dot(buf.reshape(mesh[0],-1).T, expRGx, c=out[i].reshape(-1,mesh[0]))
40-
g = lib.dot(g.reshape(mesh[1],-1).T, expRGy, c=buf.reshape(-1,mesh[1]))
41-
g = lib.dot(g.reshape(mesh[2],-1).T, expRGz, c=out[i].reshape(-1,mesh[2]))
42-
return out.reshape(-1, *mesh)
30+
assert f.ndim == 4
31+
mx, my, mz = mesh
32+
expRGx = np.exp(-2j*np.pi*np.arange(mx)[:,None] * np.fft.fftfreq(mx))
33+
expRGy = np.exp(-2j*np.pi*np.arange(my)[:,None] * np.fft.fftfreq(my))
34+
expRGz = np.exp(-2j*np.pi*np.arange(mz)[:,None] * np.fft.fftfreq(mz))
35+
blksize = max(int(1e5 / (mx * my * mz)), 8) * 4
36+
n = f.shape[0]
37+
out = np.empty((n, mx*my*mz), dtype=np.complex128)
38+
buf = np.empty((blksize, mx*my*mz), dtype=np.complex128)
39+
for i0, i1 in lib.prange(0, n, blksize):
40+
ni = i1 - i0
41+
buf1 = buf[:ni]
42+
out1 = out[i0:i1]
43+
g = lib.transpose(f[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
44+
g = lib.dot(g.reshape(mx,-1).T, expRGx, c=out1.reshape(-1,mx))
45+
g = lib.dot(g.reshape(my,-1).T, expRGy, c=buf1.reshape(-1,my))
46+
g = lib.dot(g.reshape(mz,-1).T, expRGz, c=out1.reshape(-1,mz))
47+
return out.reshape(n, *mesh)
4348

4449
def _ifftn_blas(g, mesh):
45-
Gx = np.fft.fftfreq(mesh[0])
46-
Gy = np.fft.fftfreq(mesh[1])
47-
Gz = np.fft.fftfreq(mesh[2])
48-
expRGx = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[0]), Gx))
49-
expRGy = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[1]), Gy))
50-
expRGz = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[2]), Gz))
51-
out = np.empty(g.shape, dtype=np.complex128)
52-
buf = np.empty(mesh, dtype=np.complex128)
53-
for i, gi in enumerate(g):
54-
buf[:] = gi.reshape(mesh)
55-
f = lib.dot(buf.reshape(mesh[0],-1).T, expRGx, 1./mesh[0], c=out[i].reshape(-1,mesh[0]))
56-
f = lib.dot(f.reshape(mesh[1],-1).T, expRGy, 1./mesh[1], c=buf.reshape(-1,mesh[1]))
57-
f = lib.dot(f.reshape(mesh[2],-1).T, expRGz, 1./mesh[2], c=out[i].reshape(-1,mesh[2]))
58-
return out.reshape(-1, *mesh)
50+
assert g.ndim == 4
51+
mx, my, mz = mesh
52+
expRGx = np.exp(2j*np.pi*np.fft.fftfreq(mx)[:,None] * np.arange(mx))
53+
expRGy = np.exp(2j*np.pi*np.fft.fftfreq(my)[:,None] * np.arange(my))
54+
expRGz = np.exp(2j*np.pi*np.fft.fftfreq(mz)[:,None] * np.arange(mz))
55+
blksize = max(int(1e5 / (mx * my * mz)), 8) * 4
56+
n = g.shape[0]
57+
out = np.empty((n, mx*my*mz), dtype=np.complex128)
58+
buf = np.empty((blksize, mx*my*mz), dtype=np.complex128)
59+
for i0, i1 in lib.prange(0, n, blksize):
60+
ni = i1 - i0
61+
buf1 = buf[:ni]
62+
out1 = out[i0:i1]
63+
f = lib.transpose(g[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
64+
f = lib.dot(f.reshape(mx,-1).T, expRGx, 1./mx, c=out1.reshape(-1,mx))
65+
f = lib.dot(f.reshape(my,-1).T, expRGy, 1./my, c=buf1.reshape(-1,my))
66+
f = lib.dot(f.reshape(mz,-1).T, expRGz, 1./mz, c=out1.reshape(-1,mz))
67+
return out.reshape(n, *mesh)
68+
69+
nproc = lib.num_threads()
70+
71+
def _fftn_wrapper(a): # noqa
72+
return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)
73+
74+
def _ifftn_wrapper(a): # noqa
75+
return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)
5976

6077
if FFT_ENGINE == 'FFTW':
6178
try:
@@ -88,60 +105,50 @@ def _complex_fftn_fftw(f, mesh, func):
88105
ctypes.c_int(rank))
89106
return out
90107

91-
def _fftn_wrapper(a):
108+
def _fftn_wrapper(a): # noqa
92109
mesh = a.shape[1:]
93110
return _complex_fftn_fftw(a, mesh, 'fft')
94-
def _ifftn_wrapper(a):
111+
def _ifftn_wrapper(a): # noqa
95112
mesh = a.shape[1:]
96113
return _complex_fftn_fftw(a, mesh, 'ifft')
97114

98115
elif FFT_ENGINE == 'PYFFTW':
99-
# pyfftw is slower than np.fft in most cases
116+
# Note: pyfftw is likely slower than scipy.fft in multi-threading environments
100117
try:
101118
import pyfftw
119+
pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
102120
pyfftw.interfaces.cache.enable()
103-
nproc = lib.num_threads()
104-
def _fftn_wrapper(a):
121+
def _fftn_wrapper(a): # noqa
105122
return pyfftw.interfaces.numpy_fft.fftn(a, axes=(1,2,3), threads=nproc)
106-
def _ifftn_wrapper(a):
123+
def _ifftn_wrapper(a): # noqa
107124
return pyfftw.interfaces.numpy_fft.ifftn(a, axes=(1,2,3), threads=nproc)
108125
except ImportError:
109-
def _fftn_wrapper(a):
110-
return np.fft.fftn(a, axes=(1,2,3))
111-
def _ifftn_wrapper(a):
112-
return np.fft.ifftn(a, axes=(1,2,3))
113-
114-
elif FFT_ENGINE == 'NUMPY':
115-
def _fftn_wrapper(a):
116-
return np.fft.fftn(a, axes=(1,2,3))
117-
def _ifftn_wrapper(a):
118-
return np.fft.ifftn(a, axes=(1,2,3))
126+
print('PyFFTW not installed. SciPy fft module will be used.')
119127

120128
elif FFT_ENGINE == 'NUMPY+BLAS':
121129
_EXCLUDE = [17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
122130
83, 89, 97,101,103,107,109,113,127,131,137,139,149,151,157,163,
123131
167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,
124132
257,263,269,271,277,281,283,293]
125-
_EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE] + [n*3 for n in _EXCLUDE])
126-
def _fftn_wrapper(a):
133+
_EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE[:30]] + [n*3 for n in _EXCLUDE[:20]])
134+
def _fftn_wrapper(a): # noqa
127135
mesh = a.shape[1:]
128136
if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
129137
return _fftn_blas(a, mesh)
130138
else:
131-
return np.fft.fftn(a, axes=(1,2,3))
132-
def _ifftn_wrapper(a):
139+
return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)
140+
def _ifftn_wrapper(a): # noqa
133141
mesh = a.shape[1:]
134142
if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
135143
return _ifftn_blas(a, mesh)
136144
else:
137-
return np.fft.ifftn(a, axes=(1,2,3))
145+
return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)
138146

139-
#?elif: # 'FFTW+BLAS'
140-
else: # 'BLAS'
141-
def _fftn_wrapper(a):
147+
elif FFT_ENGINE == 'BLAS':
148+
def _fftn_wrapper(a): # noqa
142149
mesh = a.shape[1:]
143150
return _fftn_blas(a, mesh)
144-
def _ifftn_wrapper(a):
151+
def _ifftn_wrapper(a): # noqa
145152
mesh = a.shape[1:]
146153
return _ifftn_blas(a, mesh)
147154

0 commit comments

Comments
 (0)