6 changes: 6 additions & 0 deletions extensions/spark/kyuubi-spark-authz/pom.xml
@@ -397,6 +397,12 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>

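<!-- mockito-core is used to spy on and stub RuleFunctionAuthorization in the authz test suite -->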
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
@@ -32,6 +32,7 @@ class AccessResource private (val objectType: ObjectType, val catalog: Option[String])
extends RangerAccessResourceImpl {
implicit def asString(obj: Object): String = if (obj != null) obj.asInstanceOf[String] else null
def getDatabase: String = getValue("database")
def getUdf: String = getValue("udf")
def getTable: String = getValue("table")
def getColumn: String = getValue("column")
def getColumns: Seq[String] = {
@@ -45,6 +45,7 @@ class RangerSparkExtension extends (SparkSessionExtensions => Unit) {

override def apply(v1: SparkSessionExtensions): Unit = {
v1.injectCheckRule(AuthzConfigurationChecker)
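// authorize function usage in the analyzed plan (KYUUBI #7186)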
v1.injectCheckRule(RuleFunctionAuthorization)
v1.injectResolutionRule(_ => RuleReplaceShowObjectCommands)
v1.injectResolutionRule(_ => RuleApplyPermanentViewMarker)
v1.injectResolutionRule(_ => RuleApplyTypeOfMarker)
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.plugin.spark.authz.ranger

import scala.collection.mutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import org.apache.kyuubi.plugin.spark.authz._
import org.apache.kyuubi.plugin.spark.authz.ranger.AccessType.AccessType
import org.apache.kyuubi.plugin.spark.authz.ranger.SparkRangerAdminPlugin._
import org.apache.kyuubi.plugin.spark.authz.util.AuthZUtils._

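/**
 * A check rule that builds function-level privilege objects from the analyzed
 * plan via PrivilegesBuilder.buildFunctions and verifies the resulting access
 * requests against the Ranger plugin before execution.
 */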
case class RuleFunctionAuthorization(spark: SparkSession) extends (LogicalPlan => Unit) {
override def apply(plan: LogicalPlan): Unit = {
val auditHandler = new SparkRangerAuditHandler
val ugi = getAuthzUgi(spark.sparkContext)
val (inputs, _, opType) = PrivilegesBuilder.buildFunctions(plan, spark)

// Use a HashSet to deduplicate identical (AccessResource, AccessType) pairs, so `requests`
// contains only distinct access requests, in the same order as the inputs.
val requests = new mutable.ArrayBuffer[AccessRequest]()
val requestsSet = new mutable.HashSet[(AccessResource, AccessType)]()

def addAccessRequest(objects: Iterable[PrivilegeObject], isInput: Boolean): Unit = {
objects.foreach { obj =>
val resource = AccessResource(obj, opType)
val accessType = ranger.AccessType(obj, opType, isInput)
if (accessType != AccessType.NONE && !requestsSet.contains((resource, accessType))) {
requests += AccessRequest(resource, ugi, opType, accessType)
requestsSet.add((resource, accessType))
}
}
}

addAccessRequest(inputs, isInput = true)
checkPrivileges(requests, auditHandler)
}

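// When `authorizeInSingleCall` is enabled, all requests are verified in one batched
// call to the Ranger plugin; otherwise each request is verified individually.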
def checkPrivileges(
    requests: mutable.ArrayBuffer[AccessRequest],
    auditHandler: SparkRangerAuditHandler): Unit = {
  if (authorizeInSingleCall) {
    verify(requests, auditHandler)
  } else {
    requests.foreach(req => verify(Seq(req), auditHandler))
  }
}
}
@@ -17,13 +17,19 @@

package org.apache.kyuubi.plugin.spark.authz

import scala.collection.mutable

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
// scalastyle:off
import org.scalatest.funsuite.AnyFunSuite

import org.apache.kyuubi.plugin.spark.authz.OperationType.QUERY
import org.apache.kyuubi.plugin.spark.authz.ranger.AccessType
import org.apache.kyuubi.plugin.spark.authz.ranger.{AccessRequest, AccessResource, AccessType, RuleFunctionAuthorization, SparkRangerAuditHandler}

abstract class FunctionPrivilegesBuilderSuite extends AnyFunSuite
with SparkSessionProvider with BeforeAndAfterAll with BeforeAndAfterEach {
@@ -193,4 +199,56 @@ class HiveFunctionPrivilegesBuilderSuite extends FunctionPrivilegesBuilderSuite
}
}

test("Built in and UDF Function Call Query") {
val plan = sql(s"SELECT kyuubi_fun_0('TESTSTRING'), " +
s"kyuubi_fun_0(value)," +
s"abs(key)," +
s"abs(-100)," +
s"lower(value)," +
s"lower('TESTSTRING') " +
s"FROM $reusedTable").queryExecution.analyzed
val (inputs, _, _) = PrivilegesBuilder.buildFunctions(plan, spark)
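// Only the two kyuubi_fun_0 calls yield privilege objects; built-in functions
// such as abs and lower are not subject to authorization.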
assert(inputs.size === 2)
inputs.foreach { po =>
assert(po.actionType === PrivilegeObjectActionType.OTHER)
assert(po.privilegeObjectType === PrivilegeObjectType.FUNCTION)
assert(po.dbname startsWith reusedDb.toLowerCase)
assert(po.objectName startsWith functionNamePrefix.toLowerCase)
val accessType = ranger.AccessType(po, QUERY, isInput = true)
assert(accessType === AccessType.SELECT)
}
}

test("[KYUUBI #7186] Introduce RuleFunctionAuthorization") {

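// Spy on the rule and stub checkPrivileges so that any request on the UDF
// `reusedDb.kyuubi_fun_0` is denied, simulating a Ranger deny policy
// without a live Ranger server.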
val ruleFunc = Mockito.spy[RuleFunctionAuthorization](RuleFunctionAuthorization(spark))
Mockito.doAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val requests = invocation.getArgument[mutable.ArrayBuffer[AccessRequest]](0)
requests.foreach { request =>
// deny udf `reusedDb.kyuubi_fun_0`
val database: String = request.getResource.asInstanceOf[AccessResource].getDatabase
val udf: String = request.getResource.asInstanceOf[AccessResource].getUdf
if (database.equalsIgnoreCase(reusedDb) && udf.equalsIgnoreCase("kyuubi_fun_0")) {
throw new AccessControlException("Access denied")
}
}
}
}).when(ruleFunc).checkPrivileges(
any[mutable.ArrayBuffer[AccessRequest]](),
any[SparkRangerAuditHandler]())

val query1 = sql(s"SELECT " +
s"${reusedDb}.kyuubi_fun_0('KYUUBI_STRING')," +
s"${reusedDb}.kyuubi_fun_1('KYUUBI_STRING') ").queryExecution.analyzed
intercept[AccessControlException] { ruleFunc.apply(query1) }

val query2 = sql(s"SELECT " +
s"${reusedDb}.kyuubi_fun_0('KYUUBI_STRING')").queryExecution.analyzed
intercept[AccessControlException] { ruleFunc.apply(query2) }

val query3 = sql(s"SELECT " +
s"${reusedDb}.kyuubi_fun_1('KYUUBI_STRING')").queryExecution.analyzed
ruleFunc.apply(query3)
}
}