Skip to content

Commit a0ce020

Browse files
committed
[query] read hail-generated cloud credentails
1 parent 43f1be9 commit a0ce020

File tree

12 files changed

+153
-96
lines changed

12 files changed

+153
-96
lines changed

hail/hail/src/is/hail/backend/driver/BackendRpc.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@ import is.hail.io.plink.LoadPlink
99
import is.hail.io.vcf.LoadVCF
1010
import is.hail.types.virtual.{Kind, TFloat64}
1111
import is.hail.types.virtual.Kinds._
12-
import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer, FastSeq}
12+
import is.hail.utils.{jsonToBytes, using, BoxedArrayBuilder, ExecutionTimer, FastSeq}
1313
import is.hail.utils.ExecutionTimer.Timings
1414
import is.hail.variant.ReferenceGenome
1515

1616
import scala.util.control.NonFatal
1717

1818
import java.io.ByteArrayOutputStream
19-
import java.nio.charset.StandardCharsets
2019

2120
import org.json4s.{DefaultFormats, Extraction, Formats, JArray, JValue}
22-
import org.json4s.jackson.JsonMethods
2321

2422
case class SerializedIRFunction(
2523
name: String,
@@ -146,9 +144,6 @@ trait BackendRpc {
146144
case NonFatal(error) => R.failure(env, error)
147145
}
148146

149-
def jsonToBytes(v: JValue): Array[Byte] =
150-
JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8)
151-
152147
private[this] def withRegisterSerializedFns[A](
153148
ctx: ExecuteContext,
154149
serializedFns: Array[SerializedIRFunction],

hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import is.hail.expr.ir.lowering.IrMetadata
99
import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS}
1010
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1111
import is.hail.services._
12+
import is.hail.services.oauth2.CloudCredentials
1213
import is.hail.types.virtual.Kinds._
1314
import is.hail.utils._
1415
import is.hail.utils.ExecutionTimer.Timings
@@ -187,7 +188,7 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
187188
name,
188189
BatchClient(
189190
DeployConfig.fromConfigFile("/deploy-config/deploy-config.json"),
190-
Path.of(scratchDir, "secrets/gsa-key/key.json"),
191+
CloudCredentials(Some(Path.of(scratchDir, "secrets/gsa-key/key.json"))),
191192
),
192193
JarUrl(jarLocation),
193194
BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")),

hail/hail/src/is/hail/io/fs/AzureStorageFS.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import is.hail.shadedazure.com.azure.storage.blob.models.{
1212
BlobItem, BlobRange, BlobStorageException, ListBlobsOptions,
1313
}
1414
import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient
15-
import is.hail.utils.FastSeq
1615

