Skip to content

Commit d09f3ca

Browse files
authored
Merge pull request #51 from rchristie/transfer-data
Add common fit utils for transferring data and matching field names
2 parents 06562dd + 8902470 commit d09f3ca

File tree

5 files changed

+161
-12
lines changed

5 files changed

+161
-12
lines changed

src/cmlibs/utils/zinc/group.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44

55
from enum import Enum
66

7+
from cmlibs.utils.zinc.field import get_group_list
78
from cmlibs.utils.zinc.finiteelement import evaluate_mesh_centroid, evaluate_nearest_mesh_location, \
89
evaluate_field_nodeset_mean
910
from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager
1011
from cmlibs.zinc.element import Element
1112
from cmlibs.zinc.field import Field, FieldGroup
1213
from cmlibs.zinc.result import RESULT_OK
14+
import logging
15+
16+
17+
logger = logging.getLogger(__name__)
1318

1419

1520
class GroupOperator(Enum):
@@ -487,3 +492,40 @@ def group_remove_group_local_contents(group, source_group):
487492
nodeset_group = group.getNodesetGroup(nodeset)
488493
if nodeset_group.isValid() and (nodeset_group.getSize() > 0):
489494
nodeset_group.removeNodesConditional(source_group)
495+
496+
497+
def match_fitting_group_names(data_fieldmodule, model_fieldmodule, log_diagnostics=False):
498+
"""
499+
Used for fitting problems. Rename any group names in the data fieldmodule that differ only in
500+
case and whitespace from any in the model fieldmodule, to the accepted values from the model,
501+
which are expected to be lower case without leading or trailing whitespace characters.
502+
Note that internal whitespace must be exactly matched.
503+
:param data_fieldmodule: Data Fieldmodule whose group names may be modified.
504+
:param model_fieldmodule: Model Fieldmodule containing preferred group names.
505+
:param log_diagnostics: Set to True to write diagonstic messages about name matches and changes to logging.
506+
"""
507+
# future: match with annotation terms
508+
model_names = [group.getName() for group in get_group_list(model_fieldmodule)]
509+
for data_group in get_group_list(data_fieldmodule):
510+
data_name = data_group.getName()
511+
compare_name = data_name.strip().casefold()
512+
for model_name in model_names:
513+
if model_name == data_name:
514+
if log_diagnostics:
515+
logger.info("Data group '" + data_name + "' found in model")
516+
break
517+
elif model_name.strip().casefold() == compare_name:
518+
result = data_group.setName(model_name)
519+
if result == RESULT_OK:
520+
if log_diagnostics:
521+
logger.info("Data group '" + data_name + "' found in model as '" +
522+
model_name + "'. Renaming to match.")
523+
else:
524+
logger.error("Error: Data group '" + data_name + "' found in model as '" +
525+
model_name + "'. Renaming to match FAILED.")
526+
if fieldmodule.findFieldByName(model_name).isValid():
527+
logger.error(" Reason: field of that name already exists.")
528+
break
529+
else:
530+
if log_diagnostics:
531+
logger.info("Data group '" + data_name + "' not found in model")

src/cmlibs/utils/zinc/mesh.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,18 @@ def undefine_field(field):
278278
mesh = fm.findMeshByDimension(i)
279279
mesh_group = mesh
280280
_undefine_field_on_elements(field, mesh_group)
281+
282+
283+
def element_or_ancestor_is_in_mesh(element, mesh):
284+
"""
285+
Query whether element is in mesh or is from its tree of faces, lines etc.
286+
:param element: Element to query.
287+
:param mesh: Equal or higher dimension ancestor mesh or mesh group to check.
288+
:return: True if element or any parent/ancestor is in mesh.
289+
"""
290+
if mesh.containsElement(element):
291+
return True
292+
for p in range(1, element.getNumberOfParents() + 1):
293+
if element_or_ancestor_is_in_mesh(element.getParentElement(p), mesh):
294+
return True
295+
return False

src/cmlibs/utils/zinc/region.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cmlibs.utils.zinc.finiteelement import get_identifiers, evaluate_field_nodeset_range
1+
from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range, get_identifiers, get_maximum_node_identifier
22
from cmlibs.zinc.field import Field
33
from cmlibs.zinc.result import RESULT_OK
44

@@ -10,11 +10,25 @@ def _find_missing(lst):
1010
for i in range(x + 1, y) if y - x > 1]
1111

1212

