Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3179] Add GloVe implementation #2201

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions scripts/builtin/glove.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# The implementation is based on
# (1) https://github.com/roamanalytics/mittens/blob/master/mittens/
# (2) https://github.com/stanfordnlp/GloVe/blob/master/src/glove.c
#-------------------------------------------------------------

source("scripts/builtin/cooccurrenceMatrix.dml") as cooc

init = function(matrix[double] cooc_matrix, double x_max, double alpha)
return(matrix[double] weights, matrix[double] log_cooc_matrix){
E = 2.718281828;
bounded = pmin(cooc_matrix, x_max);
weights = pmin(1, (bounded / x_max) ^ alpha);
log_cooc_matrix = ifelse(cooc_matrix > 0, log(cooc_matrix, E), 0);
}

gloveWithCoocMatrix = function(matrix[double] cooc_matrix, frame[Unknown] cooc_index, int seed, int vector_size, double alpha, double eta, double x_max, double tol, int iterations,int print_loss_it)
return (frame[Unknown] G){
/*
* Computes the vector embeddings for words by analyzing their co-occurrence statistics in a large text corpus.
*
* Inputs:
* - cooc_matrix: Precomputed co-occurrence matrix of shape (N, N).
* - cooc_index: Index file mapping words to their positions in the co-occurrence matrix.
* The second column should contain the word list in the same order as the matrix.
* - seed: Random seed for reproducibility.
* - vector_size: Dimensionality of word vectors, V.
* - eta: Learning rate for optimization, recommended value: 0.05.
* - alpha: Weighting function parameter, recommended value: 0.75.
* - x_max: Maximum co-occurrence value as per the GloVe paper: 100.
* - tol: Tolerance value to avoid overfitting, recommended value: 1e-4.
* - iterations: Total number of training iterations.
* - print_loss_it: Interval (in iterations) for printing the loss.
*
* Outputs:
* - G: frame of the word indices and their word vectors, of shape (N, V). Each represented as a vector, of shape (1,V)
*/

vocab_size = nrow(cooc_matrix);
W = (rand(rows=vocab_size, cols=vector_size, min=0, max=1, seed=seed)-0.5)/vector_size;
C = (rand(rows=vocab_size, cols=vector_size, min=0, max=1, seed=seed+1)-0.5)/vector_size;
bw = (rand(rows=vocab_size, cols=1, min=0, max=1, seed=seed+2)-0.5)/vector_size;
bc = (rand(rows=vocab_size, cols=1, min=0, max=1, seed=seed+3)-0.5)/vector_size;
[weights, log_cooc_matrix] = init(cooc_matrix, x_max, alpha);

momentum_W = 1e-8 + 0.1 * matrix(1, nrow(W), ncol(W));
momentum_C = 1e-8 + 0.1 * matrix(1, nrow(C), ncol(C));
momentum_bw = 1e-8 + 0.1 * matrix(1, nrow(bw), ncol(bw));
momentum_bc = 1e-8 + 0.1 * matrix(1, nrow(bc), ncol(bc));

error = 0;
iter = 0;
tolerance = tol;
previous_error = 1e10;
conti = TRUE;

while (conti) {

# compute predictions for all co-occurring word pairs at once
predictions = W %*% t(C) + bw + t(bc);
diffs = predictions - log_cooc_matrix;
weighted_diffs = weights * diffs;

# compute gradients
wgrad = weighted_diffs %*% C;
cgrad = t(weighted_diffs) %*% W;
bwgrad = rowSums(weighted_diffs);
bcgrad = matrix(colSums(weighted_diffs), nrow(bc), ncol(bc));

error = sum(0.5 * (weights * (diffs ^ 2)));
iter = iter + 1;


if (abs(previous_error - error) >= tolerance) {
if(iter <= iterations){

# get steps and update
momentum_W = momentum_W + (wgrad ^ 2);
momentum_C = momentum_C + (cgrad ^ 2);
momentum_bw = momentum_bw + (bwgrad ^ 2);
momentum_bc = momentum_bc + (bcgrad ^ 2);

W = W - (eta * wgrad / (sqrt(momentum_W) + 1e-8));
C = C - (eta * cgrad / (sqrt(momentum_C) + 1e-8));
bw = bw - (eta * bwgrad / (sqrt(momentum_bw) + 1e-8));
bc = bc - (eta * bcgrad / (sqrt(momentum_bc) + 1e-8));

G = W + C;

previous_error = error;

final_iter = iter;
} else {
conti = FALSE;
}
} else {
conti = FALSE;
}

if (iter - floor(iter / print_loss_it) * print_loss_it == 0) {
print("iteration: " + iter + " error: " + error);
}
}

# add the word index to the word vectors
print("Given " + iterations + " iterations, " + "stopped (or converged) at the " + final_iter + " iteration / error: " + error);
G = cbind(cooc_index[,2], as.frame(G));
}

