-
Notifications
You must be signed in to change notification settings - Fork 123
Generate embeddings for DC core schema #6046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Build the schema embeddings input CSV from core_schema.mcf.""" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where does this file come from? Maybe add a comment about where to find it. |
||
|
|
||
| import csv | ||
| import os | ||
| import re | ||
|
|
||
| from absl import app | ||
| from absl import flags | ||
|
|
||
# Command-line configuration for the script.
FLAGS = flags.FLAGS

# Note: parse_mcf() splits this value on commas, so a comma-separated list
# of MCF files is also accepted.
flags.DEFINE_string('schema_path', 'core_schema.mcf',
                    'Path to the core_schema.mcf file.')
flags.DEFINE_string('output_path', 'input/schema/schema_nodes.csv',
                    'Path to the output CSV file.')
|
|
||
|
|
||
def parse_mcf(file_path):
  """Parses one or more MCF files and yields nodes as dicts.

  Args:
    file_path: Path to an MCF file, or a comma-separated list of paths.

  Yields:
    One dict per MCF node block, mapping property names to values. A
    property that appears multiple times in a block (e.g. typeOf) maps to a
    list of values; otherwise the value is a string. Blocks without a
    'Node' property are skipped.

  NOTE(review): multi-line MCF values are not handled; every property is
  assumed to fit on a single line. TODO: consider reusing an existing MCF
  parsing library instead of this ad-hoc parser, and document where
  core_schema.mcf comes from.
  """
  # Read nodes from one file at a time.
  for path in file_path.split(','):
    with open(path.strip(), 'r') as f:
      content = f.read()

    # Node blocks are separated by blank lines (possibly several in a row).
    for block in re.split(r'\n\s*\n', content):
      node = {}
      for line in block.strip().split('\n'):
        if not line:
          continue
        # Capture 'key: value'; everything before the first colon is the key.
        match = re.match(r'^([^:]+):\s*(.*)$', line)
        if not match:
          continue
        key = match.group(1).strip()
        value = match.group(2).strip()
        # Strip surrounding double quotes. The length guard prevents a value
        # of a lone '"' (which both starts and ends with '"') from being
        # mangled into an empty string by the slice below.
        if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
          value = value[1:-1]

        # A repeated key (e.g. typeOf) accumulates its values in a list.
        if key in node:
          if isinstance(node[key], list):
            node[key].append(value)
          else:
            node[key] = [node[key], value]
        else:
          node[key] = value

      if 'Node' in node:
        yield node
|
|
||
|
Comment on lines
+31
to
+71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The parser currently assumes every property value fits on a single line. To improve robustness, I recommend updating the parser to handle multi-line values. This would typically involve iterating through the lines of a block and concatenating indented lines to the value of the preceding line. |
||
|
|
||
def _is_schema_node(node):
  """Returns True if the node looks like a Class, Property, or enum node.

  NOTE(review): enum detection is a fragile heuristic (the substring 'Enum'
  in a typeOf value, or subClassOf Enumeration). A more robust approach
  would first collect all enumeration classes in a separate pass over the
  MCF and then match node types against that explicit set.
  """
  type_of = node.get('typeOf')
  if not type_of:
    return False
  # Normalize to a list: parse_mcf returns a string for single-valued keys.
  if isinstance(type_of, str):
    type_of = [type_of]

  for t in type_of:
    if t in ('dcid:Class', 'dcid:Property', 'Class', 'Property'):
      return True
    if 'Enum' in t:  # Heuristic: enum instances have an ...Enum... type.
      return True

  # Also treat direct subclasses of Enumeration as schema nodes.
  sub_class_of = node.get('subClassOf')
  if isinstance(sub_class_of, str):
    sub_class_of = [sub_class_of]
  if sub_class_of:
    for s in sub_class_of:
      if s in ('dcid:Enumeration', 'Enumeration'):
        return True
  return False


def _build_sentence(node):
  """Builds the embedding sentence from a node's name and description(s).

  Returns the '; '-joined, camelCase-split sentence, or None if the node
  has neither a name nor a description.
  """
  sentences = []
  name = node.get('name')
  if name:
    sentences.append(name)
  description = node.get('description')
  if description:
    # A repeated description property parses to a list of strings.
    if isinstance(description, list):
      sentences.extend(description)
    else:
      sentences.append(description)
  if not sentences:
    return None

  sentence = '; '.join(sentences)
  # Convert camelCase identifiers to space-separated words so the
  # embedding model sees natural-language tokens.
  return re.sub(r'([a-z0-9])([A-Z]+)', r'\1 \2', sentence)


def main(_):
  """Reads the MCF schema file(s) and writes a (dcid, sentence) CSV."""
  schema_path = FLAGS.schema_path
  output_path = FLAGS.output_path

  print(f"Reading schema from: {schema_path}")
  print(f"Writing CSV to: {output_path}")

  # Ensure the output directory exists. dirname() is '' for a bare
  # filename, and os.makedirs('') raises, so only create when non-empty.
  out_dir = os.path.dirname(output_path)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  with open(output_path, 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['dcid', 'sentence'])
    writer.writeheader()

    count = 0
    for node in parse_mcf(schema_path):
      dcid = node.get('Node')
      if not dcid:
        continue
      # Strip the 'dcid:' namespace prefix if present.
      if dcid.startswith('dcid:'):
        dcid = dcid[5:]

      if not _is_schema_node(node):
        continue

      sentence = _build_sentence(node)
      if not sentence:
        continue

      writer.writerow({'dcid': dcid, 'sentence': sentence})
      count += 1

  print(f"Processed {count} nodes.")


