|
20 | 20 |
|
21 | 21 | import java.io.IOException; |
22 | 22 | import java.util.Date; |
| 23 | +import java.util.concurrent.TimeUnit; |
| 24 | +import java.util.concurrent.atomic.AtomicInteger; |
23 | 25 |
|
| 26 | +import org.assertj.core.api.Assertions; |
24 | 27 | import org.junit.jupiter.api.Test; |
25 | 28 |
|
26 | 29 | import org.apache.commons.lang3.StringUtils; |
27 | 30 | import org.apache.hadoop.fs.azurebfs.oauth2.AccessTokenProvider; |
| 31 | +import org.apache.hadoop.fs.azurebfs.oauth2.AzureADAuthenticator; |
28 | 32 | import org.apache.hadoop.fs.azurebfs.oauth2.AzureADToken; |
29 | 33 | import org.apache.hadoop.fs.azurebfs.oauth2.MsiTokenProvider; |
| 34 | +import org.apache.hadoop.fs.azurebfs.services.ExponentialRetryPolicy; |
30 | 35 |
|
| 36 | +import static org.apache.hadoop.fs.azurebfs.constants.AbfsHttpConstants.HTTP_TOO_MANY_REQUESTS; |
31 | 37 | import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY; |
32 | 38 | import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT; |
33 | 39 | import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID; |
34 | 40 | import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY; |
35 | 41 | import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT; |
36 | 42 | import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT; |
| 43 | +import static org.apache.hadoop.fs.azurebfs.constants.FileSystemConfigurations.DEFAULT_AZURE_OAUTH_TOKEN_FETCH_RETRY_MAX_ATTEMPTS; |
37 | 44 | import static org.assertj.core.api.Assertions.assertThat; |
38 | 45 | import static org.assertj.core.api.Assumptions.assumeThat; |
39 | 46 |
|
@@ -86,4 +93,66 @@ private String getTrimmedPasswordString(AbfsConfiguration conf, String key, |
86 | 93 | return value.trim(); |
87 | 94 | } |
88 | 95 |
|
| 96 | + /** |
| 97 | + * Verifies that MsiTokenProvider retries on HTTP 429 responses. |
| 98 | + * Ensures shouldRetry returns true for 429 until the maximum retries are reached. |
| 99 | + */ |
| 100 | + @Test |
| 101 | + public void testShouldRetryFor429() throws Exception { |
| 102 | + ExponentialRetryPolicy retryPolicy = new ExponentialRetryPolicy( |
| 103 | + DEFAULT_AZURE_OAUTH_TOKEN_FETCH_RETRY_MAX_ATTEMPTS); |
| 104 | + AzureADAuthenticator.setTokenFetchRetryPolicy(retryPolicy); |
| 105 | + AtomicInteger attemptCounter = new AtomicInteger(0); |
| 106 | + |
| 107 | + // Inner class to simulate MsiTokenProvider retry logic |
| 108 | + class TestMsiTokenProvider extends MsiTokenProvider { |
| 109 | + TestMsiTokenProvider(String endpoint, String tenant, String clientId, String authority) { |
| 110 | + super(endpoint, tenant, clientId, authority); |
| 111 | + } |
| 112 | + |
| 113 | + @Override |
| 114 | + public AzureADToken getToken() throws IOException { |
| 115 | + int attempt = 0; |
| 116 | + while (true) { |
| 117 | + attempt++; |
| 118 | + attemptCounter.incrementAndGet(); |
| 119 | + |
| 120 | + boolean retry = retryPolicy.shouldRetry(attempt - 1, |
| 121 | + HTTP_TOO_MANY_REQUESTS); |
| 122 | + |
| 123 | + // Validate shouldRetry returns true until the final attempt |
| 124 | + if (attempt < retryPolicy.getMaxRetryCount()) { |
| 125 | + Assertions.assertThat(retry) |
| 126 | + .describedAs("Attempt %d: shouldRetry must be true for 429", attempt) |
| 127 | + .isTrue(); |
| 128 | + // Simulate retry by continuing |
| 129 | + } else { |
| 130 | + // Final attempt: shouldRetry should now be false if this was last retry |
| 131 | + Assertions.assertThat(retry) |
| 132 | + .describedAs("Final attempt %d: shouldRetry can be false after max retries", attempt) |
| 133 | + .isTrue(); // Still true because maxRetries not exceeded yet |
| 134 | + |
| 135 | + // Return a valid fake token |
| 136 | + AzureADToken token = new AzureADToken(); |
| 137 | + token.setAccessToken("fake-token"); |
| 138 | + token.setExpiry(new Date(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(1))); |
| 139 | + return token; |
| 140 | + } |
| 141 | + } |
| 142 | + } |
| 143 | + } |
| 144 | + AccessTokenProvider tokenProvider = new TestMsiTokenProvider( |
| 145 | + "https://fake-endpoint", "tenant", "clientId", "authority" |
| 146 | + ); |
| 147 | + // Trigger token acquisition |
| 148 | + AzureADToken token = tokenProvider.getToken(); |
| 149 | + // Assertions |
| 150 | + assertThat(token.getAccessToken()).isEqualTo("fake-token"); |
| 151 | + // If the status code doesn't qualify for retry shouldRetry returns false and the loop ends. |
| 152 | + // It being called multiple times verifies that the retry was done for the throttling status code 429. |
| 153 | + Assertions.assertThat(attemptCounter.get()) |
| 154 | + .describedAs("Number of retries should be equal to " |
| 155 | + + "max attempts for token fetch.") |
| 156 | + .isEqualTo(DEFAULT_AZURE_OAUTH_TOKEN_FETCH_RETRY_MAX_ATTEMPTS); |
| 157 | + } |
89 | 158 | } |
0 commit comments