Skip to content
2 changes: 1 addition & 1 deletion sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ ALTER TABLE exist_db.exist_tb_1 Add index idx_2 (id,id);
ALTER TABLE exist_db.exist_tb_1 Add index (id,id);
`,
newTestResult().add(driver.RuleLevelError, DuplicateIndexedColumnMessage, "(匿名)",
"id").addResult(rulepkg.DDLCheckIndexPrefix, "idx_"),
"id").addResult(rulepkg.DDLCheckIndexPrefix, "idx_").addResult(rulepkg.DDLCheckIndexNameExisted),
)
}

Expand Down
332 changes: 319 additions & 13 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ const (
DDLNotAllowRenaming = "ddl_not_allow_renaming"
DDLCheckObjectNameIsUpperAndLowerLetterMixed = "ddl_check_object_name_is_upper_and_lower_letter_mixed"
DDLCheckFieldNotNUllMustContainDefaultValue = "ddl_check_field_not_null_must_contain_default_value"
DDLCheckIndexNameExisted = "ddl_check_index_name_existed"
DDLCheckTableRowLength = "ddl_check_table_row_length"
)

// inspector DML rules
Expand Down Expand Up @@ -169,6 +171,8 @@ const (
ConfigOptimizeIndexEnabled = "optimize_index_enabled"
ConfigDMLExplainPreCheckEnable = "dml_enable_explain_pre_check"
ConfigSQLIsExecuted = "sql_is_executed"
ConfigAvoidSet = "config_avoid_set"
ConfigCheckEventScheduler = "config_check_event_scheduler"
)

type RuleHandlerInput struct {
Expand Down Expand Up @@ -1881,6 +1885,62 @@ var RuleHandlers = []RuleHandler{
Message: "禁止使用rename或change对表名字段名进行修改",
Func: ddlNotAllowRenaming,
},
{
Rule: driver.Rule{
Name: DDLCheckIndexNameExisted,
Desc: "索引必须设置索引名",
Annotation: "普通索引定义索引名,且名称遵循固定的命名规范、避免特殊字符的使用,可以提高代码的可读性、可维护性,并减少潜在的兼容性和语法问题。",
Level: driver.RuleLevelNormal,
Category: RuleTypeNamingConvention,
},
AllowOffline: true,
Message: "索引必须设置索引名",
Func: checkIndexNameExisted,
},
{
Rule: driver.Rule{
Name: DDLCheckTableRowLength,
Desc: "表设计做到行不跨页",
Annotation: "在表设计时,应该尽量确保每一行数据都不会跨越数据页(Page)的边界,以提高数据的读取和写入性能,减少物理I/O操作,并优化存储空间的利用率。",
Level: driver.RuleLevelWarn,
Category: RuleTypeDDLConvention,
Params: params.Params{
&params.Param{
Key: DefaultSingleParamKeyName,
Value: "65535",
Desc: "最大行长 (byte)",
Type: params.ParamTypeInt,
},
},
},
AllowOffline: true,
Message: "表设计做到行不跨页",
Func: checkTableRowLength,
},
{
Rule: driver.Rule{
Name: ConfigAvoidSet,
Desc: "不允许使用SET操作",
Annotation: "禁止使用SET命令来修改MySQL的系统参数,以确保数据库的稳定性、一致性和安全性。",
Level: driver.RuleLevelError,
Category: RuleTypeGlobalConfig,
},
AllowOffline: true,
Message: "不允许使用SET操作",
Func: avoidSet,
},
{
Rule: driver.Rule{
Name: ConfigCheckEventScheduler,
Desc: "禁止使用event scheduler",
Annotation: "禁用MySQL的事件调度器(event_scheduler),以提高数据库的安全性、稳定性和可控性,避免非预期的事件执行对系统造成影响。",
Level: driver.RuleLevelError,
Category: RuleTypeGlobalConfig,
},
AllowOffline: true,
Message: "禁止使用event schedule",
Func: checkEventScheduler,
},
}

func checkFieldNotNUllMustContainDefaultValue(input *RuleHandlerInput) error {
Expand Down Expand Up @@ -4352,12 +4412,13 @@ var createTriggerReg1 = regexp.MustCompile(`(?i)create[\s]+trigger[\s]+[\S\s]+be
var createTriggerReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+trigger[\s]+[\S\s]+before|after`)

// CREATE
// [DEFINER = user]
// TRIGGER trigger_name
// trigger_time trigger_event
// ON tbl_name FOR EACH ROW
// [trigger_order]
// trigger_body
//
// [DEFINER = user]
// TRIGGER trigger_name
// trigger_time trigger_event
// ON tbl_name FOR EACH ROW
// [trigger_order]
// trigger_body
//
// ref:https://dev.mysql.com/doc/refman/8.0/en/create-trigger.html
//
Expand All @@ -4378,10 +4439,11 @@ var createFunctionReg1 = regexp.MustCompile(`(?i)create[\s]+function[\s]+[\S\s]+
var createFunctionReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+function[\s]+[\S\s]+returns`)

// CREATE
// [DEFINER = user]
// FUNCTION sp_name ([func_parameter[,...]])
// RETURNS type
// [characteristic ...] routine_body
//
// [DEFINER = user]
// FUNCTION sp_name ([func_parameter[,...]])
// RETURNS type
// [characteristic ...] routine_body
//
// ref: https://dev.mysql.com/doc/refman/5.7/en/create-procedure.html
// For now, we do character matching for CREATE FUNCTION Statement. Maybe we need
Expand All @@ -4401,9 +4463,10 @@ var createProcedureReg1 = regexp.MustCompile(`(?i)create[\s]+procedure[\s]+[\S\s
var createProcedureReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+procedure[\s]+[\S\s]+`)

// CREATE
// [DEFINER = user]
// PROCEDURE sp_name ([proc_parameter[,...]])
// [characteristic ...] routine_body
//
// [DEFINER = user]
// PROCEDURE sp_name ([proc_parameter[,...]])
// [characteristic ...] routine_body
//
// ref: https://dev.mysql.com/doc/refman/8.0/en/create-procedure.html
// For now, we do character matching for CREATE PROCEDURE Statement. Maybe we need
Expand Down Expand Up @@ -5097,3 +5160,246 @@ func ddlNotAllowRenaming(input *RuleHandlerInput) error {
}
return nil
}

func checkIndexNameExisted(input *RuleHandlerInput) error {
indexNameNotExisted := false
switch stmt := input.Node.(type) {
case *ast.CreateTableStmt:
for _, constraint := range stmt.Constraints {
switch constraint.Tp {
case ast.ConstraintIndex, ast.ConstraintUniqIndex, ast.ConstraintKey, ast.ConstraintUniqKey:
if constraint.Name == "" {
indexNameNotExisted = true
}
default:
return nil
}
}
case *ast.AlterTableStmt:
for _, spec := range stmt.Specs {
if spec.Tp == ast.AlterTableAddConstraint && IsIndexConstraint(spec.Constraint.Tp) {
// 遍历Keys
if spec.Constraint.Name == "" {
indexNameNotExisted = true
}
}
}
default:
return nil
}
if indexNameNotExisted {
addResult(input.Res, input.Rule, DDLCheckIndexNameExisted)
}
return nil
}

func IsIndexConstraint(constraintType ast.ConstraintType) bool {
return constraintType == ast.ConstraintIndex || constraintType == ast.ConstraintUniqIndex || constraintType == ast.ConstraintKey || constraintType == ast.ConstraintUniqKey
}

func checkTableRowLength(input *RuleHandlerInput) error {
var rowLengthLimit = input.Rule.Params.GetParam(DefaultSingleParamKeyName).Int()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

若表实际的页长小于人工设置的页长,则该规则失效

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充进文档

rowLength := 0
switch stmt := input.Node.(type) {
case *ast.CreateTableStmt:
charsetNum := GetTableCharsetNum(stmt.Options)
for _, col := range stmt.Cols {
colCharsetNum := MappingCharsetLength(col.Tp.Charset)
// 可能会设置列级别的字符串
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
oneColumnLength := ComputeOneColumnLength(col, charsetNum)
rowLength += oneColumnLength
}
case *ast.AlterTableStmt:
// 获取在线表信息
tableStmt, tableExist, err := input.Ctx.GetCreateTableStmt(stmt.Table)
if !tableExist || err != nil {
return err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果table不存在,那error会不等于nil吗?
这是两种情况,都用返回error处理可以吗

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果表不存在,返回error,会不会导致审核阻塞?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯 这里有问题

}
charsetNum := GetTableCharsetNum(tableStmt.Options)
columnLengthMap := make(map[string]int, len(tableStmt.Cols))
// 计算原表的长度
for _, col := range tableStmt.Cols {
colCharsetNum := MappingCharsetLength(col.Tp.Charset)
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
oneColumnLength := ComputeOneColumnLength(col, charsetNum)
rowLength += oneColumnLength
columnLengthMap[col.Name.String()] = oneColumnLength
}
// 计算alter语句修改列之后的长度
for _, alteredSpec := range stmt.Specs {
for _, alterCol := range alteredSpec.NewColumns {
if alterCol.Tp == nil {
// 不是对于列类型相关的变更
continue
}
// 可能会设置列级别的字符串
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不理解这个注释的含义
注释+代码,需要能够说明这里为什么这么做,做了什么

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加注释

colCharsetNum := MappingCharsetLength(alterCol.Tp.Charset)
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
if alteredSpec.Tp == ast.AlterTableAddColumns {
rowLength += ComputeOneColumnLength(alterCol, charsetNum)
}
if alteredSpec.Tp == ast.AlterTableModifyColumn {
// 如果是修改某个字段,减去原来字段的长度,使用新的字段长度
rowLength -= columnLengthMap[alterCol.Name.String()]
rowLength += ComputeOneColumnLength(alterCol, charsetNum)
}
}
}
default:
return nil
}
if rowLength > rowLengthLimit {
addResult(input.Res, input.Rule, DDLCheckTableRowLength)
}
return nil
}

func GetTableCharsetNum(options []*ast.TableOption) int {
charsetNum := 4
for _, opt := range options {
if opt.Tp == ast.TableOptionCharset {
charsetNum = MappingCharsetLength(opt.StrValue)
}
}
return charsetNum
}

// ComputeOneColumnLength 计算一个列的长度
Copy link
Collaborator

@winfredLIN winfredLIN Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

贴一下长度计算的依据:

  1. 官方文档
  2. 用ai使用mysql的风格格式绘制一个表格放在这里

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加了注释

func ComputeOneColumnLength(columnDef *ast.ColumnDef, charsetNum int) int {
oneColumnLength := 0
switch columnDef.Tp.Tp {
case mysql.TypeVarchar:
// 0~255 长度需要一个字节存储长度
lLength := 1
if columnDef.Tp.Flen > 255 {
// > 255 需要两个字节来存储长度
lLength = 2
}
// length * charsetNum + notNull + lLength
oneColumnLength = columnDef.Tp.Flen*charsetNum + OptionNotNullLength(columnDef.Options) + lLength
case mysql.TypeString:
oneColumnLength = columnDef.Tp.Flen*charsetNum + OptionNotNullLength(columnDef.Options)
case mysql.TypeYear, mysql.TypeTiny:
oneColumnLength = 1 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDate, mysql.TypeInt24:
// DATE MEDIUMINT
oneColumnLength = 3 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDuration:
// TIME
oneColumnLength = 3 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeDatetime:
oneColumnLength = 5 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeTimestamp:
oneColumnLength = 4 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeShort:
// SMALLINT
oneColumnLength = 2 + OptionNotNullLength(columnDef.Options)
case mysql.TypeLong, mysql.TypeFloat:
// INT FLOAT
oneColumnLength = 4 + OptionNotNullLength(columnDef.Options)
case mysql.TypeLonglong:
// BIGINT
oneColumnLength = 8 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDouble:
// BIGINT DOUBLE REAL
oneColumnLength = 8 + OptionNotNullLength(columnDef.Options)
case mysql.TypeNewDecimal:
// 整数部分
partition := (columnDef.Tp.Flen - columnDef.Tp.Decimal) / 9
oneColumnLength += partition * 4
oneColumnLength += decimalLeftoverLength((columnDef.Tp.Flen - columnDef.Tp.Decimal) % 9)
// 小数部分
decimalPartition := columnDef.Tp.Decimal / 9
oneColumnLength += decimalPartition * 4
oneColumnLength += decimalLeftoverLength((columnDef.Tp.Decimal) % 9)
}
return oneColumnLength
}

// typeTimePrecisionLength 时间类型会根据精度的不同有不同的存储大小
// decimal bytes
// 0 0
// 1,2 1
// 3,4 2
// 5,6 3
func typeTimePrecisionLength(decimal int) int {
if decimal < 0 {
return 0
} else if decimal < 3 {
return 1
} else if decimal < 5 {
return 2
} else if decimal < 7 {
return 3
}
return 0
}

// decimalLeftoverLength decimal被9整除后的部分,根据位数使用相印字节数
// leftover bytes
// 1-2 1
// 3-4 2
// 5-6 3
// 7-9 4
func decimalLeftoverLength(leftover int) int {
if leftover < 0 {
return 0
} else if leftover < 3 {
return 1
} else if leftover < 5 {
return 2
} else if leftover < 7 {
return 3
} else if leftover < 10 {
return 4
}
return 0
}

// OptionNotNullLength 当有not null 约束时会占用一个字节
func OptionNotNullLength(columnOptions []*ast.ColumnOption) int {
for _, option := range columnOptions {
if option.Tp == ast.ColumnOptionNotNull {
return 0
}
}
return 1
}

// MappingCharsetLength 不同的字符集会用不同数量表示一个字符
func MappingCharsetLength(charset string) int {
charNum := 4
switch charset {
case "utf8mb4", "utf16", "utf16le", "utf32":
charNum = 4
case "utf8":
charNum = 3
default:
charNum = 4
}
return charNum
}

func avoidSet(input *RuleHandlerInput) error {
switch input.Node.(type) {
case *ast.SetStmt:
addResult(input.Res, input.Rule, ConfigAvoidSet)
default:
return nil
}
return nil
}

func checkEventScheduler(input *RuleHandlerInput) error {
if utils.IsOpenEventScheduler(input.Node.Text()) {
addResult(input.Res, input.Rule, input.Rule.Name)
}
return nil
}
Loading