if __name__ == '__main__':
  app.run(main)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Tests for build_schema_csv.py.""" | ||
|
|
||
| import csv | ||
| import os | ||
| import tempfile | ||
| import unittest | ||
| from absl import flags | ||
| from absl.testing import absltest | ||
|
|
||
| # Import the module to be tested. | ||
| # Assuming this test is run from the repo root or with appropriate PYTHONPATH. | ||
| from website.tools.nl.embeddings import build_schema_csv | ||
|
|
||
# Flag handle used by test_main_logic to inject test paths into main().
FLAGS = flags.FLAGS


class BuildSchemaCsvTest(absltest.TestCase):
  """Tests for build_schema_csv: MCF parsing and CSV generation."""

  def setUp(self):
    # Each test gets its own scratch directory with a fixed MCF input path
    # and CSV output path; tearDown removes the whole tree.
    super().setUp()
    self.test_dir = tempfile.TemporaryDirectory()
    self.mcf_path = os.path.join(self.test_dir.name, 'test.mcf')
    self.output_path = os.path.join(self.test_dir.name, 'output.csv')

  def tearDown(self):
    self.test_dir.cleanup()
    super().tearDown()

  def test_parse_mcf(self):
    """Two blank-line-separated node blocks parse into two dicts."""
    content = """
Node: dcid:TestNode
typeOf: dcid:Class
name: "Test Node"
description: "A test node."

Node: dcid:TestProperty
typeOf: dcid:Property
name: "testProperty"
domainIncludes: dcid:TestNode
"""
    with open(self.mcf_path, 'w') as f:
      f.write(content)

    nodes = list(build_schema_csv.parse_mcf(self.mcf_path))
    self.assertLen(nodes, 2)

    # Quoted values are unquoted; keys map to plain strings.
    self.assertEqual(nodes[0]['Node'], 'dcid:TestNode')
    self.assertEqual(nodes[0]['typeOf'], 'dcid:Class')
    self.assertEqual(nodes[0]['name'], 'Test Node')
    self.assertEqual(nodes[0]['description'], 'A test node.')

    self.assertEqual(nodes[1]['Node'], 'dcid:TestProperty')
    self.assertEqual(nodes[1]['typeOf'], 'dcid:Property')
    self.assertEqual(nodes[1]['name'], 'testProperty')

  def test_parse_mcf_multiple_files(self):
    """A comma-separated path list yields nodes from each file in order."""
    mcf1 = os.path.join(self.test_dir.name, 'test1.mcf')
    mcf2 = os.path.join(self.test_dir.name, 'test2.mcf')

    with open(mcf1, 'w') as f:
      f.write("Node: dcid:Node1\ntypeOf: Class\n")
    with open(mcf2, 'w') as f:
      f.write("Node: dcid:Node2\ntypeOf: Class\n")

    nodes = list(build_schema_csv.parse_mcf(f"{mcf1},{mcf2}"))
    self.assertLen(nodes, 2)
    self.assertEqual(nodes[0]['Node'], 'dcid:Node1')
    self.assertEqual(nodes[1]['Node'], 'dcid:Node2')

  def test_main_logic(self):
    """End-to-end: main() reads FLAGS-configured paths and writes the CSV.

    main() reads its inputs from FLAGS, so this test sets FLAGS directly
    rather than going through absl flag parsing.
    """
    content = """
Node: dcid:TestClass
typeOf: dcid:Class
name: "TestClass"
description: "A test class."

Node: dcid:TestEnum
typeOf: dcid:Class
subClassOf: dcid:Enumeration
name: "TestEnum"

Node: dcid:TestEnumValue
typeOf: dcid:TestEnum
name: "TestEnumValue"

Node: dcid:IgnoreMe
typeOf: dcid:SomethingElse
name: "IgnoreMe"
"""
    with open(self.mcf_path, 'w') as f:
      f.write(content)

    # Point main() at the test fixture paths.
    FLAGS.schema_path = self.mcf_path
    FLAGS.output_path = self.output_path

    # Run main
    build_schema_csv.main([])

    # Check output
    self.assertTrue(os.path.exists(self.output_path))

    with open(self.output_path, 'r') as f:
      reader = csv.DictReader(f)
      rows = list(reader)

    # Expected rows:
    # - TestClass: typeOf is Class.
    # - TestEnum: subClassOf Enumeration.
    # - TestEnumValue: typeOf dcid:TestEnum matches the "'Enum' in type"
    #   heuristic in the script.
    # - IgnoreMe: typeOf dcid:SomethingElse matches nothing, so excluded.
    self.assertLen(rows, 3)

    # The sentence column splits camelCase names into words, e.g.
    # TestEnumValue -> "Test Enum Value".
    dcids = {row['dcid']: row['sentence'] for row in rows}

    self.assertIn('TestClass', dcids)
    self.assertIn('Test Class', dcids['TestClass'])

    self.assertIn('TestEnum', dcids)
    self.assertIn('Test Enum', dcids['TestEnum'])

    self.assertIn('TestEnumValue', dcids)
    self.assertIn('Test Enum Value', dcids['TestEnumValue'])


if __name__ == '__main__':
  absltest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we separate the PR into 2? One for yaml + csv files and another for the code? The former can be submitted right away while we iterate on the latter.