13-
def convert_nodes_to_datapoints(target_region, source_region):
13+
def convert_nodes_to_datapoints(target_region, source_region, source_nodeset_type=Field.DOMAIN_TYPE_NODES,
14+
destroy_after_conversion=True):
15+
"""
16+
Converts nodes in the source region to datapoints in the target region, renumbering any existing
17+
datapoints in target region to not clash.
18+
When the source nodeset type is Field.DOMAIN_TYPE_DATAPOINTS, then datapoints are transferred from the
19+
source region to the target region.
20+
:param target_region: Zinc Region to read data into. Existing data points are renumbered to avoid nodes.
21+
:param source_region: Zinc Region containing nodes to transfer.
22+
:param source_nodeset_type: Set to Field.DOMAIN_TYPE_DATAPOINTS or Field.DOMAIN_TYPE_NODES to transfer datapoints
23+
or convert nodes. Datapoint transfer should only be to different regions [default: Field.DOMAIN_TYPE_NODES].
24+
:param destroy_after_conversion: Set to True to destroy nodes that have been successfully converted, or False
25+
to leave intact in source region [default: True].
26+
"""
1427
source_fieldmodule = source_region.getFieldmodule()
1528
target_fieldmodule = target_region.getFieldmodule()
1629
with ChangeManager(source_fieldmodule), ChangeManager(target_fieldmodule):
17-
nodes = source_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)
30+
# Could be nodes or datapoints.
31+
nodes = source_fieldmodule.findNodesetByFieldDomainType(source_nodeset_type)
1832
if nodes.getSize() > 0:
1933
datapoints = target_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
2034
if datapoints.getSize() > 0:
@@ -44,11 +58,27 @@ def convert_nodes_to_datapoints(target_region, source_region):
4458
datapoint.setIdentifier(new_identifier)
4559

4660
# transfer nodes as datapoints to target_region
47-
buffer = write_to_buffer(source_region, resource_domain_type=Field.DOMAIN_TYPE_NODES)
48-
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
61+
buffer = write_to_buffer(source_region, resource_domain_type=source_nodeset_type)
62+
if source_nodeset_type == Field.DOMAIN_TYPE_NODES:
63+
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
4964
result = read_from_buffer(target_region, buffer)
5065
assert result == RESULT_OK, "Failed to load nodes as datapoints"
51-
nodes.destroyAllNodes()
66+
if destroy_after_conversion:
67+
# note this cannot destroy nodes in use by elements
68+
nodes.destroyAllNodes()
69+
70+
71+
def copy_fitting_data(target_region, source_region):
72+
"""
73+
Copy nodes and data points from source_region to target_region, converting nodes to data points and
74+
offsetting data point identifiers to not clash. All groups and fields in use are transferred.
75+
This is used for setting up fitting problems where data needs to be in datapoints only.
76+
:param target_region: Zinc Region to read nodes/data into.
77+
:param source_region: Zinc Region containing nodes/data to transfer. Unmodified.
78+
"""
79+
for domain_type in [Field.DOMAIN_TYPE_DATAPOINTS, Field.DOMAIN_TYPE_NODES]:
80+
convert_nodes_to_datapoints(target_region, source_region, source_nodeset_type=domain_type,
81+
destroy_after_conversion=False)
5282

5383

5484
def copy_nodeset(region, nodeset):

tests/test_zinc.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import unittest
2-
from cmlibs.utils.zinc.field import createFieldMeshIntegral, findOrCreateFieldCoordinates, \
3-
findOrCreateFieldGroup
4-
from cmlibs.utils.zinc.finiteelement import createCubeElement, createSquareElement, createNodes, \
5-
createTriangleElements, evaluateFieldNodesetMean
2+
from cmlibs.utils.zinc.field import (
3+
createFieldMeshIntegral, findOrCreateFieldCoordinates, findOrCreateFieldGroup)
4+
from cmlibs.utils.zinc.mesh import element_or_ancestor_is_in_mesh
5+
from cmlibs.utils.zinc.finiteelement import (
6+
createCubeElement, createSquareElement, createNodes, createTriangleElements, evaluateFieldNodesetMean)
67
from cmlibs.zinc.context import Context
8+
from cmlibs.zinc.element import Element
79
from cmlibs.zinc.field import Field
810
from cmlibs.zinc.result import RESULT_OK
9-
from utilities import assert_almost_equal_list
11+
from utilities import assert_almost_equal_list, get_test_resource_name
1012

1113

1214
class ZincUtilsTestCase(unittest.TestCase):
@@ -105,6 +107,41 @@ def test_create_nodes_and_elements(self):
105107
self.assertEqual(RESULT_OK, result)
106108
self.assertAlmostEqual(0.9, volume, delta=1.0E-7)
107109

