Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the use of user-provided provisioner (resolves #2806). #2807

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/toil/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,11 @@ def _addOptions(addGroupFn, config):
"in an autoscaled cluster, as well as parameters to control the "
"level of provisioning.")

addOptionFn("--provisioner", dest="provisioner", choices=['aws', 'azure', 'gce'],
help="The provisioner for cluster auto-scaling. The currently supported choices are"
"'azure', 'gce', or 'aws'. The default is %s." % config.provisioner)
addOptionFn("--provisioner", dest="provisioner",
help="The provisioner for cluster auto-scaling. The currently logical name "
"supported choices are 'azure', 'gce', or 'aws'. You can use your own "
"implementation ex: --provisioner my.own.implementation."
"provisioner.MyProvisioner. The default is %s." % config.provisioner)

addOptionFn('--nodeTypes', default=None,
help="List of node types separated by commas. The syntax for each node type "
Expand Down
56 changes: 32 additions & 24 deletions src/toil/provisioners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,48 @@
# limitations under the License.
from __future__ import absolute_import

import importlib
import logging
logger = logging.getLogger(__name__)


def clusterFactory(provisioner, clusterName=None, zone=None, nodeStorage=50, sseKey=None):
def getClassFromFQN(cls):
modules, className = cls.rsplit('.', 1)
module = importlib.import_module(modules)
return getattr(module, className)


def clusterFactory(provisioner, clusterName=None, zone=None, nodeStorage=50, **kwargs):
"""
:param clusterName: The name of the cluster.
:param provisioner: The cloud type of the cluster.
:param zone: The cloud zone
:return: A cluster object for the the cloud type.
"""
if provisioner == 'aws':
try:
from toil.provisioners.aws.awsProvisioner import AWSProvisioner
except ImportError:
logger.error('The aws extra must be installed to use this provisioner')
raise
return AWSProvisioner(clusterName, zone, nodeStorage, sseKey)
elif provisioner == 'gce':
try:
from toil.provisioners.gceProvisioner import GCEProvisioner
except ImportError:
logger.error('The google extra must be installed to use this provisioner')
raise
return GCEProvisioner(clusterName, zone, nodeStorage, sseKey)
elif provisioner == 'azure':
try:
from toil.provisioners.azure.azureProvisioner import AzureProvisioner
except ImportError:
logger.error('The azure extra must be installed to use this provisioner')
raise
return AzureProvisioner(clusterName, zone, nodeStorage)
else:
raise RuntimeError("Invalid provisioner '%s'" % provisioner)
known_provisionners = {'aws': ('toil.provisioners.aws.awsProvisioner.AWSProvisioner',
'The aws extra must be installed to use this provisioner'),
'gce': ('toil.provisioners.gceProvisioner.GCEProvisioner',
'The google extra must be installed to use this provisioner'),
'azure': ('toil.provisioners.azure.azureProvisioner.AzureProvisioner',
'The azure extra must be installed to use this provisioner')}

_provisioner = provisioner
_err_msg = "Invalid provisioner '%s'" % provisioner

known_provisionner = known_provisionners.get(provisioner)
if known_provisionner:
_provisioner = known_provisionner[0]
_err_msg = known_provisionner[1]

try:
clazz = getClassFromFQN(_provisioner)
return clazz(clusterName=clusterName, zone=zone, nodeStorage=nodeStorage, **kwargs)
except ImportError:
logger.error(_err_msg)
raise
except (ValueError, AttributeError):
raise RuntimeError(_err_msg)


class NoSuchClusterException(Exception):
"""Indicates that the specified cluster does not exist."""
Expand Down
3 changes: 2 additions & 1 deletion src/toil/provisioners/abstractProvisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,14 @@ class AbstractProvisioner(with_metaclass(ABCMeta, object)):
"""
LEADER_HOME_DIR = '/root/' # home directory in the Toil appliance on an instance

def __init__(self, clusterName=None, zone=None, nodeStorage=50):
def __init__(self, clusterName=None, zone=None, nodeStorage=50, **kwargs):
"""
Initialize provisioner.

:param clusterName: The cluster identifier.
:param zone: The zone the cluster runs in.
:param nodeStorage: The amount of storage on the worker instances, in gigabytes.
:param kwargs: Optional additional parameters to fit specific provisioner needs.
"""
self.clusterName = clusterName
self._zone = zone
Expand Down
2 changes: 1 addition & 1 deletion src/toil/provisioners/ansibleDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class AnsibleDriver(AbstractProvisioner):
"""
Wrapper class for Ansible calls.
"""
def __init__(self, playbooks, clusterName, zone, nodeStorage):
def __init__(self, playbooks, clusterName, zone, nodeStorage, **kwargs):
self.playbooks = playbooks
super(AnsibleDriver, self).__init__(clusterName, zone, nodeStorage)

Expand Down
4 changes: 2 additions & 2 deletions src/toil/provisioners/aws/awsProvisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ class AWSProvisioner(AbstractProvisioner):
Implements an AWS provisioner using the boto libraries.
"""

def __init__(self, clusterName, zone, nodeStorage, sseKey):
def __init__(self, clusterName, zone, nodeStorage, **kwargs):
super(AWSProvisioner, self).__init__(clusterName, zone, nodeStorage)
self.cloud = 'aws'
self._sseKey = sseKey
self._sseKey = kwargs.get('sseKey')
if not zone:
self._zone = getCurrentAWSZone()
if clusterName:
Expand Down
2 changes: 1 addition & 1 deletion src/toil/provisioners/azure/azureProvisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AzureProvisioner(AnsibleDriver):

"""

def __init__(self, clusterName, zone, nodeStorage):
def __init__(self, clusterName, zone, nodeStorage, **kwargs):
self.cloud = 'azure'

self.playbook = {
Expand Down
4 changes: 2 additions & 2 deletions src/toil/provisioners/gceProvisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ class GCEProvisioner(AbstractProvisioner):
SOURCE_IMAGE = (b'https://www.googleapis.com/compute/v1/projects/coreos-cloud/global/'
b'images/coreos-stable-1576-4-0-v20171206')

def __init__(self, clusterName, zone, nodeStorage, sseKey):
def __init__(self, clusterName, zone, nodeStorage, **kwargs):
super(GCEProvisioner, self).__init__(clusterName, zone, nodeStorage)
self.cloud = 'gce'
self._sseKey = sseKey
self._sseKey = kwargs.get('sseKey')

# If the clusterName is not given, then Toil must be running on the leader
# and should read the settings from the instance meta-data.
Expand Down
82 changes: 82 additions & 0 deletions src/toil/test/provisioners/provisionerTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (C) 2015 UCSC Computational Genomics Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import unittest

from unittest.mock import patch, Mock
from toil.provisioners import clusterFactory
import toil.provisioners
from toil.provisioners.abstractProvisioner import AbstractProvisioner

from toil.test import needs_azure, integrative, ToilTest, needs_appliance, timeLimit, slow

log = logging.getLogger(__name__)


class DummyProvisioner(AbstractProvisioner):

def readClusterSettings(self):
pass

def launchCluster(self, leaderNodeType, leaderStorage, owner, **kwargs):
pass

def addNodes(self, nodeType, numNodes, preemptable, spotBid=None):
pass

def terminateNodes(self, nodes):
pass

def getLeader(self):
pass

def getProvisionedWorkers(self, nodeType, preemptable):
pass

def getNodeShape(self, nodeType=None, preemptable=False):
pass

def destroyCluster(self):
pass


class ProvisionerTest(ToilTest):

@patch('toil.provisioners.aws.awsProvisioner.AWSProvisioner', Mock())
def test_clusteFactory_should_return_AWSProvisioner_instance_when_aws(self):
actual = clusterFactory('aws')
toil.provisioners.aws.awsProvisioner.AWSProvisioner.assert_called_once()

@patch('toil.provisioners.gceProvisioner.GCEProvisioner', Mock())
def test_clusteFactory_should_return_GCEProvisioner_instance_when_gce(self):
actual = clusterFactory('gce')
toil.provisioners.gceProvisioner.GCEProvisioner.assert_called_once()

@patch('toil.provisioners.azure.azureProvisioner.AzureProvisioner', Mock())
def test_clusteFactory_should_return_AzureProvisioner_instance_when_azure(self):
actual = clusterFactory('azure')
toil.provisioners.azure.azureProvisioner.AzureProvisioner.assert_called_once()

@patch('toil.provisioners.aws.awsProvisioner.AWSProvisioner', Mock())
def test_clusteFactory_should_return_AWSProvisioner_instance_when_awsFQN(self):
actual = clusterFactory('toil.provisioners.aws.awsProvisioner.AWSProvisioner')
toil.provisioners.aws.awsProvisioner.AWSProvisioner.assert_called_once()

def test_clusteFactory_should_return_DummyProvisioner_when_FQN(self):
actual = clusterFactory('toil.test.provisioners.provisionerTest.DummyProvisioner')
self.assertTrue(isinstance(actual, toil.test.provisioners.provisionerTest.DummyProvisioner))


if __name__ == '__main__':
unittest.main()
8 changes: 5 additions & 3 deletions src/toil/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

def addBasicProvisionerOptions(parser):
parser.add_argument("--version", action='version', version=version)
parser.add_argument('-p', "--provisioner", dest='provisioner', choices=['aws', 'azure', 'gce'], required=False, default="aws",
help="The provisioner for cluster auto-scaling. Only aws is currently "
"supported")
parser.add_argument('-p', "--provisioner", dest='provisioner', required=False, default="aws",
help="The provisioner for cluster auto-scaling. Possible choices "
"(but not limited to) are 'aws', 'azure', 'gce'. Only aws is currently"
" supported or your own implementation "
"ex: -p my.own.implementation.provisioner.MyProvisioner")
parser.add_argument('-z', '--zone', dest='zone', required=False, default=None,
help="The availability zone of the master. This parameter can also be set via the 'TOIL_X_ZONE' "
"environment variable, where X is AWS, GCE, or AZURE, or by the ec2_region_name parameter "
Expand Down