glove = function(
Frame[Unknown] input,
int seed, int vector_size,
double alpha, double eta,
double x_max,
double tol,
int iterations,
int print_loss_it,
Int maxTokens,
Int windowSize,
Boolean distanceWeighting,
Boolean symmetric)
return (frame[Unknown] G){

/*
* Main function to Computes the vector embeddings for words in a large text corpus.
* INPUT:
* ------------------------------------------------------------------------------
* - input (Frame[Unknown]): 1DInput corpus in CSV format.
* - seed: Random seed for reproducibility.
* - vector_size: Dimensionality of word vectors, V.
* - eta: Learning rate for optimization, recommended value: 0.05.
* - alpha: Weighting function parameter, recommended value: 0.75.
* - x_max: Maximum co-occurrence value as per the GloVe paper: 100.
* - tol: Tolerance value to avoid overfitting, recommended value: 1e-4.
* - iterations: Total number of training iterations.
* - print_loss_it: Interval (in iterations) for printing the loss.
* - maxTokens (Int): Maximum number of tokens per text entry.
* - windowSize (Int): Context window size.
* - distanceWeighting (Boolean): Whether to apply distance-based weighting.
* - symmetric (Boolean): Determines if the matrix is symmetric (TRUE) or asymmetric (FALSE).
* ------------------------------------------------------------------------------
* OUTPUT:
* ------------------------------------------------------------------------------
* G (Frame[Unknown]): The word indices and their word vectors, of shape (N, V). Each represented as a vector, of shape (1,V)
* ------------------------------------------------------------------------------
*/

[cooc_matrix, cooc_index] = cooc::cooccurrenceMatrix(input, maxTokens, windowSize, distanceWeighting, symmetric);
G = gloveWithCoocMatrix(cooc_matrix, cooc_index, seed, vector_size, alpha, eta, x_max, tol, iterations, print_loss_it);
}

#input = read("src/test/resources/datasets/20news/20news_subset_untokenized.csv", data_type="frame", format="csv", sep=",");
#G = glove(input[,4], seed, vector_size, alpha, eta, x_max, tol, iterations, print_loss_it, 2600, 15, TRUE, TRUE);
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public enum Builtins {
GET_ACCURACY("getAccuracy", true),
GLM("glm", true),
GLM_PREDICT("glmPredict", true),
GLOVE("glove", true),
GMM("gmm", true),
GMM_PREDICT("gmmPredict", true),
GNMF("gnmf", true),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.builtin.part1;

import java.io.IOException;
import java.util.Objects;

import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.junit.Test;

public class BuiltinGloVeTest extends AutomatedTestBase {

private static final String TEST_NAME = "glove";
private static final String TEST_DIR = "functions/builtin/";
private static final String RESOURCE_DIRECTORY = "./src/test/resources/datasets/";
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGloVeTest.class.getSimpleName() + "/";

private final static Types.ValueType[] schema = {Types.ValueType.STRING};

private static final int TOP_K = 5;
private static final double ACCURACY_THRESHOLD = 0.85;

private static final double seed = 45;
private static final double vector_size = 100;
private static final double alpha = 0.75;
private static final double eta = 0.05;
private static final double x_max = 100;
private static final double tol = 1e-4;
private static final double iterations = 10000;
private static final double print_loss_it = 100;
private static final double maxTokens = 2600;
private static final double windowSize = 15;
private static final String distanceWeighting = "TRUE";
private static final String symmetric = "TRUE";

@Override
public void setUp() {
addTestConfiguration(TEST_NAME,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"out_result"}));
}

@Test
public void gloveTest() throws IOException{
// Using top-5 words for similarity comparison
runGloVe(TOP_K);

// Read the computed similarity results from SystemDS
FrameBlock computedSimilarity = readDMLFrameFromHDFS("out_result", FileFormat.CSV);

// Load expected results (precomputed in Python)
FrameBlock expectedSimilarity = readDMLFrameFromHDFS(RESOURCE_DIRECTORY + "/GloVe/gloveExpectedTop10.csv", FileFormat.CSV, false);

// Compute accuracy by comparing computed and expected results
double accuracy = computeAccuracy(computedSimilarity, expectedSimilarity, TOP_K);

System.out.println("Computed Accuracy: " + accuracy);

// Ensure accuracy is above a reasonable threshold
assert accuracy > ACCURACY_THRESHOLD : "Accuracy too low! Expected > 85% match.";
}

public void runGloVe(int topK) {
// Load test configuration
Types.ExecMode platformOld = setExecMode(Types.ExecType.CP);
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));

String HOME = SCRIPT_DIR + TEST_DIR;

fullDMLScriptName = HOME + TEST_NAME + ".dml";

programArgs = new String[] {
"-nvargs",
"input=" + RESOURCE_DIRECTORY + "20news/20news_subset_untokenized.csv",
"seed=" + seed,
"vector_size=" + vector_size,
"alpha=" + alpha,
"eta=" + eta,
"x_max=" + x_max,
"tol=" + tol,
"iterations=" + iterations,
"print_loss_it=" + print_loss_it,
"maxTokens=" + maxTokens,
"windowSize=" + windowSize,
"distanceWeighting=" + distanceWeighting,
"symmetric=" + symmetric,
"topK=" + topK,
"out_result=" + output("out_result")
};

System.out.println("Running DML script...");
runTest(true, false, null, -1);
System.out.println("Test execution completed.");
} finally {
rtplatform = platformOld;
}
}

/**
* Computes accuracy by comparing top-K similar words between computed and expected results.
*/
private double computeAccuracy(FrameBlock computed, FrameBlock expected, int k) {
int count = 0;
for (int i = 0; i < computed.getNumRows(); i++) {
int matchCount = 0;
for (int j = 1; j < k; j++) {
String word1 = computed.getString(i, j);
for (int m = 0; m < k; m++) {
if (Objects.equals(word1, expected.getString(i, m))) {
matchCount++;
break;
}
}
}
if (matchCount > 0) {
count++;
}
}
return (double) count / computed.getNumRows();
}
}
Loading
Loading