Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ Testing
*.pdf
*.gv
*.json
*.gz
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ if(NEML2_DOC)
FetchContent_MakeAvailable(doxygen-awesome-css)
endif()

# ----------------------------------------------------------------------------
# zlib
# ----------------------------------------------------------------------------
find_package(ZLIB)

if(NOT ZLIB_FOUND)
message(WARNING "ZLIB not found, model packaging features will be disabled")
endif()

# ----------------------------------------------------------------------------
# base neml2 library
# ----------------------------------------------------------------------------
Expand Down
33 changes: 29 additions & 4 deletions include/neml2/base/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <filesystem>
#include <iostream>
#include <memory>

#include "neml2/misc/errors.h"
#include "neml2/base/InputFile.h"
Expand All @@ -41,6 +42,7 @@ class Data;
class Model;
class Driver;
class WorkScheduler;
class BundledModel;

/**
* @brief A convenient function to parse all options from an input file
Expand Down Expand Up @@ -73,7 +75,7 @@ class Factory
const InputFile & input_file() const { return _input_file; }

/// Global settings
const std::shared_ptr<Settings> & settings() const { return _input_file.settings(); }
const std::shared_ptr<Settings> & settings() const { return _settings; }

/// Check if an object with the given name exists under the given section.
bool has_object(const std::string & section, const std::string & name);
Expand Down Expand Up @@ -116,6 +118,17 @@ class Factory
template <class T = WorkScheduler>
std::shared_ptr<T> get_scheduler(const std::string & name);

/**
* @brief Serialize an object to an input file. The returned input file contains the exact
* information needed to reconstruct the object.
*
* @note Behind the scenes, this method calls the get_object method with \p force_create set to
* true, which has the side effect of creating the object if it does not already exist.
*/
std::unique_ptr<InputFile> serialize_object(const std::string & section,
const std::string & name,
const OptionSet & additional_options = OptionSet());

/// @brief Delete all factories and destruct all the objects.
void clear();

Expand All @@ -139,14 +152,26 @@ class Factory
/// Check if the options are compatible with the object
bool options_compatible(const std::shared_ptr<NEML2Object> & obj, const OptionSet & opts) const;

/// BundledModel will need to squeeze the unpacked model into the factory
friend class BundledModel;

/// The input file
InputFile _input_file;

/// Global settings of the input file
const std::shared_ptr<Settings> _settings;

/**
* Manufactured objects. The key of the outer map is the section name, and the key of the inner
* map is the object name.
*/
std::map<std::string, std::map<std::string, std::vector<std::shared_ptr<NEML2Object>>>> _objects;

/// Whether the factory is currently serializing an object
bool _serializing = false;

/// The output serialized input file (used by the serialize_object method)
std::unique_ptr<InputFile> _serialized_file;
};

