Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
- os: win64
os-version: windows-2022
- install-mode: dev
python-version: "3.11" # choice of Python version is arbitrary among those in matrix
python-version: "3.10" # choice of Python version is arbitrary among those in matrix
coverage: "true"
- os: win64 # only run mpi, on windows, until GH starts working again.
install-mode: dev
Expand Down
27 changes: 15 additions & 12 deletions src/parameter_sweep/loop_tool/loop_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_working_dir():
class loopTool:
def __init__(
self,
loop_file,
loop_configuration,
solver=None,
build_function=None,
initialize_function=None,
Expand All @@ -56,7 +56,7 @@ def __init__(
Loop tool class that runs iterative paramter sweeps

Arguments:
loop_file : .yaml config file that contains iterative loops to run
loop_configuration : .yaml file or dict that contains iterative loops
solver : solver to use in model, default uses watertap solver
build_function : function to build unit model
initialize_function : function for initialization of the unit model
Expand All @@ -76,7 +76,7 @@ def __init__(
use as many workers as specified to run through all loop options in parallel.
"""

self.loop_file = loop_file
self.loop_configuration = loop_configuration
self.solver = solver

self.build_function = build_function
Expand Down Expand Up @@ -132,8 +132,10 @@ def build_run_dict(self):
"""
This builds the dict that will be used for simulations
"""

loop_dict = ParameterSweepReader()._yaml_to_dict(self.loop_file)
if isinstance(self.loop_configuration, dict):
loop_dict = self.loop_configuration
else:
loop_dict = ParameterSweepReader()._yaml_to_dict(self.loop_configuration)
self.sweep_directory = {}
for key, loop in loop_dict.items():
self.check_dict_keys(loop)
Expand Down Expand Up @@ -247,8 +249,8 @@ def get_loop_key(self, loop, loop_type):
def build_sweep_directories(
self, loop, loop_type, sweep_directory, cur_dir, cur_h5_dir
):
"""this creats the loop directory dict, which is then used to run
the paramter sweep"""
"""this creates the loop directory dict, which is then used to run
the parameter sweep"""
if loop_type != None:
loop_type_recursive = self.get_loop_type(loop)
loop_key_current = self.get_loop_key(loop, loop_type)
Expand Down Expand Up @@ -308,7 +310,6 @@ def build_sweep_directories(
"original_options_dict": copy.deepcopy(self.options),
}
}

return sweep_directory, cur_dir

def update_dir_path(self, cur_dir, key, value):
Expand Down Expand Up @@ -484,7 +485,7 @@ def find_execution_configs(
if key != "simulation_setup":
self.find_execution_configs(value)
else:
self.execution_list.append(value)
self.execution_list.append(copy.deepcopy(value))
Comment thread
avdudchenko marked this conversation as resolved.

def execute_param_sweep_run(self, value):
"""this executes the parameter sweep
Expand All @@ -508,7 +509,7 @@ def setup_param_sweep(self, value):
tool, resets any of prior options"""
self.init_sim_options()
self.options = value["original_options_dict"]
self.build_default = value["build_defaults"]
self.build_defaults = value["build_defaults"]
self.optimize_defaults = value["optimize_defaults"]

self.init_defaults = value["init_defaults"]
Expand Down Expand Up @@ -555,9 +556,11 @@ def build_sim_kwargs(self):
)
self.build_outputs_kwargs = self.options.get("build_outputs_kwargs", None)
# generated combined build kwargs (default + loop)
self.combined_build_defaults = {} # self.build_default
self.combined_build_defaults = {} # self.build_defaults
self.combined_build_defaults.update(self.options.get("build_defaults", {}))
self.combined_build_defaults.update(self.build_default)
print("build_defaults", self.build_defaults)
self.combined_build_defaults.update(self.build_defaults)
print(self.combined_build_defaults)
# generated combined optimize kwargs (default + loop)
self.combined_optimize_defaults = {}
self.combined_optimize_defaults.update(
Expand Down
42 changes: 42 additions & 0 deletions src/parameter_sweep/loop_tool/tests/test_loop_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# information, respectively. These files are also available online at the URL
# "https://github.com/watertap-org/watertap/"
#################################################################################
from parameter_sweep.reader import ParameterSweepReader
import pytest
import os
import numpy as np
Expand Down Expand Up @@ -99,6 +100,36 @@ def loop_sweep_setup():
return lp, expected_run_dict


@pytest.fixture()
def loop_sweep_setup_from_dict():
dict_setup = ParameterSweepReader()._yaml_to_dict(
_this_file_path + "/test_sweep.yaml"
)

lp = loopTool(
dict_setup,
build_function=ro_setup.ro_build,
initialize_function=ro_setup.ro_init,
optimize_function=ro_setup.ro_solve,
saving_dir=_this_file_path,
save_name="ro_with_erd",
execute_simulations=False,
number_of_subprocesses=1,
)
lp.build_run_dict()
""" used to generate test file"""
if has_mpi_peer_processes() == False or (
has_mpi_peer_processes() and get_mpi_comm_process().Get_rank() == 0
):
with open(
_this_file_path + "/test_expected_sweep_directory.yaml", "r"
) as infile:
expected_run_dict = yaml.safe_load(infile)
else:
expected_run_dict = None
return lp, expected_run_dict


@pytest.fixture()
def loop_sweep_setup_with_workers():
lp = loopTool(
Expand Down Expand Up @@ -208,6 +239,17 @@ def test_sweep_setup(loop_sweep_setup):
assert diff_dict_check(lp.sweep_directory, expected_run_dict)


@pytest.mark.component
def test_sweep_setup_from_dict(loop_sweep_setup_from_dict):
if has_mpi_peer_processes() == False or (
has_mpi_peer_processes() and get_mpi_comm_process().Get_rank() == 0
):
lp, expected_run_dict = loop_sweep_setup_from_dict
lp.build_run_dict()

assert diff_dict_check(lp.sweep_directory, expected_run_dict)


@pytest.mark.component
def test_diff_setup(loop_diff_setup):
if has_mpi_peer_processes() == False or (
Expand Down
Loading