Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/cmlibs/utils/zinc/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@

from enum import Enum

from cmlibs.utils.zinc.field import get_group_list
from cmlibs.utils.zinc.finiteelement import evaluate_mesh_centroid, evaluate_nearest_mesh_location, \
evaluate_field_nodeset_mean
from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager
from cmlibs.zinc.element import Element
from cmlibs.zinc.field import Field, FieldGroup
from cmlibs.zinc.result import RESULT_OK
import logging


logger = logging.getLogger(__name__)


class GroupOperator(Enum):
Expand Down Expand Up @@ -487,3 +492,40 @@ def group_remove_group_local_contents(group, source_group):
nodeset_group = group.getNodesetGroup(nodeset)
if nodeset_group.isValid() and (nodeset_group.getSize() > 0):
nodeset_group.removeNodesConditional(source_group)


def match_fitting_group_names(data_fieldmodule, model_fieldmodule, log_diagnostics=False):
"""
Used for fitting problems. Rename any group names in the data fieldmodule that differ only in
case and whitespace from any in the model fieldmodule, to the accepted values from the model,
which are expected to be lower case without leading or trailing whitespace characters.
Note that internal whitespace must be exactly matched.
:param data_fieldmodule: Data Fieldmodule whose group names may be modified.
:param model_fieldmodule: Model Fieldmodule containing preferred group names.
:param log_diagnostics: Set to True to write diagonstic messages about name matches and changes to logging.
"""
# future: match with annotation terms
model_names = [group.getName() for group in get_group_list(model_fieldmodule)]
for data_group in get_group_list(data_fieldmodule):
data_name = data_group.getName()
compare_name = data_name.strip().casefold()
for model_name in model_names:
if model_name == data_name:
if log_diagnostics:
logger.info("Data group '" + data_name + "' found in model")
break
elif model_name.strip().casefold() == compare_name:
result = data_group.setName(model_name)
if result == RESULT_OK:
if log_diagnostics:
logger.info("Data group '" + data_name + "' found in model as '" +
model_name + "'. Renaming to match.")
else:
logger.error("Error: Data group '" + data_name + "' found in model as '" +
model_name + "'. Renaming to match FAILED.")
if fieldmodule.findFieldByName(model_name).isValid():
logger.error(" Reason: field of that name already exists.")
break
else:
if log_diagnostics:
logger.info("Data group '" + data_name + "' not found in model")
15 changes: 15 additions & 0 deletions src/cmlibs/utils/zinc/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,18 @@ def undefine_field(field):
mesh = fm.findMeshByDimension(i)
mesh_group = mesh
_undefine_field_on_elements(field, mesh_group)


def element_or_ancestor_is_in_mesh(element, mesh):
"""
Query whether element is in mesh or is from its tree of faces, lines etc.
:param element: Element to query.
:param mesh: Equal or higher dimension ancestor mesh or mesh group to check.
:return: True if element or any parent/ancestor is in mesh.
"""
if mesh.containsElement(element):
return True
for p in range(1, element.getNumberOfParents() + 1):
if element_or_ancestor_is_in_mesh(element.getParentElement(p), mesh):
return True
return False
42 changes: 36 additions & 6 deletions src/cmlibs/utils/zinc/region.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cmlibs.utils.zinc.finiteelement import get_identifiers, evaluate_field_nodeset_range
from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range, get_identifiers, get_maximum_node_identifier
from cmlibs.zinc.field import Field
from cmlibs.zinc.result import RESULT_OK

Expand All @@ -10,11 +10,25 @@ def _find_missing(lst):
for i in range(x + 1, y) if y - x > 1]


