Skip to content

Commit 69f93b9

Browse files
authored
Merge pull request #3161 from actiontech/fix/3140
fix: mismatching between pipeline permissions and actual obtained values
2 parents 04fd0c3 + 1f1d64a commit 69f93b9

File tree

3 files changed

+110
-29
lines changed

3 files changed

+110
-29
lines changed

sqle/api/controller/v1/pipeline.go

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package v1
33
import (
44
"context"
55
"fmt"
6-
v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
7-
"github.com/actiontech/sqle/sqle/errors"
86
"net/http"
97
"strconv"
108

9+
"github.com/actiontech/sqle/sqle/errors"
10+
1111
"github.com/actiontech/sqle/sqle/api/controller"
1212
"github.com/actiontech/sqle/sqle/dms"
1313
"github.com/actiontech/sqle/sqle/server/pipeline"
@@ -234,19 +234,10 @@ func GetPipelines(c echo.Context) error {
234234
if err != nil {
235235
return errors.New(errors.ConnectStorageError, fmt.Errorf("check get pipelines failed: %v", err))
236236
}
237-
userId := ""
238-
if !userPermission.CanViewProject() {
239-
userId = user.GetIDStr()
240-
}
241-
rangeDatasourceIds := make([]string, 0)
242-
viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline)
243-
if viewPipelinePermission != nil {
244-
userId = ""
245-
rangeDatasourceIds = viewPipelinePermission.RangeUids
246-
}
247-
// 4. 获取存储对象并查询流水线列表
237+
238+
// 3. 获取存储对象并查询流水线列表
248239
var pipelineSvc pipeline.PipelineSvc
249-
count, pipelineList, err := pipelineSvc.GetPipelineList(limit, offset, req.FuzzySearchNameDesc, projectUid, userId, rangeDatasourceIds)
240+
count, pipelineList, err := pipelineSvc.GetPipelineListWithPermission(limit, offset, req.FuzzySearchNameDesc, projectUid, userPermission, user.GetIDStr())
250241
if err != nil {
251242
return controller.JSONBaseErrorReq(c, err)
252243
}

sqle/model/pipline.go

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,22 +101,50 @@ func isValidAuditMethod(a string) bool {
101101
return false
102102
}
103103

104-
func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string) ([]*Pipeline, uint64, error) {
104+
func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string, canViewAll bool) ([]*Pipeline, uint64, error) {
105105
var count int64
106106
var pipelines []*Pipeline
107107
query := s.db.Model(&Pipeline{}).Where("project_uid = ?", projectID)
108-
if userId != "" {
109-
query = query.Where("create_user_id = ? OR create_user_id IS NULL", userId)
110-
}
108+
109+
// 1. 模糊搜索
111110
if fuzzySearchContent != "" {
112111
query = query.Where("name LIKE ? OR description LIKE ?", "%"+fuzzySearchContent+"%", "%"+fuzzySearchContent+"%")
113112
}
114-
if len(rangeDatasourceIds) > 0 {
115-
query = query.Joins("JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id").
116-
Where("pipeline_nodes.instance_id IN (?)", rangeDatasourceIds).
117-
Group("pipelines.id")
113+
114+
// 2. 权限过滤
115+
if !canViewAll {
116+
if len(rangeDatasourceIds) > 0 {
117+
// 有数据源权限的用户可以看到:
118+
// 1. 包含权限范围内数据源的流水线(通过LEFT JOIN匹配)
119+
// 2. 自己创建的所有流水线
120+
// 3. 所有节点都是离线节点的流水线(通过NOT EXISTS检查)
121+
query = query.
122+
Joins("LEFT JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id").
123+
Where(`
124+
pipeline_nodes.instance_id IN (?) OR
125+
pipelines.create_user_id = ? OR
126+
NOT EXISTS (
127+
SELECT 1 FROM pipeline_nodes pn2
128+
WHERE pn2.pipeline_id = pipelines.id
129+
AND pn2.instance_id != 0
130+
)`, rangeDatasourceIds, userId).
131+
Group("pipelines.id") // 去重,因为LEFT JOIN可能产生重复记录
132+
} else if userId != "" {
133+
// 普通用户只能看到:
134+
// 1. 自己创建的流水线
135+
// 2. 所有节点都是离线节点的流水线
136+
query = query.Where(`
137+
create_user_id = ? OR
138+
NOT EXISTS (
139+
SELECT 1 FROM pipeline_nodes pn
140+
WHERE pn.pipeline_id = pipelines.id
141+
AND pn.instance_id != 0
142+
)`, userId)
143+
}
118144
}
145+
// canViewAll = true 时不添加任何过滤条件
119146

147+
// 3. 统计和分页查询
120148
err := query.Count(&count).Error
121149
if err != nil {
122150
return pipelines, uint64(count), errors.New(errors.ConnectStorageError, err)
@@ -169,6 +197,27 @@ func (s *Storage) GetPipelineNodesByInstanceId(instanceID uint64) ([]*PipelineNo
169197
return nodes, nil
170198
}
171199

200+
// GetPipelineNodesInBatch 批量获取多个流水线的节点
201+
func (s *Storage) GetPipelineNodesInBatch(pipelineIDs []uint) (map[uint][]*PipelineNode, error) {
202+
if len(pipelineIDs) == 0 {
203+
return make(map[uint][]*PipelineNode), nil
204+
}
205+
206+
var nodes []*PipelineNode
207+
err := s.db.Model(PipelineNode{}).Where("pipeline_id IN (?)", pipelineIDs).Find(&nodes).Error
208+
if err != nil {
209+
return nil, errors.New(errors.ConnectStorageError, err)
210+
}
211+
212+
// 按pipeline_id分组
213+
nodeMap := make(map[uint][]*PipelineNode)
214+
for _, node := range nodes {
215+
nodeMap[node.PipelineID] = append(nodeMap[node.PipelineID], node)
216+
}
217+
218+
return nodeMap, nil
219+
}
220+
172221
func (s *Storage) CreatePipeline(pipeline *Pipeline, nodes []*PipelineNode) error {
173222
return s.Tx(func(txDB *gorm.DB) error {
174223
// 4.1 保存 Pipeline 到数据库

sqle/server/pipeline/pipeline.go

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/actiontech/sqle/sqle/errors"
1111

12+
v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
1213
dmsCommonJwt "github.com/actiontech/dms/pkg/dms-common/api/jwt"
1314
"github.com/actiontech/sqle/sqle/api/controller"
1415
scannerCmd "github.com/actiontech/sqle/sqle/cmd/scannerd/command"
@@ -235,19 +236,59 @@ func (svc PipelineSvc) GetPipeline(projectUID string, pipelineID uint) (*Pipelin
235236
return svc.toPipeline(modelPipeline, modelPiplineNodes), nil
236237
}
237238

238-
func (svc PipelineSvc) GetPipelineList(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userId string, rangeDatasourceIds []string) (count uint64, pipelines []*Pipeline, err error) {
239+
// GetPipelineListWithPermission 根据用户权限获取流水线列表
240+
func (svc PipelineSvc) GetPipelineListWithPermission(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userPermission *dms.UserPermission, userId string) (count uint64, pipelines []*Pipeline, err error) {
239241
s := model.GetStorage()
240-
modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, userId, rangeDatasourceIds)
242+
243+
// 根据用户权限确定查询参数
244+
var queryUserId string
245+
var rangeDatasourceIds []string
246+
var canViewAll bool
247+
248+
// 权限判断逻辑
249+
if userPermission.IsAdmin() || userPermission.IsProjectAdmin() {
250+
// 超级管理员或项目管理员:可以查看所有流水线
251+
canViewAll = true
252+
} else if viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline); viewPipelinePermission != nil {
253+
// 拥有"查看流水线"权限的普通用户:可以查看指定数据源相关的流水线 + 自己创建的所有流水线
254+
queryUserId = userId
255+
rangeDatasourceIds = viewPipelinePermission.RangeUids
256+
canViewAll = false
257+
} else {
258+
// 普通用户:只能查看自己创建的流水线
259+
queryUserId = userId
260+
rangeDatasourceIds = nil
261+
canViewAll = false
262+
}
263+
264+
// 执行数据库查询
265+
modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, queryUserId, rangeDatasourceIds, canViewAll)
241266
if err != nil {
242267
return 0, nil, err
243268
}
269+
270+
// 转换为服务层对象
244271
pipelines = make([]*Pipeline, 0, len(modelPipelines))
272+
if len(modelPipelines) == 0 {
273+
return count, pipelines, nil
274+
}
275+
276+
// 收集所有pipeline ID
277+
pipelineIDs := make([]uint, 0, len(modelPipelines))
278+
for _, mp := range modelPipelines {
279+
pipelineIDs = append(pipelineIDs, mp.ID)
280+
}
281+
282+
// 批量获取所有节点
283+
nodesMap, err := s.GetPipelineNodesInBatch(pipelineIDs)
284+
if err != nil {
285+
return 0, nil, err
286+
}
287+
288+
// 组装结果
245289
for _, modelPipeline := range modelPipelines {
246-
modelPiplineNodes, err := s.GetPipelineNodes(modelPipeline.ID)
247-
if err != nil {
248-
return 0, nil, err
249-
}
250-
pipelines = append(pipelines, svc.toPipeline(modelPipeline, modelPiplineNodes))
290+
nodes := nodesMap[modelPipeline.ID]
291+
pipelines = append(pipelines, svc.toPipeline(modelPipeline, nodes))
251292
}
252293
return count, pipelines, nil
253294
}

0 commit comments

Comments
 (0)