Skip to content

Commit

Permalink
Add sql bind logic for create table statement and check simple table …
Browse files Browse the repository at this point in the history
…binder (#34074)

* Add sql bind logic for create table statement

* update release note

* fix unit test

* Pass metadata to PipelineDDLGenerator

* fix checkstyle

* Add it test for create table statement

* fix unit test

* fix unit test
  • Loading branch information
strongduanmu authored Dec 17, 2024
1 parent 8c33423 commit 694c6bf
Show file tree
Hide file tree
Showing 61 changed files with 598 additions and 139 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
1. Proxy Native: Support Seata AT integration under Proxy Native in GraalVM Native Image - [#33889](https://github.com/apache/shardingsphere/pull/33889)
1. Agent: Simplify the use of Agent's Docker Image - [#33356](https://github.com/apache/shardingsphere/pull/33356)
1. Metadata: Add load-table-metadata-batch-size props to concurrent load table metadata - [#34009](https://github.com/apache/shardingsphere/pull/34009)
1. SQL Binder: Add sql bind logic for create table statement - [#34074](https://github.com/apache/shardingsphere/pull/34074)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public Stream<? extends Arguments> provideArguments(final ExtensionContext exten
Arguments.of("update t_warehouse set warehouse_name = ? where id = ?", Arrays.asList("foo", 1), true, Collections.singletonList(1)),
Arguments.of("delete from t_warehouse where id = ?", Collections.singletonList(1), true, Collections.singletonList(0)));
Collection<? extends Arguments> nonCacheableCases = Arrays.asList(
Arguments.of("create table t_warehouse (id int4 not null primary key)", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("create table t_warehouse_for_create (id int4 not null primary key)", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("insert into t_warehouse (id) select warehouse_id from t_order", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("insert into t_warehouse (id) values (?), (?)", Arrays.asList(1, 2), false, Collections.emptyList()),
Arguments.of("insert into t_non_sharding_table (id) values (?)", Collections.singletonList(1), false, Collections.emptyList()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class ShardingCreateFunctionSupportedCheckerTest {
void assertCheckCreateFunctionForMySQL() {
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down Expand Up @@ -104,7 +105,8 @@ void assertCheckCreateFunctionWithNoSuchTableForMySQL() {

@Test
void assertCheckCreateFunctionWithTableExistsForMySQL() {
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class ShardingCreateProcedureSupportedCheckerTest {
void assertCheckForMySQL() {
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down Expand Up @@ -105,7 +106,8 @@ void assertCheckWithNoSuchTableForMySQL() {

@Test
void assertCheckWithTableExistsForMySQL() {
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class ShardingCreateTableSupportedCheckerTest {

@Test
void assertCheckForMySQL() {
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertThrows(TableExistsException.class, () -> assertCheck(sqlStatement));
}
Expand All @@ -63,7 +64,8 @@ void assertCheckForOracle() {

@Test
void assertCheckForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertThrows(TableExistsException.class, () -> assertCheck(sqlStatement));
}
Expand Down Expand Up @@ -92,14 +94,16 @@ private void assertCheck(final CreateTableStatement sqlStatement) {

@Test
void assertCheckIfNotExistsForMySQL() {
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement(true);
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement();
sqlStatement.setIfNotExists(true);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertCheckIfNotExists(sqlStatement);
}

@Test
void assertCheckIfNotExistsForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(true);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(true);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertCheckIfNotExists(sqlStatement);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class ShardingCreateTableRouteContextCheckerTest {

@Test
void assertCheckWithSameRouteResultShardingTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
when(shardingRule.isShardingTable("t_order")).thenReturn(true);
when(shardingRule.getShardingTable("t_order")).thenReturn(new ShardingTable(Arrays.asList("ds_0", "ds_1"), "t_order"));
Expand All @@ -78,7 +79,8 @@ void assertCheckWithSameRouteResultShardingTableForPostgreSQL() {

@Test
void assertCheckWithDifferentRouteResultShardingTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
when(shardingRule.isShardingTable("t_order")).thenReturn(true);
when(shardingRule.getShardingTable("t_order")).thenReturn(new ShardingTable(Arrays.asList("ds_0", "ds_1"), "t_order"));
Expand All @@ -92,7 +94,8 @@ void assertCheckWithDifferentRouteResultShardingTableForPostgreSQL() {

@Test
void assertCheckWithSameRouteResultBroadcastTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_config"))));
when(queryContext.getSqlStatementContext()).thenReturn(new CreateTableStatementContext(sqlStatement));
assertDoesNotThrow(() -> new ShardingCreateTableRouteContextChecker().check(shardingRule, queryContext, database, mock(ConfigurationProperties.class), routeContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
*/
public enum SegmentType {

PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, INSERT_COLUMNS
PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, INSERT_COLUMNS, DEFINITION_COLUMNS
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.column;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.expression.type.ColumnSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.type.SimpleTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.column.ColumnDefinitionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;

/**
* Column definition segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ColumnDefinitionSegmentBinder {

/**
* Bind column definition segment.
*
* @param segment column definition segment
* @param binderContext SQL statement binder context
* @param tableBinderContexts table binder contexts
* @return bound column definition segment
*/
public static ColumnDefinitionSegment bind(final ColumnDefinitionSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment boundColumnSegment = ColumnSegmentBinder.bind(segment.getColumnName(), SegmentType.DEFINITION_COLUMNS, binderContext, tableBinderContexts, LinkedHashMultimap.create());
ColumnDefinitionSegment result =
new ColumnDefinitionSegment(segment.getStartIndex(), segment.getStopIndex(), boundColumnSegment, segment.getDataType(), segment.isPrimaryKey(), segment.isNotNull(), segment.getText());
copy(segment, result);
segment.getReferencedTables().forEach(each -> result.getReferencedTables().add(SimpleTableSegmentBinder.bind(each, binderContext, tableBinderContexts)));
return result;
}

private static void copy(final ColumnDefinitionSegment result, final ColumnDefinitionSegment segment) {
result.setAutoIncrement(segment.isAutoIncrement());
result.setRef(segment.isRef());
result.setCharsetName(segment.getCharsetName());
result.setCollateName(segment.getCollateName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static CombineSegment bind(final CombineSegment segment, final SQLStateme
private static SubquerySegment bindSubquerySegment(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveMap.CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SubquerySegment result = new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), segment.getText());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
result.setSelect(new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), subqueryBinderContext));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private static Optional<ColumnSegment> findInputColumnSegment(final ColumnSegmen
}
}
if (!isFindInputColumn) {
result = findInputColumnSegmentByVariables(segment, binderContext.getVariableNames()).orElse(null);
result = findInputColumnSegmentByVariables(segment, binderContext.getSqlStatement().getVariableNames()).orElse(null);
isFindInputColumn = null != result;
}
if (!isFindInputColumn) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public final class SubquerySegmentBinder {
*/
public static SubquerySegment bind(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
selectBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
SelectStatement boundSelectStatement = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), selectBinderContext);
return new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), boundSelectStatement, segment.getText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ public static JoinTableSegment bind(final JoinTableSegment segment, final SQLSta
result.setDerivedUsing(bindUsingColumns(derivedUsingColumns, tableBinderContexts));
result.getDerivedUsing().forEach(each -> binderContext.getUsingColumnNames().add(each.getIdentifier().getValue()));
}
result.getDerivedJoinTableProjectionSegments().addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
result.getDerivedJoinTableProjectionSegments()
.addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getSqlStatement().getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
binderContext.getJoinTableProjectionSegments().addAll(result.getDerivedJoinTableProjectionSegments());
return result;
}
Expand Down
Loading

0 comments on commit 694c6bf

Please sign in to comment.