Skip to content

Commit a093546

Browse files
committed
Add tensorflow basics test
1 parent e66ac96 commit a093546

File tree

4 files changed

+72
-2
lines changed

4 files changed

+72
-2
lines changed

.pylintrc

+5-2
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,10 @@ contextmanager-decorators=contextlib.contextmanager
236236
# List of members which are set dynamically and missed by pylint inference
237237
# system, and so shouldn't trigger E1101 when accessed. Python regular
238238
# expressions are accepted.
239-
generated-members=numpy.*,torch.*,mpi4py.MPI.*,bluefog.torch.mpi_lib.*,tensorflow.*
239+
generated-members=numpy.*,torch.*,mpi4py.MPI.*,
240+
tensorflow.*,
241+
bluefog.torch.mpi_lib.*,
242+
bluefog.tensorflow.mpi_lib.*,
240243

241244
# Tells whether missing members accessed in mixin class should be ignored. A
242245
# mixin class is detected if its name ends with "mixin" (case insensitive).
@@ -263,7 +266,7 @@ ignored-classes=optparse.Values,thread._local,_thread._local
263266
# (useful for modules/projects where namespaces are manipulated during runtime
264267
# and thus existing member attributes cannot be deduced by static analysis. It
265268
# supports qualified module names, as well as Unix pattern matching.
266-
ignored-modules=
269+
ignored-modules=tensorflow
267270

268271
# Show a hint with possible names when a member name was not found. The aspect
269272
# of finding the hint is based on edit distance.

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ test_torch_basic:
99
test_torch_ops:
1010
${MPIRUN} -np 4 ${PYTEST} ./test/torch_ops_test.py
1111

12+
test_tensorflow_basic:
13+
${PYTEST} ./test/tensorflow_basics_test.py && ${MPIRUN} -np 4 ${PYTEST} ./test/tensorflow_basics_test.py
14+
1215
clean_build:
1316
rm -R build
1417

test/tensorflow_basics_test.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import unittest
6+
import warnings
7+
import numpy as np
8+
import networkx as nx
9+
import pytest
10+
import tensorflow as tf
11+
12+
from common import mpi_env_rank_and_size
13+
import bluefog.tensorflow as bf
14+
from bluefog.common.topology_util import PowerTwoRingGraph, BiRingGraph
15+
16+
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
17+
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
18+
19+
20+
class BasicsTests(tf.test.TestCase):
21+
"""
22+
Tests for basics.py
23+
"""
24+
25+
def __init__(self, *args, **kwargs):
26+
super(BasicsTests, self).__init__(*args, **kwargs)
27+
warnings.simplefilter("module")
28+
29+
def test_bluefog_rank(self):
30+
"""Test that the rank returned by bf.rank() is correct."""
31+
true_rank, _ = mpi_env_rank_and_size()
32+
bf.init()
33+
rank = bf.rank()
34+
# print("Rank: ", true_rank, rank)
35+
assert true_rank == rank
36+
37+
def test_bluefog_size(self):
38+
"""Test that the size returned by bf.size() is correct."""
39+
_, true_size = mpi_env_rank_and_size()
40+
bf.init()
41+
size = bf.size()
42+
# print("Size: ", true_size, size)
43+
assert true_size == size
44+
45+
def test_set_and_load_topology(self):
46+
_, size = mpi_env_rank_and_size()
47+
if size == 4:
48+
expected_topology = nx.DiGraph(np.array(
49+
[[0, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1], [1, 1, 0, 0]]
50+
))
51+
elif size == 1:
52+
expected_topology = nx.DiGraph(np.array([[0]]))
53+
else:
54+
expected_topology = PowerTwoRingGraph(size)
55+
bf.init()
56+
_, _, topology = bf.load_topology()
57+
assert isinstance(topology, nx.DiGraph)
58+
np.testing.assert_array_equal(
59+
nx.to_numpy_array(expected_topology), nx.to_numpy_array(topology))
60+
61+
62+
if __name__ == "__main__":
63+
tf.test.main()

test/torch_basics_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
1616
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
1717

18+
1819
class BasicsTests(unittest.TestCase):
1920
"""
2021
Tests for basics.py

0 commit comments

Comments
 (0)