From 9e9cda6bb11ad819c474b7cf8d37bfa05d710298 Mon Sep 17 00:00:00 2001 From: Lu Lu Date: Sat, 30 Jul 2022 19:52:52 -0400 Subject: [PATCH] Improve point sampling --- deepxde/data/func_constraint.py | 8 +++---- deepxde/geometry/geometry_nd.py | 10 ++++----- deepxde/geometry/sampler.py | 39 +++++++++++++++++++-------------- docs/requirements.txt | 2 +- requirements.txt | 2 +- 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/deepxde/data/func_constraint.py b/deepxde/data/func_constraint.py index 6c54313d1..f04336138 100644 --- a/deepxde/data/func_constraint.py +++ b/deepxde/data/func_constraint.py @@ -43,14 +43,14 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): @run_if_any_none("train_x", "train_y") def train_next_batch(self, batch_size=None): - if self.dist_train == "log uniform": + if self.dist_train == "uniform": + self.train_x = self.geom.uniform_points(self.num_train, False) + elif self.dist_train == "log uniform": self.train_x = self.geom.log_uniform_points(self.num_train, False) - elif self.dist_train == "random": + else: self.train_x = self.geom.random_points( self.num_train, random=self.dist_train ) - else: - self.train_x = self.geom.uniform_points(self.num_train, False) if self.anchors is not None: self.train_x = np.vstack((self.anchors, self.train_x)) self.train_y = self.func(self.train_x) diff --git a/deepxde/geometry/geometry_nd.py b/deepxde/geometry/geometry_nd.py index 834152204..ef1768650 100644 --- a/deepxde/geometry/geometry_nd.py +++ b/deepxde/geometry/geometry_nd.py @@ -118,7 +118,7 @@ def on_boundary(self, x): return np.isclose(np.linalg.norm(x - self.center, axis=-1), self.radius) def distance2boundary_unitdirn(self, x, dirn): - """https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection""" + # https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection xc = x - self.center ad = np.dot(xc, dirn) return -ad + (ad ** 2 - np.sum(xc * xc, axis=-1) + self._r2) ** 0.5 @@ -136,24 +136,24 @@ def boundary_normal(self, x): return _n def random_points(self, n, random="pseudo"): - """https://math.stackexchange.com/questions/87230/picking-random-points-in-the-volume-of-sphere-with-uniform-probability""" + # https://math.stackexchange.com/questions/87230/picking-random-points-in-the-volume-of-sphere-with-uniform-probability if random == "pseudo": U = np.random.rand(n, 1) X = np.random.normal(size=(n, self.dim)) else: rng = sample(n, self.dim + 1, random) - U, X = rng[:, 0:1], rng[:, 1:] + U, X = rng[:, 0:1], rng[:, 1:] # Error if X = [0, 0, ...] X = stats.norm.ppf(X) X = preprocessing.normalize(X) X = U ** (1 / self.dim) * X return self.radius * X + self.center def random_boundary_points(self, n, random="pseudo"): - """http://mathworld.wolfram.com/HyperspherePointPicking.html""" + # http://mathworld.wolfram.com/HyperspherePointPicking.html if random == "pseudo": X = np.random.normal(size=(n, self.dim)).astype(config.real(np)) else: - U = sample(n, self.dim, random) + U = sample(n, self.dim, random) # Error for [0, 0, ...] or [0.5, 0.5, ...] X = stats.norm.ppf(U) X = preprocessing.normalize(X) return self.radius * X + self.center diff --git a/deepxde/geometry/sampler.py b/deepxde/geometry/sampler.py index aad44b5cb..7b384eddc 100644 --- a/deepxde/geometry/sampler.py +++ b/deepxde/geometry/sampler.py @@ -1,7 +1,5 @@ __all__ = ["sample"] -from distutils.version import LooseVersion - import numpy as np import skopt @@ -35,24 +33,31 @@ def pseudorandom(n_samples, dimension): def quasirandom(n_samples, dimension, sampler): + # Certain points should be removed: + # - Boundary points such as [..., 0, ...] + # - Special points [0, 0, 0, ...] and [0.5, 0.5, 0.5, ...], which cause error in + # Hypersphere.random_points() and Hypersphere.random_boundary_points() + skip = 0 if sampler == "LHS": - sampler = skopt.sampler.Lhs( - lhs_type="centered", criterion="maximin", iterations=1000 - ) + sampler = skopt.sampler.Lhs() elif sampler == "Halton": - sampler = skopt.sampler.Halton(min_skip=-1, max_skip=-1) + # 1st point: [0, 0, ...] + sampler = skopt.sampler.Halton(min_skip=1, max_skip=1) elif sampler == "Hammersley": - sampler = skopt.sampler.Hammersly(min_skip=-1, max_skip=-1) + # 1st point: [0, 0, ...] + if dimension == 1: + sampler = skopt.sampler.Hammersly(min_skip=1, max_skip=1) + else: + sampler = skopt.sampler.Hammersly() + skip = 1 elif sampler == "Sobol": - # Remove the first point [0, 0, ...] and the second point [0.5, 0.5, ...], which - # are too special and may cause some error. - if LooseVersion(skopt.__version__) < LooseVersion("0.9"): - sampler = skopt.sampler.Sobol(min_skip=2, max_skip=2, randomize=False) + # 1st point: [0, 0, ...], 2nd point: [0.5, 0.5, ...] + sampler = skopt.sampler.Sobol(randomize=False) + if dimension < 3: + skip = 1 else: - sampler = skopt.sampler.Sobol(skip=0, randomize=False) - space = [(0.0, 1.0)] * dimension - return np.asarray( - sampler.generate(space, n_samples + 2)[2:], dtype=config.real(np) - ) + skip = 2 space = [(0.0, 1.0)] * dimension - return np.asarray(sampler.generate(space, n_samples), dtype=config.real(np)) + return np.asarray( + sampler.generate(space, n_samples + skip)[skip:], dtype=config.real(np) + ) diff --git a/docs/requirements.txt b/docs/requirements.txt index 8484f776d..71b94f83b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ matplotlib numpy scikit-learn -scikit-optimize +scikit-optimize>=0.9.0 scipy docutils<0.18 # https://github.com/readthedocs/readthedocs.org/issues/8616 diff --git a/requirements.txt b/requirements.txt index 4798085f8..38608b103 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ matplotlib numpy scikit-learn -scikit-optimize +scikit-optimize>=0.9.0 scipy