template <class T>
Expand All @@ -160,7 +185,7 @@ Factory::get_object(const std::string & section,
throw FactoryException("The input file is empty.");

// Easy if it already exists
if (!force_create)
if (!force_create && !_serializing)
if (_objects.count(section) && _objects.at(section).count(name))
for (const auto & neml2_obj : _objects[section][name])
{
Expand All @@ -183,8 +208,8 @@ Factory::get_object(const std::string & section,
if (options.first == name)
{
auto new_options = options.second;
new_options.set<Factory *>("_factory") = this;
new_options.set<std::shared_ptr<Settings>>("_settings") = settings();
new_options.set<Factory *>("factory") = this;
new_options.set<std::shared_ptr<Settings>>("settings") = settings();
new_options += additional_options;
create_object(section, new_options);
break;
Expand Down
14 changes: 10 additions & 4 deletions include/neml2/base/HITParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
namespace neml2
{
class OptionSet;
class OptionBase;

/**
* @copydoc neml2::Parser
Expand All @@ -47,11 +48,14 @@ class HITParser : public Parser
~HITParser() override = default;

/// Parse a HIT input file from a filename.
InputFile parse(const std::filesystem::path & filename,
const std::string & additional_input = "") const override;
InputFile parse_from_string(const std::string & input,
const std::string & additional_input = "") const override;

/// Parse a HIT input file from a root node.
InputFile parse(hit::Node * root) const;
InputFile parse_from_hit_node(hit::Node * root) const;

/// Serialize an input file to a string.
std::string serialize(const InputFile & inp) const override;

private:
/**
Expand All @@ -61,10 +65,12 @@ class HITParser : public Parser
* @param section The current section node.
* @return OptionSet The options of the object.
*/
virtual OptionSet extract_object_options(hit::Node * object, hit::Node * section) const;
OptionSet extract_object_options(hit::Node * object, hit::Node * section) const;

void extract_options(hit::Node * object, OptionSet & options) const;
void extract_option(hit::Node * node, OptionSet & options) const;

void serialize_options(hit::Node * node, const OptionSet & options) const;
};

} // namespace neml2
8 changes: 3 additions & 5 deletions include/neml2/base/InputFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,16 @@

namespace neml2
{
class Settings;

/**
* @brief A data structure that holds options of multiple objects.
*/
class InputFile
{
public:
InputFile(const OptionSet & settings);
InputFile(OptionSet settings);

/// Get global settings
const std::shared_ptr<Settings> & settings() const { return _settings; }
const OptionSet & settings() const { return _settings; }

/// Get all the object options under a specific section.
std::map<std::string, OptionSet> & operator[](const std::string & section);
Expand All @@ -55,7 +53,7 @@ class InputFile

private:
/// Global settings specified under the [Settings] section
const std::shared_ptr<Settings> _settings;
const OptionSet _settings;

/// Collection of options for all manufacturable objects
std::map<std::string, std::map<std::string, OptionSet>> _data;
Expand Down
23 changes: 19 additions & 4 deletions include/neml2/base/Option.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#pragma once

#include <cstddef>
#include <vector>

#include "neml2/base/OptionBase.h"
Expand All @@ -49,6 +50,9 @@ template <typename P>
void _print_helper(std::ostream & os, const std::vector<P> *);
template <typename P>
void _print_helper(std::ostream & os, const std::vector<std::vector<P>> *);
/// bool
template <>
void _print_helper(std::ostream & os, const bool *);
/// The evil vector of bool :/
template <>
void _print_helper(std::ostream & os, const std::vector<bool> *);
Expand All @@ -58,6 +62,9 @@ void _print_helper(std::ostream & os, const char *);
/// Specialization so that we don't print out unprintable characters
template <>
void _print_helper(std::ostream & os, const unsigned char *);
/// Specialization for tensor shape
template <>
void _print_helper(std::ostream & os, const TensorShape *);
///@}
}

Expand Down Expand Up @@ -108,16 +115,24 @@ template <typename P>
void
_print_helper(std::ostream & os, const std::vector<P> * option)
{
for (const auto & p : *option)
os << p << " ";
for (std::size_t i = 0; i < option->size(); i++)
{
if (i > 0)
os << " ";
_print_helper(os, &(*option)[i]);
}
}

template <typename P>
void
_print_helper(std::ostream & os, const std::vector<std::vector<P>> * option)
{
for (const auto & pv : *option)
_print_helper(os, &pv);
for (std::size_t i = 0; i < option->size(); i++)
{
if (i > 0)
os << "; ";
_print_helper(os, &(*option)[i]);
}
}
} // namespace details
// LCOV_EXCL_STOP
Expand Down
12 changes: 9 additions & 3 deletions include/neml2/base/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ class Parser
* @return InputFile The extracted object options.
*/
virtual InputFile parse(const std::filesystem::path & filename,
const std::string & additional_input = "") const = 0;
const std::string & additional_input = "") const;

virtual InputFile parse_from_string(const std::string & input,
const std::string & additional_input = "") const = 0;

/// @brief Serialize an input file to a string
virtual std::string serialize(const InputFile & inp) const = 0;
};

namespace utils
Expand Down Expand Up @@ -117,7 +123,7 @@ parse_vector_(std::vector<T> & vals, const std::string & raw_str)
vals.resize(tokens.size(), kCPU);
else
vals.resize(tokens.size());
for (size_t i = 0; i < tokens.size(); i++)
for (std::size_t i = 0; i < tokens.size(); i++)
{
auto success = parse_<T>(vals[i], tokens[i]);
if (!success)
Expand All @@ -144,7 +150,7 @@ parse_vector_vector_(std::vector<std::vector<T>> & vals, const std::string & raw
{
auto token_vecs = split(raw_str, ";");
vals.resize(token_vecs.size());
for (size_t i = 0; i < token_vecs.size(); i++)
for (std::size_t i = 0; i < token_vecs.size(); i++)
{
auto success = parse_vector_<T>(vals[i], token_vecs[i]);
if (!success)
Expand Down
3 changes: 3 additions & 0 deletions include/neml2/base/Registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class Registry
/// Load registry from a dynamic library
static void load(const std::filesystem::path &);

/// Check if an object is registered in the registry.
static bool is_registered(const std::string & name);

/// Get information of all registered objects.
static const std::map<std::string, NEML2ObjectInfo> & info();

Expand Down
8 changes: 2 additions & 6 deletions include/neml2/drivers/ModelDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,8 @@ class ModelDriver : public Driver
/// The device on which to evaluate the model
const Device _device;

/// Set to true to list all the model parameters at the beginning
const bool _show_params;
/// Set to true to show model's input axis at the beginning
const bool _show_input;
/// Set to true to show model's output axis at the beginning
const bool _show_output;
/// Set to true to show model summary at the beginning
const bool _show_model;

#ifdef NEML2_HAS_DISPATCHER
/// The work scheduler to use
Expand Down
80 changes: 80 additions & 0 deletions include/neml2/models/BundledModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include <filesystem>

#ifdef NEML2_HAS_ZLIB
#include <zlib.h>
#endif

#ifdef NEML2_HAS_JSON
#include "nlohmann/json.hpp"
#endif

#define NEML2_CAN_BUNDLE_MODEL defined(NEML2_HAS_ZLIB) && defined(NEML2_HAS_JSON)

#include "neml2/models/Model.h"

namespace neml2
{
#ifdef NEML2_CAN_BUNDLE_MODEL
void bundle_model(const std::string & file,
const std::string & name,
const std::string & cliargs = "",
const nlohmann::json & config = nlohmann::json(),
std::filesystem::path output_path = std::filesystem::path());

std::pair<std::shared_ptr<Model>, nlohmann::json> unbundle_model(const std::filesystem::path & pkg,
NEML2Object * host = nullptr);
#endif // NEML2_CAN_BUNDLE_MODEL

class BundledModel : public Model
{
public:
static OptionSet expected_options();

BundledModel(const OptionSet & options);

const nlohmann::json & config() const { return _config; }

///@{
/// Methods for retrieving descriptions
std::string description() const;
std::string input_description(const VariableName & name) const;
std::string output_description(const VariableName & name) const;
std::string param_description(const std::string & name) const;
std::string buffer_description(const std::string & name) const;
///@}

protected:
void link_output_variables() override;
void set_value(bool, bool, bool) override;

std::shared_ptr<Model> _bundled_model;

nlohmann::json _config;
};
} // namespace neml2
10 changes: 5 additions & 5 deletions include/neml2/models/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ class Data : public NEML2Object, public BufferStore
T & register_data(const std::string & name)
{
OptionSet extra_opts;
extra_opts.set<NEML2Object *>("_host") = host();
if (!host()->factory())
throw SetupException("Internal error: Host object '" + host()->name() +
"' does not have a factory set.");
auto data = host()->factory()->get_object<T>("Data", name, extra_opts, /*force_create=*/false);
extra_opts.set<NEML2Object *>("host") = host();
if (!factory())
throw SetupException("Internal error: Object '" + this->name() +
"' does not have a factory.");
auto data = factory()->get_object<T>("Data", name, extra_opts, /*force_create=*/false);

if (std::find(_registered_data.begin(), _registered_data.end(), data) != _registered_data.end())
throw SetupException("Data named '" + name + "' has already been registered.");
Expand Down
Loading
Loading