Skip to content

Commit a70aa24

Browse files
committed
credentials are scoped
1 parent 089acdb commit a70aa24

File tree

10 files changed

+115
-66
lines changed

10 files changed

+115
-66
lines changed

hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import is.hail.shadedazure.com.azure.storage.blob.models.{
1212
BlobItem, BlobRange, BlobStorageException, ListBlobsOptions,
1313
}
1414
import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient
15+
import is.hail.utils.FastSeq
1516

1617
import scala.collection.JavaConverters._
1718
import scala.collection.mutable
@@ -58,6 +59,9 @@ object AzureStorageFS {
5859
private val AZURE_HTTPS_URI_REGEX =
5960
"^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r
6061

62+
/** OAuth2 scopes that credentials are created with when they back an `AzureStorageFS`. */
val RequiredOAuthScopes: IndexedSeq[String] = FastSeq("https://storage.azure.com/.default")
64+
6165
def parseUrl(filename: String): AzureStorageFSURL = {
6266
AZURE_HTTPS_URI_REGEX
6367
.findFirstMatchIn(filename)

hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ object GoogleStorageFS {
4444

4545
private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r
4646

47+
/** OAuth2 scopes that credentials are created with when they back a `GoogleStorageFS`. */
val RequiredOAuthScopes: IndexedSeq[String] =
  FastSeq[String]("https://www.googleapis.com/auth/devstorage.read_write")
49+
4750
def parseUrl(filename: String): GoogleStorageFSURL = {
4851
val scheme = filename.split(":")(0)
4952
if (scheme == null || scheme != "gs") {

hail/src/main/scala/is/hail/io/fs/RouterFS.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,18 @@ object RouterFS {
5252
def buildRoutes(cloudConfig: CloudStorageFSConfig, env: Map[String, String] = sys.env): FS =
5353
new RouterFS(
5454
IndexedSeq.concat(
55-
cloudConfig.google.map { case GoogleStorageFSConfig(path, maybeRequesterPaysConfig) =>
56-
new GoogleStorageFS(GoogleCloudCredentials(path), maybeRequesterPaysConfig)
55+
cloudConfig.google.map { case GoogleStorageFSConfig(path, mRPConfig) =>
56+
new GoogleStorageFS(
57+
GoogleCloudCredentials(path, GoogleStorageFS.RequiredOAuthScopes, env),
58+
mRPConfig,
59+
)
5760
},
5861
cloudConfig.azure.map { case AzureStorageFSConfig(path) =>
59-
val cred = AzureCloudCredentials(path)
60-
if (env.contains("HAIL_TERRA")) new TerraAzureStorageFS(cred)
61-
else new AzureStorageFS(cred)
62+
if (env.contains("HAIL_TERRA")) {
63+
val creds = AzureCloudCredentials(path, TerraAzureStorageFS.RequiredOAuthScopes, env)
64+
new TerraAzureStorageFS(creds)
65+
} else
66+
new AzureStorageFS(AzureCloudCredentials(path, AzureStorageFS.RequiredOAuthScopes, env))
6267
},
6368
FastSeq(new HadoopFS(new SerializableHadoopConfiguration(new Configuration()))),
6469
)

hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ import org.json4s.jackson.JsonMethods
1515

1616
/** Constants used when constructing and running a `TerraAzureStorageFS`. */
object TerraAzureStorageFS {
  // Ten minutes, expressed in milliseconds.
  private val TEN_MINUTES_IN_MS = 10 * 60 * 1000

  /** OAuth2 scopes that credentials are created with when they back a `TerraAzureStorageFS`.
    * The resulting access token is sent as the bearer token on the SAS-token request.
    */
  val RequiredOAuthScopes: IndexedSeq[String] =
    FastSeq("https://management.azure.com/.default")
}
1922

2023
class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorageFS(credential) {
@@ -55,8 +58,7 @@ class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorag
5558
val url =
5659
s"$workspaceManagerUrl/api/workspaces/v1/$workspaceId/resources/controlled/azure/storageContainer/$containerResourceId/getSasToken"
5760
val req = new HttpPost(url)
58-
val token = credential.accessToken(FastSeq("https://management.azure.com/.default"))
59-
req.addHeader("Authorization", s"Bearer $token")
61+
req.addHeader("Authorization", s"Bearer ${credential.accessToken}")
6062

6163
val tenHoursInSeconds = 10 * 3600
6264
val expiration = System.currentTimeMillis() + tenHoursInSeconds * 1000

hail/src/main/scala/is/hail/services/BatchClient.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package is.hail.services
22

33
import is.hail.expr.ir.ByteArrayBuilder
4-
import is.hail.services.requests.{BatchServiceRequester, Requester}
4+
import is.hail.services.oauth2.CloudCredentials
5+
import is.hail.services.requests.Requester
56
import is.hail.utils._
67

78
import scala.util.Random
89

10+
import java.net.URL
911
import java.nio.charset.StandardCharsets
1012
import java.nio.file.Path
1113

@@ -88,9 +90,29 @@ object JobGroupStates {
8890
}
8991

9092
/** Factory for [[BatchClient]] instances. */
object BatchClient {

  /** OAuth2 scopes to request for the batch service, selected by the `HAIL_CLOUD` entry of `env`.
    *
    * For `gcp` a fixed set of profile/email/openid scopes is returned; for `azure` the value of
    * `HAIL_AZURE_OAUTH_SCOPE` is returned (an empty array when that variable is absent).
    *
    * @throws IllegalArgumentException
    *   if `HAIL_CLOUD` is missing or names an unknown cloud
    */
  private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] = {
    val cloud =
      env.getOrElse("HAIL_CLOUD", throw new IllegalArgumentException("HAIL_CLOUD must be set."))

    cloud match {
      case "gcp" =>
        Array(
          "https://www.googleapis.com/auth/userinfo.profile",
          "https://www.googleapis.com/auth/userinfo.email",
          "openid",
        )
      case "azure" =>
        env.get("HAIL_AZURE_OAUTH_SCOPE").toArray
      case unknown =>
        throw new IllegalArgumentException(s"Unknown cloud: '$unknown'.")
    }
  }

  /** Creates a client that talks to the deployment's `batch` base URL, authenticating with
    * credentials read from `credentialsFile` and scoped for the batch service.
    *
    * @param deployConfig
    *   supplies the base URL of the batch service
    * @param credentialsFile
    *   path to the credentials key file
    * @param env
    *   environment used to select the cloud and its scopes; defaults to the process environment
    */
  def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
    : BatchClient = {
    val batchUrl = new URL(deployConfig.baseUrl("batch"))
    val credentials = CloudCredentials(credentialsFile, BatchServiceScopes(env), env)
    new BatchClient(Requester(batchUrl, credentials))
  }
}
95117

96118
case class BatchClient private (req: Requester) extends Logging with AutoCloseable {

hail/src/main/scala/is/hail/services/oauth2.scala

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package is.hail.services
22

33
import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials
44
import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials
5-
import is.hail.shadedazure.com.azure.core.credential.{TokenCredential, TokenRequestContext}
5+
import is.hail.shadedazure.com.azure.core.credential.{
6+
AccessToken, TokenCredential, TokenRequestContext,
7+
}
68
import is.hail.shadedazure.com.azure.identity.{
79
ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder,
810
}
@@ -12,6 +14,7 @@ import scala.collection.JavaConverters._
1214

1315
import java.io.Serializable
1416
import java.nio.file.{Files, Path}
17+
import java.time.OffsetDateTime
1518

1619
import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials}
1720
import org.json4s.Formats
@@ -20,38 +23,25 @@ import org.json4s.jackson.JsonMethods
2023
object oauth2 {
2124

2225
sealed trait CloudCredentials extends Product with Serializable {
23-
def accessToken(scopes: IndexedSeq[String]): String
26+
def accessToken: String
2427
}
2528

26-
/** Builds scoped credentials for the cloud named by the `HAIL_CLOUD` entry of `env`.
  *
  * @param keyPath
  *   path to the credentials key file
  * @param scopes
  *   OAuth2 scopes the credentials are created with
  * @param env
  *   environment consulted for `HAIL_CLOUD`; defaults to the process environment
  * @throws IllegalArgumentException
  *   if `HAIL_CLOUD` is missing or names an unknown cloud
  */
def CloudCredentials(
  keyPath: Path,
  scopes: IndexedSeq[String],
  env: Map[String, String] = sys.env,
): CloudCredentials = {
  val key: Option[Path] = Some(keyPath)
  env.get("HAIL_CLOUD") match {
    case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
    case Some("gcp") => GoogleCloudCredentials(key, scopes, env)
    case Some("azure") => AzureCloudCredentials(key, scopes, env)
    case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'")
  }
}
3440

35-
def CloudScopes(env: Map[String, String] = sys.env): Array[String] =
36-
env.get("HAIL_CLOUD") match {
37-
case Some("gcp") =>
38-
Array(
39-
"https://www.googleapis.com/auth/userinfo.profile",
40-
"https://www.googleapis.com/auth/userinfo.email",
41-
"openid",
42-
)
43-
case Some("azure") =>
44-
sys.env.get("HAIL_AZURE_OAUTH_SCOPE").toArray
45-
case Some(cloud) =>
46-
throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.")
47-
case None =>
48-
throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
49-
}
50-
5141
/** [[CloudCredentials]] backed by a `GoogleCredentials` instance. */
case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials {

  /** Asks `value` to refresh itself if it has expired, then returns the current token value. */
  override def accessToken: String = {
    value.refreshIfExpired()
    val current = value.getAccessToken
    current.getTokenValue
  }
}
5747

@@ -60,40 +50,68 @@ object oauth2 {
6050
val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS"
6151
}
6252

63-
/** Loads Google credentials and restricts them to `scopes`.
  *
  * The key file is taken from `keyPath` if given, otherwise from the
  * `GOOGLE_APPLICATION_CREDENTIALS` entry of `env`; when neither is present the application
  * default credentials are used.
  */
def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env)
  : GoogleCloudCredentials = {
  val resolvedPath: Option[Path] =
    keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_)))

  val unscoped: GoogleCredentials =
    resolvedPath.fold[GoogleCredentials](GoogleCredentials.getApplicationDefault) { path =>
      using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream)
    }

  GoogleCloudCredentials(unscoped.createScoped(scopes: _*))
}
7066
}
7167

7268
/** [[CloudCredentials]] backed by an Azure `TokenCredential`.
  *
  * The most recently fetched access token is cached and reused until it gets close to its
  * expiry time, at which point a new token is requested for `scopes`.
  */
sealed trait AzureCloudCredentials extends CloudCredentials {

  /** Azure credential used to request access tokens. */
  def value: TokenCredential

  /** OAuth2 scopes requested whenever a new access token is fetched. */
  def scopes: IndexedSeq[String]

  // Most recently fetched token; null until the first fetch. It is @transient, so it is not
  // serialized along with the credentials and is fetched again on first use afterwards.
  @transient private[this] var token: AccessToken = _

  /** Returns the cached access token, fetching a new one first if none is cached or the cached
    * one is about to expire.
    */
  override def accessToken: String = {
    refreshIfRequired()
    token.getToken
  }

  // Fetches a new token when needed. The expiry check is repeated inside the lock so that
  // threads which waited on the lock do not fetch again after another thread already has.
  private[this] def refreshIfRequired(): Unit =
    if (isExpired) synchronized {
      if (isExpired) {
        token = value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava))
      }
    }

  // True when there is no token yet, or the token expires within the next hour.
  // The previous check used `isBefore`, which is the opposite condition: it reported tokens
  // with more than an hour left as expired and never refreshed tokens that were running out.
  // NOTE(review): the one-hour margin is kept from the original code; if issued tokens live
  // for an hour or less this will fetch on every call - confirm the intended margin.
  private[this] def isExpired: Boolean =
    token == null || OffsetDateTime.now.plusHours(1).isAfter(token.getExpiresAt)
}
7893

7994
/** Factory for [[AzureCloudCredentials]]. */
object AzureCloudCredentials {
  object EnvVars {
    // Name of the environment variable that may hold the path to the Azure key file.
    val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS"
  }

  /** Creates Azure credentials that request tokens for `scopes`.
    *
    * The key file is taken from `keyPath` if given, otherwise from the
    * `AZURE_APPLICATION_CREDENTIALS` entry of `env`. With a key file, client-secret credentials
    * are returned; without one, the default Azure credentials are returned.
    */
  def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env)
    : AzureCloudCredentials = {
    val resolvedPath: Option[Path] =
      keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_)))

    resolvedPath.fold[AzureCloudCredentials](AzureDefaultCredentials(scopes)) { path =>
      AzureClientSecretCredentials(path, scopes)
    }
  }
}
90106

