|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +import unittest |
| 6 | +import warnings |
| 7 | +import numpy as np |
| 8 | +import networkx as nx |
| 9 | +import pytest |
| 10 | +import tensorflow as tf |
| 11 | + |
| 12 | +from common import mpi_env_rank_and_size |
| 13 | +import bluefog.tensorflow as bf |
| 14 | +from bluefog.common.topology_util import PowerTwoRingGraph, BiRingGraph |
| 15 | + |
| 16 | +warnings.filterwarnings("ignore", message="numpy.dtype size changed") |
| 17 | +warnings.filterwarnings("ignore", message="numpy.ufunc size changed") |
| 18 | + |
| 19 | + |
| 20 | +class BasicsTests(tf.test.TestCase): |
| 21 | + """ |
| 22 | + Tests for basics.py |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, *args, **kwargs): |
| 26 | + super(BasicsTests, self).__init__(*args, **kwargs) |
| 27 | + warnings.simplefilter("module") |
| 28 | + |
| 29 | + def test_bluefog_rank(self): |
| 30 | + """Test that the rank returned by bf.rank() is correct.""" |
| 31 | + true_rank, _ = mpi_env_rank_and_size() |
| 32 | + bf.init() |
| 33 | + rank = bf.rank() |
| 34 | + # print("Rank: ", true_rank, rank) |
| 35 | + assert true_rank == rank |
| 36 | + |
| 37 | + def test_bluefog_size(self): |
| 38 | + """Test that the size returned by bf.size() is correct.""" |
| 39 | + _, true_size = mpi_env_rank_and_size() |
| 40 | + bf.init() |
| 41 | + size = bf.size() |
| 42 | + # print("Size: ", true_size, size) |
| 43 | + assert true_size == size |
| 44 | + |
| 45 | + def test_set_and_load_topology(self): |
| 46 | + _, size = mpi_env_rank_and_size() |
| 47 | + if size == 4: |
| 48 | + expected_topology = nx.DiGraph(np.array( |
| 49 | + [[0, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1], [1, 1, 0, 0]] |
| 50 | + )) |
| 51 | + elif size == 1: |
| 52 | + expected_topology = nx.DiGraph(np.array([[0]])) |
| 53 | + else: |
| 54 | + expected_topology = PowerTwoRingGraph(size) |
| 55 | + bf.init() |
| 56 | + _, _, topology = bf.load_topology() |
| 57 | + assert isinstance(topology, nx.DiGraph) |
| 58 | + np.testing.assert_array_equal( |
| 59 | + nx.to_numpy_array(expected_topology), nx.to_numpy_array(topology)) |
| 60 | + |
| 61 | + |
| 62 | +if __name__ == "__main__": |
| 63 | + tf.test.main() |
0 commit comments