Skip to content

Commit

Permalink
Merge pull request #33613 from terrymanu/dev
Browse files Browse the repository at this point in the history
Refactor EncryptParameterRewriterBuilder
  • Loading branch information
iamhucong authored Nov 11, 2024
2 parents 0afafe0 + 7fa23a5 commit 9d9af4d
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlRewriteContext);
String databaseName = sqlRewriteContext.getDatabase().getName();
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters =
new EncryptParameterRewriterBuilder(rule, databaseName, sqlRewriteContext.getDatabase().getSchemas(), sqlStatementContext, encryptConditions).getParameterRewriters();
Collection<ParameterRewriter> parameterRewriters = new EncryptParameterRewriterBuilder(rule, databaseName, sqlStatementContext, encryptConditions).getParameterRewriters();
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
SQLTokenGeneratorBuilder sqlTokenGeneratorBuilder = new EncryptTokenGenerateBuilder(rule, sqlStatementContext, encryptConditions, databaseName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,19 @@
package org.apache.shardingsphere.encrypt.rewrite.parameter;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseTypeAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptConditionsAware;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptAssignmentParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertValueParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;

import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;

/**
* Parameter rewriter builder for encrypt.
Expand All @@ -49,43 +42,24 @@ public final class EncryptParameterRewriterBuilder implements ParameterRewriterB

private final String databaseName;

private final Map<String, ShardingSphereSchema> schemas;

private final SQLStatementContext sqlStatementContext;

private final Collection<EncryptCondition> encryptConditions;

@Override
public Collection<ParameterRewriter> getParameterRewriters() {
Collection<ParameterRewriter> result = new LinkedList<>();
addParameterRewriter(result, new EncryptAssignmentParameterRewriter(rule));
addParameterRewriter(result, new EncryptPredicateParameterRewriter(rule));
addParameterRewriter(result, new EncryptInsertPredicateParameterRewriter(rule));
addParameterRewriter(result, new EncryptInsertValueParameterRewriter(rule));
addParameterRewriter(result, new EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter(rule));
addParameterRewriter(result, new EncryptAssignmentParameterRewriter(rule, databaseName));
addParameterRewriter(result, new EncryptPredicateParameterRewriter(rule, databaseName, encryptConditions));
addParameterRewriter(result, new EncryptInsertPredicateParameterRewriter(rule, databaseName, encryptConditions));
addParameterRewriter(result, new EncryptInsertValueParameterRewriter(rule, databaseName));
addParameterRewriter(result, new EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter(rule, databaseName));
return result;
}

private void addParameterRewriter(final Collection<ParameterRewriter> paramRewriters, final ParameterRewriter toBeAddedParamRewriter) {
if (toBeAddedParamRewriter.isNeedRewrite(sqlStatementContext)) {
setUpParameterRewriter(toBeAddedParamRewriter);
paramRewriters.add(toBeAddedParamRewriter);
}
}

private void setUpParameterRewriter(final ParameterRewriter toBeAddedParamRewriter) {
if (toBeAddedParamRewriter instanceof SchemaMetaDataAware) {
((SchemaMetaDataAware) toBeAddedParamRewriter).setSchemas(schemas);
((SchemaMetaDataAware) toBeAddedParamRewriter).setDefaultSchema(schemas.get(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName)));
}
if (toBeAddedParamRewriter instanceof EncryptConditionsAware) {
((EncryptConditionsAware) toBeAddedParamRewriter).setEncryptConditions(encryptConditions);
}
if (toBeAddedParamRewriter instanceof DatabaseNameAware) {
((DatabaseNameAware) toBeAddedParamRewriter).setDatabaseName(databaseName);
}
if (toBeAddedParamRewriter instanceof DatabaseTypeAware) {
((DatabaseTypeAware) toBeAddedParamRewriter).setDatabaseType(sqlStatementContext.getDatabaseType());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
Expand Down Expand Up @@ -49,12 +47,11 @@
* Assignment parameter rewriter for encrypt.
*/
@RequiredArgsConstructor
@Setter
public final class EncryptAssignmentParameterRewriter implements ParameterRewriter, DatabaseNameAware {
public final class EncryptAssignmentParameterRewriter implements ParameterRewriter {

private final EncryptRule rule;

private String databaseName;
private final String databaseName;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.OnDuplicateUpdateContext;
Expand All @@ -39,12 +37,11 @@
* Insert on duplicate key update parameter rewriter for encrypt.
*/
@RequiredArgsConstructor
@Setter
public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter implements ParameterRewriter, DatabaseNameAware {
public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter implements ParameterRewriter {

private final EncryptRule rule;

private String databaseName;
private final String databaseName;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
package org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptConditionsAware;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
Expand All @@ -35,14 +32,13 @@
* Insert predicate parameter rewriter for encrypt.
*/
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateParameterRewriter implements ParameterRewriter, EncryptConditionsAware, DatabaseNameAware {
public final class EncryptInsertPredicateParameterRewriter implements ParameterRewriter {

private final EncryptRule rule;

private Collection<EncryptCondition> encryptConditions;
private final String databaseName;

private String databaseName;
private final Collection<EncryptCondition> encryptConditions;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand All @@ -52,9 +48,7 @@ public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {

@Override
public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementContext sqlStatementContext, final List<Object> params) {
EncryptPredicateParameterRewriter rewriter = new EncryptPredicateParameterRewriter(rule);
rewriter.setEncryptConditions(encryptConditions);
rewriter.setDatabaseName(databaseName);
EncryptPredicateParameterRewriter rewriter = new EncryptPredicateParameterRewriter(rule, databaseName, encryptConditions);
rewriter.rewrite(paramBuilder, ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext(), params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
Expand All @@ -44,12 +42,11 @@
* Insert value parameter rewriter for encrypt.
*/
@RequiredArgsConstructor
@Setter
public final class EncryptInsertValueParameterRewriter implements ParameterRewriter, DatabaseNameAware {
public final class EncryptInsertValueParameterRewriter implements ParameterRewriter {

private final EncryptRule rule;

private String databaseName;
private final String databaseName;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
package org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptConditionsAware;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.encrypt.rewrite.condition.impl.EncryptBinaryCondition;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
Expand All @@ -43,14 +40,13 @@
* Predicate parameter rewriter for encrypt.
*/
@RequiredArgsConstructor
@Setter
public final class EncryptPredicateParameterRewriter implements ParameterRewriter, EncryptConditionsAware, DatabaseNameAware {
public final class EncryptPredicateParameterRewriter implements ParameterRewriter {

private final EncryptRule rule;

private Collection<EncryptCondition> encryptConditions;
private final String databaseName;

private String databaseName;
private final Collection<EncryptCondition> encryptConditions;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.junit.jupiter.api.Test;

Expand All @@ -31,32 +30,20 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class EncryptParameterRewriterBuilderTest {

@Test
void assertGetParameterRewritersWhenPredicateIsNeedRewrite() {
EncryptRule encryptRule = mock(EncryptRule.class, RETURNS_DEEP_STUBS);
when(encryptRule.findEncryptTable("t_order").isPresent()).thenReturn(true);
void assertGetParameterRewriters() {
EncryptRule rule = mock(EncryptRule.class, RETURNS_DEEP_STUBS);
when(rule.findEncryptTable("foo_tbl").isPresent()).thenReturn(true);
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("t_order"));
Collection<ParameterRewriter> actual = new EncryptParameterRewriterBuilder(
encryptRule, DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)), sqlStatementContext, Collections.emptyList()).getParameterRewriters();
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("foo_tbl"));
Collection<ParameterRewriter> actual = new EncryptParameterRewriterBuilder(rule, DefaultDatabase.LOGIC_NAME, sqlStatementContext, Collections.emptyList()).getParameterRewriters();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(EncryptPredicateParameterRewriter.class));
}

@Test
void assertGetParameterRewritersWhenPredicateIsNotNeedRewrite() {
EncryptRule encryptRule = mock(EncryptRule.class, RETURNS_DEEP_STUBS);
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("t_order"));
when(sqlStatementContext.getWhereSegments()).thenReturn(Collections.emptyList());
assertTrue(new EncryptParameterRewriterBuilder(encryptRule,
DefaultDatabase.LOGIC_NAME, Collections.singletonMap("test", mock(ShardingSphereSchema.class)), sqlStatementContext, Collections.emptyList()).getParameterRewriters().isEmpty());
}
}

0 comments on commit 9d9af4d

Please sign in to comment.