diff --git a/contrib/admin/mypy-with-ignore.py b/contrib/admin/mypy-with-ignore.py
index b940f40507..acefbe9e46 100755
--- a/contrib/admin/mypy-with-ignore.py
+++ b/contrib/admin/mypy-with-ignore.py
@@ -91,6 +91,14 @@ def main():
         'src/toil/lib/encryption/conftest.py',
         'src/toil/lib/encryption/__init__.py',
         'src/toil/lib/aws/__init__.py',
+        'src/toil/lib/aws/ec2.py',
+        'src/toil/lib/aws/s3.py',
+        'src/toil/lib/aws/iam.py',
+        'src/toil/lib/aws/config.py',
+        'src/toil/lib/aws/utils.py',
+        'src/toil/lib/aws/util.py',
+        'src/toil/lib/checksum.py',
+        'src/toil/lib/pipes.py',
         'src/toil/server/utils.py',
         'src/toil/utils/toilStats.py'
     ]]
diff --git a/src/toil/batchSystems/awsBatch.py b/src/toil/batchSystems/awsBatch.py
index 5789d1ca82..1784969b2a 100644
--- a/src/toil/batchSystems/awsBatch.py
+++ b/src/toil/batchSystems/awsBatch.py
@@ -47,7 +47,7 @@
 from toil.bus import MessageBus, MessageOutbox, JobAnnotationMessage
 from toil.common import Config, Toil
 from toil.job import JobDescription
-from toil.lib.aws import get_current_aws_region, zone_to_region
+from toil.lib.aws.util import get_current_aws_region, zone_to_region
 from toil.lib.aws.session import establish_boto3_session
 from toil.lib.conversions import b_to_mib, mib_to_b
 from toil.lib.misc import slow_down, utc_now, unix_now_ms
diff --git a/src/toil/common.py b/src/toil/common.py
index 6b4e48daed..71dcd64a72 100644
--- a/src/toil/common.py
+++ b/src/toil/common.py
@@ -68,7 +68,7 @@
                       ClusterSizeMessage,
                       ClusterDesiredSizeMessage)
 from toil.fileStores import FileID
-from toil.lib.aws import zone_to_region
+from toil.lib.aws.util import zone_to_region
 from toil.lib.compatibility import deprecated
 from toil.lib.conversions import bytes2human, human2bytes
 from toil.lib.retry import retry
diff --git a/src/toil/jobStores/aws/jobStore.py b/src/toil/jobStores/aws/jobStore.py
index 69e4489383..006f617b3e 100644
--- a/src/toil/jobStores/aws/jobStore.py
+++ b/src/toil/jobStores/aws/jobStore.py
@@ -64,6 +64,7 @@
     retry_s3,
     retryable_s3_errors
 )
+from toil.lib.aws.s3 import list_multipart_uploads, delete_bucket
 from toil.lib.compatibility import compat_bytes
 from toil.lib.aws.session import establish_boto3_session
 from toil.lib.ec2nodes import EC2Regions
@@ -1211,14 +1212,15 @@ def readFrom(self, readable):
             parts = []
             logger.debug('Multipart upload started as %s', uploadId)
 
-            for attempt in retry_s3():
                 with attempt:
                     for i in range(CONSISTENCY_TICKS):
-                        # Sometimes we can create a multipart upload and not see it. Wait around for it.
-                        response = client.list_multipart_uploads(Bucket=bucket_name,
-                                                                 MaxUploads=1,
-                                                                 Prefix=compat_bytes(info.fileID))
+                        # Sometimes we can create a multipart upload and then not be able to see it. Wait around for it.
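+                        # (ListMultipartUploads has been observed to lag behind
+                        # CreateMultipartUpload, so we poll up to CONSISTENCY_TICKS times.)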
+ response = list_multipart_uploads( + s3_resource=store.s3_resource, + bucket=bucket_name, + prefix=compat_bytes(info.fileID) + ) if len(response['Uploads']) != 0 and response['Uploads'][0]['UploadId'] == uploadId: logger.debug('Multipart upload visible as %s', uploadId) break @@ -1627,27 +1629,7 @@ def _delete_domain(self, domain): @staticmethod def _delete_bucket(bucket): - """ - :param bucket: S3.Bucket - """ - for attempt in retry_s3(): - with attempt: - try: - uploads = s3_boto3_client.list_multipart_uploads(Bucket=bucket.name).get('Uploads') - if uploads: - for u in uploads: - s3_boto3_client.abort_multipart_upload(Bucket=bucket.name, - Key=u["Key"], - UploadId=u["UploadId"]) - - bucket.objects.all().delete() - bucket.object_versions.delete() - bucket.delete() - except s3_boto3_client.exceptions.NoSuchBucket: - pass - except ClientError as e: - if get_error_status(e) != 404: - raise + delete_bucket(s3_boto3_resource, bucket.name) aRepr = reprlib.Repr() diff --git a/src/toil/jobStores/utils.py b/src/toil/jobStores/utils.py index 0b6409ebf3..b61fd47295 100644 --- a/src/toil/jobStores/utils.py +++ b/src/toil/jobStores/utils.py @@ -366,7 +366,7 @@ def generate_locator( from toil.jobStores.aws.jobStore import AWSJobStore # noqa # Find a region - from toil.lib.aws import get_current_aws_region + from toil.lib.aws.util import get_current_aws_region region = get_current_aws_region() diff --git a/src/toil/lib/aws/__init__.py b/src/toil/lib/aws/__init__.py index 206dce5173..471fe7cd32 100644 --- a/src/toil/lib/aws/__init__.py +++ b/src/toil/lib/aws/__init__.py @@ -11,158 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections -import inspect -import json -import logging -import os -import re -import socket -import threading -from functools import lru_cache -from urllib.request import urlopen -from urllib.error import URLError - -from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union - -logger = logging.getLogger(__name__) - -# This file isn't allowed to import anything that depends on Boto or Boto3, -# which may not be installed, because it has to be importable everywhere. - -def get_current_aws_region() -> Optional[str]: - """ - Return the AWS region that the currently configured AWS zone (see - get_current_aws_zone()) is in. - """ - # Try to derive it from the zone. - aws_zone = get_current_aws_zone() - return zone_to_region(aws_zone) if aws_zone else None - -def get_aws_zone_from_environment() -> Optional[str]: - """ - Get the AWS zone from TOIL_AWS_ZONE if set. - """ - return os.environ.get('TOIL_AWS_ZONE', None) - -def get_aws_zone_from_metadata() -> Optional[str]: - """ - Get the AWS zone from instance metadata, if on EC2 and the boto module is - available. Otherwise, gets the AWS zone from ECS task metadata, if on ECS. - """ - - # When running on ECS, we also appear to be running on EC2, but the EC2 - # metadata service doesn't seem to be contactable. So we check ECS first. - - if running_on_ecs(): - # Use the ECS metadata service - logger.debug("Fetch AZ from ECS metadata") - try: - resp = json.load(urlopen(os.environ['ECS_CONTAINER_METADATA_URI_V4'] + '/task', timeout=1)) - logger.debug("ECS metadata: %s", resp) - if isinstance(resp, dict): - # We found something. Go with that. 
- return resp.get('AvailabilityZone') - except (json.decoder.JSONDecodeError, KeyError, URLError) as e: - # We're on ECS but can't get the metadata. That's odd. - logger.warning("Skipping ECS metadata due to error: %s", e) - if running_on_ec2(): - # On EC2 alone, or on ECS but we couldn't get ahold of the ECS - # metadata. - try: - # Use the EC2 metadata service - import boto - from boto.utils import get_instance_metadata - logger.debug("Fetch AZ from EC2 metadata") - return get_instance_metadata()['placement']['availability-zone'] - except ImportError: - # This is expected to happen a lot - logger.debug("No boto to fetch ECS metadata") - except (KeyError, URLError) as e: - # We're on EC2 but can't get the metadata. That's odd. - logger.warning("Skipping EC2 metadata due to error: %s", e) - return None - -def get_aws_zone_from_boto() -> Optional[str]: - """ - Get the AWS zone from the Boto config file, if it is configured and the - boto module is available. - """ - try: - import boto - zone = boto.config.get('Boto', 'ec2_region_name') - if zone is not None: - zone += 'a' # derive an availability zone in the region - return zone - except ImportError: - pass - return None - -def get_aws_zone_from_environment_region() -> Optional[str]: - """ - Pick an AWS zone in the region defined by TOIL_AWS_REGION, if it is set. - """ - aws_region = os.environ.get('TOIL_AWS_REGION') - if aws_region is not None: - # If a region is specified, use the first zone in the region. - return aws_region + 'a' - # Otherwise, don't pick a region and let us fall back on the next method. - return None - -def get_current_aws_zone() -> Optional[str]: - """ - Get the currently configured or occupied AWS zone to use. - - Reports the TOIL_AWS_ZONE environment variable if set. - - Otherwise, if we have boto and are running on EC2, or if we are on ECS, - reports the zone we are running in. - - Otherwise, if we have the TOIL_AWS_REGION variable set, chooses a zone in - that region. - - Finally, if we have boto2, and a default region is configured in Boto 2, - chooses a zone in that region. - - Returns None if no method can produce a zone to use. - """ - return get_aws_zone_from_environment() or \ - get_aws_zone_from_metadata() or \ - get_aws_zone_from_environment_region() or \ - get_aws_zone_from_boto() - -def zone_to_region(zone: str) -> str: - """Get a region (e.g. us-west-2) from a zone (e.g. us-west-1c).""" - # re.compile() caches the regex internally so we don't have to - availability_zone = re.compile(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$') - m = availability_zone.match(zone) - if not m: - raise ValueError(f"Can't extract region from availability zone '{zone}'") - return m.group(1) - -def running_on_ec2() -> bool: - """ - Return True if we are currently running on EC2, and false otherwise. - """ - # TODO: Move this to toil.lib.ec2 and make toil.lib.ec2 importable without boto? - def file_begins_with(path, prefix): - with open(path) as f: - return f.read(len(prefix)) == prefix - - hv_uuid_path = '/sys/hypervisor/uuid' - if os.path.exists(hv_uuid_path) and file_begins_with(hv_uuid_path, 'ec2'): - return True - # Some instances do not have the /sys/hypervisor/uuid file, so check the identity document instead. 
- # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html - try: - urlopen('http://169.254.169.254/latest/dynamic/instance-identity/document', timeout=1) - return True - except (URLError, socket.timeout): - return False - -def running_on_ecs() -> bool: - """ - Return True if we are currently running on Amazon ECS, and false otherwise. - """ - # We only care about relatively current ECS - return 'ECS_CONTAINER_METADATA_URI_V4' in os.environ diff --git a/src/toil/lib/aws/config.py b/src/toil/lib/aws/config.py new file mode 100644 index 0000000000..4943f67077 --- /dev/null +++ b/src/toil/lib/aws/config.py @@ -0,0 +1,30 @@ +import math + +from toil.lib.units import MIB, TB + +S3_PARALLELIZATION_FACTOR = 8 +S3_PART_SIZE = 16 * MIB + +# AWS Defined Limits +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html +AWS_MAX_CHUNK_SIZE = 5 * TB +# Files must be larger than this before we consider multipart uploads. +AWS_MIN_CHUNK_SIZE = 64 * MIB +# Convenience variable for Boto3 TransferConfig(multipart_threshold=). +MULTIPART_THRESHOLD = AWS_MIN_CHUNK_SIZE + 1 +# Maximum number of parts allowed in a multipart upload. This is a limitation imposed by S3. +AWS_MAX_MULTIPART_COUNT = 10000 +# Note: There is no minimum size limit on the last part of a multipart upload. + +# The chunk size we chose arbitrarily, but it must be consistent for etags +DEFAULT_AWS_CHUNK_SIZE = 128 * MIB +assert AWS_MAX_CHUNK_SIZE > DEFAULT_AWS_CHUNK_SIZE > AWS_MIN_CHUNK_SIZE + + +def get_s3_multipart_chunk_size(file_size: int) -> int: + if file_size >= AWS_MAX_CHUNK_SIZE * AWS_MAX_MULTIPART_COUNT: + return AWS_MAX_CHUNK_SIZE + elif file_size <= DEFAULT_AWS_CHUNK_SIZE * AWS_MAX_MULTIPART_COUNT: + return DEFAULT_AWS_CHUNK_SIZE + else: + return math.ceil(file_size / AWS_MAX_MULTIPART_COUNT) diff --git a/src/toil/lib/aws/ec2.py b/src/toil/lib/aws/ec2.py new file mode 100644 index 0000000000..b1efca1096 --- /dev/null +++ b/src/toil/lib/aws/ec2.py @@ -0,0 +1,39 @@ +# Copyright (C) 2015-2021 Regents of the University of California +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AWS functions dealing with ec2.""" +import os +import socket +from urllib.error import URLError +from urllib.request import urlopen + + +def running_on_ec2() -> bool: + """ + Return True if we are currently running on EC2, and false otherwise. + """ + # TODO: Move this to toil.lib.ec2 and make toil.lib.ec2 importable without boto? + def file_begins_with(path, prefix): + with open(path) as f: + return f.read(len(prefix)) == prefix + + hv_uuid_path = '/sys/hypervisor/uuid' + if os.path.exists(hv_uuid_path) and file_begins_with(hv_uuid_path, 'ec2'): + return True + # Some instances do not have the /sys/hypervisor/uuid file, so check the identity document instead. 
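+    # (The identity-document URL below is served by the EC2 instance metadata
+    # service at 169.254.169.254, which only answers on EC2.)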
+    # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
+    try:
+        urlopen('http://169.254.169.254/latest/dynamic/instance-identity/document', timeout=1)
+        return True
+    except (URLError, socket.timeout):
+        return False
diff --git a/src/toil/lib/aws/ecs.py b/src/toil/lib/aws/ecs.py
new file mode 100644
index 0000000000..c3e71a3daa
--- /dev/null
+++ b/src/toil/lib/aws/ecs.py
@@ -0,0 +1,9 @@
+import os
+
+
+def running_on_ecs() -> bool:
+    """
+    Return True if we are currently running on Amazon ECS, and false otherwise.
+    """
+    # We only care about relatively current ECS
+    return 'ECS_CONTAINER_METADATA_URI_V4' in os.environ
diff --git a/src/toil/lib/aws/iam.py b/src/toil/lib/aws/iam.py
index 57e56540e5..7c25798b07 100644
--- a/src/toil/lib/aws/iam.py
+++ b/src/toil/lib/aws/iam.py
@@ -1,19 +1,39 @@
-
+# Copyright (C) 2015-2021 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import boto3
 import fnmatch
 import json
 
-from toil.lib.aws import zone_to_region
-from toil.provisioners.aws import get_best_aws_zone
-from functools import lru_cache
-from typing import Any, Optional, List, Dict, Set, cast
+from functools import lru_cache
+from typing import Optional, List, Dict, cast, Any
 from mypy_boto3_iam import IAMClient
 from mypy_boto3_sts import STSClient
 from mypy_boto3_iam.type_defs import AttachedPolicyTypeDef
 from toil.lib.aws.session import client as get_client
 from collections import defaultdict
+from toil.lib.retry import retry
+from toil.lib.aws.session import client, resource
+
+try:
+    from boto.exception import BotoServerError
+except ImportError:
+    # AWS/boto extra is not installed
+    BotoServerError = None  # type: ignore
+
 
 logger = logging.getLogger(__name__)
 
 #TODO Make this comprehensive
@@ -51,8 +71,43 @@
     "ec2:TerminateInstances",
 ]
 
+
+@retry(errors=[BotoServerError])
+def delete_iam_role(role_name: str, region: Optional[str] = None, display_type='print') -> None:
+    display = print if display_type == 'print' else logger.debug
+    from boto.iam.connection import IAMConnection
+    iam_client = client('iam', region_name=region)
+    iam_resource = resource('iam', region_name=region)
+    boto_iam_connection = IAMConnection()
+    role = iam_resource.Role(role_name)
+    # normal (attached) policies are detached by ARN
+    for attached_policy in role.attached_policies.all():
+        display(f'Now dissociating policy: {attached_policy.policy_name} from role {role.name}')
+        role.detach_policy(PolicyArn=attached_policy.arn)
+    # inline policies
+    for inline_policy in role.policies.all():
+        display(f'Deleting inline policy: {inline_policy.name} from role {role.name}')
+        # couldn't find an easy way to remove inline policies with boto3; use boto
+        boto_iam_connection.delete_role_policy(role.name, inline_policy.name)
+    iam_client.delete_role(RoleName=role_name)
+    display(f'Role {role_name} successfully deleted.')
+
+
+@retry(errors=[BotoServerError])
+def delete_iam_instance_profile(instance_profile_name: str, region: Optional[str] = None, display_type='print') -> None:
+    display = print if display_type == 'print' else logger.debug
+    iam_resource = resource('iam', region_name=region)
+    instance_profile = iam_resource.InstanceProfile(instance_profile_name)
+    for role in instance_profile.roles:
+        display(f'Now dissociating role: {role.name} from instance profile {instance_profile_name}')
+        instance_profile.remove_role(RoleName=role.name)
+    instance_profile.delete()
+    display(f'Instance profile "{instance_profile_name}" successfully deleted.')
+
+
 AllowedActionCollection = Dict[str, Dict[str, List[str]]]
 
+
 def init_action_collection() -> AllowedActionCollection:
     '''
     Initialization of an action collection, an action collection contains allowed Actions and NotActions
@@ -64,6 +119,7 @@ def init_action_collection() -> AllowedActionCollection:
     '''
     return defaultdict(lambda: {'Action': [], 'NotAction': []})
 
+
 def add_to_action_collection(a: AllowedActionCollection, b: AllowedActionCollection) -> AllowedActionCollection:
     '''
     Combines two action collections
@@ -80,8 +136,6 @@ def add_to_action_collection(a: AllowedActionCollection, b: AllowedActionCollect
     return to_return
 
-
-
 
 def policy_permissions_allow(given_permissions: AllowedActionCollection, required_permissions: List[str] = []) -> bool:
     """
     Check whether given set of actions are a subset of another given set of actions, returns true if they are
@@ -124,6 +178,7 @@ def permission_matches_any(perm: str, list_perms: List[str]) -> bool:
             return True
     return False
 
+
 def get_actions_from_policy_document(policy_doc: Dict[str, Any]) -> AllowedActionCollection:
     '''
     Given a policy document, go through each statement and create an AllowedActionCollection representing the
@@ -148,6 +203,8 @@ def get_actions_from_policy_document(policy_doc: Dict[str, Any]) -> AllowedActio
                 allowed_actions[resource][key].append(statement[key])
     return allowed_actions
 
+
+
 def allowed_actions_attached(iam: IAMClient, attached_policies: List[AttachedPolicyTypeDef]) -> AllowedActionCollection:
     """
     Go through all attached policy documents and create an AllowedActionCollection representing granted permissions.
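A quick sanity check of the wildcard semantics used by the matching helpers above
(a minimal sketch, assuming fnmatch-style matching as the fnmatch import suggests;
the action names are illustrative):

    import fnmatch

    # 'ec2:*' in an allow list should cover any EC2 action...
    assert fnmatch.fnmatch('ec2:RunInstances', 'ec2:*')
    # ...but not actions belonging to other services.
    assert not fnmatch.fnmatch('s3:GetObject', 'ec2:*')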
@@ -189,6 +246,7 @@ def allowed_actions_roles(iam: IAMClient, policy_names: List[str], role_name: st
 
     return allowed_actions
 
+
 def allowed_actions_users(iam: IAMClient, policy_names: List[str], user_name: str) -> AllowedActionCollection:
     """
     Gets all allowed actions for a user given by user_name, returns a dictionary, keyed by resource,
@@ -210,6 +268,7 @@ def allowed_actions_users(iam: IAMClient, policy_names: List[str], user_name: st
 
     return allowed_actions
 
+
 def get_policy_permissions(region: str) -> AllowedActionCollection:
     """
     Returns an action collection containing lists of all permission grant patterns keyed by resource
@@ -217,10 +276,9 @@ def get_policy_permissions(region: str) -> AllowedActionCollection:
 
     :param zone: AWS zone to connect to
     """
-
     iam: IAMClient = cast(IAMClient, get_client('iam', region))
     sts: STSClient = cast(STSClient, get_client('sts', region))
-    #TODO Condider effect: deny at some point
+    #TODO Consider effect: deny at some point
     allowed_actions: AllowedActionCollection = defaultdict(lambda: {'Action': [], 'NotAction': []})
     try:
         # If successful then we assume we are operating as a user, and grab the associated permissions
@@ -257,4 +315,4 @@ def get_aws_account_num() -> Optional[str]:
     """
     Returns AWS account num
     """
-    return boto3.client('sts').get_caller_identity().get('Account')
\ No newline at end of file
+    return boto3.client('sts').get_caller_identity().get('Account')
diff --git a/src/toil/lib/aws/s3.py b/src/toil/lib/aws/s3.py
new file mode 100644
index 0000000000..e46caba2d5
--- /dev/null
+++ b/src/toil/lib/aws/s3.py
@@ -0,0 +1,450 @@
+# Copyright (C) 2015-2021 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import errno
+import hashlib
+import itertools
+import logging
+import os
+import socket
+import urllib.parse
+
+import boto3
+
+from io import BytesIO
+from typing import Tuple, Optional, Union, Iterable, Callable, Iterator, ContextManager
+from datetime import timedelta
+from contextlib import contextmanager
+from boto3.s3.transfer import TransferConfig
+from botocore.exceptions import ClientError
+
+from toil.lib.aws.config import DEFAULT_AWS_CHUNK_SIZE
+# THROTTLED_ERROR_CODES is assumed to live alongside the other S3 retry helpers in utils
+from toil.lib.aws.utils import THROTTLED_ERROR_CODES
+from toil.lib.misc import modify_url
+from toil.lib.pipes import WritablePipe, ReadablePipe, HashingPipe
+from toil.lib.retry import (ErrorCondition, old_retry, retry, get_error_code,
+                            DEFAULT_DELAYS, DEFAULT_TIMEOUT)
+
+try:
+    from boto.exception import BotoServerError, S3ResponseError
+    from mypy_boto3_s3 import S3ServiceResource
+    from mypy_boto3_s3.literals import BucketLocationConstraintType
+    from mypy_boto3_s3.service_resource import Bucket
+except ImportError:
+    # AWS/boto extra is not installed
+    BotoServerError = None  # type: ignore
+    S3ResponseError = None  # type: ignore
+
+logger = logging.getLogger(__name__)
+
+
+class NoSuchFileException(Exception):
+    pass
+
+
+class AWSKeyNotFoundError(Exception):
+    pass
+
+
+class AWSKeyAlreadyExistsError(Exception):
+    pass
+
+
+class AWSBadEncryptionKeyError(Exception):
+    pass
+
+
+def retryable_s3_errors(e: Exception) -> bool:
+    """
+    Return true if this is an error from S3 that looks like we ought to retry our request.
+ """ + return ((isinstance(e, socket.error) and e.errno in (errno.ECONNRESET, 104)) + or (isinstance(e, BotoServerError) and e.status in (429, 500)) + or (isinstance(e, BotoServerError) and e.code in THROTTLED_ERROR_CODES) + # boto3 errors + or (isinstance(e, (S3ResponseError, ClientError)) and get_error_code(e) in THROTTLED_ERROR_CODES) + or (isinstance(e, ClientError) and 'BucketNotEmpty' in str(e)) + or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 409 and 'try again' in str(e)) + or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') in (404, 429, 500, 502, 503, 504))) + + +def retry_s3(delays: Iterable[float] = DEFAULT_DELAYS, timeout: float = DEFAULT_TIMEOUT, predicate: Callable[[Exception], bool] = retryable_s3_errors) -> Iterator[ContextManager[None]]: + """ + Retry iterator of context managers specifically for S3 operations. + """ + return old_retry(delays=delays, timeout=timeout, predicate=predicate) + + +# TODO: Determine specific retries +@retry() +def list_multipart_uploads(s3_resource, bucket, prefix, max_uploads=1): + s3_client = s3_resource.meta.client + return s3_client.list_multipart_uploads(Bucket=bucket, MaxUploads=max_uploads, Prefix=prefix) + + +# TODO: Determine specific retries +@retry() +def create_bucket(s3_resource: S3ServiceResource, bucket: str) -> Bucket: + """ + Create an AWS S3 bucket, using the given Boto3 S3 resource, with the + given name, in the S3 resource's region. + + Supports the us-east-1 region, where bucket creation is special. + + *ALL* S3 bucket creation should use this function. + """ + s3_client = s3_resource.meta.client + logger.info(f"Creating AWS bucket {bucket} in region {s3_client.meta.region_name}") + if s3_client.meta.region_name == "us-east-1": # see https://github.com/boto/boto3/issues/125 + s3_client.create_bucket(Bucket=bucket) + else: + s3_client.create_bucket( + Bucket=bucket, + CreateBucketConfiguration={"LocationConstraint": s3_client.meta.region_name}, + ) + waiter = s3_client.get_waiter('bucket_exists') + waiter.wait(Bucket=bucket) + owner_tag = os.environ.get('TOIL_OWNER_TAG') + if owner_tag: + bucket_tagging = s3_resource.BucketTagging(bucket) + bucket_tagging.put(Tagging={'TagSet': [{'Key': 'Owner', 'Value': owner_tag}]}) + logger.debug(f"Successfully created new bucket '{bucket}'") + return s3_resource.Bucket(bucket) + + +# # TODO: Determine specific retries +@retry(errors=[BotoServerError]) +def delete_bucket(s3_resource: S3ServiceResource, bucket: str, display_type='log') -> None: + display = print if display_type == 'print' else logger.debug + s3_client = s3_resource.meta.client + bucket_obj = s3_resource.Bucket(bucket) + try: + uploads = s3_client.list_multipart_uploads(Bucket=bucket).get('Uploads') or list() + for u in uploads: + s3_client.abort_multipart_upload(Bucket=bucket, Key=u["Key"], UploadId=u["UploadId"]) + bucket_obj.objects.all().delete() + bucket_obj.object_versions.delete() + bucket_obj.delete() + except s3_client.exceptions.NoSuchBucket: + display(f"Bucket already deleted (NoSuchBucket): '{bucket}'") + except ClientError as e: + if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') != 404: + raise + display(f"Bucket already deleted (404): '{bucket}'") + else: + display(f"Successfully deleted bucket: '{bucket}'") + + +# TODO: Determine specific retries +@retry(errors=[BotoServerError]) +def bucket_exists(s3_resource, bucket: str) -> Union[bool, Bucket]: + s3_client = s3_resource.meta.client + try: + 
s3_client.head_bucket(Bucket=bucket)
+        return s3_resource.Bucket(bucket)
+    except (ClientError, s3_client.exceptions.NoSuchBucket) as e:
+        error_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
+        if error_code == 404:
+            return False
+        else:
+            raise
+
+
+# TODO: Determine specific retries
+@retry(errors=[BotoServerError])
+def copy_s3_to_s3(s3_resource, src_bucket, src_key, dst_bucket, dst_key, extra_args: Optional[dict] = None):
+    source = {'Bucket': src_bucket, 'Key': src_key}
+    # Note: this may have errors if using sse-c because of
+    # a bug with encryption using copy_object and copy (which uses copy_object for files <5GB):
+    # https://github.com/aws/aws-cli/issues/6012
+    # this will only happen if we attempt to copy a file previously encrypted with sse-c
+    # copying an unencrypted file and encrypting it as sse-c seems to work fine though
+    s3_resource.meta.client.copy(CopySource=source, Bucket=dst_bucket, Key=dst_key, ExtraArgs=extra_args)
+
+
+# TODO: Determine specific retries
+@retry(errors=[BotoServerError])
+def copy_local_to_s3(s3_resource, local_file_path, dst_bucket, dst_key, extra_args: Optional[dict] = None):
+    s3_client = s3_resource.meta.client
+    s3_client.upload_file(local_file_path, dst_bucket, dst_key, ExtraArgs=extra_args)
+
+
+class MultiPartPipe(WritablePipe):
+    def __init__(self, part_size, s3_client, bucket_name, file_id, encryption_args, encoding, errors):
+        super(MultiPartPipe, self).__init__()
+        self.encoding = encoding
+        self.errors = errors
+        self.part_size = part_size
+        self.s3_client = s3_client
+        self.bucket_name = bucket_name
+        self.file_id = file_id
+        self.encryption_args = encryption_args
+
+    def readFrom(self, readable):
+        # Get the first block of data we want to put
+        buf = readable.read(self.part_size)
+        assert isinstance(buf, bytes)
+
+        # We will compute a checksum
+        hasher = hashlib.sha1()
+        hasher.update(buf)
+
+        # low-level clients are thread safe
+        response = self.s3_client.create_multipart_upload(Bucket=self.bucket_name,
+                                                          Key=self.file_id,
+                                                          **self.encryption_args)
+        upload_id = response['UploadId']
+        parts = []
+        try:
+            for part_num in itertools.count():
+                logger.debug('[%s] Uploading part %d (%d bytes)', upload_id, part_num + 1, len(buf))
+                # TODO: include the Content-MD5 header:
+                # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.complete_multipart_upload
+                part = self.s3_client.upload_part(Bucket=self.bucket_name,
+                                                  Key=self.file_id,
+                                                  PartNumber=part_num + 1,
+                                                  UploadId=upload_id,
+                                                  Body=BytesIO(buf),
+                                                  **self.encryption_args)
+                parts.append({"PartNumber": part_num + 1, "ETag": part["ETag"]})
+
+                # Get the next block of data we want to put
+                buf = readable.read(self.part_size)
+                if len(buf) == 0:
+                    # Don't allow any part other than the very first to be empty.
+                    break
+                hasher.update(buf)
+        except:
+            self.s3_client.abort_multipart_upload(Bucket=self.bucket_name,
+                                                  Key=self.file_id,
+                                                  UploadId=upload_id)
+            raise
+        else:
+            # Save the checksum
+            checksum = f'sha1${hasher.hexdigest()}'
+            response = self.s3_client.complete_multipart_upload(Bucket=self.bucket_name,
+                                                                Key=self.file_id,
+                                                                UploadId=upload_id,
+                                                                MultipartUpload={"Parts": parts})
+            logger.debug('[%s] Upload complete...', upload_id)
+
+
+def parse_s3_uri(uri: str) -> Tuple[str, str]:
+    # urllib.parse has no special handling for s3/gs schemes:
+    # https://docs.python.org/3/library/urllib.parse.html
+    # use regex instead?
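+    # Expected behavior (illustrative): parse_s3_uri('s3://bucket/dir/key') == ('bucket', 'dir/key')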
+ if isinstance(uri, str): + uri = urllib.parse.urlparse(uri) + if uri.scheme.lower() != 's3': + raise ValueError(f'Invalid schema. Expecting s3 prefix, not: {uri}') + # bucket_name, key_name = uri[len('s3://'):].split('/', 1) + bucket_name, key_name = uri.netloc.strip('/'), uri.path.strip('/') + return bucket_name, key_name + + +def list_s3_items(s3_resource, bucket, prefix, startafter=None): + s3_client = s3_resource.meta.client + paginator = s3_client.get_paginator('list_objects_v2') + kwargs = dict(Bucket=bucket, Prefix=prefix) + if startafter: + kwargs['StartAfter'] = startafter + for page in paginator.paginate(**kwargs): + for key in page.get('Contents', []): + yield key + + +@retry(errors=[ErrorCondition(error=ClientError, error_codes=[404, 500, 502, 503, 504])]) +def upload_to_s3(readable, + s3_resource, + bucket: str, + key: str, + extra_args: Optional[dict] = None): + """ + Upload a readable object to s3, using multipart uploading if applicable. + + :param readable: a readable stream or a local file path to upload to s3 + :param S3.Resource resource: boto3 resource + :param str bucket: name of the bucket to upload to + :param str key: the name of the file to upload to + :param dict extra_args: http headers to use when uploading - generally used for encryption purposes + :param int partSize: max size of each part in the multipart upload, in bytes + :return: version of the newly uploaded file + """ + if extra_args is None: + extra_args = {} + + s3_client = s3_resource.meta.client + config = TransferConfig( + multipart_threshold=DEFAULT_AWS_CHUNK_SIZE, + multipart_chunksize=DEFAULT_AWS_CHUNK_SIZE, + use_threads=True + ) + logger.debug("Uploading %s", key) + # these methods use multipart if necessary + if isinstance(readable, str): + s3_client.upload_file(Filename=readable, + Bucket=bucket, + Key=key, + ExtraArgs=extra_args, + Config=config) + else: + s3_client.upload_fileobj(Fileobj=readable, + Bucket=bucket, + Key=key, + ExtraArgs=extra_args, + Config=config) + + object_summary = s3_resource.ObjectSummary(bucket, key) + object_summary.wait_until_exists(**extra_args) + + +@contextmanager +def download_stream(s3_resource, bucket: str, key: str, checksum_to_verify: Optional[str] = None, + extra_args: Optional[dict] = None, encoding=None, errors=None): + """Context manager that gives out a download stream to download data.""" + bucket = s3_resource.Bucket(bucket) + + class DownloadPipe(ReadablePipe): + def writeTo(self, writable): + kwargs = dict(Key=key, Fileobj=writable, ExtraArgs=extra_args) + if not extra_args: + del kwargs['ExtraArgs'] + bucket.download_fileobj(**kwargs) + + try: + if checksum_to_verify: + with DownloadPipe(encoding=encoding, errors=errors) as readable: + # Interpose a pipe to check the hash + with HashingPipe(readable, encoding=encoding, errors=errors) as verified: + yield verified + else: + # Readable end of pipe produces text mode output if encoding specified + with DownloadPipe(encoding=encoding, errors=errors) as readable: + # No true checksum available, so don't hash + yield readable + except s3_resource.meta.client.exceptions.NoSuchKey: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + except ClientError as e: + if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \ + e.response.get('Error', {}).get('Message') == 'Bad Request' and \ + 
e.operation_name == 'HeadObject':
+            # An error occurred (400) when calling the HeadObject operation: Bad Request
+            raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly '
+                                           '(HeadObject operation: Bad Request).')
+        raise
+
+
+def download_fileobject(s3_resource, bucket: Bucket, key: str, fileobj, extra_args: Optional[dict] = None):
+    try:
+        bucket.download_fileobj(Key=key, Fileobj=fileobj, ExtraArgs=extra_args)
+    except s3_resource.meta.client.exceptions.NoSuchKey:
+        raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.")
+    except ClientError as e:
+        if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404:
+            raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.")
+        elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \
+                e.response.get('Error', {}).get('Message') == 'Bad Request' and \
+                e.operation_name == 'HeadObject':
+            # An error occurred (400) when calling the HeadObject operation: Bad Request
+            raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly '
+                                           '(HeadObject operation: Bad Request).')
+        raise
+
+
+def s3_key_exists(s3_resource, bucket: str, key: str, check: bool = False, extra_args: dict = None):
+    """Return True if the s3 object exists, and False if not.  Will error if encryption args are incorrect."""
+    extra_args = extra_args or {}
+    s3_client = s3_resource.meta.client
+    try:
+        s3_client.head_object(Bucket=bucket, Key=key, **extra_args)
+        return True
+    except s3_client.exceptions.NoSuchKey:
+        if check:
+            raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.")
+        return False
+    except ClientError as e:
+        if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404:
+            if check:
+                raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.")
+            return False
+        elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \
+                e.response.get('Error', {}).get('Message') == 'Bad Request' and \
+                e.operation_name == 'HeadObject':
+            # An error occurred (400) when calling the HeadObject operation: Bad Request
+            raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly '
+                                           '(HeadObject operation: Bad Request).')
+        else:
+            raise
+
+
+def head_s3_object(s3_resource, bucket: str, key: str, check=False, extra_args: dict = None):
+    s3_client = s3_resource.meta.client
+    extra_args = extra_args or {}
+    try:
+        return s3_client.head_object(Bucket=bucket, Key=key, **extra_args)
+    except s3_client.exceptions.NoSuchKey:
+        if check:
+            raise NoSuchFileException(f"File '{key}' not found in AWS jobstore bucket: '{bucket}'")
+
+
+def get_s3_object(s3_resource, bucket: str, key: str, extra_args: dict = None):
+    if extra_args is None:
+        extra_args = dict()
+    s3_client = s3_resource.meta.client
+    return s3_client.get_object(Bucket=bucket, Key=key, **extra_args)
+
+
+def put_s3_object(s3_resource, bucket: str, key: str, body: Optional[bytes], extra_args: dict = None):
+    if extra_args is None:
+        extra_args = dict()
+    s3_client = s3_resource.meta.client
+    return s3_client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args)
+
+
+def generate_presigned_url(s3_resource, bucket: str, key_name: str, expiration: int) -> str:
+    s3_client = s3_resource.meta.client
+    return s3_client.generate_presigned_url(
+        'get_object',
+        Params={'Bucket': bucket, 'Key': key_name},
+        ExpiresIn=expiration)
+
+
+def create_public_url(s3_resource, bucket: str, key: str):
+    bucket_obj = s3_resource.Bucket(bucket)
+    bucket_obj.Object(key).Acl().put(ACL='public-read')  # TODO: do we need to generate a signed url after doing this?
+    url = generate_presigned_url(s3_resource=s3_resource,
+                                 bucket=bucket,
+                                 key_name=key,
+                                 # One year should be sufficient to finish any pipeline ;-)
+                                 expiration=int(timedelta(days=365).total_seconds()))
+    # boto doesn't properly remove the x-amz-security-token parameter when
+    # query_auth is False when using an IAM role (see issue #2043). Including the
+    # x-amz-security-token parameter without the access key results in a 403,
+    # even if the resource is public, so we need to remove it.
+    # TODO: verify that this is still the case
+    return modify_url(url, remove=['x-amz-security-token', 'AWSAccessKeyId', 'Signature'])
+
+
+def get_s3_bucket_region(s3_resource, bucket: str):
+    s3_client = s3_resource.meta.client
+    # AWS returns None for the default of 'us-east-1'
+    return s3_client.get_bucket_location(Bucket=bucket).get('LocationConstraint', None) or 'us-east-1'
+
+
+def get_bucket_region(bucket_name: str, endpoint_url: Optional[str] = None) -> str:
+    """
+    Get the AWS region name associated with the given S3 bucket.
+
+    Takes an optional S3 API URL override.
+    """
+    s3_client = boto3.client('s3', endpoint_url=endpoint_url)
+    for attempt in retry_s3():
+        with attempt:
+            # AWS returns None for the default of 'us-east-1'
+            return s3_client.get_bucket_location(Bucket=bucket_name).get('LocationConstraint', None) or 'us-east-1'
diff --git a/src/toil/lib/aws/util.py b/src/toil/lib/aws/util.py
new file mode 100644
index 0000000000..c512c07ba0
--- /dev/null
+++ b/src/toil/lib/aws/util.py
@@ -0,0 +1,153 @@
+# Copyright (C) 2015-2021 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+import re
+
+from urllib.request import urlopen
+from urllib.error import URLError
+
+from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union
+from toil.lib.aws.ec2 import running_on_ec2
+from toil.lib.aws.ecs import running_on_ecs
+
+logger = logging.getLogger(__name__)
+
+CLOUD_KEY_REGEX = re.compile(
+    "^"
+    "(?P<schema>(?:s3|gs|wasb))"
+    "://"
+    "(?P<bucket>[^/]+)"
+    "/"
+    "(?P<key>.+)"
+    "$")
+AWS_ZONE_REGEX = re.compile(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$')
+
+
+def get_aws_zone_from_metadata() -> Optional[str]:
+    """
+    Get the AWS zone from instance metadata, if on EC2 and the boto module is
+    available. Otherwise, gets the AWS zone from ECS task metadata, if on ECS.
+    """
+
+    # When running on ECS, we also appear to be running on EC2, but the EC2
+    # metadata service doesn't seem to be contactable. So we check ECS first.
+
+    if running_on_ecs():
+        # Use the ECS metadata service
+        logger.debug("Fetch AZ from ECS metadata")
+        try:
+            resp = json.load(urlopen(os.environ['ECS_CONTAINER_METADATA_URI_V4'] + '/task', timeout=1))
+            logger.debug("ECS metadata: %s", resp)
+            if isinstance(resp, dict):
+                # We found something. Go with that.
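+                # (The v4 task metadata document is a JSON object that includes,
+                # e.g., an 'AvailabilityZone' field such as 'us-east-1a'.)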
+ return resp.get('AvailabilityZone') + except (json.decoder.JSONDecodeError, KeyError, URLError) as e: + # We're on ECS but can't get the metadata. That's odd. + logger.warning("Skipping ECS metadata due to error: %s", e) + if running_on_ec2(): + # On EC2 alone, or on ECS but we couldn't get ahold of the ECS + # metadata. + try: + # Use the EC2 metadata service + import boto + from boto.utils import get_instance_metadata + logger.debug("Fetch AZ from EC2 metadata") + return get_instance_metadata()['placement']['availability-zone'] + except ImportError: + # This is expected to happen a lot + logger.debug("No boto to fetch ECS metadata") + except (KeyError, URLError) as e: + # We're on EC2 but can't get the metadata. That's odd. + logger.warning("Skipping EC2 metadata due to error: %s", e) + return None + + +def get_current_aws_region() -> Optional[str]: + """ + Return the AWS region that the currently configured AWS zone (see + get_current_aws_zone()) is in. + """ + # Try to derive it from the zone. + aws_zone = get_current_aws_zone() + return zone_to_region(aws_zone) if aws_zone else None + + +def get_aws_zone_from_environment() -> Optional[str]: + """ + Get the AWS zone from TOIL_AWS_ZONE if set. + """ + return os.environ.get('TOIL_AWS_ZONE', None) + + +def get_aws_zone_from_boto() -> Optional[str]: + """ + Get the AWS zone from the Boto config file, if it is configured and the + boto module is available. + """ + try: + import boto + zone = boto.config.get('Boto', 'ec2_region_name') + if zone is not None: + zone += 'a' # derive an availability zone in the region + return zone + except ImportError: + pass + return None + + +def get_aws_zone_from_environment_region() -> Optional[str]: + """ + Pick an AWS zone in the region defined by TOIL_AWS_REGION, if it is set. + """ + aws_region = os.environ.get('TOIL_AWS_REGION') + if aws_region is not None: + # If a region is specified, use the first zone in the region. + return aws_region + 'a' + # Otherwise, don't pick a region and let us fall back on the next method. + return None + + +def get_current_aws_zone() -> Optional[str]: + """ + Get the currently configured or occupied AWS zone to use. + + Reports the TOIL_AWS_ZONE environment variable if set. + + Otherwise, if we have boto and are running on EC2, or if we are on ECS, + reports the zone we are running in. + + Otherwise, if we have the TOIL_AWS_REGION variable set, chooses a zone in + that region. + + Finally, if we have boto2, and a default region is configured in Boto 2, + chooses a zone in that region. + + Returns None if no method can produce a zone to use. + """ + return get_aws_zone_from_environment() or \ + get_aws_zone_from_metadata() or \ + get_aws_zone_from_environment_region() or \ + get_aws_zone_from_boto() + + +def zone_to_region(zone: str) -> str: + """Get a region (e.g. us-west-2) from a zone (e.g. 
us-west-1c).""" + # re.compile() caches the regex internally so we don't have to + availability_zone_regex = re.compile(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$') + availability_zone = availability_zone_regex.match(zone) + if not availability_zone: + raise ValueError(f"Can't extract region from availability zone '{zone}'") + return availability_zone.group(1) diff --git a/src/toil/lib/aws/utils.py b/src/toil/lib/aws/utils.py index 459f0631a0..1cdf3f3e8e 100644 --- a/src/toil/lib/aws/utils.py +++ b/src/toil/lib/aws/utils.py @@ -20,14 +20,14 @@ from urllib.parse import ParseResult from toil.lib.aws import session -from toil.lib.misc import printq from toil.lib.retry import ( retry, old_retry, get_error_status, get_error_code, DEFAULT_DELAYS, - DEFAULT_TIMEOUT + DEFAULT_TIMEOUT, + ErrorCondition ) if sys.version_info >= (3, 8): @@ -69,9 +69,8 @@ ] @retry(errors=[BotoServerError]) -def delete_iam_role( - role_name: str, region: Optional[str] = None, quiet: bool = True -) -> None: +def delete_iam_role(role_name: str, region: Optional[str] = None, display_type='print') -> None: + display = print if display_type == 'print' else logger.debug from boto.iam.connection import IAMConnection # TODO: the Boto3 type hints are a bit oversealous here; they want hundreds # of overloads of the client-getting methods to exist based on the literal @@ -88,38 +87,36 @@ def delete_iam_role( role = iam_resource.Role(role_name) # normal policies for attached_policy in role.attached_policies.all(): - printq(f'Now dissociating policy: {attached_policy.policy_name} from role {role.name}', quiet) + display(f'Now dissociating policy: {attached_policy.policy_name} from role {role.name}') role.detach_policy(PolicyArn=attached_policy.arn) # inline policies for inline_policy in role.policies.all(): - printq(f'Deleting inline policy: {inline_policy.policy_name} from role {role.name}', quiet) + display(f'Deleting inline policy: {inline_policy.policy_name} from role {role.name}') # couldn't find an easy way to remove inline policies with boto3; use boto boto_iam_connection.delete_role_policy(role.name, inline_policy.policy_name) iam_client.delete_role(RoleName=role_name) - printq(f'Role {role_name} successfully deleted.', quiet) + display(f'Role {role_name} successfully deleted.') @retry(errors=[BotoServerError]) -def delete_iam_instance_profile( - instance_profile_name: str, region: Optional[str] = None, quiet: bool = True -) -> None: +def delete_iam_instance_profile(instance_profile_name: str, region: Optional[str] = None, display_type='print') -> None: + display = print if display_type == 'print' else logger.debug iam_resource = cast(IAMServiceResource, session.resource("iam", region_name=region)) instance_profile = iam_resource.InstanceProfile(instance_profile_name) if instance_profile.roles is not None: for role in instance_profile.roles: - printq(f'Now dissociating role: {role.name} from instance profile {instance_profile_name}', quiet) + display(f'Now dissociating role: {role.name} from instance profile {instance_profile_name}') instance_profile.remove_role(RoleName=role.name) instance_profile.delete() - printq(f'Instance profile "{instance_profile_name}" successfully deleted.', quiet) + display(f'Instance profile "{instance_profile_name}" successfully deleted.') @retry(errors=[BotoServerError]) -def delete_sdb_domain( - sdb_domain_name: str, region: Optional[str] = None, quiet: bool = True -) -> None: +def delete_sdb_domain(sdb_domain_name: str, region: Optional[str] = None, display_type='print') -> None: + display = 
print if display_type == 'print' else logger.debug sdb_client = cast(SimpleDBClient, session.client("sdb", region_name=region)) sdb_client.delete_domain(DomainName=sdb_domain_name) - printq(f'SBD Domain: "{sdb_domain_name}" successfully deleted.', quiet) + display(f'SBD Domain: "{sdb_domain_name}" successfully deleted.') def connection_reset(e: Exception) -> bool: @@ -146,22 +143,32 @@ def retryable_s3_errors(e: Exception) -> bool: or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') in (404, 429, 500, 502, 503, 504))) -def retry_s3(delays: Iterable[float] = DEFAULT_DELAYS, timeout: float = DEFAULT_TIMEOUT, predicate: Callable[[Exception], bool] = retryable_s3_errors) -> Iterator[ContextManager[None]]: +def retry_s3(delays: Iterable[float] = (0, 1, 1, 4, 16, 64), + timeout: float = 300, + predicate: Callable[[Exception], bool] = retryable_s3_errors) -> Iterator[ContextManager[None]]: """ Retry iterator of context managers specifically for S3 operations. """ return old_retry(delays=delays, timeout=timeout, predicate=predicate) -@retry(errors=[BotoServerError]) -def delete_s3_bucket( - s3_resource: "S3ServiceResource", - bucket: str, - quiet: bool = True -) -> None: - """ - Delete the given S3 bucket. - """ - printq(f'Deleting s3 bucket: {bucket}', quiet) + +@retry( + errors=[ + ErrorCondition( + error=BotoServerError, + error_codes=[404, 429, 500, 502, 503, 504]), + ErrorCondition( + error=ClientError, + error_message_must_include='BucketNotEmpty'), + ErrorCondition( + error=ClientError, + error_codes=[404, 429, 500, 502, 503, 504] + ), + ] +) +def delete_s3_bucket(s3_resource: "S3ServiceResource", bucket: str, display_type='log') -> None: + display = print if display_type == 'print' else logger.debug + display(f'Deleting s3 bucket: {bucket}') paginator = s3_resource.meta.client.get_paginator('list_object_versions') try: @@ -174,12 +181,12 @@ def delete_s3_bucket( to_delete: List[Dict[str, Any]] = cast(List[Dict[str, Any]], response.get('Versions', [])) + \ cast(List[Dict[str, Any]], response.get('DeleteMarkers', [])) for entry in to_delete: - printq(f" Deleting {entry['Key']} version {entry['VersionId']}", quiet) + display(f" Deleting {entry['Key']} version {entry['VersionId']}") s3_resource.meta.client.delete_object(Bucket=bucket, Key=entry['Key'], VersionId=entry['VersionId']) s3_resource.Bucket(bucket).delete() - printq(f'\n * Deleted s3 bucket successfully: {bucket}\n\n', quiet) + display(f'\n * Deleted s3 bucket successfully: {bucket}\n\n') except s3_resource.meta.client.exceptions.NoSuchBucket: - printq(f'\n * S3 bucket no longer exists: {bucket}\n\n', quiet) + display(f'\n * S3 bucket no longer exists: {bucket}\n\n') def create_s3_bucket( @@ -360,7 +367,7 @@ def list_objects_for_url(url: ParseResult) -> List[str]: key_name = key_name + '/' # Decide if we need to override Boto's built-in URL here. 
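     # (For example, TOIL_S3_HOST=127.0.0.1 with TOIL_S3_PORT=9000 would point Boto at a
     # local S3 clone rather than AWS itself; these values are illustrative.)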
-    # TODO: Decuplicate with get_object_for_url, or push down into session module
+    # TODO: Deduplicate with get_object_for_url, or push down into session module
     endpoint_url: Optional[str] = None
     host = os.environ.get('TOIL_S3_HOST', None)
     port = os.environ.get('TOIL_S3_PORT', None)
diff --git a/src/toil/lib/checksum.py b/src/toil/lib/checksum.py
new file mode 100644
index 0000000000..101a91d6da
--- /dev/null
+++ b/src/toil/lib/checksum.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2015-2021 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import hashlib
+
+from io import BytesIO
+from typing import BinaryIO, Union
+
+from toil.lib.aws.config import S3_PART_SIZE
+
+logger = logging.getLogger(__name__)
+
+
+class ChecksumError(Exception):
+    """Raised when a download does not contain the correct data."""
+
+
+class Etag:
+    """A hasher for s3 etags."""
+    def __init__(self, chunk_size):
+        self.etag_bytes = 0
+        self.etag_parts = []
+        self.etag_hasher = hashlib.md5()
+        self.chunk_size = chunk_size
+
+    def update(self, chunk):
+        if self.etag_bytes + len(chunk) > self.chunk_size:
+            chunk_head = chunk[:self.chunk_size - self.etag_bytes]
+            chunk_tail = chunk[self.chunk_size - self.etag_bytes:]
+            self.etag_hasher.update(chunk_head)
+            self.etag_parts.append(self.etag_hasher.digest())
+            self.etag_hasher = hashlib.md5()
+            self.etag_hasher.update(chunk_tail)
+            self.etag_bytes = len(chunk_tail)
+        else:
+            self.etag_hasher.update(chunk)
+            self.etag_bytes += len(chunk)
+
+    def hexdigest(self):
+        if self.etag_bytes:
+            self.etag_parts.append(self.etag_hasher.digest())
+            self.etag_bytes = 0
+        if len(self.etag_parts) > 1:
+            etag = hashlib.md5(b"".join(self.etag_parts)).hexdigest()
+            return f'{etag}-{len(self.etag_parts)}'
+        else:
+            return self.etag_hasher.hexdigest()
+
+
+# Hashers are constructed per call: a shared instance would carry state over
+# from one file to the next and produce the wrong digest.
+hashers = {'sha1': hashlib.sha1,
+           'sha256': hashlib.sha256,
+           'etag': lambda: Etag(chunk_size=S3_PART_SIZE)}
+
+
+def compute_checksum_for_file(local_file_path: str, algorithm: str = 'sha1') -> str:
+    with open(local_file_path, 'rb') as fh:
+        checksum_result = compute_checksum_for_content(fh, algorithm=algorithm)
+    return checksum_result
+
+
+def compute_checksum_for_content(fh: Union[BinaryIO, BytesIO], algorithm: str = 'sha1') -> str:
+    """
+    Note: Chunk size matters for s3 etags, and must be the same to get the same hash for the same object.
+    Therefore this buffer size is not modifiable throughout Toil.
+    """
+    hasher = hashers[algorithm]()
+    contents = fh.read(S3_PART_SIZE)
+    while contents != b'':
+        hasher.update(contents)
+        contents = fh.read(S3_PART_SIZE)
+
+    return f'{algorithm}${hasher.hexdigest()}'
diff --git a/src/toil/lib/conversions.py b/src/toil/lib/conversions.py
index 8ddf301de3..5f8d8f177c 100644
--- a/src/toil/lib/conversions.py
+++ b/src/toil/lib/conversions.py
@@ -2,9 +2,10 @@
 Conversion utilities for mapping memory, disk, core declarations from strings to numbers and vice versa.
Also contains general conversion functions """ - import math + from typing import SupportsInt, Tuple, Union +from toil.lib.units import KIB, MIB, GIB, TIB, PIB, EIB, KB, MB, GB, TB, PB, EB # See https://en.wikipedia.org/wiki/Binary_prefix BINARY_PREFIXES = ['ki', 'mi', 'gi', 'ti', 'pi', 'ei', 'kib', 'mib', 'gib', 'tib', 'pib', 'eib'] @@ -15,30 +16,30 @@ def bytes_in_unit(unit: str = 'B') -> int: num_bytes = 1 if unit.lower() in ['ki', 'kib']: - num_bytes = 1 << 10 + num_bytes = KIB if unit.lower() in ['mi', 'mib']: - num_bytes = 1 << 20 + num_bytes = MIB if unit.lower() in ['gi', 'gib']: - num_bytes = 1 << 30 + num_bytes = GIB if unit.lower() in ['ti', 'tib']: - num_bytes = 1 << 40 + num_bytes = TIB if unit.lower() in ['pi', 'pib']: - num_bytes = 1 << 50 + num_bytes = PIB if unit.lower() in ['ei', 'eib']: - num_bytes = 1 << 60 + num_bytes = EIB if unit.lower() in ['k', 'kb']: - num_bytes = 1000 + num_bytes = KB if unit.lower() in ['m', 'mb']: - num_bytes = 1000 ** 2 + num_bytes = MB if unit.lower() in ['g', 'gb']: - num_bytes = 1000 ** 3 + num_bytes = GB if unit.lower() in ['t', 'tb']: - num_bytes = 1000 ** 4 + num_bytes = TB if unit.lower() in ['p', 'pb']: - num_bytes = 1000 ** 5 + num_bytes = PB if unit.lower() in ['e', 'eb']: - num_bytes = 1000 ** 6 + num_bytes = EB return num_bytes @@ -82,7 +83,7 @@ def bytes2human(n: SupportsInt) -> str: elif n < 1: return '0 b' - power_level = math.floor(math.log(n, 1024)) + power_level = math.floor(math.log(n, KB)) units = ('b', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei') unit = units[power_level if power_level < len(units) else -1] diff --git a/src/toil/lib/misc.py b/src/toil/lib/misc.py index 879606f90d..d0fc66c511 100644 --- a/src/toil/lib/misc.py +++ b/src/toil/lib/misc.py @@ -8,11 +8,12 @@ import sys import time import typing +import urllib.parse +import pytz + from contextlib import closing from typing import Iterator, List, Optional, Union -import pytz - logger = logging.getLogger(__name__) @@ -79,10 +80,6 @@ def slow_down(seconds: float) -> float: return max(seconds, sys.float_info.epsilon) -def printq(msg: str, quiet: bool) -> None: - if not quiet: - print(msg) - def truncExpBackoff() -> Iterator[float]: # as recommended here https://forums.aws.amazon.com/thread.jspa?messageID=406788#406788 @@ -135,3 +132,17 @@ def call_command(cmd: List[str], *args: str, input: Optional[str] = None, timeou raise CalledProcessErrorStderr(proc.returncode, cmd, output=stdout, stderr=stderr) logger.debug("command succeeded: {}: {}".format(" ".join(cmd), stdout.rstrip())) return stdout + + +def modify_url(url: str, remove: List[str]) -> str: + """ + Given a valid URL string, split out the params, remove any offending + params in 'remove', and return the cleaned URL. + """ + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + params = urllib.parse.parse_qs(query) + for param_key in remove: + if param_key in params: + del params[param_key] + query = urllib.parse.urlencode(params, doseq=True) + return urllib.parse.urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/src/toil/lib/pipes.py b/src/toil/lib/pipes.py new file mode 100644 index 0000000000..68f9144328 --- /dev/null +++ b/src/toil/lib/pipes.py @@ -0,0 +1,358 @@ +import errno +import logging +import os +import hashlib +from abc import ABC, abstractmethod + +from toil.lib.checksum import ChecksumError +from toil.lib.threading import ExceptionalThread + +log = logging.getLogger(__name__) + + +class WritablePipe(ABC): + """ + An object-oriented wrapper for os.pipe. 
Clients should subclass it, implement + :meth:`.readFrom` to consume the readable end of the pipe, then instantiate the class as a + context manager to get the writable end. See the example below. + + >>> import sys, shutil + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... shutil.copyfileobj(codecs.getreader('utf-8')(readable), sys.stdout) + >>> with MyPipe() as writable: + ... _ = writable.write('Hello, world!\\n'.encode('utf-8')) + Hello, world! + + Each instance of this class creates a thread and invokes the readFrom method in that thread. + The thread will be join()ed upon normal exit from the context manager, i.e. the body of the + `with` statement. If an exception occurs, the thread will not be joined but a well-behaved + :meth:`.readFrom` implementation will terminate shortly thereafter due to the pipe having + been closed. + + Now, exceptions in the reader thread will be reraised in the main thread: + + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as writable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + + More complicated, less illustrative tests: + + Same as above, but proving that handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as writable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + + Exceptions in the body of the with statement aren't masked, and handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... pass + >>> with MyPipe() as writable: + ... raise RuntimeError('Hello, world!') + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + """ + + @abstractmethod + def readFrom(self, readable): + """ + Implement this method to read data from the pipe. This method should support both + binary and text mode output. + + :param file readable: the file object representing the readable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def _reader(self): + with os.fdopen(self.readable_fh, 'rb') as readable: + # TODO: If the reader somehow crashes here, both threads might try + # to close readable_fh. Fortunately we don't do anything that + # should be able to fail here. + self.readable_fh = None # signal to parent thread that we've taken over + self.readFrom(readable) + self.reader_done = True + + def __init__(self, encoding=None, errors=None): + """ + The specified encoding and errors apply to the writable end of the pipe. + + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. 
+ """ + super(WritablePipe, self).__init__() + self.encoding = encoding + self.errors = errors + self.readable_fh = None + self.writable = None + self.thread = None + self.reader_done = False + + def __enter__(self): + self.readable_fh, writable_fh = os.pipe() + self.writable = os.fdopen(writable_fh, 'wb' if self.encoding == None else 'wt', encoding=self.encoding, errors=self.errors) + self.thread = ExceptionalThread(target=self._reader) + self.thread.start() + return self.writable + + def __exit__(self, exc_type, exc_val, exc_tb): + # Closeing the writable end will send EOF to the readable and cause the reader thread + # to finish. + # TODO: Can close() fail? If so, would we try and clean up after the reader? + self.writable.close() + try: + if self.thread is not None: + # reraises any exception that was raised in the thread + self.thread.join() + except Exception as e: + if exc_type is None: + # Only raise the child exception if there wasn't + # already an exception in the main thread + raise + else: + log.error('Swallowing additional exception in reader thread: %s', str(e)) + finally: + # The responsibility for closing the readable end is generally that of the reader + # thread. To cover the small window before the reader takes over we also close it here. + readable_fh = self.readable_fh + if readable_fh is not None: + # Close the file handle. The reader thread must be dead now. + os.close(readable_fh) + + +class ReadablePipe(ABC): + """ + An object-oriented wrapper for os.pipe. Clients should subclass it, implement + :meth:`.writeTo` to place data into the writable end of the pipe, then instantiate the class + as a context manager to get the writable end. See the example below. + + >>> import sys, shutil + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... writable.write('Hello, world!\\n'.encode('utf-8')) + >>> with MyPipe() as readable: + ... shutil.copyfileobj(codecs.getreader('utf-8')(readable), sys.stdout) + Hello, world! + + Each instance of this class creates a thread and invokes the :meth:`.writeTo` method in that + thread. The thread will be join()ed upon normal exit from the context manager, i.e. the body + of the `with` statement. If an exception occurs, the thread will not be joined but a + well-behaved :meth:`.writeTo` implementation will terminate shortly thereafter due to the + pipe having been closed. + + Now, exceptions in the reader thread will be reraised in the main thread: + + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as readable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + + More complicated, less illustrative tests: + + Same as above, but proving that handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as readable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + + Exceptions in the body of the with statement aren't masked, and handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... pass + >>> with MyPipe() as readable: + ... raise RuntimeError('Hello, world!') + Traceback (most recent call last): + ... + RuntimeError: Hello, world! 
+ >>> y = os.dup(0); os.close(y); x == y + True + """ + + @abstractmethod + def writeTo(self, writable): + """ + Implement this method to write data from the pipe. This method should support both + binary and text mode input. + + :param file writable: the file object representing the writable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def _writer(self): + try: + with os.fdopen(self.writable_fh, 'wb') as writable: + self.writeTo(writable) + except IOError as e: + # The other side of the pipe may have been closed by the + # reading thread, which is OK. + if e.errno != errno.EPIPE: + raise + + def __init__(self, encoding=None, errors=None): + """ + The specified encoding and errors apply to the readable end of the pipe. + + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(ReadablePipe, self).__init__() + self.encoding = encoding + self.errors = errors + self.writable_fh = None + self.readable = None + self.thread = None + + def __enter__(self): + readable_fh, self.writable_fh = os.pipe() + self.readable = os.fdopen(readable_fh, 'rb' if self.encoding == None else 'rt', encoding=self.encoding, errors=self.errors) + self.thread = ExceptionalThread(target=self._writer) + self.thread.start() + return self.readable + + def __exit__(self, exc_type, exc_val, exc_tb): + # Close the read end of the pipe. The writing thread may + # still be writing to the other end, but this will wake it up + # if that's the case. + self.readable.close() + try: + if self.thread is not None: + # reraises any exception that was raised in the thread + self.thread.join() + except: + if exc_type is None: + # Only raise the child exception if there wasn't + # already an exception in the main thread + raise + + +class ReadableTransformingPipe(ReadablePipe): + """ + A pipe which is constructed around a readable stream, and which provides a + context manager that gives a readable stream. + + Useful as a base class for pipes which have to transform or otherwise visit + bytes that flow through them, instead of just consuming or producing data. + + Clients should subclass it and implement :meth:`.transform`, like so: + + >>> import sys, shutil + >>> class MyPipe(ReadableTransformingPipe): + ... def transform(self, readable, writable): + ... writable.write(readable.read().decode('utf-8').upper().encode('utf-8')) + >>> class SourcePipe(ReadablePipe): + ... def writeTo(self, writable): + ... writable.write('Hello, world!\\n'.encode('utf-8')) + >>> with SourcePipe() as source: + ... with MyPipe(source) as transformed: + ... shutil.copyfileobj(codecs.getreader('utf-8')(transformed), sys.stdout) + HELLO, WORLD! + + The :meth:`.transform` method runs in its own thread, and should move data + chunk by chunk instead of all at once. It should finish normally if it + encounters either an EOF on the readable, or a :class:`BrokenPipeError` on + the writable. This means tat it should make sure to actually catch a + :class:`BrokenPipeError` when writing. + + See also: :class:`toil.lib.misc.WriteWatchingStream`. 
+ + """ + def __init__(self, source, encoding=None, errors=None): + """ + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(ReadableTransformingPipe, self).__init__(encoding=encoding, errors=errors) + self.source = source + + @abstractmethod + def transform(self, readable, writable): + """ + Implement this method to ship data through the pipe. + + :param file readable: the input stream file object to transform. + + :param file writable: the file object representing the writable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def writeTo(self, writable): + self.transform(self.source, writable) + + +class HashingPipe(ReadableTransformingPipe): + """ + Class which checksums all the data read through it. If it + reaches EOF and the checksum isn't correct, raises ChecksumError. + + Assumes info actually has a checksum. + """ + def __init__(self, source, encoding=None, errors=None, checksum_to_verify=None): + """ + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(HashingPipe, self).__init__(source=source, encoding=encoding, errors=errors) + self.checksum_to_verify = checksum_to_verify + + def transform(self, readable, writable): + hash_object = hashlib.sha1() + contents = readable.read(1024 * 1024) + while contents != b'': + hash_object.update(contents) + try: + writable.write(contents) + except BrokenPipeError: + # Read was stopped early by user code. + # Can't check the checksum. + return + contents = readable.read(1024 * 1024) + final_computed_checksum = f'sha1${hash_object.hexdigest()}' + if not self.checksum_to_verify == final_computed_checksum: + raise ChecksumError(f'Checksum mismatch. Expected: {self.checksum_to_verify} Actual: {final_computed_checksum}') diff --git a/src/toil/lib/retry.py b/src/toil/lib/retry.py index f225714c48..299cf4c3d5 100644 --- a/src/toil/lib/retry.py +++ b/src/toil/lib/retry.py @@ -165,6 +165,12 @@ def boto_bucket(bucket_name): except ModuleNotFoundError: botocore = None +try: + import boto.exception + SUPPORTED_HTTP_ERRORS.append(boto.exception.BotoServerError) +except ModuleNotFoundError: + boto = None + logger = logging.getLogger(__name__) @@ -180,7 +186,7 @@ def __init__(self, error: Optional[Any] = None, error_codes: List[int] = None, boto_error_codes: List[str] = None, - error_message_must_include: str = None, + error_message_must_include: List[str] = None, retry_on_this_condition: bool = True): """ Initialize this ErrorCondition. @@ -193,7 +199,7 @@ def __init__(self, (e.g. "BucketNotFound", "ClientError", etc.) that are specific to Boto 3 and must match to be retried (optional; defaults to not checking). 
diff --git a/src/toil/lib/retry.py b/src/toil/lib/retry.py
index f225714c48..299cf4c3d5 100644
--- a/src/toil/lib/retry.py
+++ b/src/toil/lib/retry.py
@@ -165,6 +165,12 @@ def boto_bucket(bucket_name):
 except ModuleNotFoundError:
     botocore = None
 
+try:
+    import boto.exception
+    SUPPORTED_HTTP_ERRORS.append(boto.exception.BotoServerError)
+except ModuleNotFoundError:
+    boto = None
+
 logger = logging.getLogger(__name__)
 
 
@@ -180,7 +186,7 @@ def __init__(self,
                  error: Optional[Any] = None,
                  error_codes: List[int] = None,
                  boto_error_codes: List[str] = None,
-                 error_message_must_include: str = None,
+                 error_message_must_include: List[str] = None,
                  retry_on_this_condition: bool = True):
         """
         Initialize this ErrorCondition.
@@ -193,7 +199,7 @@ def __init__(self,
             (e.g. "BucketNotFound", "ClientError", etc.) that are specific to
             Boto 3 and must match to be retried (optional; defaults to not checking).
-        :param error_message_must_include: A string that must be in the error message
+        :param error_message_must_include: A list of strings, one of which must be in the error message
             to be retried (optional; defaults to not checking)
         :param retry_on_this_condition: This can be set to False to always error on this
             condition.
@@ -209,7 +215,7 @@ def __init__(self,
         self.error = error
         self.error_codes = error_codes
         self.boto_error_codes = boto_error_codes
-        self.error_message_must_include = error_message_must_include
+        self.error_message_must_include = error_message_must_include or []
         self.retry_on_this_condition = retry_on_this_condition
 
         if self.error_codes:
@@ -477,7 +483,13 @@ def error_meets_conditions(e, error_conditions):
     for error in error_conditions:
         if isinstance(e, error.error):
             if error.error_codes or error.boto_error_codes or error.error_message_must_include:
-                error_message_condition_met = meets_error_message_condition(e, error.error_message_must_include)
+                # An empty keyword list imposes no message requirement, so the
+                # condition can still match on error codes alone.
+                error_message_condition_met = not error.error_message_must_include
+                for error_message_keyword in error.error_message_must_include:
+                    error_message_condition_met = meets_error_message_condition(e, error_message_keyword)
+                    if error_message_condition_met:
+                        break
                 error_code_condition_met = meets_error_code_condition(e, error.error_codes)
                 boto_error_code_condition_met = meets_boto_error_code_condition(e, error.boto_error_codes)
                 if error_message_condition_met and error_code_condition_met and boto_error_code_condition_met:
diff --git a/src/toil/lib/units.py b/src/toil/lib/units.py
new file mode 100644
index 0000000000..de1143d75f
--- /dev/null
+++ b/src/toil/lib/units.py
@@ -0,0 +1,15 @@
+B = 1
+
+KI = KIB = 1 << 10
+MI = MIB = 1 << 20
+GI = GIB = 1 << 30
+TI = TIB = 1 << 40
+PI = PIB = 1 << 50
+EI = EIB = 1 << 60
+
+K = KB = 1000
+M = MB = 1000 ** 2
+G = GB = 1000 ** 3
+T = TB = 1000 ** 4
+P = PB = 1000 ** 5
+E = EB = 1000 ** 6
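These constants also explain the bytes2human fix above: its unit table is binary, so the log base must be KIB (1024), not KB (1000). A quick sanity check using only the definitions in this new module:

import math

from toil.lib.units import KIB, KB

# 1000 bytes is still less than 1 KiB, so the binary power level is 0 ('b')...
assert math.floor(math.log(1000, KIB)) == 0
# ...whereas a base-1000 log would wrongly promote it to power level 1 ('Ki').
assert math.floor(math.log(1000, KB)) == 1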
diff --git a/src/toil/provisioners/__init__.py b/src/toil/provisioners/__init__.py
index 6162da4725..13a06015ca 100644
--- a/src/toil/provisioners/__init__.py
+++ b/src/toil/provisioners/__init__.py
@@ -148,7 +148,7 @@ def check_valid_node_types(provisioner, node_types: List[Tuple[Set[str], Optiona
     # check if a valid node type for aws
     from toil.lib.generatedEC2Lists import E2Instances, regionDict
     if provisioner == 'aws':
-        from toil.lib.aws import get_current_aws_region
+        from toil.lib.aws.util import get_current_aws_region
         current_region = get_current_aws_region() or 'us-west-2'
         # check if instance type exists in this region
         for node_type in node_types:
diff --git a/src/toil/provisioners/aws/__init__.py b/src/toil/provisioners/aws/__init__.py
index 3074c122b0..ed30bd6a68 100644
--- a/src/toil/provisioners/aws/__init__.py
+++ b/src/toil/provisioners/aws/__init__.py
@@ -13,24 +13,23 @@
 # limitations under the License.
 import datetime
 import logging
-import os
+
 from collections import namedtuple
 from operator import attrgetter
 from statistics import mean, stdev
-from typing import List, Optional, Any
+from typing import List, Optional
 
-from toil.lib.aws import (get_aws_zone_from_environment,
-                          get_aws_zone_from_metadata,
-                          get_aws_zone_from_environment_region,
-                          get_aws_zone_from_boto,
-                          running_on_ec2,
-                          zone_to_region)
+from toil.lib.aws.util import (get_aws_zone_from_environment,
+                               get_aws_zone_from_metadata,
+                               get_aws_zone_from_environment_region,
+                               get_aws_zone_from_boto)
 
 logger = logging.getLogger(__name__)
 
 ZoneTuple = namedtuple('ZoneTuple', ['name', 'price_deviation'])
 
+
 def get_aws_zone_from_spot_market(spotBid: Optional[float], nodeType: Optional[str],
                                   boto2_ec2: Optional["boto.connection.AWSAuthConnection"],
                                   zone_options: Optional[List[str]]) -> Optional[str]:
     """
diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py
index b3775d6a80..9cf7b988f1 100644
--- a/src/toil/provisioners/aws/awsProvisioner.py
+++ b/src/toil/provisioners/aws/awsProvisioner.py
@@ -35,7 +35,7 @@
 from boto.utils import get_instance_metadata
 from botocore.exceptions import ClientError
 
-from toil.lib.aws import zone_to_region
+from toil.lib.aws.util import zone_to_region
 from toil.lib.aws.ami import get_flatcar_ami
 from toil.lib.aws.utils import create_s3_bucket
 from toil.lib.aws.session import AWSConnectionManager
diff --git a/src/toil/server/app.py b/src/toil/server/app.py
index e382b772b7..0f0bab28ab 100644
--- a/src/toil/server/app.py
+++ b/src/toil/server/app.py
@@ -23,7 +23,10 @@
 from toil.server.wes.toil_backend import ToilBackend
 from toil.server.wsgi_app import run_app
 from toil.version import version
-from toil.lib.aws import running_on_ec2, running_on_ecs, get_current_aws_region
+
+from toil.lib.aws.ec2 import running_on_ec2
+from toil.lib.aws.ecs import running_on_ecs
+from toil.lib.aws.util import get_current_aws_region
 
 logger = logging.getLogger(__name__)
diff --git a/src/toil/server/utils.py b/src/toil/server/utils.py
index 7e2d4b5a46..4a64d6f23c 100644
--- a/src/toil/server/utils.py
+++ b/src/toil/server/utils.py
@@ -25,7 +25,7 @@
 from toil.lib.io import AtomicFileCreate
 
 try:
-    from toil.lib.aws import get_current_aws_region
+    from toil.lib.aws.util import get_current_aws_region
     from toil.lib.aws.session import client
     from toil.lib.aws.utils import retry_s3
     HAVE_S3 = True
diff --git a/src/toil/test/__init__.py b/src/toil/test/__init__.py
index 92684af7d7..9b861454bc 100644
--- a/src/toil/test/__init__.py
+++ b/src/toil/test/__init__.py
@@ -58,7 +58,7 @@
 from toil.lib.iterables import concat
 from toil.lib.memoize import memoize
 from toil.lib.threading import ExceptionalThread, cpu_count
-from toil.lib.aws import running_on_ec2
+from toil.lib.aws.ec2 import running_on_ec2
 from toil.version import distVersion
 
 logger = logging.getLogger(__name__)
@@ -396,7 +396,7 @@ def needs_aws_batch(test_item: MT) -> MT:
     test_item = needs_env_var('TOIL_AWS_BATCH_QUEUE', 'an AWS Batch queue name or ARN')(test_item)
     test_item = needs_env_var('TOIL_AWS_BATCH_JOB_ROLE_ARN', 'an IAM role ARN that grants S3 and SDB access')(test_item)
     try:
-        from toil.lib.aws import get_current_aws_region
+        from toil.lib.aws.util import get_current_aws_region
         if get_current_aws_region() is None:
             # We don't know a region so we need one set.
             # TODO: It always won't be set if we get here.
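The import moves above are mechanical, but the intended call pattern may be worth spelling out; a minimal sketch (the resolve_region wrapper is hypothetical, and the 'us-west-2' fallback mirrors check_valid_node_types above):

from typing import Optional

from toil.lib.aws.util import get_current_aws_region, zone_to_region

def resolve_region(zone: Optional[str] = None) -> str:
    # An explicit zone like 'us-west-2a' is trimmed to its region.
    if zone:
        return zone_to_region(zone)
    # Otherwise probe the environment and instance metadata, falling back
    # to a default region if nothing is configured.
    return get_current_aws_region() or 'us-west-2'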
diff --git a/src/toil/test/cwl/cwlTest.py b/src/toil/test/cwl/cwlTest.py
index b78877d0a9..fbd928a552 100644
--- a/src/toil/test/cwl/cwlTest.py
+++ b/src/toil/test/cwl/cwlTest.py
@@ -41,7 +41,8 @@
 from toil.fileStores import FileID
 from toil.fileStores.abstractFileStore import AbstractFileStore
 from toil.lib.threading import cpu_count
-from toil.lib.aws import zone_to_region
+from toil.lib.aws.util import zone_to_region
+from toil.lib.retry import ErrorCondition
 from toil.provisioners import cluster_factory
 from toil.provisioners.aws import get_best_aws_zone
 from toil.test.provisioners.aws.awsProvisionerTest import AbstractAWSAutoscaleTest
diff --git a/src/toil/test/jobStores/jobStoreTest.py b/src/toil/test/jobStores/jobStoreTest.py
index 7497fc679d..61ae63ca25 100644
--- a/src/toil/test/jobStores/jobStoreTest.py
+++ b/src/toil/test/jobStores/jobStoreTest.py
@@ -41,6 +41,7 @@
                                              NoSuchJobException)
 from toil.jobStores.fileJobStore import FileJobStore
 from toil.lib.aws.utils import create_s3_bucket, get_object_for_url
+from toil.lib.retry import ErrorCondition
 from toil.lib.memoize import memoize
 from toil.statsAndLogging import StatsAndLogging
 from toil.test import (ToilTest,
diff --git a/src/toil/test/lib/aws/test_iam.py b/src/toil/test/lib/aws/test_iam.py
index 5883edc1ed..8fadf16772 100644
--- a/src/toil/test/lib/aws/test_iam.py
+++ b/src/toil/test/lib/aws/test_iam.py
@@ -12,15 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
-import uuid
-from typing import Optional
-import pytest
-
-from toil.jobStores.aws.jobStore import AWSJobStore
-from toil.lib.aws.utils import create_s3_bucket
-from toil.lib.ec2 import establish_boto3_session
-from toil.test import ToilTest, needs_aws_s3
+
+from toil.test import ToilTest
 from toil.lib.aws import iam
 
 logger = logging.getLogger(__name__)
diff --git a/src/toil/test/provisioners/clusterTest.py b/src/toil/test/provisioners/clusterTest.py
index d7327df21c..c610cdd77f 100644
--- a/src/toil/test/provisioners/clusterTest.py
+++ b/src/toil/test/provisioners/clusterTest.py
@@ -20,7 +20,7 @@
 
 from uuid import uuid4
 
-from toil.lib.aws import zone_to_region
+from toil.lib.aws.util import zone_to_region
 from toil.provisioners.aws import get_best_aws_zone
 from toil.test import ToilTest, needs_aws_ec2, needs_fetchable_appliance
diff --git a/src/toil/test/server/serverTest.py b/src/toil/test/server/serverTest.py
index 9f0fa8aa89..368940f7e5 100644
--- a/src/toil/test/server/serverTest.py
+++ b/src/toil/test/server/serverTest.py
@@ -203,11 +203,12 @@ def setUpClass(cls) -> None:
         """
         super().setUpClass()
 
-        from toil.lib.aws import get_current_aws_region, session
+        from toil.lib.aws.util import get_current_aws_region
+        from toil.lib.aws.session import establish_boto3_session
        from toil.lib.aws.utils import create_s3_bucket
 
         cls.region = get_current_aws_region()
-        cls.s3_resource = session.resource("s3", region_name=cls.region)
+        cls.s3_resource = establish_boto3_session(region_name=cls.region).resource("s3", region_name=cls.region)
 
         cls.bucket_name = f"toil-test-{uuid.uuid4()}"
         cls.bucket = create_s3_bucket(cls.s3_resource, cls.bucket_name, cls.region)
@@ -217,7 +218,7 @@ def setUpClass(cls) -> None:
     def tearDownClass(cls) -> None:
         from toil.lib.aws.utils import delete_s3_bucket
         if cls.bucket_name:
-            delete_s3_bucket(cls.s3_resource, cls.bucket_name, cls.region)
+            delete_s3_bucket(cls.s3_resource, cls.bucket_name)
         super().tearDownClass()
 
 class AWSStateStoreTest(hidden.AbstractStateStoreTest, BucketUsingTest):
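With error_message_must_include now a list, one condition can retry on any of several message fragments, and an empty list imposes no message requirement (see the error_meets_conditions fix above). A sketch of the new usage; the function name and message fragments are illustrative, not from this diff:

from botocore.exceptions import ClientError

from toil.lib.retry import ErrorCondition, retry

@retry(errors=[
    ErrorCondition(
        error=ClientError,
        # Retried if any one of these substrings appears in the error message.
        error_message_must_include=['SlowDown', 'RequestTimeout']
    )
])
def fetch_object():  # hypothetical caller
    ...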