Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public NodesEdges parseArcRow(String row, Counter mcfNodesWithoutTypeCounter) {
Node.builder()
.subjectId(nodeId)
.value(nodeValue)
.bytes(bytes)
.bytes(bytes != null ? bytes.toByteArray() : new byte[0])
.name(entity.getName())
.types(types)
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public class Edge implements Serializable {
private String objectId;
private String provenance;

@SuppressWarnings("unused")
private Edge() {}

// Private constructor to enforce use of Builder
private Edge(Builder builder) {
this.subjectId = builder.subjectId;
Expand Down
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/gemini review

Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
package org.datacommons.ingestion.data;

import com.google.cloud.ByteArray;
import com.google.cloud.spanner.Mutation;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.datacommons.Storage.Observations;
import org.datacommons.ingestion.spanner.SpannerClient;
import org.datacommons.pipeline.util.PipelineUtils;
Expand All @@ -25,13 +30,128 @@
import org.datacommons.proto.Mcf.McfStatVarObsSeries.StatVarObs;
import org.datacommons.proto.Mcf.ValueType;
import org.datacommons.util.GraphUtils;
import org.datacommons.util.McfUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GraphReader implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(GraphReader.class);
// Maximum size for a single column value in Spanner (10MB)
private static final String DC_AGGREGATE = "dcAggregate/";
private static final String DATCOM_AGGREGATE = "DataCommonsAggregate";
private static final String IMPORT_METADATA_FILE = "import_metadata_mcf.mcf";

public static PCollection<Node> combineNodes(PCollection<Node> nodes) {
return nodes
.apply(
"MapNodesToKV",
ParDo.of(
new DoFn<Node, KV<String, Node>>() {
@ProcessElement
public void processElement(
@Element Node node, OutputReceiver<KV<String, Node>> receiver) {
receiver.output(KV.of(node.getSubjectId(), node));
}
}))
.apply(
"CombineNodes",
Combine.perKey(
new Combine.CombineFn<Node, List<Node>, Node>() {
@Override
public List<Node> createAccumulator() {
return new ArrayList<>();
}

@Override
public List<Node> addInput(List<Node> accumulator, Node input) {
accumulator.add(input);
return accumulator;
}

@Override
public List<Node> mergeAccumulators(Iterable<List<Node>> accumulators) {
List<Node> merged = new ArrayList<>();
for (List<Node> acc : accumulators) {
merged.addAll(acc);
}
return merged;
}

@Override
public Node extractOutput(List<Node> accumulator) {
if (accumulator.isEmpty()) return null;
Node first = accumulator.get(0);
Node.Builder builder =
Node.builder()
.subjectId(first.getSubjectId())
.value(first.getValue())
.name(first.getName())
.types(first.getTypes())
.bytes(first.getBytes());

Set<String> types = new java.util.TreeSet<>();
for (Node n : accumulator) {
types.addAll(n.getTypes());
if (!n.getValue().isEmpty()) {
builder.value(n.getValue());
}
if (!n.getName().isEmpty()) {
builder.name(n.getName());
}
if (n.getBytes().length > 0) {
builder.bytes(n.getBytes());
}
}
if (types.size() > 1 && types.contains("ProvisionalNode")) {
types.remove("ProvisionalNode");
}
builder.types(new ArrayList<>(types));
return builder.build();
}
}))
.apply(
"ExtractNodes",
ParDo.of(
new DoFn<KV<String, Node>, Node>() {
@ProcessElement
public void processElement(
@Element KV<String, Node> element, OutputReceiver<Node> receiver) {
receiver.output(element.getValue());
}
}));
}

public static PCollection<Mutation> nodeToMutations(
PCollection<Node> nodes, SpannerClient spannerClient) {
return nodes.apply(
"NodesToMutations",
ParDo.of(
new DoFn<Node, Mutation>() {
@ProcessElement
public void processElement(@Element Node node, OutputReceiver<Mutation> receiver) {
Mutation mutation = spannerClient.toNodeMutation(node);
if (mutation != null) {
receiver.output(mutation);
}
}
}));
}

public static PCollection<Mutation> edgeToMutations(
PCollection<Edge> edges, SpannerClient spannerClient) {
return edges.apply(
"EdgesToMutations",
ParDo.of(
new DoFn<Edge, Mutation>() {
@ProcessElement
public void processElement(@Element Edge edge, OutputReceiver<Mutation> receiver) {
Mutation mutation = spannerClient.toEdgeMutation(edge);
if (mutation != null) {
receiver.output(mutation);
}
}
}));
}

public static List<Node> graphToNodes(McfGraph graph, Counter mcfNodesWithoutTypeCounter) {
List<Node> nodes = new ArrayList<>();
Expand All @@ -42,9 +162,12 @@ public static List<Node> graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp
// Generate corresponding node
Map<String, McfGraph.Values> pv = pvs.getPvsMap();
Node.Builder node = Node.builder();
node.subjectId(nodeEntry.getKey());
node.value(nodeEntry.getKey());
String dcid = GraphUtils.getPropertyValue(pv, "dcid");
String subjectId = !dcid.isEmpty() ? dcid : McfUtil.stripNamespace(nodeEntry.getKey());
node.subjectId(subjectId);
node.value(subjectId);
node.name(GraphUtils.getPropertyValue(pv, "name"));

List<String> types = GraphUtils.getPropertyValues(pv, "typeOf");
if (types.isEmpty()) {
types = List.of(PipelineUtils.TYPE_THING);
Expand All @@ -55,16 +178,25 @@ public static List<Node> graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp
nodes.add(node.build());

// Generate any leaf nodes
for (Map.Entry<String, McfGraph.Values> entry : pv.entrySet()) { // Iterate over properties
for (Map.Entry<String, McfGraph.Values> entry : pv.entrySet()) {
for (TypedValue val : entry.getValue().getTypedValuesList()) {
if (val.getType() != ValueType.RESOLVED_REF) {
int valSize = val.getValue().getBytes(StandardCharsets.UTF_8).length;
if (valSize > SpannerClient.MAX_SPANNER_COLUMN_SIZE) {
LOGGER.warn(
"Dropping node from {} because value size {} exceeds max size.",
subjectId,
valSize);
continue;
}
node = Node.builder();
node.subjectId(PipelineUtils.generateObjectValueKey(val.getValue()));
if (PipelineUtils.storeValueAsBytes(entry.getKey())) {
node.bytes(ByteArray.copyFrom(PipelineUtils.compressString(val.getValue())));
node.bytes(PipelineUtils.compressString(val.getValue()));
} else {
node.value(val.getValue());
}
node.types(List.of(ValueType.TEXT.toString()));
nodes.add(node.build());
}
}
Expand All @@ -74,22 +206,54 @@ public static List<Node> graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp
return nodes;
}

public static PCollection<McfGraph> getProvenanceMcf(
String bucketName, String importName, String latestVersion, Pipeline p) {
String provenanceFile = "gs://" + bucketName + "/" + "provenance/" + importName + ".mcf";
String metadataFile = latestVersion + "/" + IMPORT_METADATA_FILE;
LOGGER.info("Reading provenance mcf from {} {}", provenanceFile, metadataFile);
List<McfGraph> mcfList = new ArrayList<>();
String defaultProvenance =
"Node: dcid:dc/base/" + importName + "\n" + "typeOf: dcid:Provenance\n";
mcfList.add(GraphUtils.convertToGraph(defaultProvenance));
// try {
// mcfList.add(GraphUtils.convertToGraph(PipelineUtils.getGCSFileContent(metadataFile)));
// } catch (IOException e) {
// LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage());
// }
try {
mcfList.add(GraphUtils.convertToGraph(PipelineUtils.getGCSFileContent(provenanceFile)));
} catch (IOException e) {
LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage());
}
return p.apply(Create.of(mcfList).withType(TypeDescriptor.of(McfGraph.class)));
}

public static List<Edge> graphToEdges(McfGraph graph, String provenance) {
List<Edge> edges = new ArrayList<>();
for (Map.Entry<String, PropertyValues> nodeEntry : graph.getNodesMap().entrySet()) {
PropertyValues pvs = nodeEntry.getValue();
if (!GraphUtils.isObservation(pvs)) {
Map<String, McfGraph.Values> pv = pvs.getPvsMap();
// String provenance = GraphUtils.getPropertyValue(pv, "provenance");
String subjectId = nodeEntry.getKey(); // Use the map key as the subjectId
for (Map.Entry<String, McfGraph.Values> entry : pv.entrySet()) { // Iterate over properties
String dcid = GraphUtils.getPropertyValue(pv, "dcid");
String subjectId = !dcid.isEmpty() ? dcid : McfUtil.stripNamespace(nodeEntry.getKey());
for (Map.Entry<String, McfGraph.Values> entry : pv.entrySet()) {
for (TypedValue val : entry.getValue().getTypedValuesList()) {
if (val.getType() != ValueType.RESOLVED_REF) {
int valSize = val.getValue().getBytes(StandardCharsets.UTF_8).length;
if (valSize > SpannerClient.MAX_SPANNER_COLUMN_SIZE) {
LOGGER.warn(
"Dropping edge from {} because value size {} exceeds max size.",
subjectId,
valSize);
continue;
}
}
Edge.Builder edge = Edge.builder();
edge.subjectId(subjectId);
edge.predicate(entry.getKey());
edge.provenance(provenance);
if (val.getType() == ValueType.RESOLVED_REF) {
edge.objectId(val.getValue());
edge.objectId(McfUtil.stripNamespace(val.getValue()));
} else {
edge.objectId(PipelineUtils.generateObjectValueKey(val.getValue()));
}
Expand Down Expand Up @@ -164,6 +328,42 @@ public void processElement(
}));
}

public static PCollection<Node> mcfToNodes(
PCollection<McfGraph> graph, Counter nodeCounter, Counter mcfNodesWithoutTypeCounter) {
return graph.apply(
"McfToNodes",
ParDo.of(
new DoFn<McfGraph, Node>() {
@ProcessElement
public void processElement(@Element McfGraph element, OutputReceiver<Node> receiver) {
List<Node> nodes = graphToNodes(element, mcfNodesWithoutTypeCounter);
for (Node node : nodes) {
// LOGGER.info("Node: {}", node.toString());
receiver.output(node);
}
nodeCounter.inc(nodes.size());
}
}));
}

public static PCollection<Edge> mcfToEdges(
PCollection<McfGraph> graph, String provenance, Counter edgeCounter) {
return graph.apply(
"McfToEdges",
ParDo.of(
new DoFn<McfGraph, Edge>() {
@ProcessElement
public void processElement(@Element McfGraph element, OutputReceiver<Edge> receiver) {
List<Edge> edges = graphToEdges(element, provenance);
for (Edge edge : edges) {
receiver.output(edge);
// LOGGER.info("Edge : {}", edge.toString());
}
edgeCounter.inc(edges.size());
}
}));
}

public static PCollection<KV<String, Mutation>> graphToNodes(
PCollection<McfGraph> graph,
SpannerClient spannerClient,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.datacommons.ingestion.data;

import com.google.cloud.ByteArray;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.beam.sdk.coders.DefaultCoder;
Expand All @@ -16,10 +16,13 @@ public class Node implements Serializable {

private String subjectId;
private String value;
private ByteArray bytes;
private byte[] bytes;
private String name;
private List<String> types;

@SuppressWarnings("unused")
private Node() {}

// Private constructor to enforce use of Builder
private Node(Builder builder) {
this.subjectId = builder.subjectId;
Expand All @@ -41,7 +44,7 @@ public String getValue() {
return value;
}

public ByteArray getBytes() {
public byte[] getBytes() {
return bytes;
}

Expand All @@ -60,7 +63,7 @@ public boolean equals(Object o) {
Node node = (Node) o;
return Objects.equals(subjectId, node.subjectId)
&& Objects.equals(value, node.value)
&& Objects.equals(bytes, node.bytes)
&& Arrays.equals(bytes, node.bytes)
&& Objects.equals(name, node.name)
&& Objects.equals(types, node.types);
}
Expand All @@ -74,13 +77,13 @@ public int hashCode() {
public String toString() {
return String.format(
"Node{subjectId='%s', value='%s', bytes='%s', name='%s', types=%s}",
subjectId, value, bytes, name, types);
subjectId, value, Arrays.toString(bytes), name, types);
}

public static class Builder {
private String subjectId = "";
private String value = "";
private ByteArray bytes = null;
private byte[] bytes = new byte[0];
private String name = "";
private List<String> types = List.of();

Expand All @@ -96,7 +99,7 @@ public Builder value(String value) {
return this;
}

public Builder bytes(ByteArray bytes) {
public Builder bytes(byte[] bytes) {
this.bytes = bytes;
return this;
}
Expand Down
Loading