110+
def test_element_or_ancestor_is_in_mesh(self):
111+
exf_file = get_test_resource_name('two_element_cube.exf')
112+
113+
context = Context("test")
114+
source_region = context.createRegion()
115+
result = source_region.readFile(exf_file)
116+
self.assertTrue(result == RESULT_OK)
117+
118+
fm = source_region.getFieldmodule()
119+
mesh3d = fm.findMeshByDimension(3)
120+
mesh2d = fm.findMeshByDimension(2)
121+
mesh1d = fm.findMeshByDimension(1)
122+
element2 = mesh3d.findElementByIdentifier(2)
123+
self.assertTrue(element2.isValid())
124+
self.assertTrue(element_or_ancestor_is_in_mesh(element2, mesh3d))
125+
self.assertFalse(element_or_ancestor_is_in_mesh(element2, mesh2d))
126+
self.assertFalse(element_or_ancestor_is_in_mesh(element2, mesh1d))
127+
face4 = mesh2d.findElementByIdentifier(4)
128+
self.assertTrue(face4.isValid())
129+
self.assertTrue(element_or_ancestor_is_in_mesh(face4, mesh3d))
130+
self.assertTrue(element_or_ancestor_is_in_mesh(face4, mesh2d))
131+
self.assertFalse(element_or_ancestor_is_in_mesh(face4, mesh1d))
132+
line5 = mesh1d.findElementByIdentifier(5)
133+
self.assertTrue(line5.isValid())
134+
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh3d))
135+
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh2d))
136+
self.assertTrue(element_or_ancestor_is_in_mesh(line5, mesh1d))
137+
138+
elementtemplate = mesh1d.createElementtemplate()
139+
elementtemplate.setElementShapeType(Element.SHAPE_TYPE_LINE)
140+
new_line = mesh1d.createElement(-1, elementtemplate)
141+
self.assertFalse(element_or_ancestor_is_in_mesh(new_line, mesh3d))
142+
self.assertFalse(element_or_ancestor_is_in_mesh(new_line, mesh2d))
143+
self.assertTrue(element_or_ancestor_is_in_mesh(new_line, mesh1d))
144+
108145

109146
if __name__ == "__main__":
110147
unittest.main()

tests/test_zinc_group.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
import unittest
3+
from cmlibs.utils.zinc.field import find_or_create_field_group
34
from cmlibs.utils.zinc.group import (
45
group_add_group_local_contents, group_evaluate_centroid, group_remove_group_local_contents,
5-
group_evaluate_representative_point, groups_have_same_local_contents)
6+
group_evaluate_representative_point, groups_have_same_local_contents, match_fitting_group_names)
67
from cmlibs.zinc.context import Context
78
from cmlibs.zinc.element import Element
89
from cmlibs.zinc.field import Field
@@ -134,3 +135,27 @@ def test_group_add_compare_group_local_contents(self):
134135
self.assertEqual(0, group1.getMeshGroup(mesh2d).getSize())
135136
self.assertEqual(0, group1.getMeshGroup(mesh1d).getSize())
136137
self.assertEqual(0, group1.getNodesetGroup(nodes).getSize())
138+
139+
140+
def test_match_fitting_group_names(self):
141+
"""
142+
Test utility functions for adding and comparing group local contents.
143+
"""
144+
context = Context("test")
145+
model_region = context.createRegion()
146+
model_fieldmodule = model_region.getFieldmodule()
147+
find_or_create_field_group(model_fieldmodule, "bob", managed=True)
148+
find_or_create_field_group(model_fieldmodule, "fred", managed=True)
149+
find_or_create_field_group(model_fieldmodule, "two names", managed=True)
150+
151+
data_region = context.createRegion()
152+
data_fieldmodule = data_region.getFieldmodule()
153+
data_group_bob = find_or_create_field_group(data_fieldmodule, " Bob")
154+
data_group_fred = find_or_create_field_group(data_fieldmodule, " fRed\t")
155+
data_group_two_names = find_or_create_field_group(data_fieldmodule, "\t two NAMES ")
156+
157+
match_fitting_group_names(data_fieldmodule, model_fieldmodule, log_diagnostics=True)
158+
159+
self.assertEqual(data_group_bob.getName(), "bob")
160+
self.assertEqual(data_group_fred.getName(), "fred")
161+
self.assertEqual(data_group_two_names.getName(), "two names")

0 commit comments

Comments
 (0)