diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7fc43b4e38..a724e8a564 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - EvalCallback now works also for recurrent policies (@mily20001) +- Add minimal support for TF2 using tensorflow.compat.v1 while keeping support for TF1 Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index ce3e6e7674..48f53b9d2f 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -2,7 +2,7 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines import logger from stable_baselines.common import explained_variance, tf_util, ActorCriticRLModel, SetVerbosity, TensorboardWriter diff --git a/stable_baselines/acer/acer_simple.py b/stable_baselines/acer/acer_simple.py index 160f48a3df..5f5a7a7ba1 100644 --- a/stable_baselines/acer/acer_simple.py +++ b/stable_baselines/acer/acer_simple.py @@ -2,7 +2,7 @@ import warnings import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Discrete, Box from collections import deque diff --git a/stable_baselines/acktr/acktr.py b/stable_baselines/acktr/acktr.py index e8509bd59f..67ae9fdbff 100644 --- a/stable_baselines/acktr/acktr.py +++ b/stable_baselines/acktr/acktr.py @@ -1,7 +1,7 @@ import time import warnings -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Box, Discrete from stable_baselines import logger diff --git a/stable_baselines/acktr/kfac.py b/stable_baselines/acktr/kfac.py index 4ab208056e..d62bd47763 100644 --- a/stable_baselines/acktr/kfac.py +++ b/stable_baselines/acktr/kfac.py @@ -1,7 +1,7 @@ import re from functools import reduce -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np from stable_baselines.acktr.kfac_utils import detect_min_val, factor_reshape, gmatmul diff --git a/stable_baselines/acktr/kfac_utils.py b/stable_baselines/acktr/kfac_utils.py index 512e21a239..e901a22294 100644 --- a/stable_baselines/acktr/kfac_utils.py +++ b/stable_baselines/acktr/kfac_utils.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf def gmatmul(tensor_a, tensor_b, transpose_a=False, transpose_b=False, reduce_dim=None): diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 994a539714..5326311d78 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -10,7 +10,7 @@ import gym import cloudpickle import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines.common.misc_util import set_global_seeds from stable_baselines.common.save_util import data_to_json, json_to_data, params_to_bytes, bytes_to_params diff --git a/stable_baselines/common/distributions.py b/stable_baselines/common/distributions.py index f38dd65767..b9cfb96223 100644 --- a/stable_baselines/common/distributions.py +++ b/stable_baselines/common/distributions.py @@ -1,5 +1,5 @@ import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from tensorflow.python.ops import math_ops from gym import spaces diff --git a/stable_baselines/common/input.py b/stable_baselines/common/input.py index e8cfa3c8b4..5b1b00e2c0 100644 --- a/stable_baselines/common/input.py +++ b/stable_baselines/common/input.py @@ -1,5 +1,5 @@ import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete diff --git a/stable_baselines/common/misc_util.py b/stable_baselines/common/misc_util.py index 2d8730c03e..840bfabb0a 100644 --- a/stable_baselines/common/misc_util.py +++ b/stable_baselines/common/misc_util.py @@ -2,7 +2,7 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf def zipsame(*seqs): diff --git a/stable_baselines/common/mpi_adam.py b/stable_baselines/common/mpi_adam.py index 02c3de8bb2..b42dd3d012 100644 --- a/stable_baselines/common/mpi_adam.py +++ b/stable_baselines/common/mpi_adam.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np import mpi4py diff --git a/stable_baselines/common/mpi_running_mean_std.py b/stable_baselines/common/mpi_running_mean_std.py index 5e52129b7a..ae75223815 100644 --- a/stable_baselines/common/mpi_running_mean_std.py +++ b/stable_baselines/common/mpi_running_mean_std.py @@ -1,5 +1,5 @@ import mpi4py -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np import stable_baselines.common.tf_util as tf_util diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index a2f3cc7f1c..7372bb2263 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Discrete from stable_baselines.common.tf_util import batch_to_seq, seq_to_batch diff --git a/stable_baselines/common/tf_layers.py b/stable_baselines/common/tf_layers.py index c35bf85261..5d7f734e61 100644 --- a/stable_baselines/common/tf_layers.py +++ b/stable_baselines/common/tf_layers.py @@ -1,5 +1,5 @@ import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf def ortho_init(scale=1.0): diff --git a/stable_baselines/common/tf_util.py b/stable_baselines/common/tf_util.py index ba78c042f0..6c40fb8c01 100644 --- a/stable_baselines/common/tf_util.py +++ b/stable_baselines/common/tf_util.py @@ -5,7 +5,7 @@ from typing import Set import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf def is_image(tensor): diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 5e806f3cdd..ab98765b69 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -7,7 +7,7 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf import tensorflow.contrib as tc from mpi4py import MPI diff --git a/stable_baselines/ddpg/main.py b/stable_baselines/ddpg/main.py index 3e1232788c..742f9d3568 100644 --- a/stable_baselines/ddpg/main.py +++ b/stable_baselines/ddpg/main.py @@ -3,7 +3,7 @@ import os import gym -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np from mpi4py import MPI diff --git a/stable_baselines/ddpg/policies.py b/stable_baselines/ddpg/policies.py index 19ac6463ef..6d36a1f629 100644 --- a/stable_baselines/ddpg/policies.py +++ b/stable_baselines/ddpg/policies.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Box from stable_baselines.common.policies import BasePolicy, nature_cnn, register_policy diff --git a/stable_baselines/deepq/build_graph.py b/stable_baselines/deepq/build_graph.py index 51453ec6e5..0266b2aaa5 100644 --- a/stable_baselines/deepq/build_graph.py +++ b/stable_baselines/deepq/build_graph.py @@ -62,7 +62,7 @@ Q' is set to Q once every 10000 updates training steps. """ -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import MultiDiscrete from stable_baselines.common import tf_util diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index 8deaf5a090..a9c4ee8f81 100644 --- a/stable_baselines/deepq/dqn.py +++ b/stable_baselines/deepq/dqn.py @@ -1,6 +1,6 @@ from functools import partial -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np import gym diff --git a/stable_baselines/deepq/policies.py b/stable_baselines/deepq/policies.py index 3a2dfec16d..1fdf057df3 100644 --- a/stable_baselines/deepq/policies.py +++ b/stable_baselines/deepq/policies.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf import tensorflow.contrib.layers as tf_layers import numpy as np from gym.spaces import Discrete diff --git a/stable_baselines/gail/adversary.py b/stable_baselines/gail/adversary.py index 7c6cb63c68..fcaf0e1dd1 100644 --- a/stable_baselines/gail/adversary.py +++ b/stable_baselines/gail/adversary.py @@ -3,7 +3,7 @@ I follow the architecture from the official repository """ import gym -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np from stable_baselines.common.mpi_running_mean_std import RunningMeanStd diff --git a/stable_baselines/logger.py b/stable_baselines/logger.py index 0cfcde68db..55ee818e8d 100644 --- a/stable_baselines/logger.py +++ b/stable_baselines/logger.py @@ -9,7 +9,7 @@ from collections import defaultdict from typing import Optional -import tensorflow as tf +import tensorflow.compat.v1 as tf from tensorflow.python import pywrap_tensorflow from tensorflow.core.util import event_pb2 from tensorflow.python.util import compat @@ -715,7 +715,7 @@ def read_tb(path): import numpy as np from glob import glob # from collections import defaultdict - import tensorflow as tf + import tensorflow.compat.v1 as tf if os.path.isdir(path): fnames = glob(os.path.join(path, "events.*")) elif os.path.basename(path).startswith("events."): diff --git a/stable_baselines/ppo1/pposgd_simple.py b/stable_baselines/ppo1/pposgd_simple.py index b570df26e3..7c0898c5ea 100644 --- a/stable_baselines/ppo1/pposgd_simple.py +++ b/stable_baselines/ppo1/pposgd_simple.py @@ -3,7 +3,7 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from mpi4py import MPI from stable_baselines.common import Dataset, explained_variance, fmt_row, zipsame, ActorCriticRLModel, SetVerbosity, \ diff --git a/stable_baselines/ppo2/ppo2.py b/stable_baselines/ppo2/ppo2.py index 8a2ddff9f7..db187c904c 100644 --- a/stable_baselines/ppo2/ppo2.py +++ b/stable_baselines/ppo2/ppo2.py @@ -2,7 +2,7 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines import logger from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter diff --git a/stable_baselines/sac/policies.py b/stable_baselines/sac/policies.py index 7cc61e7dd9..35b7cf9984 100644 --- a/stable_baselines/sac/policies.py +++ b/stable_baselines/sac/policies.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np from gym.spaces import Box diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index f466af76d3..1b74bf0756 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -2,7 +2,7 @@ import warnings import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter from stable_baselines.common.vec_env import VecEnv diff --git a/stable_baselines/td3/policies.py b/stable_baselines/td3/policies.py index bc68890012..e83ca696ec 100644 --- a/stable_baselines/td3/policies.py +++ b/stable_baselines/td3/policies.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf from gym.spaces import Box from stable_baselines.common.policies import BasePolicy, nature_cnn, register_policy diff --git a/stable_baselines/td3/td3.py b/stable_baselines/td3/td3.py index 5f65d8bbc3..95b1fd78a2 100644 --- a/stable_baselines/td3/td3.py +++ b/stable_baselines/td3/td3.py @@ -2,7 +2,7 @@ import warnings import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines import logger from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter diff --git a/stable_baselines/trpo_mpi/trpo_mpi.py b/stable_baselines/trpo_mpi/trpo_mpi.py index 7d95356db8..fc796a3d91 100644 --- a/stable_baselines/trpo_mpi/trpo_mpi.py +++ b/stable_baselines/trpo_mpi/trpo_mpi.py @@ -4,7 +4,7 @@ import gym from mpi4py import MPI -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np import stable_baselines.common.tf_util as tf_util diff --git a/tests/test_a2c_conv.py b/tests/test_a2c_conv.py index 59de1e04d6..38d740791a 100644 --- a/tests/test_a2c_conv.py +++ b/tests/test_a2c_conv.py @@ -1,6 +1,6 @@ import gym import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines.common.tf_layers import conv from stable_baselines.common.input import observation_input diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 725d88ffeb..c74d05b5ce 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -2,7 +2,7 @@ import gym import pytest -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO, SAC, DDPG from stable_baselines.common.policies import FeedForwardPolicy diff --git a/tests/test_distri.py b/tests/test_distri.py index d3be362617..4e055f4b94 100644 --- a/tests/test_distri.py +++ b/tests/test_distri.py @@ -1,5 +1,5 @@ import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf import stable_baselines.common.tf_util as tf_util from stable_baselines.common.distributions import DiagGaussianProbabilityDistributionType,\ diff --git a/tests/test_math_util.py b/tests/test_math_util.py index 584ba98d47..3221322e58 100644 --- a/tests/test_math_util.py +++ b/tests/test_math_util.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf import numpy as np from gym.spaces.box import Box diff --git a/tests/test_tf_util.py b/tests/test_tf_util.py index d71374da03..553e59c99c 100644 --- a/tests/test_tf_util.py +++ b/tests/test_tf_util.py @@ -1,6 +1,6 @@ # tests for tf_util import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from stable_baselines.common.tf_util import function, initialize, single_threaded_session, is_image