1716
import scala.collection.JavaConverters._
1817
import scala.collection.mutable
@@ -59,8 +58,8 @@ object AzureStorageFS {
5958
private val AZURE_HTTPS_URI_REGEX =
6059
"^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r
6160

62-
val RequiredOAuthScopes: IndexedSeq[String] =
63-
FastSeq("https://storage.azure.com/.default")
61+
val RequiredOAuthScopes: Array[String] =
62+
Array("https://storage.azure.com/.default")
6463

6564
def parseUrl(filename: String): AzureStorageFSURL = {
6665
AZURE_HTTPS_URI_REGEX

hail/hail/src/is/hail/io/fs/GoogleStorageFS.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ object GoogleStorageFS {
4444

4545
private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r
4646

47-
val RequiredOAuthScopes: IndexedSeq[String] =
48-
FastSeq("https://www.googleapis.com/auth/devstorage.read_write")
47+
val RequiredOAuthScopes: Array[String] =
48+
Array("https://www.googleapis.com/auth/devstorage.read_write")
4949

5050
def parseUrl(filename: String): GoogleStorageFSURL = {
5151
val scheme = filename.split(":")(0)

hail/hail/src/is/hail/io/fs/RouterFS.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,23 @@ object RouterFS {
5252
def buildRoutes(cloudConfig: CloudStorageFSConfig, env: Map[String, String] = sys.env): FS =
5353
new RouterFS(
5454
IndexedSeq.concat(
55-
cloudConfig.google.map { case GoogleStorageFSConfig(path, mRPConfig) =>
55+
cloudConfig.google.map { case GoogleStorageFSConfig(path, rpConfig) =>
5656
new GoogleStorageFS(
57-
GoogleCloudCredentials(path, GoogleStorageFS.RequiredOAuthScopes, env),
58-
mRPConfig,
57+
GoogleCloudCredentials(path).scoped(GoogleStorageFS.RequiredOAuthScopes),
58+
rpConfig,
5959
)
6060
},
6161
cloudConfig.azure.map { case AzureStorageFSConfig(path) =>
62-
if (env.contains("HAIL_TERRA")) {
63-
val creds = AzureCloudCredentials(path, TerraAzureStorageFS.RequiredOAuthScopes, env)
64-
new TerraAzureStorageFS(creds)
65-
} else
66-
new AzureStorageFS(AzureCloudCredentials(path, AzureStorageFS.RequiredOAuthScopes, env))
62+
if (env.contains("HAIL_TERRA"))
63+
new TerraAzureStorageFS(
64+
AzureCloudCredentials(path, env)
65+
.scoped(TerraAzureStorageFS.RequiredOAuthScopes)
66+
)
67+
else
68+
new AzureStorageFS(
69+
AzureCloudCredentials(path, env)
70+
.scoped(AzureStorageFS.RequiredOAuthScopes)
71+
)
6772
},
6873
FastSeq(new HadoopFS(new SerializableHadoopConfiguration(new Configuration()))),
6974
)

hail/hail/src/is/hail/io/fs/TerraAzureStorageFS.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ import org.json4s.jackson.JsonMethods
1616
object TerraAzureStorageFS {
1717
private val TEN_MINUTES_IN_MS = 10 * 60 * 1000
1818

19-
val RequiredOAuthScopes: IndexedSeq[String] =
20-
FastSeq("https://management.azure.com/.default")
19+
val RequiredOAuthScopes: Array[String] =
20+
Array("https://management.azure.com/.default")
2121
}
2222

2323
class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorageFS(credential) {

hail/hail/src/is/hail/services/BatchClient.scala

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import scala.util.Random
1515

1616
import java.net.{URL, URLEncoder}
1717
import java.nio.charset.StandardCharsets.UTF_8
18-
import java.nio.file.Path
1918

2019
import org.apache.http.entity.ByteArrayEntity
2120
import org.apache.http.entity.ContentType.APPLICATION_JSON
@@ -116,31 +115,38 @@ case class JobListEntry(
116115
)
117116

118117
object BatchClient {
118+
object RequiredOAuth2Scopes {
119+
private[this] val Google: Array[String] =
120+
Array(
121+
"https://www.googleapis.com/auth/userinfo.profile",
122+
"https://www.googleapis.com/auth/userinfo.email",
123+
"openid",
124+
)
125+
126+
private[this] val Microsoft: Array[String] =
127+
Array(".default")
128+
129+
def apply(env: Map[String, String] = sys.env): Array[String] =
130+
env.get("HAIL_CLOUD") match {
131+
case None | Some("gcp") => Google
132+
case Some("azure") => env.get("HAIL_AZURE_OAUTH_SCOPE").map(Array(_)).getOrElse(Microsoft)
133+
case None => throw new IllegalArgumentException(s"'HAIL_CLOUD' must be set.")
134+
}
135+
}
119136

120137
val BunchMaxSizeBytes: Int = 1024 * 1024
121138

122-
def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
123-
: BatchClient =
124-
new BatchClient(Requester(
125-
new URL(deployConfig.baseUrl("batch")),
126-
CloudCredentials(credentialsFile, BatchServiceScopes(env), env),
127-
))
128-
129-
def BatchServiceScopes(env: Map[String, String]): Array[String] =
130-
env.get("HAIL_CLOUD") match {
131-
case Some("gcp") =>
132-
Array(
133-
"https://www.googleapis.com/auth/userinfo.profile",
134-
"https://www.googleapis.com/auth/userinfo.email",
135-
"openid",
136-
)
137-
case Some("azure") =>
138-
env.get("HAIL_AZURE_OAUTH_SCOPE").toArray
139-
case Some(cloud) =>
140-
throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.")
141-
case None =>
142-
throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
143-
}
139+
def apply(
140+
deployConfig: DeployConfig,
141+
credentials: CloudCredentials,
142+
env: Map[String, String] = sys.env,
143+
): BatchClient =
144+
new BatchClient(
145+
Requester(
146+
new URL(deployConfig.baseUrl("batch")),
147+
credentials.scoped(RequiredOAuth2Scopes(env)),
148+
)
149+
)
144150

145151
object JobProcessRequestSerializer extends CustomSerializer[JobProcess](implicit fmts =>
146152
(

hail/hail/src/is/hail/services/oauth2.scala

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,92 @@ package is.hail.services
22

33
import is.hail.services.oauth2.AzureCloudCredentials.AzureTokenRefreshMinutes
44
import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials
5-
import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials
65
import is.hail.shadedazure.com.azure.core.credential.{
76
AccessToken, TokenCredential, TokenRequestContext,
87
}
98
import is.hail.shadedazure.com.azure.identity.{
109
ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder,
1110
}
12-
import is.hail.utils.{defaultJSONFormats, using}
11+
import is.hail.utils.{jsonToBytes, using}
1312

1413
import scala.collection.JavaConverters._
1514

16-
import java.io.Serializable
15+
import java.io.{ByteArrayInputStream, Serializable}
1716
import java.nio.file.{Files, Path}
1817
import java.time.OffsetDateTime
1918

20-
import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials}
21-
import org.json4s.Formats
19+
import com.google.auth.oauth2.GoogleCredentials
20+
import org.json4s.{DefaultFormats, Formats, JValue}
2221
import org.json4s.jackson.JsonMethods
2322

2423
object oauth2 {
2524

2625
sealed trait CloudCredentials extends Product with Serializable {
2726
def accessToken: String
27+
def scoped(scopes: Array[String]): CloudCredentials
2828
}
2929

30-
def CloudCredentials(
31-
keyPath: Path,
32-
scopes: IndexedSeq[String],
33-
env: Map[String, String] = sys.env,
34-
): CloudCredentials =
30+
implicit lazy val fmts: Formats = DefaultFormats
31+
32+
def HailCredentials(env: Map[String, String] = sys.env): Option[CloudCredentials] =
33+
for {
34+
config <-
35+
env
36+
.get("XDG_CONFIG_HOME")
37+
.map(Path.of(_))
38+
.orElse(env.get("HOME").map(Path.of(_, ".config")))
39+
40+
identity = config.resolve("hail/identity.json").toFile
41+
if identity.exists()
42+
43+
jvalue <- JsonMethods.parseOpt(identity)
44+
} yield (jvalue \ "idp").extract[String] match {
45+
case "Google" => GoogleCloudCredentials.fromJson(jvalue \ "credentials")
46+
case "Microsoft" => AzureCloudCredentials.fromJson(jvalue \ "credentials")
47+
case other => throw new IllegalArgumentException(s"Unknown identity provider: '$other'")
48+
}
49+
50+
def CloudCredentials(keyPath: Option[Path], env: Map[String, String] = sys.env)
51+
: CloudCredentials =
3552
env.get("HAIL_CLOUD") match {
36-
case Some("gcp") => GoogleCloudCredentials(Some(keyPath), scopes, env)
37-
case Some("azure") => AzureCloudCredentials(Some(keyPath), scopes, env)
53+
case None | Some("gcp") => GoogleCloudCredentials(keyPath)
54+
case Some("azure") => AzureCloudCredentials(keyPath, env)
3855
case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'")
39-
case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
4056
}
4157

4258
case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials {
4359
override def accessToken: String = {
4460
value.refreshIfExpired()
4561
value.getAccessToken.getTokenValue
4662
}
63+
64+
override def scoped(scopes: Array[String]): GoogleCloudCredentials =
65+
GoogleCloudCredentials(value.createScoped(scopes: _*))
4766
}
4867

4968
object GoogleCloudCredentials {
50-
object EnvVars {
51-
val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS"
52-
}
53-
54-
def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env)
55-
: GoogleCloudCredentials =
69+
def fromJson(jv: JValue): GoogleCloudCredentials =
5670
GoogleCloudCredentials {
57-
val creds: GoogleCredentials =
58-
keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_))) match {
59-
case Some(path) =>
60-
using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream)
61-
case None =>
62-
GoogleCredentials.getApplicationDefault
63-
}
71+
GoogleCredentials.fromStream(
72+
new ByteArrayInputStream(jsonToBytes(jv))
73+
)
74+
}
6475

65-
creds.createScoped(scopes: _*)
76+
def apply(keyPath: Option[Path]): GoogleCloudCredentials =
77+
GoogleCloudCredentials {
78+
keyPath match {
79+
case Some(path) =>
80+
using(Files.newInputStream(path))(GoogleCredentials.fromStream)
81+
case None =>
82+
GoogleCredentials.getApplicationDefault
83+
}
6684
}
6785
}
6886

6987
sealed trait AzureCloudCredentials extends CloudCredentials {
7088

7189
def value: TokenCredential
72-
def scopes: IndexedSeq[String]
90+
def scopes: Array[String]
7391

7492
@transient private[this] var token: AccessToken = _
7593

@@ -78,11 +96,13 @@ object oauth2 {
7896
token.getToken
7997
}
8098

99+
override def scoped(scopes: Array[String]): AzureCloudCredentials
100+
81101
private[this] def refreshIfRequired(): Unit =
82102
if (!isExpired) token.getToken
83103
else synchronized {
84104
if (isExpired) {
85-
token = value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava))
105+
token = value.getTokenSync(new TokenRequestContext().setScopes(scopes.toSeq.asJava))
86106
}
87107

88108
token.getToken: Unit
@@ -99,33 +119,52 @@ object oauth2 {
99119
val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS"
100120
}
101121

122+
val DefaultOAuth2Scopes: Array[String] =
123+
Array(".default")
124+
102125
private[AzureCloudCredentials] val AzureTokenRefreshMinutes = 5
103126

104-
def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env)
105-
: AzureCloudCredentials =
127+
def fromJson(jv: JValue, scopes: Array[String] = DefaultOAuth2Scopes): AzureCloudCredentials =
128+
AzureClientSecretCredentials(
129+
clientId = (jv \ "appId").extract[String],
130+
tenantId = (jv \ "tenant").extract[String],
131+
secret = (jv \ "password").extract[String],
132+
scopes = scopes,
133+
)
134+
135+
def apply(keyPath: Option[Path], env: Map[String, String] = sys.env): AzureCloudCredentials =
106136
keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match {
107-
case Some(path) => AzureClientSecretCredentials(path, scopes)
108-
case None => AzureDefaultCredentials(scopes)
137+
case Some(path) =>
138+
using(Files.newInputStream(path)) { in =>
139+
fromJson(JsonMethods.parse(in), DefaultOAuth2Scopes)
140+
}
141+
case None =>
142+
AzureDefaultCredentials(DefaultOAuth2Scopes)
109143
}
110144
}
111145

112-
private case class AzureDefaultCredentials(scopes: IndexedSeq[String])
113-
extends AzureCloudCredentials {
146+
private case class AzureDefaultCredentials(scopes: Array[String]) extends AzureCloudCredentials {
114147
@transient override lazy val value: TokenCredential =
115148
new DefaultAzureCredentialBuilder().build()
149+
150+
override def scoped(scopes: Array[String]): AzureDefaultCredentials =
151+
copy(scopes)
116152
}
117153

118-
private case class AzureClientSecretCredentials(path: Path, scopes: IndexedSeq[String])
119-
extends AzureCloudCredentials {
154+
private case class AzureClientSecretCredentials(
155+
clientId: String,
156+
tenantId: String,
157+
secret: String,
158+
scopes: Array[String],
159+
) extends AzureCloudCredentials {
120160
@transient override lazy val value: TokenCredential =
121-
using(Files.newInputStream(path)) { is =>
122-
implicit val fmts: Formats = defaultJSONFormats
123-
val kvs = JsonMethods.parse(is)
124-
new ClientSecretCredentialBuilder()
125-
.clientId((kvs \ "appId").extract[String])
126-
.clientSecret((kvs \ "password").extract[String])
127-
.tenantId((kvs \ "tenant").extract[String])
128-
.build()
129-
}
161+
new ClientSecretCredentialBuilder()
162+
.clientId(clientId)
163+
.clientSecret(secret)
164+
.tenantId(tenantId)
165+
.build()
166+
167+
override def scoped(scopes: Array[String]): AzureClientSecretCredentials =
168+
copy(scopes = scopes)
130169
}
131170
}

0 commit comments

Comments
 (0)