Skip to content

Commit 9c868fd

Browse files
authored
Merge pull request #19 from sethsec/fix_download_from_aws_error
Fix: Handle new AWS API format with hashed service IDs
2 parents fb62aea + 9763d2f commit 9c868fd

3 files changed

Lines changed: 226 additions & 33 deletions

File tree

awsservicemap.go

Lines changed: 133 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ import (
1313
var awsJson embed.FS
1414

1515
type serviceEntry struct {
16-
ID string `json:"id"`
16+
ID string `json:"id"`
17+
Attributes *serviceAttributes `json:"attributes,omitempty"`
18+
}
19+
20+
type serviceAttributes struct {
21+
Region string `json:"aws:region"`
22+
ServiceName string `json:"aws:serviceName"`
23+
ServiceURL string `json:"aws:serviceUrl"`
1724
}
1825

1926
type regionalServiceData struct {
@@ -30,6 +37,28 @@ func contains(element string, array []string) bool {
3037
return false
3138
}
3239

40+
// extractServiceSlug extracts the service identifier from AWS service URL
41+
// Example: "https://aws.amazon.com/ec2/" -> "ec2"
42+
// "https://aws.amazon.com/rds/mysql/" -> "rds"
43+
// "https://aws.amazon.com/systems-manager/" -> "systems-manager"
44+
func extractServiceSlug(serviceURL string) string {
45+
if serviceURL == "" {
46+
return ""
47+
}
48+
49+
// Remove trailing slash
50+
serviceURL = strings.TrimSuffix(serviceURL, "/")
51+
52+
// Split by / and get the last non-empty part that's not the domain
53+
parts := strings.Split(serviceURL, "/")
54+
for i := len(parts) - 1; i >= 0; i-- {
55+
if parts[i] != "" && !strings.Contains(parts[i], "aws.amazon.com") && !strings.Contains(parts[i], "http") {
56+
return parts[i]
57+
}
58+
}
59+
return ""
60+
}
61+
3362
type AwsServiceMap struct {
3463
JsonFileSource JsonFileSource
3564
cachedData *regionalServiceData // Cache the parsed data to avoid repeated HTTP requests
@@ -91,14 +120,18 @@ func (m *AwsServiceMap) parseJson() (regionalServiceData, error) {
91120
if err != nil {
92121
return serviceData, err
93122
}
94-
json.Unmarshal([]byte(body), &serviceData)
123+
124+
err = json.Unmarshal([]byte(body), &serviceData)
95125
if err != nil {
96126
return serviceData, err
97127
}
98128

99129
} else {
100130
jsonFile, err := awsJson.ReadFile("data/aws-service-regions.json")
101-
json.Unmarshal([]byte(jsonFile), &serviceData)
131+
if err != nil {
132+
return serviceData, err
133+
}
134+
err = json.Unmarshal([]byte(jsonFile), &serviceData)
102135
if err != nil {
103136
return serviceData, err
104137
}
@@ -130,80 +163,148 @@ func (m *AwsServiceMap) GetAllRegions() ([]string, error) {
130163

131164
// Returns a slice of strings that represent all regions that support the specific service
132165
func (m *AwsServiceMap) GetRegionsForService(reqService string) ([]string, error) {
133-
regionsForServiceMap := map[string][]string{}
166+
regionsForService := []string{}
134167

135168
serviceData, err := m.parseJson()
136169
if err != nil {
137-
return regionsForServiceMap[reqService], err
170+
return regionsForService, err
138171
}
139-
for _, id := range serviceData.ServiceEntries {
140-
service := strings.Split(id.ID, ":")[0]
141-
if _, ok := regionsForServiceMap[service]; !ok {
142-
regionsForServiceMap[service] = nil
172+
173+
for _, entry := range serviceData.ServiceEntries {
174+
idParts := strings.Split(entry.ID, ":")
175+
if len(idParts) != 2 {
176+
continue
143177
}
144-
region := strings.Split(id.ID, ":")[1]
145-
if _, ok := regionsForServiceMap[service]; ok {
146-
regionsForServiceMap[service] = append(regionsForServiceMap[service], region)
178+
179+
serviceHash := idParts[0]
180+
region := idParts[1]
181+
182+
// Check if this is the service we're looking for
183+
var matches bool
184+
185+
// Try old format first (direct match)
186+
if serviceHash == reqService {
187+
matches = true
188+
} else if entry.Attributes != nil && entry.Attributes.ServiceURL != "" {
189+
// New format: extract slug from URL
190+
serviceSlug := extractServiceSlug(entry.Attributes.ServiceURL)
191+
if serviceSlug == reqService {
192+
matches = true
193+
}
194+
}
195+
196+
if matches && !contains(region, regionsForService) {
197+
regionsForService = append(regionsForService, region)
147198
}
148199
}
149-
return regionsForServiceMap[reqService], err
150200

201+
return regionsForService, nil
151202
}
152203

153204
// Returns a slice of strings that represent all observed services
154-
155205
func (m *AwsServiceMap) GetAllServices() ([]string, error) {
156206
totalServices := []string{}
157207
serviceData, err := m.parseJson()
158208
if err != nil {
159209
return totalServices, err
160210
}
161-
for _, id := range serviceData.ServiceEntries {
162-
service := strings.Split(id.ID, ":")[0]
163-
if !contains(service, totalServices) {
164-
totalServices = append(totalServices, service)
211+
212+
for _, entry := range serviceData.ServiceEntries {
213+
var serviceName string
214+
215+
if entry.Attributes != nil && entry.Attributes.ServiceURL != "" {
216+
// New format: extract from URL
217+
serviceName = extractServiceSlug(entry.Attributes.ServiceURL)
218+
} else {
219+
// Fallback to hash from ID (old format)
220+
idParts := strings.Split(entry.ID, ":")
221+
if len(idParts) > 0 {
222+
serviceName = idParts[0]
223+
}
224+
}
225+
226+
if serviceName != "" && !contains(serviceName, totalServices) {
227+
totalServices = append(totalServices, serviceName)
165228
}
166229
}
167-
return totalServices, err
168230

231+
return totalServices, nil
169232
}
170233

171234
// Returns a slice of strings that represent all service supported in a specific region
172235
func (m *AwsServiceMap) GetServicesForRegion(reqRegion string) ([]string, error) {
173-
servicesForRegionMap := map[string][]string{}
236+
servicesForRegion := []string{}
174237
serviceData, err := m.parseJson()
175238
if err != nil {
176-
return servicesForRegionMap[reqRegion], err
239+
return servicesForRegion, err
177240
}
178241

179-
for _, id := range serviceData.ServiceEntries {
242+
for _, entry := range serviceData.ServiceEntries {
243+
idParts := strings.Split(entry.ID, ":")
244+
if len(idParts) != 2 {
245+
continue
246+
}
180247

181-
region := strings.Split(id.ID, ":")[1]
182-
if _, ok := servicesForRegionMap[region]; !ok {
183-
servicesForRegionMap[region] = nil
248+
serviceHash := idParts[0]
249+
region := idParts[1]
250+
251+
if region != reqRegion {
252+
continue
184253
}
185254

186-
service := strings.Split(id.ID, ":")[0]
187-
if _, ok := servicesForRegionMap[region]; ok {
188-
servicesForRegionMap[region] = append(servicesForRegionMap[region], service)
255+
// Determine service name
256+
var serviceName string
257+
if entry.Attributes != nil && entry.Attributes.ServiceURL != "" {
258+
serviceName = extractServiceSlug(entry.Attributes.ServiceURL)
259+
} else {
260+
// Fallback to hash (old format)
261+
serviceName = serviceHash
262+
}
263+
264+
if serviceName != "" && !contains(serviceName, servicesForRegion) {
265+
servicesForRegion = append(servicesForRegion, serviceName)
189266
}
190267
}
191-
return servicesForRegionMap[reqRegion], err
268+
269+
return servicesForRegion, nil
192270
}
193271

194272
// Is a specific service supported in a specific region. Returns true/false
273+
// Handles both old format (ec2:us-east-1) and new format (hash:us-east-1 with attributes)
195274
func (m *AwsServiceMap) IsServiceInRegion(reqService string, reqRegion string) (bool, error) {
196275
serviceData, err := m.parseJson()
197276
if err != nil {
198277
return false, err
199278
}
200279

280+
// Try direct match first (old format or if hash is provided)
201281
reqPair := serviceEntry{ID: fmt.Sprintf("%s:%s", reqService, reqRegion)}
202282
if serviceEntryContains(reqPair, serviceData.ServiceEntries) {
203-
return true, err
204-
} else {
205-
return false, err
283+
return true, nil
284+
}
285+
286+
// New format: check if any entry with matching region has the service slug
287+
for _, entry := range serviceData.ServiceEntries {
288+
// Check if region matches
289+
idParts := strings.Split(entry.ID, ":")
290+
if len(idParts) != 2 {
291+
continue
292+
}
293+
294+
if idParts[1] != reqRegion {
295+
continue
296+
}
297+
298+
// Extract service slug from URL if attributes exist
299+
if entry.Attributes != nil && entry.Attributes.ServiceURL != "" {
300+
serviceSlug := extractServiceSlug(entry.Attributes.ServiceURL)
301+
if serviceSlug == reqService {
302+
return true, nil
303+
}
304+
}
206305
}
306+
307+
return false, nil
207308
}
208309

209310
func serviceEntryContains(element serviceEntry, array []serviceEntry) bool {

cmd/examples/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ func main() {
1616
// With the new caching feature, the data is fetched only once per instance regardless of how many method calls are made.
1717

1818
servicemap := &awsservicemap.AwsServiceMap{
19-
JsonFileSource: "EMBEDDED_IN_PACKAGE",
19+
//JsonFileSource: "EMBEDDED_IN_PACKAGE",
20+
JsonFileSource: "DOWNLOAD_FROM_AWS",
2021
}
2122

2223
// Example of how you can also use the constructor pattern to simulate "instantiating" a new service map "object"

test/test_new_format.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
7+
"github.com/bishopfox/awsservicemap"
8+
)
9+
10+
func main() {
11+
log.SetFlags(log.LstdFlags | log.Lshortfile)
12+
13+
servicemap := &awsservicemap.AwsServiceMap{
14+
JsonFileSource: "DOWNLOAD_FROM_AWS",
15+
}
16+
17+
fmt.Println("=== Testing awsservicemap with new AWS API format ===\n")
18+
19+
// Test IsServiceInRegion
20+
fmt.Println("1. Testing IsServiceInRegion for EC2:")
21+
testRegions := []string{"us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"}
22+
for _, region := range testRegions {
23+
res, err := servicemap.IsServiceInRegion("ec2", region)
24+
if err != nil {
25+
fmt.Printf(" ERROR checking %s: %v\n", region, err)
26+
} else {
27+
fmt.Printf(" EC2 in %s: %v\n", region, res)
28+
}
29+
}
30+
31+
// Test GetRegionsForService
32+
fmt.Println("\n2. Testing GetRegionsForService for EC2:")
33+
regions, err := servicemap.GetRegionsForService("ec2")
34+
if err != nil {
35+
fmt.Printf(" ERROR: %v\n", err)
36+
} else {
37+
fmt.Printf(" Found EC2 in %d regions\n", len(regions))
38+
fmt.Printf(" First 5 regions: %v\n", regions[:min(5, len(regions))])
39+
}
40+
41+
// Test GetAllServices
42+
fmt.Println("\n3. Testing GetAllServices:")
43+
services, err := servicemap.GetAllServices()
44+
if err != nil {
45+
fmt.Printf(" ERROR: %v\n", err)
46+
} else {
47+
fmt.Printf(" Found %d total services\n", len(services))
48+
fmt.Printf(" First 10 services: %v\n", services[:min(10, len(services))])
49+
}
50+
51+
// Test GetServicesForRegion
52+
fmt.Println("\n4. Testing GetServicesForRegion for us-east-1:")
53+
usEast1Services, err := servicemap.GetServicesForRegion("us-east-1")
54+
if err != nil {
55+
fmt.Printf(" ERROR: %v\n", err)
56+
} else {
57+
fmt.Printf(" Found %d services in us-east-1\n", len(usEast1Services))
58+
fmt.Printf(" First 10 services: %v\n", usEast1Services[:min(10, len(usEast1Services))])
59+
}
60+
61+
// Test GetAllRegions
62+
fmt.Println("\n5. Testing GetAllRegions:")
63+
allRegions, err := servicemap.GetAllRegions()
64+
if err != nil {
65+
fmt.Printf(" ERROR: %v\n", err)
66+
} else {
67+
fmt.Printf(" Found %d total regions\n", len(allRegions))
68+
fmt.Printf(" All regions: %v\n", allRegions)
69+
}
70+
71+
// Test a few more services
72+
fmt.Println("\n6. Testing other services:")
73+
otherServices := []string{"rds", "lambda", "s3", "iam", "eks"}
74+
for _, svc := range otherServices {
75+
res, err := servicemap.IsServiceInRegion(svc, "us-east-1")
76+
if err != nil {
77+
fmt.Printf(" ERROR checking %s: %v\n", svc, err)
78+
} else {
79+
fmt.Printf(" %s in us-east-1: %v\n", svc, res)
80+
}
81+
}
82+
83+
fmt.Println("\n=== All tests completed successfully! ===")
84+
}
85+
86+
func min(a, b int) int {
87+
if a < b {
88+
return a
89+
}
90+
return b
91+
}

0 commit comments

Comments
 (0)