79 changes: 54 additions & 25 deletions dask_cloudprovider/aws/ecs.py
@@ -1,6 +1,5 @@
import asyncio
import logging
import uuid
import warnings
import weakref
from typing import List, Optional
@@ -224,9 +223,9 @@ async def start(self):
"awsvpcConfiguration": {
"subnets": self._vpc_subnets,
"securityGroups": self._security_groups,
"assignPublicIp": "ENABLED"
if self._use_public_ip
else "DISABLED",
"assignPublicIp": (
"ENABLED" if self._use_public_ip else "DISABLED"
),
}
},
}
@@ -461,7 +460,9 @@ class ECSCluster(SpecCluster, ConfigMixin):
This creates a dask scheduler and workers on an existing ECS cluster.

All the other required resources such as roles, task definitions, tasks, etc
will be created automatically like in :class:`FargateCluster`.
will be created automatically like in :class:`FargateCluster`. Resource names will
include the value of ``self.name`` to uniquely associate them with this cluster, and
they will also be tagged with ``dask_cluster_name`` set to the same value.

Parameters
----------
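For context, a minimal usage sketch of the behaviour this change enables: two ECSCluster instances sharing one ECS cluster via ``cluster_arn``, each getting resource names derived from its own ``name``. The ARN, VPC, and subnet values below are placeholders, and actually running this requires AWS credentials and an existing ECS cluster.

from dask_cloudprovider.aws import ECSCluster

# Placeholder identifiers; substitute real values for your account.
shared_arn = "arn:aws:ecs:us-east-1:111111111111:cluster/shared"
alpha = ECSCluster(name="alpha", cluster_arn=shared_arn, vpc="vpc-0abc", subnets=["subnet-0abc"])
beta = ECSCluster(name="beta", cluster_arn=shared_arn, vpc="vpc-0abc", subnets=["subnet-0abc"])

# Each Dask cluster now owns uniquely named roles, task definitions, and security groups.
assert alpha._execution_role_name == "dask-alpha-execution-role"
assert beta._execution_role_name == "dask-beta-execution-role"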
@@ -579,9 +580,11 @@ class ECSCluster(SpecCluster, ConfigMixin):
Defaults to ``None`` which results in a new cluster being created for you.
cluster_name_template: str (optional)
A template to use for the cluster name if ``cluster_arn`` is set to
``None``.
``None``. Valid substitution variables are:

Defaults to ``'dask-{uuid}'``
``name``: the value of ``self.name`` (usually a UUID)

Defaults to ``'dask-{name}'``
execution_role_arn: str (optional)
The ARN of an existing IAM role to use for ECS execution.

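Illustrating the template resolution documented above, a standalone sketch that mirrors the ``_create_cluster`` change further down; the example name is made up, and the real code path lives inside ``ECSCluster``.

import dask.config

# Hypothetical standalone rendering of the default template.
template = dask.config.expand_environment_variables("dask-{name}")
assert template.format(name="a1b2c3d4e5") == "dask-a1b2c3d4e5"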
@@ -626,9 +629,12 @@ class ECSCluster(SpecCluster, ConfigMixin):

Default ``None`` (one will be created called ``dask-ecs``)
cloudwatch_logs_stream_prefix: str (optional)
Prefix for log streams.
Prefix for log streams. Valid substitution variables are:

``name``: the value of ``self.name`` (usually a UUID)
``cluster_name``: the value of ``self.cluster_name`` (the ECS cluster name)

Defaults to the cluster name.
Defaults to ``{cluster_name}/{name}``.
cloudwatch_logs_default_retention: int (optional)
Retention for logs in days. For use when log group is auto created.

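A one-line sketch of the default prefix resolution, with names borrowed from the test added below:

prefix = "{cluster_name}/{name}".format(cluster_name="Crunchy", name="Spooky")
assert prefix == "Crunchy/Spooky"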
@@ -921,7 +927,10 @@ async def _start(
if self._cloudwatch_logs_stream_prefix is None:
self._cloudwatch_logs_stream_prefix = self.config.get(
"cloudwatch_logs_stream_prefix"
).format(cluster_name=self.cluster_name)
).format(
cluster_name=self.cluster_name,
name=self.name,
)

if self.cloudwatch_logs_group is None:
self.cloudwatch_logs_group = (
@@ -1025,7 +1034,12 @@ def _new_worker_name(self, worker_number):

@property
def tags(self):
return {**self._tags, **DEFAULT_TAGS, "cluster": self.cluster_name}
return {
**self._tags,
**DEFAULT_TAGS,
"cluster": self.cluster_name,
"dask_cluster_name": self.name,
}
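As a sketch of the resulting tag set: assuming DEFAULT_TAGS contains {"createdBy": "dask-cloudprovider"} (an assumption; check the source for the real contents), a cluster named "Spooky" on ECS cluster "Crunchy" would carry:

tags = {
    "createdBy": "dask-cloudprovider",  # stand-in for DEFAULT_TAGS
    "cluster": "Crunchy",
    "dask_cluster_name": "Spooky",  # new: ties AWS resources back to this Dask cluster
}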

async def _create_cluster(self):
if not self._fargate_scheduler or not self._fargate_workers:
@@ -1038,7 +1052,10 @@ async def _create_cluster(self):
self.cluster_name = dask.config.expand_environment_variables(
self._cluster_name_template
)
self.cluster_name = self.cluster_name.format(uuid=str(uuid.uuid4())[:10])
self.cluster_name = self.cluster_name.format(
name=self.name,
uuid=self.name, # backwards-compatible
)
async with self._client("ecs") as ecs:
response = await ecs.create_cluster(
clusterName=self.cluster_name,
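A quick sketch of the backwards compatibility noted in the comment above: a pre-existing cluster_name_template that uses {uuid} keeps working, because both keys are passed the same value.

name = "a1b2c3d4e5"
legacy = "dask-{uuid}".format(name=name, uuid=name)
current = "dask-{name}".format(name=name, uuid=name)
assert legacy == current == "dask-a1b2c3d4e5"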
@@ -1059,7 +1076,7 @@ async def _delete_cluster(self):

@property
def _execution_role_name(self):
return "{}-{}".format(self.cluster_name, "execution-role")
return "dask-{}-execution-role".format(self.name)

async def _create_execution_role(self):
async with self._client("iam") as iam:
@@ -1099,7 +1116,7 @@ async def _create_execution_role(self):

@property
def _task_role_name(self):
return "{}-{}".format(self.cluster_name, "task-role")
return "dask-{}-task-role".format(self.name)

async def _create_task_role(self):
async with self._client("iam") as iam:
@@ -1141,6 +1158,8 @@ async def _delete_role(self, role):
await iam.delete_role(RoleName=role)

async def _create_cloudwatch_logs_group(self):
# The log group does not include `name` because it is shared by all Dask ECS
# clusters; log streams do, because they are specific to each Dask cluster.
log_group_name = "dask-ecs"
async with self._client("logs") as logs:
groups = await logs.describe_log_groups()
@@ -1160,23 +1179,29 @@ async def _create_cloudwatch_logs_group(self):
# Note: Not cleaning up the logs here as they may be useful after the cluster is destroyed
return log_group_name

@property
def _security_group_name(self):
return "dask-{}-security-group".format(self.name)

async def _create_security_groups(self):
async with self._client("ec2") as client:
group = await create_default_security_group(
client, self.cluster_name, self._vpc, self.tags
client, self._security_group_name, self._vpc, self.tags
)
weakref.finalize(self, self.sync, self._delete_security_groups)
return [group]

async def _delete_security_groups(self):
timeout = Timeout(
30, "Unable to delete AWS security group " + self.cluster_name, warn=True
30,
"Unable to delete AWS security group {}".format(self._security_group_name),
warn=True,
)
async with self._client("ec2") as ec2:
while timeout.run():
try:
await ec2.delete_security_group(
GroupName=self.cluster_name, DryRun=False
GroupName=self._security_group_name, DryRun=False
)
except Exception:
await asyncio.sleep(2)
@@ -1185,7 +1210,7 @@ async def _delete_security_groups(self):
async def _create_scheduler_task_definition_arn(self):
async with self._client("ecs") as ecs:
response = await ecs.register_task_definition(
family="{}-{}".format(self.cluster_name, "scheduler"),
family="dask-{}-scheduler".format(self.name),
taskRoleArn=self._task_role_arn,
executionRoleArn=self._execution_role_arn,
networkMode="awsvpc",
Expand Down Expand Up @@ -1223,14 +1248,18 @@ async def _create_scheduler_task_definition_arn(self):
"awslogs-create-group": "true",
},
},
"mountPoints": self._mount_points
if self._mount_points and self._mount_volumes_on_scheduler
else [],
"mountPoints": (
self._mount_points
if self._mount_points and self._mount_volumes_on_scheduler
else []
),
}
],
volumes=self._volumes
if self._volumes and self._mount_volumes_on_scheduler
else [],
volumes=(
self._volumes
if self._volumes and self._mount_volumes_on_scheduler
else []
),
requiresCompatibilities=["FARGATE"] if self._fargate_scheduler else [],
runtimePlatform={"cpuArchitecture": self._cpu_architecture},
cpu=str(self._scheduler_cpu),
@@ -1255,7 +1284,7 @@ async def _create_worker_task_definition_arn(self):
)
async with self._client("ecs") as ecs:
response = await ecs.register_task_definition(
family="{}-{}".format(self.cluster_name, "worker"),
family="dask-{}-worker".format(self.name),
taskRoleArn=self._task_role_arn,
executionRoleArn=self._execution_role_arn,
networkMode="awsvpc",
127 changes: 127 additions & 0 deletions dask_cloudprovider/aws/tests/test_ecs.py
@@ -1,3 +1,6 @@
from unittest import mock
from unittest.mock import AsyncMock

import pytest

aiobotocore = pytest.importorskip("aiobotocore")
@@ -6,3 +9,127 @@
def test_import():
from dask_cloudprovider.aws import ECSCluster # noqa
from dask_cloudprovider.aws import FargateCluster # noqa


def test_reuse_ecs_cluster():
from dask_cloudprovider.aws import ECSCluster # noqa

fc1_name = "Spooky"
fc2_name = "Weevil"
vpc_name = "MyNetwork"
vpc_subnets = ["MySubnet1", "MySubnet2"]
cluster_arn = "CompletelyMadeUp"
cluster_name = "Crunchy"
log_group_name = "dask-ecs"

expected_execution_role_name1 = f"dask-{fc1_name}-execution-role"
expected_task_role_name1 = f"dask-{fc1_name}-task-role"
expected_log_stream_prefix1 = f"{cluster_name}/{fc1_name}"
expected_security_group_name1 = f"dask-{fc1_name}-security-group"
expected_scheduler_task_name1 = f"dask-{fc1_name}-scheduler"
expected_worker_task_name1 = f"dask-{fc1_name}-worker"

expected_execution_role_name2 = f"dask-{fc2_name}-execution-role"
expected_task_role_name2 = f"dask-{fc2_name}-task-role"
expected_log_stream_prefix2 = f"{cluster_name}/{fc2_name}"
expected_security_group_name2 = f"dask-{fc2_name}-security-group"
expected_scheduler_task_name2 = f"dask-{fc2_name}-scheduler"
expected_worker_task_name2 = f"dask-{fc2_name}-worker"

mock_client = AsyncMock()
mock_client.describe_clusters.return_value = {
"clusters": [{"clusterName": cluster_name}]
}
mock_client.list_account_settings.return_value = {"settings": {"value": "enabled"}}
mock_client.create_role.return_value = {"Role": {"Arn": "Random"}}
mock_client.describe_log_groups.return_value = {"logGroups": []}

class MockSession:
class MockClient:
async def __aenter__(self, *args, **kwargs):
return mock_client

async def __aexit__(self, *args, **kwargs):
return

def create_client(self, *args, **kwargs):
return MockSession.MockClient()

with (
mock.patch(
"dask_cloudprovider.aws.ecs.get_session", return_value=MockSession()
),
mock.patch("distributed.deploy.spec.SpecCluster._start"),
mock.patch("weakref.finalize"),
):
# Make ourselves a test cluster
fc1 = ECSCluster(
name=fc1_name,
cluster_arn=cluster_arn,
vpc=vpc_name,
subnets=vpc_subnets,
skip_cleanup=True,
)
# Are we re-using the existing ECS cluster?
assert fc1.cluster_name == cluster_name
# Have we made completely unique AWS resources to run on that cluster?
assert fc1._execution_role_name == expected_execution_role_name1
assert fc1._task_role_name == expected_task_role_name1
assert fc1._cloudwatch_logs_stream_prefix == expected_log_stream_prefix1
assert (
fc1.scheduler_spec["options"]["log_stream_prefix"]
== expected_log_stream_prefix1
)
assert (
fc1.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix1
)
assert fc1.cloudwatch_logs_group == log_group_name
assert fc1.scheduler_spec["options"]["log_group"] == log_group_name
assert fc1.new_spec["options"]["log_group"] == log_group_name
sg_calls = mock_client.create_security_group.call_args_list
assert len(sg_calls) == 1
assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name1
td_calls = mock_client.register_task_definition.call_args_list
assert len(td_calls) == 2
assert td_calls[0].kwargs["family"] == expected_scheduler_task_name1
assert td_calls[1].kwargs["family"] == expected_worker_task_name1

# Reset mocks ready for second cluster
mock_client.create_security_group.reset_mock()
mock_client.register_task_definition.reset_mock()

# Make ourselves a second test cluster on the same ECS cluster
fc2 = ECSCluster(
name=fc2_name,
cluster_arn=cluster_arn,
vpc=vpc_name,
subnets=vpc_subnets,
skip_cleanup=True,
)
# Are we re-using the existing ECS cluster?
assert fc2.cluster_name == cluster_name
# Have we made completely unique AWS resources to run on that cluster?
assert fc2._execution_role_name == expected_execution_role_name2
assert fc2._task_role_name == expected_task_role_name2
assert fc2._cloudwatch_logs_stream_prefix == expected_log_stream_prefix2
assert (
fc2.scheduler_spec["options"]["log_stream_prefix"]
== expected_log_stream_prefix2
)
assert (
fc2.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix2
)
assert fc2.cloudwatch_logs_group == log_group_name
assert fc2.scheduler_spec["options"]["log_group"] == log_group_name
assert fc2.new_spec["options"]["log_group"] == log_group_name
sg_calls = mock_client.create_security_group.call_args_list
assert len(sg_calls) == 1
assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name2
td_calls = mock_client.register_task_definition.call_args_list
assert len(td_calls) == 2
assert td_calls[0].kwargs["family"] == expected_scheduler_task_name2
assert td_calls[1].kwargs["family"] == expected_worker_task_name2

# Finish up
fc1.close()
fc2.close()
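To run just this test locally (the mocks mean no AWS calls are made), the following should work; the invocation is a standard pytest pattern, not something defined by this PR:

import pytest
pytest.main(["-x", "dask_cloudprovider/aws/tests/test_ecs.py::test_reuse_ecs_cluster"])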
4 changes: 2 additions & 2 deletions dask_cloudprovider/cloudprovider.yaml
@@ -17,15 +17,15 @@ cloudprovider:
image: "daskdev/dask:latest" # Docker image to use for non GPU tasks
cpu_architecture: "X86_64" # Runtime platform CPU architecture
gpu_image: "rapidsai/rapidsai:latest" # Docker image to use for GPU tasks
cluster_name_template: "dask-{uuid}" # Template to use when creating a cluster
cluster_name_template: "dask-{name}" # Template to use when creating a cluster
cluster_arn: "" # ARN of existing ECS cluster to use (if not set one will be created)
execution_role_arn: "" # Arn of existing execution role to use (if not set one will be created)
task_role_arn: "" # Arn of existing task role to use (if not set one will be created)
task_role_policies: [] # List of policy arns to attach to tasks (e.g S3 read only access)
# platform_version: "LATEST" # Fargate platformVersion string like "1.4.0" or "LATEST"

cloudwatch_logs_group: "" # Name of existing cloudwatch logs group to use (if not set one will be created)
cloudwatch_logs_stream_prefix: "{cluster_name}" # Stream prefix template
cloudwatch_logs_stream_prefix: "{cluster_name}/{name}" # Stream prefix template
cloudwatch_logs_default_retention: 30 # Number of days to retain logs (only applied if not using existing group)

vpc: "default" # VPC to use for tasks
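These defaults can also be overridden at runtime. Assuming the keys live under the cloudprovider.ecs namespace (inferred from the file layout, not confirmed here), a sketch:

import dask.config

dask.config.set({"cloudprovider.ecs.cluster_name_template": "dask-{name}",
                 "cloudprovider.ecs.cloudwatch_logs_stream_prefix": "{cluster_name}/{name}"})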