|
16 | 16 | import warnings
|
17 | 17 | import ctypes
|
18 | 18 | import numpy as np
|
| 19 | +import scipy |
19 | 20 | import scipy.linalg
|
20 | 21 | from pyscf import lib
|
21 | 22 | from pyscf.lib import logger
|
22 | 23 | from pyscf.gto import ATM_SLOTS, BAS_SLOTS, ATOM_OF, PTR_COORD
|
23 | 24 | from pyscf.pbc.lib.kpts_helper import get_kconserv, get_kconserv3 # noqa
|
24 | 25 | from pyscf import __config__
|
25 | 26 |
|
26 |
| -FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'BLAS') |
| 27 | +FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'NUMPY+BLAS') |
27 | 28 |
|
def _fftn_blas(f, mesh):
    '''Batched 3D forward DFT evaluated with dense matrix products.

    Args:
        f : 4D array holding n functions on the mesh, i.e. n*mx*my*mz
            elements with the batch index as the leading axis.
        mesh : sequence of three ints (mx, my, mz).

    Returns:
        complex128 array of shape (n, mx, my, mz).
    '''
    assert f.ndim == 4
    mx, my, mz = mesh
    ngrids = mx * my * mz
    # One-dimensional DFT matrices exp(-2*pi*i * r * k/m) for each direction.
    expRGx = np.exp(-2j*np.pi*np.arange(mx)[:,None] * np.fft.fftfreq(mx))
    expRGy = np.exp(-2j*np.pi*np.arange(my)[:,None] * np.fft.fftfreq(my))
    expRGz = np.exp(-2j*np.pi*np.arange(mz)[:,None] * np.fft.fftfreq(mz))
    n = f.shape[0]
    # The batch is processed in blocks so the scratch space stays bounded.
    # The block is never larger than the batch itself; otherwise a single
    # function on a big mesh would allocate at least 32x the needed scratch.
    blksize = max(int(1e5 / ngrids), 8) * 4
    blksize = min(blksize, max(n, 1))
    out = np.empty((n, ngrids), dtype=np.complex128)
    buf = np.empty((blksize, ngrids), dtype=np.complex128)
    for i0, i1 in lib.prange(0, n, blksize):
        ni = i1 - i0
        buf1 = buf[:ni]
        out1 = out[i0:i1]
        # Move the batch index to the fastest-varying position:
        # memory layout becomes (mx, my, mz, ni).
        g = lib.transpose(f[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
        # Each product transforms the current leading mesh axis and, through
        # the .T, rotates it to the end. The two buffers alternate as
        # input/output: (my,mz,ni,mx) -> (mz,ni,mx,my) -> (ni,mx,my,mz).
        # NOTE(review): assumes lib.dot writes a.dot(b) into ``c`` - confirm.
        g = lib.dot(g.reshape(mx,-1).T, expRGx, c=out1.reshape(-1,mx))
        g = lib.dot(g.reshape(my,-1).T, expRGy, c=buf1.reshape(-1,my))
        g = lib.dot(g.reshape(mz,-1).T, expRGz, c=out1.reshape(-1,mz))
    return out.reshape(n, *mesh)
43 | 48 |
|
def _ifftn_blas(g, mesh):
    '''Batched 3D inverse DFT evaluated with dense matrix products.

    Args:
        g : 4D array holding n sets of Fourier coefficients on the mesh,
            i.e. n*mx*my*mz elements with the batch index as the leading
            axis.
        mesh : sequence of three ints (mx, my, mz).

    Returns:
        complex128 array of shape (n, mx, my, mz). Each direction is scaled
        by 1/m, giving an overall factor of 1/(mx*my*mz).
    '''
    assert g.ndim == 4
    mx, my, mz = mesh
    ngrids = mx * my * mz
    # One-dimensional inverse DFT matrices exp(+2*pi*i * k/m * r).
    expRGx = np.exp(2j*np.pi*np.fft.fftfreq(mx)[:,None] * np.arange(mx))
    expRGy = np.exp(2j*np.pi*np.fft.fftfreq(my)[:,None] * np.arange(my))
    expRGz = np.exp(2j*np.pi*np.fft.fftfreq(mz)[:,None] * np.arange(mz))
    n = g.shape[0]
    # The batch is processed in blocks so the scratch space stays bounded.
    # The block is never larger than the batch itself; otherwise a single
    # function on a big mesh would allocate at least 32x the needed scratch.
    blksize = max(int(1e5 / ngrids), 8) * 4
    blksize = min(blksize, max(n, 1))
    out = np.empty((n, ngrids), dtype=np.complex128)
    buf = np.empty((blksize, ngrids), dtype=np.complex128)
    for i0, i1 in lib.prange(0, n, blksize):
        ni = i1 - i0
        buf1 = buf[:ni]
        out1 = out[i0:i1]
        # Move the batch index to the fastest-varying position:
        # memory layout becomes (mx, my, mz, ni).
        f = lib.transpose(g[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
        # Each product transforms the current leading mesh axis (with its
        # 1/m normalisation) and, through the .T, rotates it to the end.
        # The two buffers alternate as input/output:
        # (my,mz,ni,mx) -> (mz,ni,mx,my) -> (ni,mx,my,mz).
        # NOTE(review): assumes the third positional argument of lib.dot is
        # a scalar prefactor and the result is written into ``c`` - confirm.
        f = lib.dot(f.reshape(mx,-1).T, expRGx, 1./mx, c=out1.reshape(-1,mx))
        f = lib.dot(f.reshape(my,-1).T, expRGy, 1./my, c=buf1.reshape(-1,my))
        f = lib.dot(f.reshape(mz,-1).T, expRGz, 1./mz, c=out1.reshape(-1,mz))
    return out.reshape(n, *mesh)
| 68 | + |
# scipy.fft is imported explicitly: a bare ``import scipy`` only exposes the
# fft subpackage on SciPy releases that lazy-load their submodules.
import scipy.fft  # noqa: E402

# Thread count handed to the FFT backends; read once when the module loads.
nproc = lib.num_threads()

# Default FFT wrappers. The FFT_ENGINE branches below may redefine them.
def _fftn_wrapper(a): # noqa
    '''Forward FFT of ``a`` over axes 1, 2 and 3; axis 0 is not transformed.'''
    return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)

def _ifftn_wrapper(a): # noqa
    '''Inverse FFT of ``a`` over axes 1, 2 and 3; axis 0 is not transformed.'''
    return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)
59 | 76 |
|
60 | 77 | if FFT_ENGINE == 'FFTW':
|
61 | 78 | try:
|
@@ -88,60 +105,50 @@ def _complex_fftn_fftw(f, mesh, func):
|
88 | 105 | ctypes.c_int(rank))
|
89 | 106 | return out
|
90 | 107 |
|
91 |
| - def _fftn_wrapper(a): |
| 108 | + def _fftn_wrapper(a): # noqa |
92 | 109 | mesh = a.shape[1:]
|
93 | 110 | return _complex_fftn_fftw(a, mesh, 'fft')
|
94 |
| - def _ifftn_wrapper(a): |
| 111 | + def _ifftn_wrapper(a): # noqa |
95 | 112 | mesh = a.shape[1:]
|
96 | 113 | return _complex_fftn_fftw(a, mesh, 'ifft')
|
97 | 114 |
|
98 | 115 | elif FFT_ENGINE == 'PYFFTW':
|
99 |
| - # pyfftw is slower than np.fft in most cases |
| 116 | + # Note: pyfftw is likely slower than scipy.fft in multi-threading environments |
100 | 117 | try:
|
101 | 118 | import pyfftw
|
| 119 | + pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE' |
102 | 120 | pyfftw.interfaces.cache.enable()
|
103 |
| - nproc = lib.num_threads() |
104 |
| - def _fftn_wrapper(a): |
| 121 | + def _fftn_wrapper(a): # noqa |
105 | 122 | return pyfftw.interfaces.numpy_fft.fftn(a, axes=(1,2,3), threads=nproc)
|
106 |
| - def _ifftn_wrapper(a): |
| 123 | + def _ifftn_wrapper(a): # noqa |
107 | 124 | return pyfftw.interfaces.numpy_fft.ifftn(a, axes=(1,2,3), threads=nproc)
|
108 | 125 | except ImportError:
|
109 |
| - def _fftn_wrapper(a): |
110 |
| - return np.fft.fftn(a, axes=(1,2,3)) |
111 |
| - def _ifftn_wrapper(a): |
112 |
| - return np.fft.ifftn(a, axes=(1,2,3)) |
113 |
| - |
114 |
| -elif FFT_ENGINE == 'NUMPY': |
115 |
| - def _fftn_wrapper(a): |
116 |
| - return np.fft.fftn(a, axes=(1,2,3)) |
117 |
| - def _ifftn_wrapper(a): |
118 |
| - return np.fft.ifftn(a, axes=(1,2,3)) |
| 126 | + print('PyFFTW not installed. SciPy fft module will be used.') |
119 | 127 |
|
120 | 128 | elif FFT_ENGINE == 'NUMPY+BLAS':
|
121 | 129 | _EXCLUDE = [17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
|
122 | 130 | 83, 89, 97,101,103,107,109,113,127,131,137,139,149,151,157,163,
|
123 | 131 | 167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,
|
124 | 132 | 257,263,269,271,277,281,283,293]
|
125 |
| - _EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE] + [n*3 for n in _EXCLUDE]) |
126 |
| - def _fftn_wrapper(a): |
| 133 | + _EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE[:30]] + [n*3 for n in _EXCLUDE[:20]]) |
| 134 | + def _fftn_wrapper(a): # noqa |
127 | 135 | mesh = a.shape[1:]
|
128 | 136 | if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
|
129 | 137 | return _fftn_blas(a, mesh)
|
130 | 138 | else:
|
131 |
| - return np.fft.fftn(a, axes=(1,2,3)) |
132 |
| - def _ifftn_wrapper(a): |
| 139 | + return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc) |
| 140 | + def _ifftn_wrapper(a): # noqa |
133 | 141 | mesh = a.shape[1:]
|
134 | 142 | if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
|
135 | 143 | return _ifftn_blas(a, mesh)
|
136 | 144 | else:
|
137 |
| - return np.fft.ifftn(a, axes=(1,2,3)) |
| 145 | + return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc) |
138 | 146 |
|
139 |
| -#?elif: # 'FFTW+BLAS' |
140 |
| -else: # 'BLAS' |
141 |
| - def _fftn_wrapper(a): |
| 147 | +elif FFT_ENGINE == 'BLAS': |
| 148 | + def _fftn_wrapper(a): # noqa |
142 | 149 | mesh = a.shape[1:]
|
143 | 150 | return _fftn_blas(a, mesh)
|
144 |
| - def _ifftn_wrapper(a): |
| 151 | + def _ifftn_wrapper(a): # noqa |
145 | 152 | mesh = a.shape[1:]
|
146 | 153 | return _ifftn_blas(a, mesh)
|
147 | 154 |
|
|
0 commit comments