diff --git a/runner_manager/models/backend.py b/runner_manager/models/backend.py index cb0f8592..c08e7535 100644 --- a/runner_manager/models/backend.py +++ b/runner_manager/models/backend.py @@ -7,6 +7,7 @@ from mypy_boto3_ec2.type_defs import ( BlockDeviceMappingTypeDef, EbsBlockDeviceTypeDef, + IamInstanceProfileTypeDef, TagSpecificationTypeDef, TagTypeDef, ) @@ -133,6 +134,7 @@ class AWSConfig(BackendConfig): "BlockDeviceMappings": Sequence[BlockDeviceMappingTypeDef], "MaxCount": int, "MinCount": int, + "IamInstanceProfile": IamInstanceProfileTypeDef, }, ) @@ -150,6 +152,7 @@ class AWSInstanceConfig(InstanceConfig): tags: Dict[str, str] = {} volume_type: VolumeTypeType = "gp3" disk_size_gb: int = 20 + iam_instance_profile_arn: str = "" def configure_instance(self, runner: Runner) -> AwsInstance: """Configure instance.""" @@ -184,6 +187,9 @@ def configure_instance(self, runner: Runner) -> AwsInstance: Tags=tags, ), ] + iam_instance_profile: IamInstanceProfileTypeDef = IamInstanceProfileTypeDef( + Arn=self.iam_instance_profile_arn + ) return AwsInstance( ImageId=self.image, InstanceType=self.instance_type, @@ -194,4 +200,5 @@ def configure_instance(self, runner: Runner) -> AwsInstance: MaxCount=self.max_count, MinCount=self.min_count, BlockDeviceMappings=block_device_mappings, + IamInstanceProfile=iam_instance_profile, ) diff --git a/tests/unit/backend/test_aws.py b/tests/unit/backend/test_aws.py index 0ffcbfd0..9bad369c 100644 --- a/tests/unit/backend/test_aws.py +++ b/tests/unit/backend/test_aws.py @@ -46,11 +46,17 @@ def aws_runner(runner: Runner, aws_group: RunnerGroup) -> Runner: def test_aws_instance_config(runner: Runner): AWSConfig() instance_config = AWSInstanceConfig( - tags={"test": "test"}, subnet_id="i-0f9b0a3b7b3b3b3b3" + tags={"test": "test"}, + subnet_id="i-0f9b0a3b7b3b3b3b3", + iam_instance_profile_arn="test", ) instance: AwsInstance = instance_config.configure_instance(runner) assert instance["ImageId"] == instance_config.image assert instance["SubnetId"] == instance_config.subnet_id + assert ( + instance["IamInstanceProfile"]["Arn"] + == instance_config.iam_instance_profile_arn + ) assert runner.name in instance["UserData"] tags = instance["TagSpecifications"][0]["Tags"] assert TagTypeDef(Key="test", Value="test") in tags