def convert_nodes_to_datapoints(target_region, source_region):
def convert_nodes_to_datapoints(target_region, source_region, source_nodeset_type=Field.DOMAIN_TYPE_NODES,
destroy_after_conversion=True):
"""
Converts nodes in the source region to datapoints in the target region, renumbering any existing
datapoints in target region to not clash.
When the source nodeset type is Field.DOMAIN_TYPE_DATAPOINTS, then datapoints are transferred from the
source region to the target region.
:param target_region: Zinc Region to read data into. Existing data points are renumbered to avoid nodes.
:param source_region: Zinc Region containing nodes to transfer.
:param source_nodeset_type: Set to Field.DOMAIN_TYPE_DATAPOINTS or Field.DOMAIN_TYPE_NODES to transfer datapoints
or convert nodes. Datapoint transfer should only be to different regions [default: Field.DOMAIN_TYPE_NODES].
:param destroy_after_conversion: Set to True to destroy nodes that have been successfully converted, or False
to leave intact in source region [default: True].
"""
source_fieldmodule = source_region.getFieldmodule()
target_fieldmodule = target_region.getFieldmodule()
with ChangeManager(source_fieldmodule), ChangeManager(target_fieldmodule):
nodes = source_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)
# Could be nodes or datapoints.
nodes = source_fieldmodule.findNodesetByFieldDomainType(source_nodeset_type)
if nodes.getSize() > 0:
datapoints = target_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
if datapoints.getSize() > 0:
Expand Down Expand Up @@ -44,11 +58,27 @@ def convert_nodes_to_datapoints(target_region, source_region):
datapoint.setIdentifier(new_identifier)

# transfer nodes as datapoints to target_region
buffer = write_to_buffer(source_region, resource_domain_type=Field.DOMAIN_TYPE_NODES)
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
buffer = write_to_buffer(source_region, resource_domain_type=source_nodeset_type)
if source_nodeset_type == Field.DOMAIN_TYPE_NODES:
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
result = read_from_buffer(target_region, buffer)
assert result == RESULT_OK, "Failed to load nodes as datapoints"
nodes.destroyAllNodes()
if destroy_after_conversion:
# note this cannot destroy nodes in use by elements
nodes.destroyAllNodes()


def copy_fitting_data(target_region, source_region):
"""
Copy nodes and data points from source_region to target_region, converting nodes to data points and
offsetting data point identifiers to not clash. All groups and fields in use are transferred.
This is used for setting up fitting problems where data needs to be in datapoints only.
:param target_region: Zinc Region to read nodes/data into.
:param source_region: Zinc Region containing nodes/data to transfer. Unmodified.
"""
for domain_type in [Field.DOMAIN_TYPE_DATAPOINTS, Field.DOMAIN_TYPE_NODES]:
convert_nodes_to_datapoints(target_region, source_region, source_nodeset_type=domain_type,
destroy_after_conversion=False)


def copy_nodeset(region, nodeset):
Expand Down
47 changes: 42 additions & 5 deletions tests/test_zinc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import unittest
from cmlibs.utils.zinc.field import createFieldMeshIntegral, findOrCreateFieldCoordinates, \
findOrCreateFieldGroup
from cmlibs.utils.zinc.finiteelement import createCubeElement, createSquareElement, createNodes, \
createTriangleElements, evaluateFieldNodesetMean
from cmlibs.utils.zinc.field import (
createFieldMeshIntegral, findOrCreateFieldCoordinates, findOrCreateFieldGroup)
from cmlibs.utils.zinc.mesh import element_or_ancestor_is_in_mesh
from cmlibs.utils.zinc.finiteelement import (
createCubeElement, createSquareElement, createNodes, createTriangleElements, evaluateFieldNodesetMean)
from cmlibs.zinc.context import Context
from cmlibs.zinc.element import Element
from cmlibs.zinc.field import Field
from cmlibs.zinc.result import RESULT_OK
from utilities import assert_almost_equal_list
from utilities import assert_almost_equal_list, get_test_resource_name


class ZincUtilsTestCase(unittest.TestCase):
Expand Down Expand Up @@ -105,6 +107,41 @@ def test_create_nodes_and_elements(self):
self.assertEqual(RESULT_OK, result)
self.assertAlmostEqual(0.9, volume, delta=1.0E-7)

