diff --git a/eksrollup/cli.py b/eksrollup/cli.py
index 0a599b5..77ff55d 100755
--- a/eksrollup/cli.py
+++ b/eksrollup/cli.py
@@ -74,8 +74,8 @@ def scale_up_asg(cluster_name, asg, count):
 
     if desired_capacity == asg_old_desired_capacity:
         logger.info(f'Desired and current capacity for {asg_name} are equal. Skipping ASG.')
-        if asg_tag_desired_capacity.get('Value'):
-            logger.info('Found capacity tags on ASG from previous run. Leaving alone.')
+        if asg_tag_desired_capacity.get('Value') and asg_tag_orig_capacity.get('Value') and asg_tag_orig_max_capacity.get('Value'):
+            logger.info(f'Found capacity tags on ASG {asg_name} from previous run. Leaving alone.')
             return int(asg_tag_desired_capacity.get('Value')), int(asg_tag_orig_capacity.get(
                 'Value')), int(asg_tag_orig_max_capacity.get('Value'))
         else:
@@ -164,7 +164,7 @@ def update_asgs(asgs, cluster_name):
                 f'Setting the scale of ASG {asg_name} based on {outdated_instance_count} outdated instances.')
             asg_state_dict[asg_name] = scale_up_asg(cluster_name, asg, outdated_instance_count)
 
-    k8s_nodes = get_k8s_nodes()
+    k8s_nodes, k8s_excluded_nodes = get_k8s_nodes()
     if (run_mode == 2) or (run_mode == 3):
         for asg_name, asg_tuple in asg_outdated_instance_dict.items():
             outdated_instances, asg = asg_tuple
@@ -203,9 +203,14 @@ def update_asgs(asgs, cluster_name):
                     else:
                         taint_node(node_name)
                 except Exception as exception:
-                    logger.error(f"Encountered an error when adding taint/cordoning node {node_name}")
-                    logger.error(exception)
-                    exit(1)
+                    try:
+                        node_name = get_node_by_instance_id(k8s_excluded_nodes, outdated['InstanceId'])
+                        logger.info(f"Node {node_name} was excluded")
+                        continue
+                    except Exception as exception:
+                        logger.error(f"Encountered an error when adding taint/cordoning node {node_name}")
+                        logger.error(exception)
+                        exit(1)
 
         if len(outdated_instances) != 0:
             # if ASG termination is ignored then suspend 'Launch' and 'ReplaceUnhealthy'
@@ -235,8 +240,12 @@ def update_asgs(asgs, cluster_name):
                         logger.info(f'Waiting for {between_nodes_wait} seconds before continuing...')
                         time.sleep(between_nodes_wait)
                 except Exception as drain_exception:
-                    logger.info(drain_exception)
-                    raise RollingUpdateException("Rolling update on ASG failed", asg_name)
+                    try:
+                        node_name = get_node_by_instance_id(k8s_excluded_nodes, outdated['InstanceId'])
+                        logger.info(f"Node {node_name} was excluded")
+                        continue
+                    except Exception:
+                        raise RollingUpdateException("Rolling update on ASG failed", asg_name)
 
             # scaling cluster back down
             logger.info("Scaling asg back down to original state")
diff --git a/eksrollup/lib/aws.py b/eksrollup/lib/aws.py
index 3733a54..297ef2b 100644
--- a/eksrollup/lib/aws.py
+++ b/eksrollup/lib/aws.py
@@ -427,8 +427,7 @@ def count_all_cluster_instances(cluster_name, predictive=False, exclude_node_lab
     """
 
     # Get the K8s nodes on the cluster, while excluding nodes with certain label keys
-    k8s_nodes = get_k8s_nodes(exclude_node_label_keys)
-
+    k8s_nodes, excluded_k8s_nodes = get_k8s_nodes(exclude_node_label_keys)
     count = 0
     asgs = get_all_asgs(cluster_name)
     for asg in asgs:
diff --git a/eksrollup/lib/k8s.py b/eksrollup/lib/k8s.py
index fc22a1f..c32f39d 100644
--- a/eksrollup/lib/k8s.py
+++ b/eksrollup/lib/k8s.py
@@ -33,21 +33,25 @@ def ensure_config_loaded():
 
 
 def get_k8s_nodes(exclude_node_label_keys=app_config["EXCLUDE_NODE_LABEL_KEYS"]):
     """
-    Returns a list of kubernetes nodes
+    Returns a tuple of kubernetes nodes (nodes, excluded_nodes)
     """
     ensure_config_loaded()
     k8s_api = client.CoreV1Api()
     logger.info("Getting k8s nodes...")
     response = k8s_api.list_node()
+    nodes = []
+    excluded_nodes = []
     if exclude_node_label_keys is not None:
-        nodes = []
         for node in response.items:
             if all(key not in node.metadata.labels for key in exclude_node_label_keys):
                 nodes.append(node)
-        response.items = nodes
-    logger.info("Current k8s node count is {}".format(len(response.items)))
-    return response.items
+            else:
+                excluded_nodes.append(node)
+    else:
+        nodes = response.items
+    logger.info("Current total k8s node count is %d (included: %d, excluded: %d)", len(nodes) + len(excluded_nodes), len(nodes), len(excluded_nodes))
+    return nodes, excluded_nodes
 
 
 def get_node_by_instance_id(k8s_nodes, instance_id):
@@ -62,8 +66,8 @@
             logger.info('InstanceId {} is node {} in kubernetes land'.format(instance_id, k8s_node.metadata.name))
             node_name = k8s_node.metadata.name
     if not node_name:
-        logger.info("Could not find a k8s node name for that instance id. Exiting")
-        raise Exception("Could not find a k8s node name for that instance id. Exiting")
+        logger.info(f"Could not find a k8s node name for instance id {instance_id}")
+        raise Exception(f"Could not find a k8s node name for instance id {instance_id}")
     return node_name
 
 
@@ -205,7 +209,7 @@ def k8s_nodes_ready(max_retry=app_config['GLOBAL_MAX_RETRY'], wait=app_config['G
         # reset healthy nodes after every loop
         healthy_nodes = True
        retry_count += 1
-        nodes = get_k8s_nodes()
+        nodes, excluded_nodes = get_k8s_nodes()
         for node in nodes:
             conditions = node.status.conditions
             for condition in conditions:
@@ -236,9 +240,9 @@ def k8s_nodes_count(desired_node_count, max_retry=app_config['GLOBAL_MAX_RETRY']
     while retry_count < max_retry:
         nodes_online = True
         retry_count += 1
-        nodes = get_k8s_nodes()
+        nodes, excluded_nodes = get_k8s_nodes()
         logger.info('Current k8s node count is {}'.format(len(nodes)))
-        if len(nodes) != desired_node_count:
+        if len(nodes) < desired_node_count:
             nodes_online = False
             logger.info('Waiting for k8s nodes to reach count {}...'.format(desired_node_count))
             time.sleep(wait)
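
Note on the new get_k8s_nodes() contract: it now returns a (nodes, excluded_nodes) tuple rather than a bare list, so every call site must unpack two values, even where the excluded list goes unused. Below is a minimal, self-contained sketch of the partition behaviour this patch introduces; FakeNode and the 'example/exclude' label key are hypothetical stand-ins (the real objects are Kubernetes V1Node items and the keys come from app_config["EXCLUDE_NODE_LABEL_KEYS"]):

from dataclasses import dataclass, field

@dataclass
class FakeNode:
    # Hypothetical stand-in for a Kubernetes V1Node; only .labels matters here.
    name: str
    labels: dict = field(default_factory=dict)

def partition_nodes(items, exclude_node_label_keys):
    # Mirrors the patched get_k8s_nodes(): a node is excluded as soon as
    # any of its label keys appears in exclude_node_label_keys.
    nodes, excluded_nodes = [], []
    if exclude_node_label_keys is not None:
        for node in items:
            if all(key not in node.labels for key in exclude_node_label_keys):
                nodes.append(node)
            else:
                excluded_nodes.append(node)
    else:
        nodes = list(items)
    return nodes, excluded_nodes

# Usage: the second element is what cli.py now probes when tainting or
# draining a node fails, to decide whether to skip it instead of aborting.
items = [FakeNode('ip-10-0-0-1'), FakeNode('ip-10-0-0-2', {'example/exclude': 'true'})]
nodes, excluded = partition_nodes(items, ['example/exclude'])
assert [n.name for n in nodes] == ['ip-10-0-0-1']
assert [n.name for n in excluded] == ['ip-10-0-0-2']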