91-
/** [[AzureCloudCredentials]] that use the credential built by `DefaultAzureCredentialBuilder`. */
private case class AzureDefaultCredentials(scopes: IndexedSeq[String])
    extends AzureCloudCredentials {

  // Built on first use and excluded from serialization (@transient lazy val).
  @transient override lazy val value: TokenCredential = {
    val builder = new DefaultAzureCredentialBuilder()
    builder.build()
  }
}
95112

96-
private case class AzureClientSecretCredentials(path: Path) extends AzureCloudCredentials {
113+
private case class AzureClientSecretCredentials(path: Path, scopes: IndexedSeq[String])
114+
extends AzureCloudCredentials {
97115
@transient override lazy val value: TokenCredential =
98116
using(Files.newInputStream(path)) { is =>
99117
implicit val fmts: Formats = defaultJSONFormats

hail/src/main/scala/is/hail/services/requests.scala

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
package is.hail.services
22

3-
import is.hail.services.oauth2.{CloudCredentials, CloudScopes}
3+
import is.hail.services.oauth2.CloudCredentials
44
import is.hail.utils.{log, _}
55

66
import java.net.URL
7-
import java.nio.file.Path
87

98
import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest}
109
import org.apache.http.client.config.RequestConfig
@@ -30,15 +29,7 @@ object requests {
3029

3130
private[this] val TIMEOUT_MS = 5 * 1000
3231

33-
def BatchServiceRequester(conf: DeployConfig, keyFile: Path, env: Map[String, String] = sys.env)
34-
: Requester =
35-
Requester(
36-
new URL(conf.baseUrl("batch")),
37-
CloudCredentials(keyFile, env),
38-
CloudScopes(env),
39-
)
40-
41-
def Requester(baseUrl: URL, cred: CloudCredentials, scopes: IndexedSeq[String]): Requester = {
32+
def Requester(baseUrl: URL, cred: CloudCredentials): Requester = {
4233

4334
val httpClient: CloseableHttpClient = {
4435
log.info("creating HttpClient")
@@ -67,7 +58,7 @@ object requests {
6758

6859
def request(req: HttpUriRequest, body: Option[HttpEntity] = None): JValue = {
6960
log.info(s"request ${req.getMethod} ${req.getURI}")
70-
req.addHeader("Authorization", s"Bearer ${cred.accessToken(scopes)}")
61+
req.addHeader("Authorization", s"Bearer ${cred.accessToken}")
7162
body.foreach(entity => req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(entity))
7263
retryTransientErrors {
7364
using(httpClient.execute(req)) { resp =>

hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ class AzureStorageFSSuite extends FSSuite {
1616
}
1717
}
1818

19-
lazy val fs = new AzureStorageFS(AzureCloudCredentials(None))
19+
override lazy val fs: FS =
20+
new AzureStorageFS(AzureCloudCredentials(None, AzureStorageFS.RequiredOAuthScopes))
2021

2122
@Test def testMakeQualified(): Unit = {
2223
val qualifiedFileName = "https://account.blob.core.windows.net/container/path"

hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite {
1717
}
1818
}
1919

20-
lazy val fs = new GoogleStorageFS(GoogleCloudCredentials(None), None)
20+
override lazy val fs: FS =
21+
new GoogleStorageFS(GoogleCloudCredentials(None, GoogleStorageFS.RequiredOAuthScopes), None)
2122

2223
@Test def testMakeQualified(): Unit = {
2324
val qualifiedFileName = "gs://bucket/path"

hail/src/test/scala/is/hail/services/BatchClientSuite.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,31 @@ import java.nio.file.Path
66

77
import org.scalatestplus.testng.TestNGSuite
88
import org.testng.annotations.Test
9+
import sourcecode.FullName
910

1011
class BatchClientSuite extends TestNGSuite {
1112
@Test def testBasic(): Unit =
12-
using(BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))) { client =>
13+
using(BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json"))) { client =>
1314
val jobGroup = client.run(
1415
BatchRequest(
1516
billing_project = "test",
16-
n_jobs = 1,
17+
n_jobs = 0,
1718
token = tokenUrlSafe,
19+
attributes = Map("name" -> s"Test ${implicitly[FullName].value}"),
1820
),
1921
JobGroupRequest(
20-
job_group_id = 0,
22+
job_group_id = 1,
2123
absolute_parent_id = 0,
2224
),
2325
FastSeq(
2426
JobRequest(
25-
job_id = 0,
27+
job_id = 1,
2628
always_run = false,
2729
in_update_job_group_id = 0,
2830
in_update_parent_ids = Array(),
2931
process = BashJob(
3032
image = "ubuntu:22.04",
31-
command = Array("/bin/bash", "-c", "'hello, world!"),
33+
command = Array("/bin/bash", "-c", "echo 'hello, hail!'"),
3234
),
3335
)
3436
),

0 commit comments

Comments
 (0)