def test_element_or_ancestor_is_in_mesh(self):
exf_file = get_test_resource_name('two_element_cube.exf')

context = Context("test")
source_region = context.createRegion()
result = source_region.readFile(exf_file)
self.assertTrue(result == RESULT_OK)

fm = source_region.getFieldmodule()
mesh3d = fm.findMeshByDimension(3)
mesh2d = fm.findMeshByDimension(2)
mesh1d = fm.findMeshByDimension(1)
element2 = mesh3d.findElementByIdentifier(2)
self.assertTrue(element2.isValid())
self.assertTrue(element_or_ancestor_is_in_mesh(element2, mesh3d))
self.assertFalse(element_or_ancestor_is_in_mesh(element2, mesh2d))
self.assertFalse(element_or_ancestor_is_in_mesh(element2, mesh1d))
face4 = mesh2d.findElementByIdentifier(4)
self.assertTrue(face4.isValid())
self.assertTrue(element_or_ancestor_is_in_mesh(face4, mesh3d))
self.assertTrue(element_or_ancestor_is_in_mesh(face4, mesh2d))
self.assertFalse(element_or_ancestor_is_in_mesh(face4, mesh1d))
line5 = mesh1d.findElementByIdentifier(5)
self.assertTrue(line5.isValid())
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh3d))
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh2d))
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh1d))

elementtemplate = mesh1d.createElementtemplate()
elementtemplate.setElementShapeType(Element.SHAPE_TYPE_LINE)
new_line = mesh1d.createElement(-1, elementtemplate)
self.assertFalse(element_or_ancestor_is_in_mesh(new_line, mesh3d))
self.assertFalse(element_or_ancestor_is_in_mesh(new_line, mesh2d))
self.assertTrue(element_or_ancestor_is_in_mesh(new_line, mesh1d))


if __name__ == "__main__":
unittest.main()
27 changes: 26 additions & 1 deletion tests/test_zinc_group.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import unittest
from cmlibs.utils.zinc.field import find_or_create_field_group
from cmlibs.utils.zinc.group import (
group_add_group_local_contents, group_evaluate_centroid, group_remove_group_local_contents,
group_evaluate_representative_point, groups_have_same_local_contents)
group_evaluate_representative_point, groups_have_same_local_contents, match_fitting_group_names)
from cmlibs.zinc.context import Context
from cmlibs.zinc.element import Element
from cmlibs.zinc.field import Field
Expand Down Expand Up @@ -134,3 +135,27 @@ def test_group_add_compare_group_local_contents(self):
self.assertEqual(0, group1.getMeshGroup(mesh2d).getSize())
self.assertEqual(0, group1.getMeshGroup(mesh1d).getSize())
self.assertEqual(0, group1.getNodesetGroup(nodes).getSize())


def test_match_fitting_group_names(self):
"""
Test utility functions for adding and comparing group local contents.
"""
context = Context("test")
model_region = context.createRegion()
model_fieldmodule = model_region.getFieldmodule()
find_or_create_field_group(model_fieldmodule, "bob", managed=True)
find_or_create_field_group(model_fieldmodule, "fred", managed=True)
find_or_create_field_group(model_fieldmodule, "two names", managed=True)

data_region = context.createRegion()
data_fieldmodule = data_region.getFieldmodule()
data_group_bob = find_or_create_field_group(data_fieldmodule, " Bob")
data_group_fred = find_or_create_field_group(data_fieldmodule, " fRed\t")
data_group_two_names = find_or_create_field_group(data_fieldmodule, "\t two NAMES ")

match_fitting_group_names(data_fieldmodule, model_fieldmodule, log_diagnostics=True)

self.assertEqual(data_group_bob.getName(), "bob")
self.assertEqual(data_group_fred.getName(), "fred")
self.assertEqual(data_group_two_names.getName(), "two names")