Skip to content
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ set(NEML2_PYBIND OFF CACHE BOOL "Build NEML2 Python bindings")
set(NEML2_DOC OFF CACHE BOOL "Build NEML2 documentation (html)")
set(NEML2_WORK_DISPATCHER OFF CACHE BOOL "Enable NEML2 work dispatcher")
set(NEML2_JSON OFF CACHE BOOL "Enable JSON support")
set(AURORA_BUILD OFF CACHE BOOL "Build on ALCF Aurora machine")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer a more general flag like NEML2_IPEX. I suppose there's nothing else specific to Aurora.


# ----------------------------------------------------------------------------
# Dependencies and 3rd party packages
Expand Down Expand Up @@ -130,6 +131,13 @@ if(Torch_DOWNLOADED)
install(DIRECTORY ${Torch_LINK_DIR} DESTINATION . COMPONENT libneml2)
endif()

# ----------------------------------------------------------------------------
# Intel Extension for PyTorch
# ----------------------------------------------------------------------------
if (AURORA_BUILD)
find_library(IPEX_LIB intel-ext-pt-gpu PATHS ${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib NO_DEFAULT_PATH REQUIRED)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
find_library(IPEX_LIB intel-ext-pt-gpu PATHS ${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib NO_DEFAULT_PATH REQUIRED)
find_library(IPEX_LIB intel-ext-pt-gpu)

I suspect something like this will just work. You can then specify the path using -DIPEX_LIB_ROOT=....

endif()

# ----------------------------------------------------------------------------
# HIT
# ----------------------------------------------------------------------------
Expand Down
12 changes: 11 additions & 1 deletion include/neml2/drivers/ModelDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#pragma once

#include "neml2/drivers/Driver.h"
#include "neml2/misc/types.h"
#include "neml2/models/map_types.h"
#include "neml2/tensors/tensors.h"

Expand All @@ -37,6 +38,11 @@ namespace neml2
{
class Model;

/// Callback to dump a ton of information on model execution
void details_callback(const Model &,
const std::map<VariableName, std::unique_ptr<VariableBase>> &,
const std::map<VariableName, std::unique_ptr<VariableBase>> &);

/**
* @brief A general-purpose driver that does *something* with a model
*
Expand All @@ -55,18 +61,22 @@ class ModelDriver : public Driver

const Model & model() const { return *_model; }

void to(Device dev);

protected:
/// The model which the driver uses to perform constitutive updates.
const std::shared_ptr<Model> _model;
/// The device on which to evaluate the model
const Device _device;
Device _device;

/// Set to true to list all the model parameters at the beginning
const bool _show_params;
/// Set to true to show model's input axis at the beginning
const bool _show_input;
/// Set to true to show model's output axis at the beginning
const bool _show_output;
/// Set to output a ton of information on the model execution
const bool _log_details;

#ifdef NEML2_HAS_DISPATCHER
/// The work scheduler to use
Expand Down
20 changes: 19 additions & 1 deletion include/neml2/models/Model.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,25 @@
#include "neml2/models/Data.h"
#include "neml2/models/ParameterStore.h"
#include "neml2/models/VariableStore.h"
#include "neml2/models/Variable.h"
#include "neml2/solvers/NonlinearSystem.h"
#include "neml2/models/NonlinearParameter.h"

// These headers are not directly used by Model, but are included here so that derived classes do
// not have to include them separately. This is a convenience for the user, and is a reasonable
// choice since these headers are light and bring in little dependency.
#include "neml2/base/LabeledAxis.h"
#include "neml2/models/Variable.h"

namespace neml2
{
class Model;

/// typedef giving the call signature for a model callback
using ModelCallback =
std::function<void(const Model &,
const std::map<VariableName, std::unique_ptr<VariableBase>> &,
const std::map<VariableName, std::unique_ptr<VariableBase>> &)>;

/**
* @brief A convenient function to load an input file and get a model
*
Expand Down Expand Up @@ -148,6 +154,12 @@ class Model : public std::enable_shared_from_this<Model>,
/// Request to use AD to compute the second derivative of a variable
void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);

/// Register a callback to be called when the model is evaluated
void register_callback(const ModelCallback & callback);

/// Register a callback on this and all submodels
void register_callback_recursive(const ModelCallback & callback);

/// Forward operator without jit
void forward(bool out, bool dout, bool d2out);

Expand Down Expand Up @@ -291,6 +303,9 @@ class Model : public std::enable_shared_from_this<Model>,
std::vector<std::shared_ptr<Model>> _registered_models;

private:
/// Invoke all registered callbacks with the model's input and output variables
void call_callbacks() const;

template <typename T>
void forward_helper(T && in, bool out, bool dout, bool d2out)
{
Expand Down Expand Up @@ -365,6 +380,9 @@ class Model : public std::enable_shared_from_this<Model>,
/// Similar to _trace_functions, but for the forward operator of the nonlinear system
std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
_traced_functions_nl_sys;

/// List of callbacks
std::vector<ModelCallback> _callbacks;
};

std::ostream & operator<<(std::ostream & os, const Model & model);
Expand Down
8 changes: 6 additions & 2 deletions src/neml2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ add_library(neml2 INTERFACE)

# libneml2_misc
neml2_add_submodule(neml2_misc SHARED misc)
target_link_libraries(neml2_misc PUBLIC Torch::Torch)
if (AURORA_BUILD)
target_link_libraries(neml2_misc PUBLIC Torch::Torch ${IPEX_LIB})
else()
target_link_libraries(neml2_misc PUBLIC Torch::Torch)
endif()
set_target_properties(neml2_misc PROPERTIES INSTALL_RPATH_USE_LINK_PATH ON)

if(NEML2_JSON)
Expand Down Expand Up @@ -166,4 +170,4 @@ if(Git_FOUND)
DESTINATION .
COMPONENT libneml2
)
endif()
endif()
41 changes: 40 additions & 1 deletion src/neml2/drivers/ModelDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@

namespace neml2
{
void
details_callback(const Model & model,
                 const std::map<VariableName, std::unique_ptr<VariableBase>> & input,
                 const std::map<VariableName, std::unique_ptr<VariableBase>> & output)
{
  // Debug callback: dump, for the given model, every input and output variable
  // with its shape and the L2 norm of its value. The norm is brought back to
  // CPU before extraction so this also works when the model runs on a device.
  //
  // Printing input and output is identical work, so share one helper instead of
  // duplicating the loop. '\n' is used instead of std::endl to avoid flushing
  // the stream on every line.
  auto print_vars = [](const char * label,
                       const std::map<VariableName, std::unique_ptr<VariableBase>> & vars)
  {
    std::cout << '\t' << label << '\n';
    for (const auto & [name, var] : vars)
      std::cout << "\t\t" << name.str() << " (" << var->sizes() << ") -> "
                << at::norm(var->tensor()).cpu().item<double>() << '\n';
  };

  std::cout << model.name() << '\n';
  print_vars("Input", input);
  print_vars("Output", output);
}

OptionSet
ModelDriver::expected_options()
{
Expand All @@ -55,6 +75,10 @@ ModelDriver::expected_options()
options.set<bool>("show_output_axis") = false;
options.set("show_output_axis").doc() = "Whether to show model output axis at the beginning";

options.set<bool>("log_details") = false;
options.set("log_details").doc() =
"If true attach a callback which outputs lots of information on the model execution";

#ifdef NEML2_HAS_DISPATCHER
options.set<std::string>("scheduler");
options.set("scheduler").doc() = "The work scheduler to use";
Expand All @@ -71,7 +95,8 @@ ModelDriver::ModelDriver(const OptionSet & options)
_device(options.get<std::string>("device")),
_show_params(options.get<bool>("show_parameters")),
_show_input(options.get<bool>("show_input_axis")),
_show_output(options.get<bool>("show_output_axis"))
_show_output(options.get<bool>("show_output_axis")),
_log_details(options.get<bool>("log_details"))
#ifdef NEML2_HAS_DISPATCHER
,
_scheduler(options.get("scheduler").user_specified() ? get_scheduler("scheduler") : nullptr),
Expand All @@ -85,6 +110,12 @@ ModelDriver::setup()
{
Driver::setup();

// Send model parameters and buffers to device
_model->to(_device);

if (_log_details)
_model->register_callback_recursive(details_callback);

#ifdef NEML2_HAS_DISPATCHER
if (_scheduler)
{
Expand Down Expand Up @@ -146,4 +177,12 @@ ModelDriver::diagnose() const
Driver::diagnose();
neml2::diagnose(*_model);
}

/// Move the driver to a new device and re-run setup().
///
/// setup() (see above) sends the model parameters/buffers to _device, so
/// updating _device first and re-running it effects the device change.
///
/// NOTE(review): setup() also re-registers the details callback when
/// log_details is enabled — calling to() repeatedly may accumulate duplicate
/// callbacks on the model; verify against Model::register_callback_recursive.
void
ModelDriver::to(Device dev)
{
  _device = dev;
  setup();
}

} // namespace neml2
27 changes: 26 additions & 1 deletion src/neml2/models/Model.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <c10/core/InferenceMode.h>

#include "neml2/misc/assertions.h"
#include "neml2/base/guards.h"
#include "neml2/base/Factory.h"
#include "neml2/base/Settings.h"
#include "neml2/jit/utils.h"
Expand Down Expand Up @@ -361,6 +360,21 @@ Model::forward_operator_index(bool out, bool dout, bool d2out) const
return (out ? 4 : 0) + (dout ? 2 : 0) + (d2out ? 1 : 0);
}

void
Model::register_callback(const ModelCallback & callback)
{
  // Store the callback; it will be invoked after each model evaluation.
  _callbacks.emplace_back(callback);
}

void
Model::register_callback_recursive(const ModelCallback & callback)
{
  // Register on this model first, then depth-first on every registered
  // submodel so the whole model tree reports through the same callback.
  register_callback(callback);

  const auto & submodels = registered_models();
  for (const auto & sub : submodels)
    sub->register_callback_recursive(callback);
}

void
Model::forward(bool out, bool dout, bool d2out)
{
Expand All @@ -387,6 +401,9 @@ Model::forward(bool out, bool dout, bool d2out)
if (dout || d2out)
extract_AD_derivatives(dout, d2out);

// Call the callbacks
call_callbacks();

return;
}

Expand Down Expand Up @@ -919,5 +936,13 @@ operator<<(std::ostream & os, const Model & model)

return os;
}

void
Model::call_callbacks() const
{
for (const auto & callback : _callbacks)
callback(*this, input_variables(), output_variables());
}

// LCOV_EXCL_STOP
} // namespace neml2
8 changes: 6 additions & 2 deletions src/neml2/tensors/functions/linalg/solve.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ namespace neml2::linalg
Tensor
solve(const Tensor & A, const Tensor & B)
{
  // Solve the batched linear system A x = B.
  //
  // The stale single-call at::linalg_solve line that preceded the LU path made
  // everything after it unreachable; it has been removed. The system is solved
  // by an explicit LU factorization followed by an LU solve. B gets a trailing
  // unit dimension so it is treated as a column-vector right-hand side, which
  // is squeezed away again after the solve.
  auto [LU, pivots] = at::linalg_lu_factor(A.batch_expand_as(B), /*pivot=*/true);
  return Tensor(at::linalg_lu_solve(LU, pivots, B.unsqueeze(-1), /*left=*/true).squeeze(-1),
                B.batch_sizes());
}
} // namespace neml2::linalg
} // namespace neml2::linalg
7 changes: 6 additions & 1 deletion tests/dispatchers/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
file(GLOB_RECURSE srcs CONFIGURE_DEPENDS *.cxx)
add_executable(dispatcher_tests ${srcs})
target_link_libraries(dispatcher_tests PRIVATE testutils)
if (AURORA_BUILD)
find_package(Threads REQUIRED)
target_link_libraries(dispatcher_tests PRIVATE testutils Threads::Threads)
else()
target_link_libraries(dispatcher_tests PRIVATE testutils)
endif()
Comment on lines +3 to +8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this required?

target_compile_options(dispatcher_tests PRIVATE -Wall -Wextra -pedantic)
set_target_properties(dispatcher_tests PROPERTIES INSTALL_RPATH_USE_LINK_PATH ON)
2 changes: 2 additions & 0 deletions tests/include/TransientRegression.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class TransientRegression : public Driver

bool run() override;

void to(Device dev);

private:
/// The driver that will run the NEML2 model
const std::shared_ptr<TransientDriver> _driver;
Expand Down
8 changes: 7 additions & 1 deletion tests/src/TransientRegression.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ TransientRegression::diagnose() const
"destination file/path.");
}

/// Delegate device placement to the wrapped transient driver, which moves the
/// underlying model to the requested device.
void
TransientRegression::to(Device dev)
{
  _driver->to(dev);
}

bool
TransientRegression::run()
{
Expand All @@ -91,7 +97,7 @@ diff(const jit::named_buffer_list & res,
{
std::map<std::string, ATensor> res_map;
for (auto item : res)
res_map.emplace(item.name, item.value);
res_map.emplace(item.name, item.value.to(kCPU));

std::map<std::string, ATensor> ref_map;
for (auto item : ref)
Expand Down