Skip to content

Commit ec4d31e

Browse files
cshenton authored and sb2nov committed
Half Normal Distribution (and inverse error function) (tensorflow#14056)
* foldednormal docstring
* folded __init__ method
* prob, log_prob methods
* rewrote halfnormal docstring
* initial implementation of dist methods
* halfnormal unit tests
* registered HalfNormal to contrib.distributions
* added erfinv function
* unit tests for erfinv
* registered erfinv symbol
* cdf, pdf now deal with x < 0 correctly
* pylint fixes
* cuda_py test reference in BUILD
* erfinv fixes
* corrections to scipy reference tests
* Added reference to entropy test case.
1 parent 3bf2f35 commit ec4d31e

File tree

6 files changed

+560
-0
lines changed

6 files changed

+560
-0
lines changed

tensorflow/contrib/distributions/BUILD

+18
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,24 @@ cuda_py_test(
204204
],
205205
)
206206

207+
# Test target for the HalfNormal kernel tests in
# python/kernel_tests/half_normal_test.py.
cuda_py_test(
208+
name = "half_normal_test",
209+
size = "medium",
210+
srcs = ["python/kernel_tests/half_normal_test.py"],
211+
additional_deps = [
212+
":distributions_py",
213+
"//third_party/py/numpy",
214+
"//tensorflow/python:client",
215+
"//tensorflow/python:client_testlib",
216+
"//tensorflow/python:framework_for_generated_wrappers",
217+
"//tensorflow/python:framework_test_lib",
218+
"//tensorflow/python:gradients",
219+
"//tensorflow/python:nn_ops",
220+
"//tensorflow/python:platform_test",
221+
"//tensorflow/python:variables",
222+
],
223+
)
224+
207225
cuda_py_test(
208226
name = "inverse_gamma_test",
209227
srcs = ["python/kernel_tests/inverse_gamma_test.py"],

tensorflow/contrib/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
3737
from tensorflow.contrib.distributions.python.ops.estimator import *
3838
from tensorflow.contrib.distributions.python.ops.geometric import *
39+
from tensorflow.contrib.distributions.python.ops.half_normal import *
3940
from tensorflow.contrib.distributions.python.ops.independent import *
4041
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
4142
from tensorflow.contrib.distributions.python.ops.logistic import *
@@ -107,6 +108,7 @@
107108
'Gamma',
108109
'GammaWithSoftplusConcentrationRate',
109110
'Geometric',
111+
'HalfNormal',
110112
'Independent',
111113
'InverseGamma',
112114
'InverseGammaWithSoftplusConcentrationRate',
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for the HalfNormal distribution."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import importlib
22+
import numpy as np
23+
24+
from tensorflow.python.framework import constant_op
25+
from tensorflow.python.framework import dtypes
26+
from tensorflow.python.framework import ops
27+
from tensorflow.python.framework import tensor_shape
28+
from tensorflow.python.ops import array_ops
29+
from tensorflow.python.ops import gradients_impl
30+
from tensorflow.python.ops import variables
31+
from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib
32+
from tensorflow.python.platform import test
33+
from tensorflow.python.platform import tf_logging
34+
35+
36+
def try_import(name):  # pylint: disable=invalid-name
  """Import the module called `name`, returning None if it is unavailable.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None when the import raises ImportError (a
    warning is logged in that case).
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass the values as logging arguments so formatting is left to the logger.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
43+
44+
stats = try_import("scipy.stats")
45+
46+
47+
class HalfNormalTest(test.TestCase):
  """Tests for `HalfNormal`.

  Numerical results are compared against `scipy.stats.halfnorm` when SciPy
  could be imported (module-level `stats`); otherwise only the shape checks
  that precede the SciPy comparison are executed.
  """

  def setUp(self):
    # Fixed seed so the randomly drawn scales are reproducible across runs.
    self._rng = np.random.RandomState(123)

  def assertAllFinite(self, tensor):
    """Asserts that every element of the evaluated `tensor` is finite."""
    is_finite = np.isfinite(tensor.eval())
    # `np.bool` was removed from NumPy; `np.bool_` is the supported dtype.
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  def _testParamShapes(self, sample_shape, expected):
    """Checks `param_shapes` and the shape of a sample drawn with it."""
    with self.test_session():
      param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape)
      scale_shape = param_shapes["scale"]
      self.assertAllEqual(expected, scale_shape.eval())
      scale = array_ops.ones(scale_shape)
      self.assertAllEqual(
          expected,
          array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval())

  def _testParamStaticShapes(self, sample_shape, expected):
    """Checks that `param_static_shapes` returns `expected` for "scale"."""
    param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape)
    scale_shape = param_shapes["scale"]
    self.assertEqual(expected, scale_shape)

  def _testBatchShapes(self, dist, tensor):
    """Checks that `tensor` has the batch shape of `dist`.

    Both the dynamic (`batch_shape_tensor`) and static (`batch_shape`) batch
    shapes are compared with the static and evaluated shapes of `tensor`.
    """
    self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape)
    self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape)
    self.assertAllEqual(dist.batch_shape, tensor.shape)
    self.assertAllEqual(dist.batch_shape, tensor.eval().shape)

  def testParamShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamShapes(sample_shape, sample_shape)
    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)

  def testParamStaticShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamStaticShapes(sample_shape, sample_shape)
    self._testParamStaticShapes(
        tensor_shape.TensorShape(sample_shape), sample_shape)

  def testHalfNormalLogPDF(self):
    with self.test_session():
      batch_size = 6
      scale = constant_op.constant([3.0] * batch_size)
      # Includes negative values and zero alongside positive inputs.
      x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      log_pdf = halfnorm.log_prob(x)
      self._testBatchShapes(halfnorm, log_pdf)

      pdf = halfnorm.prob(x)
      self._testBatchShapes(halfnorm, pdf)

      if not stats:
        return
      expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())

  def testHalfNormalLogPDFMultidimensional(self):
    with self.test_session():
      batch_size = 6
      scale = constant_op.constant([[3.0, 1.0]] * batch_size)
      # Column vector of inputs, broadcast against the two scale columns.
      x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
      halfnorm = hn_lib.HalfNormal(scale=scale)

      log_pdf = halfnorm.log_prob(x)
      self._testBatchShapes(halfnorm, log_pdf)

      pdf = halfnorm.prob(x)
      self._testBatchShapes(halfnorm, pdf)

      if not stats:
        return
      expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())

  def testHalfNormalCDF(self):
    with self.test_session():
      batch_size = 50
      scale = self._rng.rand(batch_size) + 1.0
      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      cdf = halfnorm.cdf(x)
      self._testBatchShapes(halfnorm, cdf)

      log_cdf = halfnorm.log_cdf(x)
      self._testBatchShapes(halfnorm, log_cdf)

      if not stats:
        return
      expected_logcdf = stats.halfnorm(scale=scale).logcdf(x)
      self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0)
      self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0)

  def testHalfNormalSurvivalFunction(self):
    with self.test_session():
      batch_size = 50
      scale = self._rng.rand(batch_size) + 1.0
      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      sf = halfnorm.survival_function(x)
      self._testBatchShapes(halfnorm, sf)

      log_sf = halfnorm.log_survival_function(x)
      self._testBatchShapes(halfnorm, log_sf)

      if not stats:
        return
      expected_logsf = stats.halfnorm(scale=scale).logsf(x)
      self.assertAllClose(expected_logsf, log_sf.eval(), atol=0)
      self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0)

  def testHalfNormalQuantile(self):
    with self.test_session():
      batch_size = 50
      scale = self._rng.rand(batch_size) + 1.0
      p = np.linspace(0., 1.0, batch_size).astype(np.float64)

      halfnorm = hn_lib.HalfNormal(scale=scale)
      x = halfnorm.quantile(p)
      self._testBatchShapes(halfnorm, x)

      if not stats:
        return
      expected_x = stats.halfnorm(scale=scale).ppf(p)
      self.assertAllClose(expected_x, x.eval(), atol=0)

  def testFiniteGradients(self):
    for dtype in [np.float32, np.float64]:
      g = ops.Graph()
      with g.as_default():
        scale = variables.Variable(dtype(3.0))
        dist = hn_lib.HalfNormal(scale=scale)
        x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype)
        for func in [
            dist.cdf, dist.log_cdf, dist.survival_function,
            dist.log_prob, dist.prob, dist.log_survival_function,
        ]:
          value = func(x)
          grads = gradients_impl.gradients(value, [scale])
          with self.test_session(graph=g):
            variables.global_variables_initializer().run()
            self.assertAllFinite(value)
            self.assertAllFinite(grads[0])

  def testHalfNormalEntropy(self):
    with self.test_session():
      scale = np.array([[1.0, 2.0, 3.0]])
      halfnorm = hn_lib.HalfNormal(scale=scale)

      # See https://en.wikipedia.org/wiki/Half-normal_distribution for the
      # entropy formula used here.
      expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5

      entropy = halfnorm.entropy()
      self._testBatchShapes(halfnorm, entropy)
      self.assertAllClose(expected_entropy, entropy.eval())

  def testHalfNormalMeanAndMode(self):
    with self.test_session():
      scale = np.array([11., 12., 13.])

      halfnorm = hn_lib.HalfNormal(scale=scale)
      expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi)

      mean = halfnorm.mean().eval()
      self.assertAllEqual((3,), mean.shape)
      # Floating-point result: compare with a tolerance, not bit-for-bit.
      self.assertAllClose(expected_mean, mean)

      mode = halfnorm.mode().eval()
      self.assertAllEqual((3,), mode.shape)
      self.assertAllEqual([0., 0., 0.], mode)

  def testHalfNormalVariance(self):
    with self.test_session():
      scale = np.array([7., 7., 7.])
      halfnorm = hn_lib.HalfNormal(scale=scale)
      expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)

      variance = halfnorm.variance().eval()
      self.assertAllEqual((3,), variance.shape)
      # Floating-point result: compare with a tolerance, not bit-for-bit.
      self.assertAllClose(expected_variance, variance)

  def testHalfNormalStandardDeviation(self):
    with self.test_session():
      scale = np.array([7., 7., 7.])
      halfnorm = hn_lib.HalfNormal(scale=scale)
      expected_stddev = np.sqrt(scale ** 2.0 * (1.0 - 2.0 / np.pi))

      stddev = halfnorm.stddev().eval()
      self.assertAllEqual((3,), stddev.shape)
      # Floating-point result: compare with a tolerance, not bit-for-bit.
      self.assertAllClose(expected_stddev, stddev)

  def testHalfNormalSample(self):
    with self.test_session():
      scale = constant_op.constant(3.0)
      n = constant_op.constant(100000)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      sample = halfnorm.sample(n)
      # Evaluate once; each `eval()` call would draw a fresh set of samples.
      sample_values = sample.eval()

      self.assertEqual(sample_values.shape, (100000,))
      self.assertAllClose(sample_values.mean(),
                          3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)

      expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
          tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
      self.assertAllEqual(expected_shape, sample.shape)
      self.assertAllEqual(expected_shape, sample_values.shape)

      expected_shape_static = (tensor_shape.TensorShape(
          [n.eval()]).concatenate(halfnorm.batch_shape))
      self.assertAllEqual(expected_shape_static, sample.shape)
      self.assertAllEqual(expected_shape_static, sample_values.shape)

  def testHalfNormalSampleMultiDimensional(self):
    with self.test_session():
      batch_size = 2
      scale = constant_op.constant([[2.0, 3.0]] * batch_size)
      n = constant_op.constant(100000)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      sample = halfnorm.sample(n)
      self.assertEqual(sample.shape, (100000, batch_size, 2))
      # Evaluate once; each `eval()` call would draw a fresh set of samples.
      sample_values = sample.eval()
      self.assertAllClose(sample_values[:, 0, 0].mean(),
                          2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
      self.assertAllClose(sample_values[:, 0, 1].mean(),
                          3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)

      expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
          tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
      self.assertAllEqual(expected_shape, sample.shape)
      self.assertAllEqual(expected_shape, sample_values.shape)

      expected_shape_static = (tensor_shape.TensorShape(
          [n.eval()]).concatenate(halfnorm.batch_shape))
      self.assertAllEqual(expected_shape_static, sample.shape)
      self.assertAllEqual(expected_shape_static, sample_values.shape)

  def testNegativeSigmaFails(self):
    with self.test_session():
      halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
      with self.assertRaisesOpError("Condition x > 0 did not hold"):
        halfnorm.mean().eval()

  def testHalfNormalShape(self):
    with self.test_session():
      scale = constant_op.constant([6.0] * 5)
      halfnorm = hn_lib.HalfNormal(scale=scale)

      # The evaluated shape is an ndarray, so use the array-aware assertion.
      self.assertAllEqual(halfnorm.batch_shape_tensor().eval(), [5])
      self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5]))
      self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
      self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([]))

  def testHalfNormalShapeWithPlaceholders(self):
    scale = array_ops.placeholder(dtype=dtypes.float32)
    halfnorm = hn_lib.HalfNormal(scale=scale)

    with self.test_session() as sess:
      # With a shapeless placeholder the static batch shape is unknown.
      self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None))
      self.assertEqual(halfnorm.event_shape, ())
      self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
      self.assertAllEqual(
          sess.run(halfnorm.batch_shape_tensor(),
                   feed_dict={scale: [1.0, 2.0]}), [2])
318+
319+
if __name__ == "__main__":
320+
test.main()

0 commit comments

Comments
 (0)