Refactor usage of ShardingSphereDatabase.schemas #33855

Merged: 8 commits, Dec 1, 2024
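
For orientation, the diff below consistently replaces direct access to the ShardingSphereDatabase.schemas map (via getSchemas()) with the getSchema(name) accessor and the newly added getAllSchemas() method. A minimal sketch of the call-site pattern, not part of the PR and with a hypothetical class name:

import java.util.Collection;

import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;

// Hypothetical helper, for illustration only: call sites stop holding Map<String, ShardingSphereSchema>
// and instead keep the ShardingSphereDatabase and use its accessors.
public final class SchemaAccessSketch {
    
    // Old style: database.getSchemas().get(schemaName)
    // New style: let the database resolve the schema by name.
    public static ShardingSphereSchema lookup(final ShardingSphereDatabase database, final String schemaName) {
        return database.getSchema(schemaName);
    }
    
    // Old style: iterate database.getSchemas().entrySet()
    // New style: iterate schemas directly via the accessor added in this PR.
    public static int countSchemas(final ShardingSphereDatabase database) {
        Collection<ShardingSphereSchema> allSchemas = database.getAllSchemas();
        return allSchemas.size();
    }
}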
@@ -28,6 +28,7 @@
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
@@ -63,9 +64,9 @@ public final class EncryptConditionEngine {

private static final Set<String> SUPPORTED_COMPARE_OPERATOR = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);

private final EncryptRule encryptRule;
private final EncryptRule rule;

private final Map<String, ShardingSphereSchema> schemas;
private final ShardingSphereDatabase database;

static {
LOGICAL_OPERATOR.add("AND");
@@ -96,7 +97,7 @@ public Collection<EncryptCondition> createEncryptConditions(final Collection<Whe
final SQLStatementContext sqlStatementContext, final String databaseName) {
Collection<EncryptCondition> result = new LinkedList<>();
String defaultSchema = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> schemas.get(defaultSchema));
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
Map<String, String> expressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
for (WhereSegment each : whereSegments) {
Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr());
@@ -122,7 +123,7 @@ private void addEncryptConditions(final Collection<EncryptCondition> encryptCond
}
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
String tableName = expressionTableNames.getOrDefault(each.getExpression(), "");
Optional<EncryptTable> encryptTable = encryptRule.findEncryptTable(tableName);
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
createEncryptCondition(expression, tableName).ifPresent(encryptConditions::add);
}
@@ -94,7 +94,7 @@ private Collection<EncryptCondition> createEncryptConditions(final EncryptRule r
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(
rule, sqlRewriteContext.getDatabase().getSchemas()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabase().getName());
rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabase().getName());
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
@@ -22,18 +22,19 @@
import org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder;
import org.apache.shardingsphere.infra.binder.context.statement.dml.UpdateStatementContext;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class EncryptPredicateRightValueTokenGeneratorTest {

@@ -61,7 +62,9 @@ void assertGenerateSQLTokenFromGenerateNewSQLToken() {
}

private Collection<EncryptCondition> getEncryptConditions(final UpdateStatementContext updateStatementContext) {
return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), Collections.singletonMap(DefaultDatabase.LOGIC_NAME, mock(ShardingSphereSchema.class)))
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), database)
.createEncryptConditions(updateStatementContext.getWhereSegments(), updateStatementContext.getColumnSegments(), updateStatementContext, DefaultDatabase.LOGIC_NAME);
}
}
@@ -207,7 +207,7 @@ void assertNextForDistinctShorthandResultSetsEmpty() throws SQLException {
when(table.getAllColumns()).thenReturn(Collections.emptyList());
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
when(database.getSchemas()).thenReturn(Collections.singletonMap(DefaultDatabase.LOGIC_NAME, schema));
when(database.getAllSchemas()).thenReturn(Collections.singleton(schema));
when(database.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
ShardingDQLResultMerger merger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = merger.merge(Arrays.asList(queryResult, queryResult, queryResult), createSelectStatementContext(database), database, mock(ConnectionContext.class));
@@ -26,7 +26,6 @@
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;

import java.util.Collection;
import java.util.Map;
import java.util.Map.Entry;

/**
@@ -43,7 +42,7 @@ public final class SupportedSQLCheckEngine {
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public void checkSQL(final Collection<ShardingSphereRule> rules, final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) {
ShardingSphereSchema currentSchema = getCurrentSchema(sqlStatementContext, database.getSchemas(), database.getName());
ShardingSphereSchema currentSchema = getCurrentSchema(sqlStatementContext, database);
for (Entry<ShardingSphereRule, SupportedSQLCheckersBuilder> entry : OrderedSPILoader.getServices(SupportedSQLCheckersBuilder.class, rules).entrySet()) {
Collection<SupportedSQLChecker> checkers = entry.getValue().getSupportedSQLCheckers();
for (SupportedSQLChecker each : checkers) {
@@ -54,8 +53,10 @@ public void checkSQL(final Collection<ShardingSphereRule> rules, final SQLStatem
}
}

private ShardingSphereSchema getCurrentSchema(final SQLStatementContext sqlStatementContext, final Map<String, ShardingSphereSchema> schemas, final String databaseName) {
ShardingSphereSchema defaultSchema = schemas.get(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName));
return sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElse(defaultSchema) : defaultSchema;
private ShardingSphereSchema getCurrentSchema(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) {
ShardingSphereSchema defaultSchema = database.getSchema(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
return sqlStatementContext instanceof TableAvailable
? ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(database::getSchema).orElse(defaultSchema)
: defaultSchema;
}
}
@@ -138,6 +138,15 @@ private static ResourceMetaData createResourceMetaData(final Map<StorageNode, Da
return new ResourceMetaData(dataSources, storageUnits);
}

/**
* Get all schemas.
*
* @return all schemas
*/
public Collection<ShardingSphereSchema> getAllSchemas() {
return schemas.values();
}

/**
* Judge contains schema from database or not.
*
@@ -19,14 +19,13 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

/**
@@ -38,15 +37,15 @@ public final class GenericSchemaManager {
/**
* Get to be added tables by schemas.
*
* @param reloadSchemas reload schemas
* @param currentSchemas current schemas
* @param reloadDatabase reload database
* @param currentDatabase current database
* @return To be added table meta data
*/
public static Map<String, ShardingSphereSchema> getToBeAddedTablesBySchemas(final Map<String, ShardingSphereSchema> reloadSchemas, final Map<String, ShardingSphereSchema> currentSchemas) {
Map<String, ShardingSphereSchema> result = new LinkedHashMap<>(currentSchemas.size(), 1F);
reloadSchemas.entrySet().stream().filter(entry -> !currentSchemas.containsKey(entry.getKey())).forEach(entry -> result.put(entry.getKey(), entry.getValue()));
reloadSchemas.entrySet().stream().filter(entry -> currentSchemas.containsKey(entry.getKey())).collect(Collectors.toMap(Entry::getKey, Entry::getValue))
.forEach((key, value) -> result.put(key, getToBeAddedTablesBySchema(value, currentSchemas.get(key))));
public static Collection<ShardingSphereSchema> getToBeAddedTablesBySchemas(final ShardingSphereDatabase reloadDatabase, final ShardingSphereDatabase currentDatabase) {
Collection<ShardingSphereSchema> result = new LinkedList<>();
reloadDatabase.getAllSchemas().stream().filter(each -> !currentDatabase.containsSchema(each.getName())).forEach(result::add);
reloadDatabase.getAllSchemas().stream().filter(each -> currentDatabase.containsSchema(each.getName())).collect(Collectors.toList())
.forEach(each -> result.add(getToBeAddedTablesBySchema(each, currentDatabase.getSchema(each.getName()))));
return result;
}

@@ -68,14 +67,14 @@ public static Collection<ShardingSphereTable> getToBeAddedTables(final ShardingS
/**
* Get to be dropped tables by schemas.
*
* @param reloadSchemas reload schemas
* @param currentSchemas current schemas
* @param reloadDatabase reload database
* @param currentDatabase current database
* @return to be dropped table
*/
public static Map<String, ShardingSphereSchema> getToBeDroppedTablesBySchemas(final Map<String, ShardingSphereSchema> reloadSchemas, final Map<String, ShardingSphereSchema> currentSchemas) {
Map<String, ShardingSphereSchema> result = new LinkedHashMap<>(currentSchemas.size(), 1F);
currentSchemas.entrySet().stream().filter(entry -> reloadSchemas.containsKey(entry.getKey())).collect(Collectors.toMap(Entry::getKey, Entry::getValue))
.forEach((key, value) -> result.put(key, getToBeDroppedTablesBySchema(reloadSchemas.get(key), value)));
public static Collection<ShardingSphereSchema> getToBeDroppedTablesBySchemas(final ShardingSphereDatabase reloadDatabase, final ShardingSphereDatabase currentDatabase) {
Collection<ShardingSphereSchema> result = new LinkedList<>();
currentDatabase.getAllSchemas().stream().filter(entry -> reloadDatabase.containsSchema(entry.getName())).collect(Collectors.toMap(ShardingSphereSchema::getName, each -> each))
.forEach((key, value) -> result.add(getToBeDroppedTablesBySchema(reloadDatabase.getSchema(key), value)));
return result;
}

@@ -97,11 +96,11 @@ public static Collection<ShardingSphereTable> getToBeDroppedTables(final Shardin
/**
* Get to be dropped schemas.
*
* @param reloadSchemas reload schemas
* @param currentSchemas current schemas
* @param reloadDatabase reload database
* @param currentDatabase current database
* @return to be dropped schemas
*/
public static Map<String, ShardingSphereSchema> getToBeDroppedSchemas(final Map<String, ShardingSphereSchema> reloadSchemas, final Map<String, ShardingSphereSchema> currentSchemas) {
return currentSchemas.entrySet().stream().filter(entry -> !reloadSchemas.containsKey(entry.getKey())).collect(Collectors.toMap(Entry::getKey, Entry::getValue));
public static Map<String, ShardingSphereSchema> getToBeDroppedSchemas(final ShardingSphereDatabase reloadDatabase, final ShardingSphereDatabase currentDatabase) {
return currentDatabase.getAllSchemas().stream().filter(each -> !reloadDatabase.containsSchema(each.getName())).collect(Collectors.toMap(ShardingSphereSchema::getName, each -> each));
}
}
@@ -37,7 +37,7 @@

public final class MySQLShardingSphereStatisticsBuilder implements ShardingSphereStatisticsBuilder {

private static final String SHARDING_SPHERE = "shardingsphere";
private static final String SHARDINGSPHERE = "shardingsphere";

private static final String CLUSTER_INFORMATION = "cluster_information";

@@ -57,12 +57,12 @@ public ShardingSphereStatistics build(final ShardingSphereMetaData metaData) {
}

private void initSchemas(final ShardingSphereDatabase database, final ShardingSphereDatabaseData databaseData) {
for (Entry<String, ShardingSphereSchema> entry : database.getSchemas().entrySet()) {
if (SHARDING_SPHERE.equals(entry.getKey())) {
for (ShardingSphereSchema each : database.getAllSchemas()) {
if (SHARDINGSPHERE.equals(each.getName())) {
ShardingSphereSchemaData schemaData = new ShardingSphereSchemaData();
initClusterInformationTable(schemaData);
initShardingTableStatisticsTable(schemaData);
databaseData.putSchema(SHARDING_SPHERE, schemaData);
databaseData.putSchema(SHARDINGSPHERE, schemaData);
}
}
}
@@ -42,7 +42,7 @@

public final class PostgreSQLShardingSphereStatisticsBuilder implements ShardingSphereStatisticsBuilder {

private static final String SHARDING_SPHERE = "shardingsphere";
private static final String SHARDINGSPHERE = "shardingsphere";

private static final String CLUSTER_INFORMATION = "cluster_information";

@@ -68,18 +68,16 @@ public ShardingSphereStatistics build(final ShardingSphereMetaData metaData) {
}

private void initSchemas(final ShardingSphereDatabase database, final ShardingSphereDatabaseData databaseData) {
for (Entry<String, ShardingSphereSchema> entry : database.getSchemas().entrySet()) {
if (SHARDING_SPHERE.equals(entry.getKey())) {
ShardingSphereSchemaData schemaData = new ShardingSphereSchemaData();
initClusterInformationTable(schemaData);
initShardingTableStatisticsTable(schemaData);
databaseData.putSchema(SHARDING_SPHERE, schemaData);
}
if (INIT_DATA_SCHEMA_TABLES.containsKey(entry.getKey())) {
ShardingSphereSchemaData schemaData = new ShardingSphereSchemaData();
initTables(entry.getValue(), INIT_DATA_SCHEMA_TABLES.get(entry.getKey()), schemaData);
databaseData.putSchema(entry.getKey(), schemaData);
}
if (null != database.getSchema(SHARDINGSPHERE)) {
ShardingSphereSchemaData schemaData = new ShardingSphereSchemaData();
initClusterInformationTable(schemaData);
initShardingTableStatisticsTable(schemaData);
databaseData.putSchema(SHARDINGSPHERE, schemaData);
}
for (String each : INIT_DATA_SCHEMA_TABLES.keySet()) {
ShardingSphereSchemaData schemaData = new ShardingSphereSchemaData();
initTables(database.getSchema(each), INIT_DATA_SCHEMA_TABLES.get(each), schemaData);
databaseData.putSchema(each, schemaData);
}
}

@@ -31,7 +31,6 @@
import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

/**
Expand All @@ -49,11 +48,9 @@ public final class PgClassTableCollector implements ShardingSphereStatisticsColl
public Optional<ShardingSphereTableData> collect(final String databaseName, final ShardingSphereTable table, final Map<String, ShardingSphereDatabase> databases,
final RuleMetaData globalRuleMetaData) throws SQLException {
ShardingSphereTableData result = new ShardingSphereTableData(PG_CLASS);
long oid = 0L;
for (Entry<String, ShardingSphereSchema> entry : databases.get(databaseName).getSchemas().entrySet()) {
if (PUBLIC_SCHEMA.equalsIgnoreCase(entry.getKey())) {
result.getRows().addAll(collectForSchema(oid++, PUBLIC_SCHEMA_OID, entry.getValue(), table));
}
ShardingSphereSchema publicSchema = databases.get(databaseName).getSchema(PUBLIC_SCHEMA);
if (null != publicSchema) {
result.getRows().addAll(collectForSchema(0L, PUBLIC_SCHEMA_OID, publicSchema, table));
}
return Optional.of(result);
}
@@ -30,7 +30,6 @@
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

/**
Expand All @@ -49,8 +48,8 @@ public Optional<ShardingSphereTableData> collect(final String databaseName, fina
final RuleMetaData globalRuleMetaData) throws SQLException {
ShardingSphereTableData result = new ShardingSphereTableData(PG_NAMESPACE);
long oid = 1L;
for (Entry<String, ShardingSphereSchema> entry : databases.get(databaseName).getSchemas().entrySet()) {
result.getRows().add(new ShardingSphereRowData(getRow(PUBLIC_SCHEMA.equalsIgnoreCase(entry.getKey()) ? PUBLIC_SCHEMA_OID : oid++, entry.getKey(), table)));
for (ShardingSphereSchema each : databases.get(databaseName).getAllSchemas()) {
result.getRows().add(new ShardingSphereRowData(getRow(PUBLIC_SCHEMA.equalsIgnoreCase(each.getName()) ? PUBLIC_SCHEMA_OID : oid++, each.getName(), table)));
}
return Optional.of(result);
}
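
For completeness, a hedged sketch of the matching test-fixture pattern (mirroring the mock changes above; class and method names are illustrative, not from the PR): rather than stubbing getSchemas() with a singleton map, tests stub getSchema(...) and getAllSchemas() directly.

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Collections;

import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;

// Illustrative fixture builder: stubs only the accessors used after this refactor.
final class MockDatabaseSketch {
    
    static ShardingSphereDatabase create(final String name) {
        ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
        ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
        when(result.getName()).thenReturn(name);
        when(result.getSchema(name)).thenReturn(schema);
        when(result.getAllSchemas()).thenReturn(Collections.singleton(schema));
        return result;
    }
}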