diff --git a/aws/aws.go b/aws/aws.go index cfc42c03..8d6a8bb5 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -15,6 +15,7 @@ import ( "fmt" "io/ioutil" "os" + "regexp" "github.com/vaughan0/go-ini" ) @@ -289,6 +290,83 @@ func GetMetaData(path string) (contents []byte, err error) { return []byte(body), err } +// GetRegionFromEnv retrieves the region specified by an environment variable +// AWS_REGION or aws_region +func EnvRegion() (region Region, err error) { + r := os.Getenv("AWS_REGION") + if r == "" { + r = os.Getenv("aws_region") + } + + if r == "" { + err = fmt.Errorf("AWS_REGION or aws_region not found in environment") + return + } + + region = Regions[r] + + if region.Name == "" { + err = fmt.Errorf("%v region not found", r) + } + + return +} + +// GetRegionFromInstance retrieves the region from the instance metadata service +func GetInstanceRegion() (region Region, err error) { + + regionMatch := regexp.MustCompile(`^(\w+-\w+-\d+)`) + regionPath := "placement/availability-zone/" + + // Get the instance region plus zone + resp, err := GetMetaData(regionPath) + + if err != nil { + return + } + + // extract region from availability-zone + extracted := regionMatch.FindStringSubmatch(string(resp)) + + if extracted == nil { + err = fmt.Errorf("invalid region from metadata service - availability-zone") + return + } + + r := extracted[1] + + region, ok := Regions[r] + if ok == false{ + err = fmt.Errorf("cannot find region %s", r) + } + + return +} + +// GetRegion tires to get the region from either environment variables or metadata service if available +func GetRegion(r string) (region Region, err error) { + + region, ok := Regions[r] + if ok == true { + return + } + + //if not passed in, check ENV + region, err = EnvRegion() + if err == nil { + return + } + + region, err = GetInstanceRegion() + if err == nil { + return + } + + err = errors.New("Cloud not find a valid AWS region") + + return +} + func getInstanceCredentials() (cred credentials, err error) { credentialPath := "iam/security-credentials/" diff --git a/aws/aws_test.go b/aws/aws_test.go index 78cbbaf0..7a9b9fd9 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -1,12 +1,14 @@ package aws_test import ( - "github.com/mitchellh/goamz/aws" - . "github.com/motain/gocheck" + "fmt" "io/ioutil" "os" "strings" "testing" + + "github.com/mitchellh/goamz/aws" + . "github.com/motain/gocheck" ) func Test(t *testing.T) { @@ -201,3 +203,46 @@ func (s *S) TestRegionsAreNamed(c *C) { c.Assert(n, Equals, r.Name) } } + +func (s *S) TestEnvRegionNoRegion(c *C) { + _, err := aws.EnvRegion() + c.Assert(err, ErrorMatches, "AWS_REGION or aws_region not found in environment") +} + +func (s *S) TestEnvRegion(c *C) { + os.Clearenv() + os.Setenv("AWS_REGION", "eu-west-1") + region, err := aws.EnvRegion() + c.Assert(err, IsNil) + c.Assert(region.Name, Equals, "eu-west-1") +} + +func (s *S) TestEnvRegionAlt(c *C) { + os.Clearenv() + os.Setenv("aws_region", "eu-west-1") + region, err := aws.EnvRegion() + c.Assert(err, IsNil) + c.Assert(region.Name, Equals, "eu-west-1") +} + +func (s *S) TestEnvRegionInvalid(c *C) { + os.Clearenv() + os.Setenv("AWS_REGION", "eu-west-never") + _, err := aws.EnvRegion() + errorString := fmt.Sprintf("%v region not found", os.Getenv("AWS_REGION")) + c.Assert(err, ErrorMatches, errorString) +} + +func (s *S) TestGetRegionStatic(c *C) { + region, err := aws.GetRegion("eu-west-1") + c.Assert(err, IsNil) + c.Assert(region.Name, Equals, "eu-west-1") +} + +func (s *S) TestGetRegionEnv(c *C) { + os.Clearenv() + os.Setenv("AWS_REGION", "eu-west-1") + region, err := aws.GetRegion("") + c.Assert(err, IsNil) + c.Assert(region.Name, Equals, "eu-west-1") +}