From 66d17769f44c2a472dd7a36c3c7c07958a4c417c Mon Sep 17 00:00:00 2001 From: David Tao Date: Mon, 13 Aug 2018 15:17:22 -0400 Subject: [PATCH 01/11] skeleton for game logger --- textworld/envs/wrappers/game_logger.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 textworld/envs/wrappers/game_logger.py diff --git a/textworld/envs/wrappers/game_logger.py b/textworld/envs/wrappers/game_logger.py new file mode 100644 index 00000000..cd106381 --- /dev/null +++ b/textworld/envs/wrappers/game_logger.py @@ -0,0 +1,21 @@ +import sys +from typing import Tuple + +from textworld.core import Environment, GameState, Wrapper + + +class GameLogger(Wrapper): + def __init__(self, env: Environment) -> None: + """ + Wrap around a TextWorld environment to provide logging capabilities. + + Parameters + ---------- + :param env: + The TextWorld environment to wrap + """ + super().__init__(env) + self.activate_state_tracking() + + def step(self, command: str) -> Tuple[GameState, float, bool]: + game_state, score, done = super().step(command) From f8de962fe85d499a29e9b66adc82564e08c2488c Mon Sep 17 00:00:00 2001 From: Ruo Yu Tao Date: Mon, 13 Aug 2018 16:02:53 -0400 Subject: [PATCH 02/11] first iteration of game logger --- textworld/envs/wrappers/game_logger.py | 27 ++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/textworld/envs/wrappers/game_logger.py b/textworld/envs/wrappers/game_logger.py index cd106381..dac47c1e 100644 --- a/textworld/envs/wrappers/game_logger.py +++ b/textworld/envs/wrappers/game_logger.py @@ -1,5 +1,5 @@ import sys -from typing import Tuple +from typing import Tuple, List from textworld.core import Environment, GameState, Wrapper @@ -12,10 +12,33 @@ def __init__(self, env: Environment) -> None: Parameters ---------- :param env: - The TextWorld environment to wrap + The TextWorld environment to wrap. Has the correct knowledge base. """ super().__init__(env) self.activate_state_tracking() + self.serialized_game = env + + self.logs = [] + self.current_log = {} + def step(self, command: str) -> Tuple[GameState, float, bool]: + if self.current_log: + self.logs.append(self.current_log) + self.current_log = {} + + self.current_log['action_taken'] = command + game_state, score, done = super().step(command) + return game_state, score, done + + def log_action_distribution(self, actions: List, probabilities: List): + action_dist = {a: p for a, p in zip(actions, probabilities)} + self.current_log['action_distribution'] = action_dist + + return self.logs + + def log(self, to_log): + self.current_log['others'] = to_log + + return self.logs[:].append(self.current_log) From 336354b30a91dee4dd00bf0bc5bfc4deea8a6d9f Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 14 Aug 2018 11:50:47 -0400 Subject: [PATCH 03/11] Change to GlulxLogger --- textworld/envs/glulx/__init__.py | 1 + textworld/envs/glulx/git_glulx_ml.py | 4 + textworld/envs/wrappers/game_logger.py | 44 ---------- textworld/envs/wrappers/glulx_logger.py | 103 ++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 44 deletions(-) delete mode 100644 textworld/envs/wrappers/game_logger.py create mode 100644 textworld/envs/wrappers/glulx_logger.py diff --git a/textworld/envs/glulx/__init__.py b/textworld/envs/glulx/__init__.py index e69de29b..8b137891 100644 --- a/textworld/envs/glulx/__init__.py +++ b/textworld/envs/glulx/__init__.py @@ -0,0 +1 @@ + diff --git a/textworld/envs/glulx/git_glulx_ml.py b/textworld/envs/glulx/git_glulx_ml.py index 5fab75c6..3e18f451 100644 --- a/textworld/envs/glulx/git_glulx_ml.py +++ b/textworld/envs/glulx/git_glulx_ml.py @@ -424,6 +424,10 @@ def compute_intermediate_reward(self) -> None: def __del__(self) -> None: self.close() + @property + def gamefile(self) -> str: + return self._gamefile + @property def game_running(self) -> bool: """ Determines if the game is still running. """ diff --git a/textworld/envs/wrappers/game_logger.py b/textworld/envs/wrappers/game_logger.py deleted file mode 100644 index dac47c1e..00000000 --- a/textworld/envs/wrappers/game_logger.py +++ /dev/null @@ -1,44 +0,0 @@ -import sys -from typing import Tuple, List - -from textworld.core import Environment, GameState, Wrapper - - -class GameLogger(Wrapper): - def __init__(self, env: Environment) -> None: - """ - Wrap around a TextWorld environment to provide logging capabilities. - - Parameters - ---------- - :param env: - The TextWorld environment to wrap. Has the correct knowledge base. - """ - super().__init__(env) - self.activate_state_tracking() - - self.serialized_game = env - - self.logs = [] - self.current_log = {} - - def step(self, command: str) -> Tuple[GameState, float, bool]: - if self.current_log: - self.logs.append(self.current_log) - self.current_log = {} - - self.current_log['action_taken'] = command - - game_state, score, done = super().step(command) - return game_state, score, done - - def log_action_distribution(self, actions: List, probabilities: List): - action_dist = {a: p for a, p in zip(actions, probabilities)} - self.current_log['action_distribution'] = action_dist - - return self.logs - - def log(self, to_log): - self.current_log['others'] = to_log - - return self.logs[:].append(self.current_log) diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py new file mode 100644 index 00000000..a8adc2a1 --- /dev/null +++ b/textworld/envs/wrappers/glulx_logger.py @@ -0,0 +1,103 @@ +import sys +from typing import Tuple, List, Optional, Iterable, Union, Sized, Any, Mapping + +from textworld.core import Environment, GameState, Wrapper +from textworld.envs.glulx.git_glulx_ml import GitGlulxMLEnvironment, GlulxGameState + + +class GlulxLogger(Wrapper): + def __init__(self, env: GitGlulxMLEnvironment) -> None: + """ + Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. + + Parameters + ---------- + :param env: + The GitGlulxML environment to wrap. + """ + super().__init__(env) + self.activate_state_tracking() + + self.serialized_game = env.game.serialize() + self.gamefile = env.gamefile + + self._logs = [] + self.current_log = {'optional': []} + + def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: + """ + Take a step in the environment, save needed information. + :param command: + input string for taking an action + :return: + GlulxGameState, score and done. + """ + if self.current_log: + self._logs.append(self.current_log) + self.current_log = {'optional': []} + + self.current_log['command'] = command + + game_state, score, done = super().step(command) + self.current_log['feedback'] = game_state.feedback + self.current_log['score'] = score + self.current_log['done'] = done + self.current_log['action'] = game_state.action.serialize() + self.current_log['state'] = game_state.state.serialize() + + return game_state, score, done + + def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: + """ + Add custom commands to the logger. Optionally add scores for each command. + :param commands: + A list of commands. + :param scores: + scores for each command. Must be same size as commands if provided. + :return: + """ + command_mapping = commands + if scores is not None: + assert len(scores) == len(commands) + command_mapping = {a: p for a, p in zip(commands, scores)} + + self.current_log['command_distribution'] = command_mapping + + def add(self, info: Any) -> None: + """ + Add any additional information you want to log. + :param info: + Additional information to log for the current game state. + """ + self.current_log['optional'].append(info) + + @property + def logs(self) -> List[Mapping]: + """ + Get all logs + :return: List of all logs + """ + logs = self._logs[:] + logs.append(self.current_log) + return logs + + def __getitem__(self, index: int) -> Mapping: + """ + Get a certain log at a given index. + :param index: + index of log to get. + :return: + log at index. + """ + assert index <= len(self._logs) + + if index < len(self._logs) - 1: + return self._logs[index] + return self.current_log + + def serialize(self) -> List[Mapping]: + """ + Get serialized mappings of logs. + :return: List of serialized mappings. + """ + return self.logs \ No newline at end of file From 2522015777ab286bc1852c4bc2c8dec62f3db046 Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 14 Aug 2018 11:52:48 -0400 Subject: [PATCH 04/11] add str method --- textworld/envs/wrappers/glulx_logger.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index a8adc2a1..c5edf5aa 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -95,6 +95,9 @@ def __getitem__(self, index: int) -> Mapping: return self._logs[index] return self.current_log + def __str__(self) -> Mapping: + return self.logs + def serialize(self) -> List[Mapping]: """ Get serialized mappings of logs. From 862cba1746437be3486e52453e885768cbb241ea Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 14 Aug 2018 13:41:14 -0400 Subject: [PATCH 05/11] add initial unit tests --- textworld/envs/wrappers/__init__.py | 1 + textworld/envs/wrappers/glulx_logger.py | 58 +++++++++++++------ .../envs/wrappers/tests/test_glulx_logger.py | 51 ++++++++++++++++ 3 files changed, 91 insertions(+), 19 deletions(-) create mode 100644 textworld/envs/wrappers/tests/test_glulx_logger.py diff --git a/textworld/envs/wrappers/__init__.py b/textworld/envs/wrappers/__init__.py index 47ecd37b..ddd0330e 100644 --- a/textworld/envs/wrappers/__init__.py +++ b/textworld/envs/wrappers/__init__.py @@ -4,3 +4,4 @@ from textworld.envs.wrappers.viewer import HtmlViewer from textworld.envs.wrappers.recorder import Recorder +from textworld.envs.wrappers.glulx_logger import GlulxLogger diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index c5edf5aa..b0d799a1 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -1,4 +1,6 @@ -import sys +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + from typing import Tuple, List, Optional, Iterable, Union, Sized, Any, Mapping from textworld.core import Environment, GameState, Wrapper @@ -19,10 +21,8 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: self.activate_state_tracking() self.serialized_game = env.game.serialize() - self.gamefile = env.gamefile + self._gamefile = env.gamefile - self._logs = [] - self.current_log = {'optional': []} def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: """ @@ -32,21 +32,33 @@ def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: :return: GlulxGameState, score and done. """ - if self.current_log: - self._logs.append(self.current_log) - self.current_log = {'optional': []} + self._logs.append(self._current) + self._current = {'optional': []} - self.current_log['command'] = command + self._current['command'] = command game_state, score, done = super().step(command) - self.current_log['feedback'] = game_state.feedback - self.current_log['score'] = score - self.current_log['done'] = done - self.current_log['action'] = game_state.action.serialize() - self.current_log['state'] = game_state.state.serialize() + self._current['feedback'] = game_state.feedback + self._current['score'] = score + self._current['done'] = done + self._current['action'] = game_state.action.serialize() + self._current['state'] = game_state.state.serialize() return game_state, score, done + def reset(self) -> GameState: + """ + Reset the environment. + Also clears logs. + """ + self._logs = [] + game_state = super().reset() + self._current = {'optional': []} + self._current['done'] = False + self._current['state'] = game_state.state.serialize() + + return game_state + def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: """ Add custom commands to the logger. Optionally add scores for each command. @@ -61,7 +73,7 @@ def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[floa assert len(scores) == len(commands) command_mapping = {a: p for a, p in zip(commands, scores)} - self.current_log['command_distribution'] = command_mapping + self._current['command_distribution'] = command_mapping def add(self, info: Any) -> None: """ @@ -69,7 +81,11 @@ def add(self, info: Any) -> None: :param info: Additional information to log for the current game state. """ - self.current_log['optional'].append(info) + self._current['optional'].append(info) + + @property + def current(self) -> Mapping: + return self._current @property def logs(self) -> List[Mapping]: @@ -78,9 +94,13 @@ def logs(self) -> List[Mapping]: :return: List of all logs """ logs = self._logs[:] - logs.append(self.current_log) + logs.append(self._current) return logs + @property + def gamefile(self): + return self._gamefile + def __getitem__(self, index: int) -> Mapping: """ Get a certain log at a given index. @@ -93,10 +113,10 @@ def __getitem__(self, index: int) -> Mapping: if index < len(self._logs) - 1: return self._logs[index] - return self.current_log + return self._current - def __str__(self) -> Mapping: - return self.logs + def __str__(self) -> str: + return str(self.logs) def serialize(self) -> List[Mapping]: """ diff --git a/textworld/envs/wrappers/tests/test_glulx_logger.py b/textworld/envs/wrappers/tests/test_glulx_logger.py new file mode 100644 index 00000000..33dc3490 --- /dev/null +++ b/textworld/envs/wrappers/tests/test_glulx_logger.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import textworld +import numpy as np + +from textworld.envs.wrappers import GlulxLogger +from textworld.utils import make_temp_directory +from textworld.generator import compile_game +from textworld import g_rng + + +def test_glulx_logger(): + num_nodes = 3 + num_items = 10 + g_rng.set_seed(1234) + grammar_flags = {"theme": "house", "include_adj": True} + game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, grammar_flags=grammar_flags) + + game_name = "test_glulx_logger" + with make_temp_directory(prefix=game_name) as tmpdir: + game_file = compile_game(game, game_name, games_folder=tmpdir) + + env = textworld.start(game_file) + env = GlulxLogger(env) + env.activate_state_tracking() + game_state = env.reset() + + # test reset + assert hasattr(env.current, 'state') + + # test step + options = game_state.admissible_commands + game_state, score, done = env.step(options[0]) + assert len(env.logs) > 1 + assert hasattr(env.current, 'action') + assert hasattr(env.current, 'state') + assert hasattr(env.current, 'feedback') + + # test add_commands + option_scores = np.array([0.1] * len(options)) + env.add_commands(options, option_scores) + assert len(env.current['command_distribution'].values()) == len(options) + + # test add + additional_info = {'scores': option_scores} + env.add(additional_info) + assert len(env.current['optional']) > 0 + + + From ad262f86502a2af1b7f67c80e85e15fa63c6b749 Mon Sep 17 00:00:00 2001 From: David Tao Date: Wed, 15 Aug 2018 14:04:36 -0400 Subject: [PATCH 06/11] fix tests --- textworld/envs/wrappers/tests/test_glulx_logger.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/textworld/envs/wrappers/tests/test_glulx_logger.py b/textworld/envs/wrappers/tests/test_glulx_logger.py index 33dc3490..fb6f23d5 100644 --- a/textworld/envs/wrappers/tests/test_glulx_logger.py +++ b/textworld/envs/wrappers/tests/test_glulx_logger.py @@ -27,15 +27,15 @@ def test_glulx_logger(): game_state = env.reset() # test reset - assert hasattr(env.current, 'state') + assert 'state' in env.current # test step options = game_state.admissible_commands game_state, score, done = env.step(options[0]) assert len(env.logs) > 1 - assert hasattr(env.current, 'action') - assert hasattr(env.current, 'state') - assert hasattr(env.current, 'feedback') + assert 'action' in env.current + assert 'state' in env.current + assert 'feedback' in env.current # test add_commands option_scores = np.array([0.1] * len(options)) From a4ca1433a42c4e0dd7eb6a419048698212fe6822 Mon Sep 17 00:00:00 2001 From: David Tao Date: Mon, 20 Aug 2018 17:53:13 -0400 Subject: [PATCH 07/11] untested game logger --- textworld/envs/wrappers/__init__.py | 1 + textworld/envs/wrappers/game_log.py | 43 ++++++++++ textworld/envs/wrappers/glulx_logger.py | 107 ++++++++++++++++-------- 3 files changed, 114 insertions(+), 37 deletions(-) create mode 100644 textworld/envs/wrappers/game_log.py diff --git a/textworld/envs/wrappers/__init__.py b/textworld/envs/wrappers/__init__.py index ddd0330e..8f207ddb 100644 --- a/textworld/envs/wrappers/__init__.py +++ b/textworld/envs/wrappers/__init__.py @@ -5,3 +5,4 @@ from textworld.envs.wrappers.viewer import HtmlViewer from textworld.envs.wrappers.recorder import Recorder from textworld.envs.wrappers.glulx_logger import GlulxLogger +from textworld.envs.wrappers.game_log import GameLog \ No newline at end of file diff --git a/textworld/envs/wrappers/game_log.py b/textworld/envs/wrappers/game_log.py new file mode 100644 index 00000000..8c737512 --- /dev/null +++ b/textworld/envs/wrappers/game_log.py @@ -0,0 +1,43 @@ +import json + +class GameLog: + def __init__(self): + """ + GameLog object. Allows your to load and save previous game logs. + """ + self._logs = [[]] + self._current_game = self._logs[-1] + self._filename = '' + + def __getitem__(self, idx): + assert idx <= len(self._logs) + return self._logs[idx] + + def __len__(self): + return len(self._logs) + + @property + def current_game(self): + return self._current_game + + @property + def logs(self): + return self._logs + + def new_game(self): + self._logs.append([]) + self._current_game = self._logs[-1] + return self._current_game + + def save(self, filename): + self._filename = filename + try: + with open(filename, 'w') as outfile: + json.dump(self._logs, outfile) + except TypeError as e: + raise TypeError('Log not serializable') + + def load(self, filename): + self._filename = filename + with open(filename) as f: + self._logs= json.load(f) \ No newline at end of file diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index b0d799a1..c1b7c6f9 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -5,6 +5,7 @@ from textworld.core import Environment, GameState, Wrapper from textworld.envs.glulx.git_glulx_ml import GitGlulxMLEnvironment, GlulxGameState +from textworld.envs.wrappers import GameLog class GlulxLogger(Wrapper): @@ -12,10 +13,8 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: """ Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. - Parameters - ---------- - :param env: - The GitGlulxML environment to wrap. + Args: + env: The GitGlulxML environment to wrap. """ super().__init__(env) self.activate_state_tracking() @@ -23,17 +22,21 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: self.serialized_game = env.game.serialize() self._gamefile = env.gamefile + self._logs = GameLog() + self._current_log = self._logs.current_game + self._current_log.append({}) + self._current = self._current_log[-1] + def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: """ - Take a step in the environment, save needed information. - :param command: - input string for taking an action - :return: + Take a step in the environment. + Args: + command: input string for taking an action + + Returns: GlulxGameState, score and done. """ - self._logs.append(self._current) - self._current = {'optional': []} self._current['command'] = command @@ -49,8 +52,11 @@ def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: def reset(self) -> GameState: """ Reset the environment. - Also clears logs. + Adds a new game into the logs. + Returns: + GameState """ + # if not self._current self._logs = [] game_state = super().reset() self._current = {'optional': []} @@ -62,65 +68,92 @@ def reset(self) -> GameState: def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: """ Add custom commands to the logger. Optionally add scores for each command. - :param commands: - A list of commands. - :param scores: - scores for each command. Must be same size as commands if provided. - :return: + Args: + commands: A list of commands. + scores: scores for each command. Must be same size as commands if provided. + + Returns: + """ - command_mapping = commands if scores is not None: - assert len(scores) == len(commands) - command_mapping = {a: p for a, p in zip(commands, scores)} + self._current['command_scores'] = scores - self._current['command_distribution'] = command_mapping + self._current['commands'] = commands def add(self, info: Any) -> None: """ Add any additional information you want to log. - :param info: - Additional information to log for the current game state. + Args: + info: Additional information to log for the current game state. """ self._current['optional'].append(info) @property def current(self) -> Mapping: + """ + Returns: + Current game state logs. + """ return self._current @property def logs(self) -> List[Mapping]: """ - Get all logs - :return: List of all logs + Returns: List of all logs from this game. """ - logs = self._logs[:] - logs.append(self._current) - return logs + return self._current_log @property - def gamefile(self): + def all_logs(self) -> GameLog: + """ + Returns: GameLog object containing all logs. + """ + return self._logs + + @property + def gamefile(self) -> str: + """ + Returns: + Game file currently loaded + """ return self._gamefile def __getitem__(self, index: int) -> Mapping: """ Get a certain log at a given index. - :param index: - index of log to get. - :return: - log at index. + Args: + index: index of log to get. + Returns: + log at index """ assert index <= len(self._logs) - if index < len(self._logs) - 1: - return self._logs[index] return self._current def __str__(self) -> str: - return str(self.logs) + return str(self._current_log) def serialize(self) -> List[Mapping]: """ Get serialized mappings of logs. - :return: List of serialized mappings. + Returns: + List of serialized mappings. + """ + return self._logs.logs + + def save(self, filename) -> None: + """ + Saves all logs given a filename + Returns: None + """ + self._logs.save(filename) + + def load(self, filename) -> None: + """ + Loads logs from a file + Args: + filename: + string representing file location + Returns: None """ - return self.logs \ No newline at end of file + self._logs.load(filename) From eb4e0c5ed7b6f63ad3c9fc23c28cddb529fc7f8a Mon Sep 17 00:00:00 2001 From: David Tao Date: Mon, 20 Aug 2018 19:06:31 -0400 Subject: [PATCH 08/11] working game logger with separate game log --- textworld/envs/wrappers/__init__.py | 1 - textworld/envs/wrappers/game_log.py | 43 ----- textworld/envs/wrappers/glulx_logger.py | 165 +++++++++++++++--- .../envs/wrappers/tests/test_glulx_logger.py | 6 +- 4 files changed, 141 insertions(+), 74 deletions(-) delete mode 100644 textworld/envs/wrappers/game_log.py diff --git a/textworld/envs/wrappers/__init__.py b/textworld/envs/wrappers/__init__.py index 8f207ddb..ddd0330e 100644 --- a/textworld/envs/wrappers/__init__.py +++ b/textworld/envs/wrappers/__init__.py @@ -5,4 +5,3 @@ from textworld.envs.wrappers.viewer import HtmlViewer from textworld.envs.wrappers.recorder import Recorder from textworld.envs.wrappers.glulx_logger import GlulxLogger -from textworld.envs.wrappers.game_log import GameLog \ No newline at end of file diff --git a/textworld/envs/wrappers/game_log.py b/textworld/envs/wrappers/game_log.py deleted file mode 100644 index 8c737512..00000000 --- a/textworld/envs/wrappers/game_log.py +++ /dev/null @@ -1,43 +0,0 @@ -import json - -class GameLog: - def __init__(self): - """ - GameLog object. Allows your to load and save previous game logs. - """ - self._logs = [[]] - self._current_game = self._logs[-1] - self._filename = '' - - def __getitem__(self, idx): - assert idx <= len(self._logs) - return self._logs[idx] - - def __len__(self): - return len(self._logs) - - @property - def current_game(self): - return self._current_game - - @property - def logs(self): - return self._logs - - def new_game(self): - self._logs.append([]) - self._current_game = self._logs[-1] - return self._current_game - - def save(self, filename): - self._filename = filename - try: - with open(filename, 'w') as outfile: - json.dump(self._logs, outfile) - except TypeError as e: - raise TypeError('Log not serializable') - - def load(self, filename): - self._filename = filename - with open(filename) as f: - self._logs= json.load(f) \ No newline at end of file diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index c1b7c6f9..b0b4b079 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -1,11 +1,118 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. +import json from typing import Tuple, List, Optional, Iterable, Union, Sized, Any, Mapping -from textworld.core import Environment, GameState, Wrapper +from textworld.core import GameState, Wrapper from textworld.envs.glulx.git_glulx_ml import GitGlulxMLEnvironment, GlulxGameState -from textworld.envs.wrappers import GameLog + + +class GameLog: + def __init__(self): + """ + GameLog object. Allows your to load and save previous game logs. + """ + self._logs = [[]] + self._current_game = self._logs[-1] + self._filename = '' + + def __getitem__(self, idx: int) -> list: + """ + Gets a particular game log at index idx. + Args: + idx: index to retrieve + Returns: + + """ + assert idx <= len(self._logs) + return self._logs[idx] + + def __len__(self) -> int: + return len(self._logs) + + @property + def current_game(self) -> list: + """ + Gets current game we're logging. + Returns: list of logs from current game. + """ + return self._current_game + + @property + def logs(self) -> list: + """ + Get all logs from all games. + Returns: All logs from all games. + """ + return self._logs + + def new_game(self): + """ + Start logs for a new game. + Returns: log object for current game. + """ + if len(self._current_game) > 0: + self._logs.append([]) + self._current_game = self._logs[-1] + return self._current_game + + def set(self, key: Any, value: Any) -> None: + """ + Sets value for latest game + Args: + key: Key to set + value: Value to set + + """ + current = self._current_game[-1] + current[key] = value + + def append_optional(self, value: Any) -> None: + """ + Appends optional information to current game + Args: + value: Value to append + + """ + current = self._current_game[-1] + if 'optional' not in current: + current['optional'] = [] + current['optional'].append(value) + + def add_log(self, log: Mapping): + """ + Adds a new log to our logs + Args: + log: Mapping of a log + + """ + self._current_game.append(log) + + def save(self, filename): + """ + Save current logs to specified file name + Args: + filename: File path to save to (should have JSON extension) + + """ + self._filename = filename + try: + with open(filename, 'w') as outfile: + json.dump(self._logs, outfile) + except TypeError as e: + raise TypeError('Log not serializable') + + def load(self, filename): + """ + Loads a JSON object as logs + Args: + filename: file path to load. + + """ + self._filename = filename + with open(filename) as f: + self._logs= json.load(f) class GlulxLogger(Wrapper): @@ -23,29 +130,27 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: self._gamefile = env.gamefile self._logs = GameLog() - self._current_log = self._logs.current_game - self._current_log.append({}) - self._current = self._current_log[-1] - def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: """ Take a step in the environment. Args: command: input string for taking an action - Returns: GlulxGameState, score and done. """ - - self._current['command'] = command + new_log = {} + new_log['optional'] = [] + new_log['command'] = command game_state, score, done = super().step(command) - self._current['feedback'] = game_state.feedback - self._current['score'] = score - self._current['done'] = done - self._current['action'] = game_state.action.serialize() - self._current['state'] = game_state.state.serialize() + new_log['feedback'] = game_state.feedback + new_log['score'] = score + new_log['done'] = done + new_log['description'] = game_state.description + new_log['inventory'] = game_state.inventory + new_log['state'] = game_state.state.serialize() + self._logs.add_log(new_log) return game_state, score, done @@ -56,12 +161,16 @@ def reset(self) -> GameState: Returns: GameState """ - # if not self._current - self._logs = [] + new_log = {} + self._logs.new_game() + game_state = super().reset() - self._current = {'optional': []} - self._current['done'] = False - self._current['state'] = game_state.state.serialize() + new_log['optional'] = [] + new_log['done'] = False + new_log['description'] = game_state.description + new_log['inventory'] = game_state.inventory + new_log['state'] = game_state.state.serialize() + self._logs.add_log(new_log) return game_state @@ -72,13 +181,11 @@ def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[floa commands: A list of commands. scores: scores for each command. Must be same size as commands if provided. - Returns: - """ if scores is not None: - self._current['command_scores'] = scores + self._logs.set('command_scores', scores) - self._current['commands'] = commands + self._logs.set('commands', commands) def add(self, info: Any) -> None: """ @@ -86,7 +193,7 @@ def add(self, info: Any) -> None: Args: info: Additional information to log for the current game state. """ - self._current['optional'].append(info) + self._logs.append_optional(info) @property def current(self) -> Mapping: @@ -94,14 +201,14 @@ def current(self) -> Mapping: Returns: Current game state logs. """ - return self._current + return self._logs.current_game[-1] @property def logs(self) -> List[Mapping]: """ Returns: List of all logs from this game. """ - return self._current_log + return self._logs.current_game @property def all_logs(self) -> GameLog: @@ -128,10 +235,10 @@ def __getitem__(self, index: int) -> Mapping: """ assert index <= len(self._logs) - return self._current + return self._logs.current_game[index] def __str__(self) -> str: - return str(self._current_log) + return str(self._logs.current_game) def serialize(self) -> List[Mapping]: """ @@ -157,3 +264,5 @@ def load(self, filename) -> None: Returns: None """ self._logs.load(filename) + + diff --git a/textworld/envs/wrappers/tests/test_glulx_logger.py b/textworld/envs/wrappers/tests/test_glulx_logger.py index fb6f23d5..23beee28 100644 --- a/textworld/envs/wrappers/tests/test_glulx_logger.py +++ b/textworld/envs/wrappers/tests/test_glulx_logger.py @@ -25,6 +25,8 @@ def test_glulx_logger(): env = GlulxLogger(env) env.activate_state_tracking() game_state = env.reset() + game_state = env.reset() + assert len(env.all_logs.logs) == 2 # test reset assert 'state' in env.current @@ -33,14 +35,14 @@ def test_glulx_logger(): options = game_state.admissible_commands game_state, score, done = env.step(options[0]) assert len(env.logs) > 1 - assert 'action' in env.current + assert 'command' in env.current assert 'state' in env.current assert 'feedback' in env.current # test add_commands option_scores = np.array([0.1] * len(options)) env.add_commands(options, option_scores) - assert len(env.current['command_distribution'].values()) == len(options) + assert len(env.current['commands']) == len(env.current['command_scores']) # test add additional_info = {'scores': option_scores} From 5c498d5d6df9312d81d6172337aa6938577ee862 Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 21 Aug 2018 10:11:06 -0400 Subject: [PATCH 09/11] fix requested changes --- textworld/envs/glulx/__init__.py | 1 - textworld/envs/wrappers/glulx_logger.py | 60 ++++++++++++++++--------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/textworld/envs/glulx/__init__.py b/textworld/envs/glulx/__init__.py index 8b137891..e69de29b 100644 --- a/textworld/envs/glulx/__init__.py +++ b/textworld/envs/glulx/__init__.py @@ -1 +0,0 @@ - diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index b0b4b079..1d0a4061 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. + import json from typing import Tuple, List, Optional, Iterable, Union, Sized, Any, Mapping @@ -9,20 +10,21 @@ class GameLog: + """ + GameLog object. Allows you to load and save previous game logs. + """ def __init__(self): - """ - GameLog object. Allows your to load and save previous game logs. - """ self._logs = [[]] - self._current_game = self._logs[-1] self._filename = '' def __getitem__(self, idx: int) -> list: - """ + """A Gets a particular game log at index idx. + Args: idx: index to retrieve - Returns: + + Returns: list of logs at idx """ assert idx <= len(self._logs) @@ -35,14 +37,16 @@ def __len__(self) -> int: def current_game(self) -> list: """ Gets current game we're logging. + Returns: list of logs from current game. """ - return self._current_game + return self._logs[-1] @property def logs(self) -> list: """ Get all logs from all games. + Returns: All logs from all games. """ return self._logs @@ -50,32 +54,34 @@ def logs(self) -> list: def new_game(self): """ Start logs for a new game. + Returns: log object for current game. """ - if len(self._current_game) > 0: + if len(self.current_game) > 0: self._logs.append([]) - self._current_game = self._logs[-1] - return self._current_game + return self.current_game def set(self, key: Any, value: Any) -> None: """ Sets value for latest game + Args: key: Key to set value: Value to set """ - current = self._current_game[-1] + current = self.current_game[-1] current[key] = value def append_optional(self, value: Any) -> None: """ Appends optional information to current game + Args: - value: Value to append + value: Value to append. Must be JSON serializable. """ - current = self._current_game[-1] + current = self.current_game[-1] if 'optional' not in current: current['optional'] = [] current['optional'].append(value) @@ -83,15 +89,17 @@ def append_optional(self, value: Any) -> None: def add_log(self, log: Mapping): """ Adds a new log to our logs + Args: log: Mapping of a log """ - self._current_game.append(log) + self.current_game.append(log) def save(self, filename): """ Save current logs to specified file name + Args: filename: File path to save to (should have JSON extension) @@ -106,6 +114,7 @@ def save(self, filename): def load(self, filename): """ Loads a JSON object as logs + Args: filename: file path to load. @@ -116,13 +125,13 @@ def load(self, filename): class GlulxLogger(Wrapper): - def __init__(self, env: GitGlulxMLEnvironment) -> None: - """ - Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. + """ + Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. - Args: - env: The GitGlulxML environment to wrap. - """ + Args: + env: The GitGlulxML environment to wrap. + """ + def __init__(self, env: GitGlulxMLEnvironment) -> None: super().__init__(env) self.activate_state_tracking() @@ -134,8 +143,10 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: """ Take a step in the environment. + Args: command: input string for taking an action + Returns: GlulxGameState, score and done. """ @@ -158,6 +169,7 @@ def reset(self) -> GameState: """ Reset the environment. Adds a new game into the logs. + Returns: GameState """ @@ -177,6 +189,7 @@ def reset(self) -> GameState: def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: """ Add custom commands to the logger. Optionally add scores for each command. + Args: commands: A list of commands. scores: scores for each command. Must be same size as commands if provided. @@ -190,6 +203,7 @@ def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[floa def add(self, info: Any) -> None: """ Add any additional information you want to log. + Args: info: Additional information to log for the current game state. """ @@ -228,8 +242,10 @@ def gamefile(self) -> str: def __getitem__(self, index: int) -> Mapping: """ Get a certain log at a given index. + Args: index: index of log to get. + Returns: log at index """ @@ -243,6 +259,7 @@ def __str__(self) -> str: def serialize(self) -> List[Mapping]: """ Get serialized mappings of logs. + Returns: List of serialized mappings. """ @@ -251,6 +268,7 @@ def serialize(self) -> List[Mapping]: def save(self, filename) -> None: """ Saves all logs given a filename + Returns: None """ self._logs.save(filename) @@ -258,9 +276,11 @@ def save(self, filename) -> None: def load(self, filename) -> None: """ Loads logs from a file + Args: filename: string representing file location + Returns: None """ self._logs.load(filename) From 2cd2d921ee268a9a3e922c39d436dc8dce0489f8 Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 21 Aug 2018 10:16:20 -0400 Subject: [PATCH 10/11] fix comments --- textworld/envs/wrappers/glulx_logger.py | 63 ++++++++----------------- 1 file changed, 20 insertions(+), 43 deletions(-) diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index 1d0a4061..4bc9a5b6 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -51,7 +51,7 @@ def logs(self) -> list: """ return self._logs - def new_game(self): + def new_game(self) -> list: """ Start logs for a new game. @@ -96,7 +96,7 @@ def add_log(self, log: Mapping): """ self.current_game.append(log) - def save(self, filename): + def save(self, filename: str) -> None: """ Save current logs to specified file name @@ -111,7 +111,7 @@ def save(self, filename): except TypeError as e: raise TypeError('Log not serializable') - def load(self, filename): + def load(self, filename: str) -> None: """ Loads a JSON object as logs @@ -121,12 +121,12 @@ def load(self, filename): """ self._filename = filename with open(filename) as f: - self._logs= json.load(f) + self._logs = json.load(f) class GlulxLogger(Wrapper): """ - Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. + Wrapper around a TextWorld GitGlulxML environment to provide logging capabilities. Args: env: The GitGlulxML environment to wrap. @@ -138,7 +138,7 @@ def __init__(self, env: GitGlulxMLEnvironment) -> None: self.serialized_game = env.game.serialize() self._gamefile = env.gamefile - self._logs = GameLog() + self._game_log = GameLog() def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: """ @@ -161,7 +161,7 @@ def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: new_log['description'] = game_state.description new_log['inventory'] = game_state.inventory new_log['state'] = game_state.state.serialize() - self._logs.add_log(new_log) + self._game_log.add_log(new_log) return game_state, score, done @@ -174,7 +174,7 @@ def reset(self) -> GameState: GameState """ new_log = {} - self._logs.new_game() + self._game_log.new_game() game_state = super().reset() new_log['optional'] = [] @@ -182,23 +182,22 @@ def reset(self) -> GameState: new_log['description'] = game_state.description new_log['inventory'] = game_state.inventory new_log['state'] = game_state.state.serialize() - self._logs.add_log(new_log) + self._game_log.add_log(new_log) return game_state - def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: + def add_commands(self, commands: List[str], scores: Optional[Iterable[float]]=None) -> None: """ Add custom commands to the logger. Optionally add scores for each command. Args: commands: A list of commands. scores: scores for each command. Must be same size as commands if provided. - """ if scores is not None: - self._logs.set('command_scores', scores) + self._game_log.set('command_scores', scores) - self._logs.set('commands', commands) + self._game_log.set('commands', commands) def add(self, info: Any) -> None: """ @@ -207,7 +206,7 @@ def add(self, info: Any) -> None: Args: info: Additional information to log for the current game state. """ - self._logs.append_optional(info) + self._game_log.append_optional(info) @property def current(self) -> Mapping: @@ -215,21 +214,21 @@ def current(self) -> Mapping: Returns: Current game state logs. """ - return self._logs.current_game[-1] + return self._game_log.current_game[-1] @property def logs(self) -> List[Mapping]: """ Returns: List of all logs from this game. """ - return self._logs.current_game + return self._game_log.current_game @property def all_logs(self) -> GameLog: """ Returns: GameLog object containing all logs. """ - return self._logs + return self._game_log @property def gamefile(self) -> str: @@ -249,12 +248,12 @@ def __getitem__(self, index: int) -> Mapping: Returns: log at index """ - assert index <= len(self._logs) + assert index <= len(self._game_log) - return self._logs.current_game[index] + return self._game_log.current_game[index] def __str__(self) -> str: - return str(self._logs.current_game) + return str(self._game_log.current_game) def serialize(self) -> List[Mapping]: """ @@ -263,26 +262,4 @@ def serialize(self) -> List[Mapping]: Returns: List of serialized mappings. """ - return self._logs.logs - - def save(self, filename) -> None: - """ - Saves all logs given a filename - - Returns: None - """ - self._logs.save(filename) - - def load(self, filename) -> None: - """ - Loads logs from a file - - Args: - filename: - string representing file location - - Returns: None - """ - self._logs.load(filename) - - + return self._game_log.logs From 1c995646439c430484ef11502dd1756a1d4f4405 Mon Sep 17 00:00:00 2001 From: David Tao Date: Tue, 21 Aug 2018 16:00:38 -0400 Subject: [PATCH 11/11] remove type --- textworld/envs/wrappers/glulx_logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/textworld/envs/wrappers/glulx_logger.py b/textworld/envs/wrappers/glulx_logger.py index 4bc9a5b6..a609f8a2 100644 --- a/textworld/envs/wrappers/glulx_logger.py +++ b/textworld/envs/wrappers/glulx_logger.py @@ -18,7 +18,7 @@ def __init__(self): self._filename = '' def __getitem__(self, idx: int) -> list: - """A + """ Gets a particular game log at index idx. Args: