@@ -2,7 +2,9 @@ package is.hail.services
22
33import is .hail .services .oauth2 .AzureCloudCredentials .EnvVars .AzureApplicationCredentials
44import is .hail .services .oauth2 .GoogleCloudCredentials .EnvVars .GoogleApplicationCredentials
5- import is .hail .shadedazure .com .azure .core .credential .{TokenCredential , TokenRequestContext }
5+ import is .hail .shadedazure .com .azure .core .credential .{
6+ AccessToken , TokenCredential , TokenRequestContext ,
7+ }
68import is .hail .shadedazure .com .azure .identity .{
79 ClientSecretCredentialBuilder , DefaultAzureCredentialBuilder ,
810}
@@ -12,6 +14,7 @@ import scala.collection.JavaConverters._
1214
1315import java .io .Serializable
1416import java .nio .file .{Files , Path }
17+ import java .time .OffsetDateTime
1518
1619import com .google .auth .oauth2 .{GoogleCredentials , ServiceAccountCredentials }
1720import org .json4s .Formats
@@ -20,38 +23,25 @@ import org.json4s.jackson.JsonMethods
2023object oauth2 {
2124
2225 sealed trait CloudCredentials extends Product with Serializable {
23- def accessToken ( scopes : IndexedSeq [ String ]) : String
26+ def accessToken : String
2427 }
2528
26- def CloudCredentials (credentialsPath : Path , env : Map [String , String ] = sys.env)
27- : CloudCredentials =
29+ def CloudCredentials (
30+ keyPath : Path ,
31+ scopes : IndexedSeq [String ],
32+ env : Map [String , String ] = sys.env,
33+ ): CloudCredentials =
2834 env.get(" HAIL_CLOUD" ) match {
29- case Some (" gcp" ) => GoogleCloudCredentials (Some (credentialsPath) )
30- case Some (" azure" ) => AzureCloudCredentials (Some (credentialsPath) )
35+ case Some (" gcp" ) => GoogleCloudCredentials (Some (keyPath), scopes, env )
36+ case Some (" azure" ) => AzureCloudCredentials (Some (keyPath), scopes, env )
3137 case Some (cloud) => throw new IllegalArgumentException (s " Unknown cloud: ' $cloud' " )
3238 case None => throw new IllegalArgumentException (s " HAIL_CLOUD must be set. " )
3339 }
3440
35- def CloudScopes (env : Map [String , String ] = sys.env): Array [String ] =
36- env.get(" HAIL_CLOUD" ) match {
37- case Some (" gcp" ) =>
38- Array (
39- " https://www.googleapis.com/auth/userinfo.profile" ,
40- " https://www.googleapis.com/auth/userinfo.email" ,
41- " openid" ,
42- )
43- case Some (" azure" ) =>
44- sys.env.get(" HAIL_AZURE_OAUTH_SCOPE" ).toArray
45- case Some (cloud) =>
46- throw new IllegalArgumentException (s " Unknown cloud: ' $cloud'. " )
47- case None =>
48- throw new IllegalArgumentException (s " HAIL_CLOUD must be set. " )
49- }
50-
5141 case class GoogleCloudCredentials (value : GoogleCredentials ) extends CloudCredentials {
52- override def accessToken ( scopes : IndexedSeq [ String ]) : String = {
42+ override def accessToken : String = {
5343 value.refreshIfExpired()
54- value.createScoped(scopes.asJava). getAccessToken.getTokenValue
44+ value.getAccessToken.getTokenValue
5545 }
5646 }
5747
@@ -60,40 +50,68 @@ object oauth2 {
6050 val GoogleApplicationCredentials = " GOOGLE_APPLICATION_CREDENTIALS"
6151 }
6252
63- def apply (keyPath : Option [Path ], env : Map [String , String ] = sys.env): GoogleCloudCredentials =
64- GoogleCloudCredentials (
65- keyPath.orElse(env.get(GoogleApplicationCredentials ).map(Path .of(_))) match {
66- case Some (path) => using(Files .newInputStream(path))(ServiceAccountCredentials .fromStream)
67- case None => GoogleCredentials .getApplicationDefault
68- }
69- )
53+ def apply (keyPath : Option [Path ], scopes : IndexedSeq [String ], env : Map [String , String ] = sys.env)
54+ : GoogleCloudCredentials =
55+ GoogleCloudCredentials {
56+ val creds : GoogleCredentials =
57+ keyPath.orElse(env.get(GoogleApplicationCredentials ).map(Path .of(_))) match {
58+ case Some (path) =>
59+ using(Files .newInputStream(path))(ServiceAccountCredentials .fromStream)
60+ case None =>
61+ GoogleCredentials .getApplicationDefault
62+ }
63+
64+ creds.createScoped(scopes : _* )
65+ }
7066 }
7167
7268 sealed trait AzureCloudCredentials extends CloudCredentials {
69+
7370 def value : TokenCredential
71+ def scopes : IndexedSeq [String ]
72+
73+ @ transient private [this ] var token : AccessToken = _
74+
75+ override def accessToken : String = {
76+ refreshIfRequired()
77+ token.getToken
78+ }
79+
80+ private [this ] def refreshIfRequired (): Unit =
81+ if (! isExpired) token.getToken
82+ else synchronized {
83+ if (isExpired) {
84+ token = value.getTokenSync(new TokenRequestContext ().setScopes(scopes.asJava))
85+ }
86+
87+ token.getToken
88+ }
7489
75- override def accessToken ( scopes : IndexedSeq [ String ]) : String =
76- value.getTokenSync( new TokenRequestContext ().setScopes(scopes.asJava)).getToken
90+ private [ this ] def isExpired : Boolean =
91+ token == null || OffsetDateTime .now.plusHours( 1 ).isBefore(token.getExpiresAt)
7792 }
7893
7994 object AzureCloudCredentials {
8095 object EnvVars {
8196 val AzureApplicationCredentials = " AZURE_APPLICATION_CREDENTIALS"
8297 }
8398
84- def apply (keyPath : Option [Path ], env : Map [String , String ] = sys.env): AzureCloudCredentials =
99+ def apply (keyPath : Option [Path ], scopes : IndexedSeq [String ], env : Map [String , String ] = sys.env)
100+ : AzureCloudCredentials =
85101 keyPath.orElse(env.get(AzureApplicationCredentials ).map(Path .of(_))) match {
86- case Some (path) => AzureClientSecretCredentials (path)
87- case None => AzureDefaultCredentials
102+ case Some (path) => AzureClientSecretCredentials (path, scopes )
103+ case None => AzureDefaultCredentials (scopes)
88104 }
89105 }
90106
91- private case object AzureDefaultCredentials extends AzureCloudCredentials {
107+ private case class AzureDefaultCredentials (scopes : IndexedSeq [String ])
108+ extends AzureCloudCredentials {
92109 @ transient override lazy val value : TokenCredential =
93110 new DefaultAzureCredentialBuilder ().build()
94111 }
95112
96- private case class AzureClientSecretCredentials (path : Path ) extends AzureCloudCredentials {
113+ private case class AzureClientSecretCredentials (path : Path , scopes : IndexedSeq [String ])
114+ extends AzureCloudCredentials {
97115 @ transient override lazy val value : TokenCredential =
98116 using(Files .newInputStream(path)) { is =>
99117 implicit val fmts : Formats = defaultJSONFormats
0 commit comments