diff --git a/VERSION b/VERSION index 21222ce..2c3fc41 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.5.0 +v2.5.1 diff --git a/customizations-for-aws-control-tower.template b/customizations-for-aws-control-tower.template index f7bcfcd..ad7bc3c 100644 --- a/customizations-for-aws-control-tower.template +++ b/customizations-for-aws-control-tower.template @@ -12,7 +12,7 @@ # permissions and limitations under the License. AWSTemplateFormatVersion: '2010-09-09' -Description: '(SO0089) - customizations-for-aws-control-tower Solution. Version: v2.5.0' +Description: '(SO0089) - customizations-for-aws-control-tower Solution. Version: v2.5.1' Parameters: PipelineApprovalStage: @@ -127,7 +127,7 @@ Mappings: SourceBucketName: Name: control-tower-cfct-assets-prod SourceKeyName: - Name: customizations-for-aws-control-tower/v2.5.0/custom-control-tower-configuration.zip + Name: customizations-for-aws-control-tower/v2.5.1/custom-control-tower-configuration.zip CustomControlTowerPipelineS3TriggerKey: Name: custom-control-tower-configuration.zip CustomControlTowerPipelineS3NonTriggerKey: @@ -145,7 +145,7 @@ Mappings: SolutionID: 'SO0089' MetricsURL: 'https://metrics.awssolutionsbuilder.com/generic' Data: - AddonTemplate: 'https://s3.amazonaws.com/control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.0/custom-control-tower-initiation.template' + AddonTemplate: 'https://s3.amazonaws.com/control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.1/custom-control-tower-initiation.template' AWSControlTower: ExecutionRole: Name: "AWSControlTowerExecution" @@ -292,7 +292,7 @@ Resources: Code: S3: Bucket: control-tower-cfct-assets-prod - Key: !Sub customizations-for-aws-control-tower/v2.5.0/custom-control-tower-configuration-${AWS::Region}.zip + Key: !Sub customizations-for-aws-control-tower/v2.5.1/custom-control-tower-configuration-${AWS::Region}.zip # SSM Parameter to store the git repository name CustomControlTowerRepoNameParameter: @@ -551,7 +551,7 @@ Resources: - {KMSKeyName: !FindInMap [KMS, Alias, Name]} Source: Type: CODEPIPELINE - BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1>/dev/null\n - export LC_ALL='en_US.UTF-8'\n - locale-gen en_US en_US.UTF-8\n - dpkg-reconfigure locales --frontend noninteractive\n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.0/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES \n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n\n" + BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1>/dev/null\n - export LC_ALL='en_US.UTF-8'\n - locale-gen en_US en_US.UTF-8\n - dpkg-reconfigure locales --frontend noninteractive\n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.1/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES \n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n\n" Environment: ComputeType: BUILD_GENERAL1_SMALL Image: "aws/codebuild/standard:5.0" @@ -576,7 +576,7 @@ Resources: - Name: SOLUTION_ID Value: !FindInMap [ Solution, Metrics, SolutionID ] - Name: SOLUTION_VERSION - Value: v2.5.0 + Value: v2.5.1 Artifacts: Name: !Sub ${CustomControlTowerPipelineArtifactS3Bucket}-Built Type: CODEPIPELINE @@ -679,7 +679,7 @@ Resources: - {KMSKeyName: !FindInMap [KMS, Alias, Name]} Source: Type: CODEPIPELINE - BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1> /dev/null \n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.0/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES\n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n" + BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1> /dev/null \n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.1/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES\n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n" Environment: ComputeType: BUILD_GENERAL1_SMALL Image: "aws/codebuild/standard:5.0" @@ -700,7 +700,7 @@ Resources: - Name: SOLUTION_ID Value: !FindInMap [ Solution, Metrics, SolutionID ] - Name: SOLUTION_VERSION - Value: v2.5.0 + Value: v2.5.1 Artifacts: Name: !Sub ${CustomControlTowerPipelineArtifactS3Bucket}-Built Type: CODEPIPELINE @@ -855,7 +855,7 @@ Resources: - {KMSKeyName: !FindInMap [KMS, Alias, Name]} Source: Type: CODEPIPELINE - BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1> /dev/null\n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.0/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES\n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n" + BuildSpec: "version: 0.2\nphases:\n install:\n runtime-versions:\n python: 3.8\n ruby: 2.6\n commands:\n - export current=$(pwd)\n - if [ -f manifest.yaml ];then export current=$(pwd);else if [ -f custom-control-tower-configuration/manifest.yaml ]; then export current=$(pwd)/custom-control-tower-configuration; else echo 'manifest.yaml does not exist at the root level of custom-control-tower-configuration.zip or inside custom-control-tower-configuration folder, please check the ZIP file'; exit 1; fi; fi;\n - apt-get -q update 1> /dev/null\n - apt-get -q install zip wget python3-pip libyaml-dev -y 1> /dev/null\n pre_build:\n commands:\n - cd $current\n - echo 'Download CustomControlTower Scripts'\n - aws s3 cp --quiet s3://control-tower-cfct-assets-prod/customizations-for-aws-control-tower/v2.5.1/custom-control-tower-scripts.zip $current\n - unzip -q -o $current/custom-control-tower-scripts.zip -d $current\n - cp codebuild_scripts/* .\n - bash install_stage_dependencies.sh $STAGE_NAME\n build:\n commands:\n - echo 'Starting build $(date) in $(pwd)'\n - echo 'bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES'\n - bash execute_stage_scripts.sh $STAGE_NAME $LOG_LEVEL $WAIT_TIME $SM_ARN $ARTIFACT_BUCKET $KMS_KEY_ALIAS_NAME $BOOL_VALUES $NONE_TYPE_VALUES\n - echo 'Running build scripts completed $(date)'\n post_build:\n commands:\n - echo 'Starting post build $(date) in $(pwd)'\n - echo 'build completed on $(date)'\n\nartifacts:\n files:\n - '**/*'\n" Environment: ComputeType: BUILD_GENERAL1_SMALL Image: "aws/codebuild/standard:5.0" @@ -880,7 +880,7 @@ Resources: - Name: SOLUTION_ID Value: !FindInMap [Solution, Metrics, SolutionID] - Name: SOLUTION_VERSION - Value: v2.5.0 + Value: v2.5.1 - Name: METRICS_URL Value: !FindInMap [Solution, Metrics, MetricsURL] - Name: CONTROL_TOWER_BASELINE_CONFIG_STACKSET @@ -1003,10 +1003,10 @@ Resources: Variables: LOG_LEVEL: !FindInMap [LambdaFunction, Logging, Level] SOLUTION_ID: !FindInMap [Solution, Metrics, SolutionID] - SOLUTION_VERSION: v2.5.0 + SOLUTION_VERSION: v2.5.1 Code: S3Bucket: !Sub "control-tower-cfct-assets-prod-${AWS::Region}" - S3Key: customizations-for-aws-control-tower/v2.5.0/custom-control-tower-config-deployer.zip + S3Key: customizations-for-aws-control-tower/v2.5.1/custom-control-tower-config-deployer.zip FunctionName: CustomControlTowerDeploymentLambda Description: Custom Control Tower Deployment Lambda Handler: config_deployer.lambda_handler @@ -1273,14 +1273,14 @@ Resources: ADMINISTRATION_ROLE_ARN: !Sub arn:${AWS::Partition}:iam::${AWS::AccountId}:role/service-role/AWSControlTowerStackSetRole EXECUTION_ROLE_NAME: !FindInMap [AWSControlTower, ExecutionRole, Name] SOLUTION_ID: !FindInMap [Solution, Metrics, SolutionID] - SOLUTION_VERSION: v2.5.0 + SOLUTION_VERSION: v2.5.1 METRICS_URL: !FindInMap [Solution, Metrics, MetricsURL] MAX_CONCURRENT_PERCENT: !Ref MaxConcurrentPercentage FAILED_TOLERANCE_PERCENT: !Ref FailureTolerancePercentage REGION_CONCURRENCY_TYPE: !Ref RegionConcurrencyType Code: S3Bucket: !Sub "control-tower-cfct-assets-prod-${AWS::Region}" - S3Key: customizations-for-aws-control-tower/v2.5.0/custom-control-tower-state-machine.zip + S3Key: customizations-for-aws-control-tower/v2.5.1/custom-control-tower-state-machine.zip FunctionName: CustomControlTowerStateMachineLambda Description: Custom Control Tower State Machine Handler Handler: state_machine_router.lambda_handler @@ -2888,10 +2888,10 @@ Resources: LOG_LEVEL: !FindInMap [LambdaFunction, Logging, Level] CODE_PIPELINE_NAME: !Ref CustomControlTowerCodePipeline SOLUTION_ID: !FindInMap [ Solution, Metrics, SolutionID ] - SOLUTION_VERSION: v2.5.0 + SOLUTION_VERSION: v2.5.1 Code: S3Bucket: !Sub "control-tower-cfct-assets-prod-${AWS::Region}" - S3Key: customizations-for-aws-control-tower/v2.5.0/custom-control-tower-lifecycle-event-handler.zip + S3Key: customizations-for-aws-control-tower/v2.5.1/custom-control-tower-lifecycle-event-handler.zip Description: Custom Control Tower Lifecyle event Lambda to handle lifecycle events Handler: lifecycle_event_handler.lambda_handler MemorySize: 512 @@ -3062,6 +3062,6 @@ Outputs: Value: !Ref CustomControlTowerPipelineS3Bucket CustomControlTowerSolutionVersion: Description: Version Number - Value: "v2.5.0" + Value: "v2.5.1" Export: Name: Custom-Control-Tower-Version diff --git a/deployment/build-s3-dist.sh b/deployment/build-s3-dist.sh index 14edc06..44c212f 100755 --- a/deployment/build-s3-dist.sh +++ b/deployment/build-s3-dist.sh @@ -91,7 +91,26 @@ zip -Xr "$build_dist_dir"/custom-control-tower-configuration.zip ./* # build regional config zip file echo -e "\n*** Build regional config zip file" -declare -a region_list=( "ap-northeast-2" "ap-southeast-2" "ca-central-1" "eu-west-1" "eu-west-2" "me-south-1" "us-east-1" "us-west-1" "ap-east-1" "ap-south-1" "eu-central-1" "eu-north-1" "eu-west-3" "sa-east-1" "us-east-2" "us-west-2" "ap-northeast-1" "ap-southeast-1" ) +# Support all regions in https://docs.aws.amazon.com/controltower/latest/userguide/region-how.html + GovCloud regions +declare -a region_list=( + "us-east-1" + "us-east-2" + "us-west-2" + "ca-central-1" + "ap-southeast-2" + "ap-southeast-1" + "eu-central-1" + "eu-west-1" + "eu-west-2" + "eu-north-1" + "ap-south-1" + "ap-northeast-2" + "ap-northeast-1" + "eu-west-3" + "sa-east-1" + "us-gov-west-1" + "us-gov-east-1" +) for region in "${region_list[@]}" do echo -e "\n Building config zip for $region region" diff --git a/deployment/lambda_build.py b/deployment/lambda_build.py index b992a74..bae0dca 100644 --- a/deployment/lambda_build.py +++ b/deployment/lambda_build.py @@ -16,9 +16,9 @@ # !/usr/bin/env python3 import glob import os -import sys import shutil import subprocess +import sys from pathlib import Path LIB_PATH = "source/src" @@ -32,7 +32,7 @@ "deployment_lambda": "custom-control-tower-config-deployer", "build_scripts": "custom-control-tower-scripts", "lifecycle_event_handler": "custom-control-tower-lifecycle-event-handler", - "state_machine_trigger": "custom-control-tower-state-machine-trigger" + "state_machine_trigger": "custom-control-tower-state-machine-trigger", } @@ -92,7 +92,12 @@ def main(argv): else: os.makedirs(S3_OUTPUT_PATH, exist_ok=True) print(" Installing dependencies...") - install_dependencies(dist_folder=DIST_PATH, lib_path=LIB_PATH, handlers_path=HANDLERS_PATH, codebuild_script_path=CODEBUILD_SCRIPTS_PATH) + install_dependencies( + dist_folder=DIST_PATH, + lib_path=LIB_PATH, + handlers_path=HANDLERS_PATH, + codebuild_script_path=CODEBUILD_SCRIPTS_PATH, + ) for arg in argv: if arg in LAMBDA_BUILD_MAPPING: @@ -112,7 +117,9 @@ def main(argv): print(f" Creating archive for {zip_file_name}..") create_lambda_archive( - zip_file_name=zip_file_name, source=DIST_PATH, output_path=S3_OUTPUT_PATH + zip_file_name=zip_file_name, + source=DIST_PATH, + output_path=S3_OUTPUT_PATH, ) diff --git a/source/codebuild_scripts/find_replace.py b/source/codebuild_scripts/find_replace.py index 3669ace..9b09796 100644 --- a/source/codebuild_scripts/find_replace.py +++ b/source/codebuild_scripts/find_replace.py @@ -15,17 +15,17 @@ # !/bin/python -import os import inspect -import yaml -import sys import json -from jinja2 import Environment, FileSystemLoader -from cfct.utils.logger import Logger +import os +import sys +import yaml +from cfct.utils.logger import Logger +from jinja2 import Environment, FileSystemLoader # initialise logger -log_level = 'info' +log_level = "info" logger = Logger(loglevel=log_level) @@ -37,9 +37,9 @@ def find_replace(function_path, file_name, destination_file, parameters): j2template = j2env.get_template(file_name) dictionary = {} for key, value in parameters.items(): - if 'json' in file_name and not isinstance(value, list): - value = "\"%s\"" % value - elif 'json' in file_name and isinstance(value, list): + if "json" in file_name and not isinstance(value, list): + value = '"%s"' % value + elif "json" in file_name and isinstance(value, list): value = json.dumps(value) dictionary.update({key: value}) logger.debug(dictionary) @@ -47,8 +47,7 @@ def find_replace(function_path, file_name, destination_file, parameters): with open(destination_file, "w") as fh: fh.write(output) except Exception as e: - logger.log_general_exception( - __file__.split('/')[-1], inspect.stack()[0][3], e) + logger.log_general_exception(__file__.split("/")[-1], inspect.stack()[0][3], e) raise @@ -57,40 +56,46 @@ def update_add_on_manifest(event, path): exclude_j2_files = [] # Find and replace the variable in Manifest file - for item in event.get('input_parameters'): - f = item.get('file_name') + for item in event.get("input_parameters"): + f = item.get("file_name") exclude_j2_files.append(f) filename, file_extension = os.path.splitext(f) - destination_file_path = extract_path + "/" + filename \ - if file_extension == '.j2' else extract_path + "/" + f - find_replace(extract_path, f, destination_file_path, - item.get('parameters')) + destination_file_path = ( + extract_path + "/" + filename + if file_extension == ".j2" + else extract_path + "/" + f + ) + find_replace(extract_path, f, destination_file_path, item.get("parameters")) def sanitize_boolean_type(s, bools): - s = ' ' + s + s = " " + s logger.info("Adding quotes around the boolean values: {}".format(bools)) logger.info("Print original string: {}".format(s)) - for w in [x.strip() for x in bools.split(',')]: - s = s.replace(':' + ' ' + w, ': "' + w + '"') - logger.info("If found, wrapped '{}' with double quotes, printing" - " the modified string: {}".format(w, s)) + for w in [x.strip() for x in bools.split(",")]: + s = s.replace(":" + " " + w, ': "' + w + '"') + logger.info( + "If found, wrapped '{}' with double quotes, printing" + " the modified string: {}".format(w, s) + ) return yaml.safe_load(s[1:]) def sanitize_null_type(d, none_type_values): s = json.dumps(d) - s = ' ' + s + s = " " + s logger.info("Replacing none_type/null with empty quotes.") - for w in [x.strip() for x in none_type_values.split(',')]: - s = s.replace(':' + ' ' + w, ': ""') - logger.info("If found, replacing '{}' with double quotes, printing" - " the modified string: {}".format(w, s)) + for w in [x.strip() for x in none_type_values.split(",")]: + s = s.replace(":" + " " + w, ': ""') + logger.info( + "If found, replacing '{}' with double quotes, printing" + " the modified string: {}".format(w, s) + ) return yaml.safe_load(s[1:]) def generate_event(user_input_file, path, bools, none_types): - logger.info('Generating Event') + logger.info("Generating Event") with open(user_input_file) as f: user_input = sanitize_boolean_type(f.read(), bools) logger.info("Boolean values wrapped with quotes (if applicable)") @@ -101,7 +106,7 @@ def generate_event(user_input_file, path, bools, none_types): update_add_on_manifest(user_input, path) -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) > 4: path = sys.argv[2] file_name = sys.argv[1] @@ -109,8 +114,9 @@ def generate_event(user_input_file, path, bools, none_types): boolean_type_values = sys.argv[3] generate_event(file_name, path, boolean_type_values, none_type_values) else: - print('Not enough arguments provided. Please provide the path and' - ' user input file names.') - print('Example: merge_manifest.py ' - ' ') + print( + "Not enough arguments provided. Please provide the path and" + " user input file names." + ) + print("Example: merge_manifest.py " " ") sys.exit(2) diff --git a/source/codebuild_scripts/merge_baseline_template_parameter.py b/source/codebuild_scripts/merge_baseline_template_parameter.py index 434beea..13ba1cf 100644 --- a/source/codebuild_scripts/merge_baseline_template_parameter.py +++ b/source/codebuild_scripts/merge_baseline_template_parameter.py @@ -14,15 +14,16 @@ ############################################################################### import json -import sys import os import subprocess +import sys + from cfct.utils.logger import Logger def _read_file(file): if os.path.isfile(file): - logger.info('File - {} exists'.format(file)) + logger.info("File - {} exists".format(file)) logger.info("Reading from {}".format(file)) with open(file) as f: return json.load(f) @@ -31,7 +32,7 @@ def _read_file(file): sys.exit(1) -def _write_file(data, file, mode='w'): +def _write_file(data, file, mode="w"): logger.info("Writing to {}".format(file)) with open(file, mode) as outfile: json.dump(data, outfile, indent=2) @@ -46,44 +47,42 @@ def _flip_to_json(yaml_file): def _flip_to_yaml(json_file): - yaml_file = json_file[:-len(updated_flag)] + yaml_file = json_file[: -len(updated_flag)] logger.info("Flipping JSON > {} to YAML > {}".format(json_file, yaml_file)) # final stage - convert json avm template to yaml format subprocess.run(["cfn-flip", "-y", json_file, yaml_file]) -def file_matcher(master_data, add_on_data, master_key='master', - add_on_key='add_on'): +def file_matcher(master_data, add_on_data, master_key="master", add_on_key="add_on"): for item in master_data.get(master_key): logger.info("Iterating Master AVM File List") for key, value in item.items(): - logger.info('{}: {}'.format(key, value)) - master_file_name = value.split('/')[-1] - logger.info("master_value: {}".format(value.split('/')[-1])) + logger.info("{}: {}".format(key, value)) + master_file_name = value.split("/")[-1] + logger.info("master_value: {}".format(value.split("/")[-1])) for i in add_on_data.get(add_on_key): logger.info("Iterating Add-On AVM File List for comparision.") for k, v in i.items(): - logger.info('{}: {}'.format(k, v)) - add_on_file_name = v.split('/')[-1] - logger.info("add_on_value: {}".format(v.split('/')[-1])) + logger.info("{}: {}".format(k, v)) + add_on_file_name = v.split("/")[-1] + logger.info("add_on_value: {}".format(v.split("/")[-1])) if master_file_name == add_on_file_name: - logger.info("Matching file names found - " - "full path below") + logger.info("Matching file names found - " "full path below") logger.info("File in master list: {}".format(value)) logger.info("File in add-on list: {}".format(v)) # Pass value and v to merge functions - if master_file_name.lower().endswith('.template'): + if master_file_name.lower().endswith(".template"): logger.info("Processing template file") # merge master avm template with add_on template # send json data - final_json = update_template(_flip_to_json(value), - _flip_to_json(v)) + final_json = update_template( + _flip_to_json(value), _flip_to_json(v) + ) # write the json data to json file - updated_json_file_name = \ - os.path.join(value+updated_flag) + updated_json_file_name = os.path.join(value + updated_flag) _write_file(final_json, updated_json_file_name) _flip_to_yaml(updated_json_file_name) - if master_file_name.lower().endswith('.json'): + if master_file_name.lower().endswith(".json"): logger.info("Processing parameter file") update_parameters(value, v) @@ -92,16 +91,17 @@ def update_level_1_dict(master, add_on, level_1_key): for key1, value1 in add_on.items(): if isinstance(value1, dict) and key1 == level_1_key: # Check if primary key matches - logger.info("Level 1 keys matched ADDON {} == {}".format( - key1, level_1_key)) + logger.info("Level 1 keys matched ADDON {} == {}".format(key1, level_1_key)) # Iterate through the 2nd level dicts in the value for key2, value2 in value1.items(): logger.info("----------------------------------") # Match k with master dict keys - add if not present for k1, v1 in master.items(): if isinstance(v1, dict) and k1 == level_1_key: - logger.info("Level 1 keys matched MASTER " - "{} == {}".format(k1, level_1_key)) + logger.info( + "Level 1 keys matched MASTER " + "{} == {}".format(k1, level_1_key) + ) flag = False # Iterate through the 2nd level dicts in # the value @@ -110,17 +110,17 @@ def update_level_1_dict(master, add_on, level_1_key): if key2 == k2: logger.info("Found matching keys") flag = False - logger.info("Setting flag value to {}" - .format(flag)) + logger.info("Setting flag value to {}".format(flag)) break else: flag = True logger.info( "Add-on key not found in existing" " dict, setting flag value to {}" - " to update dict.".format(flag)) + " to update dict.".format(flag) + ) if flag: - logger.info('Adding key {}'.format(key2)) + logger.info("Adding key {}".format(key2)) d2 = {key2: value2} v1.update(d2) logger.debug(master) @@ -152,7 +152,7 @@ def update_template(master, add_on): return master -def update_parameters(master, add_on, decision_key='ParameterKey'): +def update_parameters(master, add_on, decision_key="ParameterKey"): logger.info("Merging parameter files.") m_list = _read_file(master) add_list = _read_file(add_on) @@ -164,11 +164,16 @@ def update_parameters(master, add_on, decision_key='ParameterKey'): for i in m_list: logger.info(i.get(decision_key)) if item.get(decision_key) == i.get(decision_key): - logger.info("Keys: '{}' matched, skipping" - .format(item.get(decision_key))) + logger.info( + "Keys: '{}' matched, skipping".format( + item.get(decision_key) + ) + ) flag = False - logger.info("Setting flag value to {} and stopping" - " the loop.".format(flag)) + logger.info( + "Setting flag value to {} and stopping" + " the loop.".format(flag) + ) break else: flag = True @@ -181,7 +186,7 @@ def update_parameters(master, add_on, decision_key='ParameterKey'): return m_list -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) > 3: log_level = sys.argv[1] master_baseline_file = sys.argv[2] @@ -196,8 +201,12 @@ def update_parameters(master, add_on, decision_key='ParameterKey'): file_matcher(master_list, add_on_list) else: - print('No arguments provided. Please provide the existing and ' - 'new manifest files names.') - print('Example: merge_baseline_template_parameter.py ' - ' ') + print( + "No arguments provided. Please provide the existing and " + "new manifest files names." + ) + print( + "Example: merge_baseline_template_parameter.py " + " " + ) sys.exit(2) diff --git a/source/codebuild_scripts/merge_manifest.py b/source/codebuild_scripts/merge_manifest.py index bc83377..ff290db 100644 --- a/source/codebuild_scripts/merge_manifest.py +++ b/source/codebuild_scripts/merge_manifest.py @@ -13,11 +13,12 @@ # governing permissions and limitations under the License. # ############################################################################## -import yaml import sys + +import yaml from cfct.utils.logger import Logger -log_level = 'info' +log_level = "info" logger = Logger(loglevel=log_level) @@ -28,10 +29,10 @@ def update_level_one_list(existing, add_on, level_one_dct_key, decision_key): for add_on_key_level_one_list in add_on.get(level_one_dct_key): flag = False if existing.get(level_one_dct_key): - for existing_key_level_one_list in \ - existing.get(level_one_dct_key): - if add_on_key_level_one_list.get(decision_key) == \ - existing_key_level_one_list.get(decision_key): + for existing_key_level_one_list in existing.get(level_one_dct_key): + if add_on_key_level_one_list.get( + decision_key + ) == existing_key_level_one_list.get(decision_key): flag = False # break the loop if same name is found in the list break @@ -41,16 +42,19 @@ def update_level_one_list(existing, add_on, level_one_dct_key, decision_key): flag = True else: flag = True - if flag and add_on_key_level_one_list not in existing.get(level_one_dct_key): + if flag and add_on_key_level_one_list not in existing.get( + level_one_dct_key + ): # to avoid duplication append check to see if value in # the list already exist - logger.info("(Level 1) Adding new {} > {}: {}" - .format(type(add_on_key_level_one_list) - .__name__, decision_key, - add_on_key_level_one_list - .get(decision_key))) - existing.get(level_one_dct_key) \ - .append(add_on_key_level_one_list) + logger.info( + "(Level 1) Adding new {} > {}: {}".format( + type(add_on_key_level_one_list).__name__, + decision_key, + add_on_key_level_one_list.get(decision_key), + ) + ) + existing.get(level_one_dct_key).append(add_on_key_level_one_list) logger.debug(existing.get(level_one_dct_key)) return existing @@ -68,30 +72,32 @@ def _json_to_yaml(json, filename): # print(yml) # create new manifest file - file = open(filename, 'w') + file = open(filename, "w") file.write(yml) file.close() def update_scp_policies(add_on, original): - level_1_key = 'organization_policies' - decision_key = 'name' + level_1_key = "organization_policies" + decision_key = "name" # process new scp policy addition updated_manifest = update_level_one_list( - original, add_on, level_1_key, decision_key) + original, add_on, level_1_key, decision_key + ) original = _reload(updated_manifest, original) return original def update_cloudformation_resources(add_on, original): - level_1_key = 'cloudformation_resources' - decision_key = 'name' + level_1_key = "cloudformation_resources" + decision_key = "name" # process new baseline addition updated_manifest = update_level_one_list( - original, add_on, level_1_key, decision_key) + original, add_on, level_1_key, decision_key + ) original = _reload(updated_manifest, original) return original @@ -118,8 +124,12 @@ def main(): output_manifest_file_path = sys.argv[3] main() else: - print('No arguments provided. Please provide the existing and' - ' new manifest files names.') - print('Example: merge_manifest.py' - ' ') + print( + "No arguments provided. Please provide the existing and" + " new manifest files names." + ) + print( + "Example: merge_manifest.py" + " " + ) sys.exit(2) diff --git a/source/codebuild_scripts/state_machine_trigger.py b/source/codebuild_scripts/state_machine_trigger.py index 5a603bb..5a900d5 100644 --- a/source/codebuild_scripts/state_machine_trigger.py +++ b/source/codebuild_scripts/state_machine_trigger.py @@ -15,10 +15,11 @@ import os import sys -from cfct.exceptions import StackSetHasFailedInstances -from cfct.utils.logger import Logger + import cfct.manifest.manifest_parser as parse +from cfct.exceptions import StackSetHasFailedInstances from cfct.manifest.sm_execution_manager import SMExecutionManager +from cfct.utils.logger import Logger def main(): @@ -39,48 +40,55 @@ def main(): if len(sys.argv) > 7: # set environment variables - manifest_name = 'manifest.yaml' + manifest_name = "manifest.yaml" file_path = sys.argv[3] - os.environ['WAIT_TIME'] = sys.argv[2] - os.environ['MANIFEST_FILE_PATH'] = file_path - os.environ['SM_ARN'] = sys.argv[4] - os.environ['STAGING_BUCKET'] = sys.argv[5] - os.environ['TEMPLATE_KEY_PREFIX'] = '_custom_ct_templates_staging' - os.environ['MANIFEST_FILE_NAME'] = manifest_name - os.environ['MANIFEST_FOLDER'] = file_path[:-len(manifest_name)] + os.environ["WAIT_TIME"] = sys.argv[2] + os.environ["MANIFEST_FILE_PATH"] = file_path + os.environ["SM_ARN"] = sys.argv[4] + os.environ["STAGING_BUCKET"] = sys.argv[5] + os.environ["TEMPLATE_KEY_PREFIX"] = "_custom_ct_templates_staging" + os.environ["MANIFEST_FILE_NAME"] = manifest_name + os.environ["MANIFEST_FOLDER"] = file_path[: -len(manifest_name)] stage_name = sys.argv[6] - os.environ['STAGE_NAME'] = stage_name - os.environ['KMS_KEY_ALIAS_NAME'] = sys.argv[7] - os.environ['CAPABILITIES'] = '["CAPABILITY_NAMED_IAM","CAPABILITY_AUTO_EXPAND"]' + os.environ["STAGE_NAME"] = stage_name + os.environ["KMS_KEY_ALIAS_NAME"] = sys.argv[7] + os.environ[ + "CAPABILITIES" + ] = '["CAPABILITY_NAMED_IAM","CAPABILITY_AUTO_EXPAND"]' enforce_successful_stack_instances = None if len(sys.argv) > 8: - enforce_successful_stack_instances = True if sys.argv[8] == "true" else False + enforce_successful_stack_instances = ( + True if sys.argv[8] == "true" else False + ) sm_input_list = [] - if stage_name.upper() == 'SCP': + if stage_name.upper() == "SCP": # get SCP state machine input list - os.environ['EXECUTION_MODE'] = 'parallel' + os.environ["EXECUTION_MODE"] = "parallel" sm_input_list = get_scp_inputs() logger.info("SCP sm_input_list:") logger.info(sm_input_list) - elif stage_name.upper() == 'STACKSET': - os.environ['EXECUTION_MODE'] = 'sequential' + elif stage_name.upper() == "STACKSET": + os.environ["EXECUTION_MODE"] = "sequential" sm_input_list = get_stack_set_inputs() logger.info("STACKSET sm_input_list:") logger.info(sm_input_list) if sm_input_list: logger.info("=== Launching State Machine Execution ===") - launch_state_machine_execution(sm_input_list, enforce_successful_stack_instances) + launch_state_machine_execution( + sm_input_list, enforce_successful_stack_instances + ) else: - logger.info("State Machine input list is empty. No action " - "required.") + logger.info("State Machine input list is empty. No action " "required.") else: - print('No arguments provided. ') - print('Example: state_machine_trigger.py ' - ' ' - ' ') + print("No arguments provided. ") + print( + "Example: state_machine_trigger.py " + " " + " " + ) sys.exit(2) except Exception as e: logger.log_unhandled_exception(e) @@ -95,19 +103,23 @@ def get_stack_set_inputs() -> list: return parse.stack_set_manifest() -def launch_state_machine_execution(sm_input_list, enforce_successful_stack_instances=False): +def launch_state_machine_execution( + sm_input_list, enforce_successful_stack_instances=False +): if isinstance(sm_input_list, list): - manager = SMExecutionManager(logger, sm_input_list, enforce_successful_stack_instances) + manager = SMExecutionManager( + logger, sm_input_list, enforce_successful_stack_instances + ) try: status, failed_list = manager.launch_executions() except StackSetHasFailedInstances as error: logger.error(f"{error.stack_set_name} has following failed instances:") for instance in error.failed_stack_set_instances: message = { - "StackID": instance['StackId'], - "Account": instance['Account'], - "Region": instance['Region'], - "StatusReason": instance['StatusReason'] + "StackID": instance["StackId"], + "Account": instance["Account"], + "Region": instance["Region"], + "StatusReason": instance["StatusReason"], } logger.error(message) sys.exit(1) @@ -115,18 +127,20 @@ def launch_state_machine_execution(sm_input_list, enforce_successful_stack_insta else: raise TypeError("State Machine Input List must be of list type") - if status == 'FAILED': + if status == "FAILED": logger.error( "\n********************************************************" "\nState Machine Execution(s) Failed. \nNavigate to the " "AWS Step Functions console \nand review the following " "State Machine Executions.\nARN List:\n" - "{}\n********************************************************" - .format(failed_list)) + "{}\n********************************************************".format( + failed_list + ) + ) sys.exit(1) -if __name__ == '__main__': - os.environ['LOG_LEVEL'] = sys.argv[1] - logger = Logger(loglevel=os.environ['LOG_LEVEL']) +if __name__ == "__main__": + os.environ["LOG_LEVEL"] = sys.argv[1] + logger = Logger(loglevel=os.environ["LOG_LEVEL"]) main() diff --git a/source/src/cfct/aws/services/cloudformation.py b/source/src/cfct/aws/services/cloudformation.py index b54a241..e161bd6 100644 --- a/source/src/cfct/aws/services/cloudformation.py +++ b/source/src/cfct/aws/services/cloudformation.py @@ -15,44 +15,50 @@ # !/bin/python +import json import os +from typing import Any, Dict, List + from botocore.exceptions import ClientError -from cfct.utils.retry_decorator import try_except_retry from cfct.aws.utils.boto3_session import Boto3Session -from cfct.types import StackSetInstanceTypeDef, StackSetRequestTypeDef, ResourcePropertiesTypeDef -import json - -from typing import Dict, List, Any +from cfct.types import ( + ResourcePropertiesTypeDef, + StackSetInstanceTypeDef, + StackSetRequestTypeDef, +) +from cfct.utils.retry_decorator import try_except_retry class StackSet(Boto3Session): - DEPLOYED_BY_CFCT_TAG = {"Key": "AWS_Solutions", "Value": "CustomControlTowerStackSet"} + DEPLOYED_BY_CFCT_TAG = { + "Key": "AWS_Solutions", + "Value": "CustomControlTowerStackSet", + } CFCT_STACK_SET_PREFIX = "CustomControlTower-" DEPLOY_METHOD = "stack_set" def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'cloudformation' - self.max_concurrent_percent = int( - os.environ.get('MAX_CONCURRENT_PERCENT', 100)) + __service_name = "cloudformation" + self.max_concurrent_percent = int(os.environ.get("MAX_CONCURRENT_PERCENT", 100)) self.failed_tolerance_percent = int( - os.environ.get('FAILED_TOLERANCE_PERCENT', 10)) + os.environ.get("FAILED_TOLERANCE_PERCENT", 10) + ) self.region_concurrency_type = os.environ.get( - 'REGION_CONCURRENCY_TYPE', 'PARALLEL').upper() + "REGION_CONCURRENCY_TYPE", "PARALLEL" + ).upper() self.max_results_per_page = 20 super().__init__(logger, __service_name, **kwargs) self.cfn_client = super().get_client() - self.operation_in_progress_except_msg = \ - 'Caught exception OperationInProgressException' \ - ' handling the exception...' + self.operation_in_progress_except_msg = ( + "Caught exception OperationInProgressException" " handling the exception..." + ) @try_except_retry() def describe_stack_set(self, stack_set_name): try: - response = self.cfn_client.describe_stack_set( - StackSetName=stack_set_name - ) + response = self.cfn_client.describe_stack_set(StackSetName=stack_set_name) return response except self.cfn_client.exceptions.StackSetNotFoundException: pass @@ -64,13 +70,15 @@ def describe_stack_set(self, stack_set_name): def describe_stack_set_operation(self, stack_set_name, operation_id): try: response = self.cfn_client.describe_stack_set_operation( - StackSetName=stack_set_name, - OperationId=operation_id + StackSetName=stack_set_name, OperationId=operation_id ) return response except ClientError as e: - self.logger.error("'{}' StackSet Operation ID: {} not found." - .format(stack_set_name, operation_id)) + self.logger.error( + "'{}' StackSet Operation ID: {} not found.".format( + stack_set_name, operation_id + ) + ) self.logger.log_unhandled_exception(e) raise @@ -94,37 +102,53 @@ def get_accounts_and_regions_per_stack_set(self, stack_name): """ try: response = self.cfn_client.list_stack_instances( - StackSetName=stack_name, - MaxResults=self.max_results_per_page + StackSetName=stack_name, MaxResults=self.max_results_per_page ) - stack_instance_list = response.get('Summaries', []) + stack_instance_list = response.get("Summaries", []) # build the account and region list for the stack set # using list(set(LIST)) to remove the duplicate values from the list - account_list = list(set([stack_instance['Account'] - for stack_instance - in stack_instance_list])) - region_list = list(set([stack_instance['Region'] - for stack_instance - in stack_instance_list])) - next_token = response.get('NextToken', None) + account_list = list( + set( + [ + stack_instance["Account"] + for stack_instance in stack_instance_list + ] + ) + ) + region_list = list( + set( + [stack_instance["Region"] for stack_instance in stack_instance_list] + ) + ) + next_token = response.get("NextToken", None) while next_token is not None: self.logger.info("Next Token Returned: {}".format(next_token)) response = self.cfn_client.list_stack_instances( StackSetName=stack_name, MaxResults=self.max_results_per_page, - NextToken=next_token + NextToken=next_token, ) - stack_instance_list = response.get('Summaries', []) - next_token = response.get('NextToken', None) + stack_instance_list = response.get("Summaries", []) + next_token = response.get("NextToken", None) # update account and region lists - additional_account_list = list(set([stack_instance['Account'] - for stack_instance in - stack_instance_list])) - additional_region_list = list(set([stack_instance['Region'] - for stack_instance - in stack_instance_list])) + additional_account_list = list( + set( + [ + stack_instance["Account"] + for stack_instance in stack_instance_list + ] + ) + ) + additional_region_list = list( + set( + [ + stack_instance["Region"] + for stack_instance in stack_instance_list + ] + ) + ) account_list = account_list + additional_account_list region_list = region_list + additional_region_list return list(set(account_list)), list(set(region_list)) @@ -132,8 +156,9 @@ def get_accounts_and_regions_per_stack_set(self, stack_name): self.logger.log_unhandled_exception(e) raise - def create_stack_set(self, stack_set_name, template_url, - cf_params, capabilities, tag_key, tag_value): + def create_stack_set( + self, stack_set_name, template_url, cf_params, capabilities, tag_key, tag_value + ): try: parameters = [] param_dict = {} @@ -146,8 +171,8 @@ def create_stack_set(self, stack_set_name, template_url, if type(value) == list: value = ",".join(map(str, value)) - param_dict['ParameterKey'] = key - param_dict['ParameterValue'] = value + param_dict["ParameterKey"] = key + param_dict["ParameterValue"] = value parameters.append(param_dict.copy()) response = self.cfn_client.create_stack_set( @@ -156,14 +181,10 @@ def create_stack_set(self, stack_set_name, template_url, Parameters=parameters, Capabilities=json.loads(capabilities), Tags=[ - { - 'Key': tag_key, - 'Value': tag_value - }, + {"Key": tag_key, "Value": tag_value}, ], - AdministrationRoleARN=os.environ.get( - 'ADMINISTRATION_ROLE_ARN'), - ExecutionRoleName=os.environ.get('EXECUTION_ROLE_NAME') + AdministrationRoleARN=os.environ.get("ADMINISTRATION_ROLE_ARN"), + ExecutionRoleName=os.environ.get("EXECUTION_ROLE_NAME"), ) return response except ClientError as e: @@ -177,23 +198,23 @@ def create_stack_instances(self, stack_set_name, account_list, region_list): Accounts=account_list, Regions=region_list, OperationPreferences={ - 'FailureTolerancePercentage': self.failed_tolerance_percent, - 'MaxConcurrentPercentage': self.max_concurrent_percent, - 'RegionConcurrencyType': self.region_concurrency_type - } + "FailureTolerancePercentage": self.failed_tolerance_percent, + "MaxConcurrentPercentage": self.max_concurrent_percent, + "RegionConcurrencyType": self.region_concurrency_type, + }, ) return response except ClientError as e: - if e.response['Error']['Code'] == 'OperationInProgressException': + if e.response["Error"]["Code"] == "OperationInProgressException": self.logger.info(self.operation_in_progress_except_msg) return {"OperationId": "OperationInProgressException"} else: self.logger.log_unhandled_exception(e) raise - def create_stack_instances_with_override_params(self, stack_set_name, - account_list, region_list, - override_params): + def create_stack_instances_with_override_params( + self, stack_set_name, account_list, region_list, override_params + ): try: parameters = [] param_dict = {} @@ -206,8 +227,8 @@ def create_stack_instances_with_override_params(self, stack_set_name, if type(value) == list: value = ",".join(map(str, value)) - param_dict['ParameterKey'] = key - param_dict['ParameterValue'] = value + param_dict["ParameterKey"] = key + param_dict["ParameterValue"] = value parameters.append(param_dict.copy()) response = self.cfn_client.create_stack_instances( @@ -216,24 +237,27 @@ def create_stack_instances_with_override_params(self, stack_set_name, Regions=region_list, ParameterOverrides=parameters, OperationPreferences={ - 'FailureTolerancePercentage': self.failed_tolerance_percent, - 'MaxConcurrentPercentage': self.max_concurrent_percent, - 'RegionConcurrencyType': self.region_concurrency_type - } + "FailureTolerancePercentage": self.failed_tolerance_percent, + "MaxConcurrentPercentage": self.max_concurrent_percent, + "RegionConcurrencyType": self.region_concurrency_type, + }, ) return response except ClientError as e: - if e.response['Error']['Code'] == 'OperationInProgressException': - self.logger.info("Caught exception " - "'OperationInProgressException', " - "handling the exception...") + if e.response["Error"]["Code"] == "OperationInProgressException": + self.logger.info( + "Caught exception " + "'OperationInProgressException', " + "handling the exception..." + ) return {"OperationId": "OperationInProgressException"} else: self.logger.log_unhandled_exception(e) raise - def update_stack_instances(self, stack_set_name, account_list, region_list, - override_params): + def update_stack_instances( + self, stack_set_name, account_list, region_list, override_params + ): try: parameters = [] param_dict = {} @@ -246,8 +270,8 @@ def update_stack_instances(self, stack_set_name, account_list, region_list, if type(value) == list: value = ",".join(map(str, value)) - param_dict['ParameterKey'] = key - param_dict['ParameterValue'] = value + param_dict["ParameterKey"] = key + param_dict["ParameterValue"] = value parameters.append(param_dict.copy()) response = self.cfn_client.update_stack_instances( @@ -256,22 +280,21 @@ def update_stack_instances(self, stack_set_name, account_list, region_list, Regions=region_list, ParameterOverrides=parameters, OperationPreferences={ - 'FailureTolerancePercentage': self.failed_tolerance_percent, - 'MaxConcurrentPercentage': self.max_concurrent_percent, - 'RegionConcurrencyType': self.region_concurrency_type - } + "FailureTolerancePercentage": self.failed_tolerance_percent, + "MaxConcurrentPercentage": self.max_concurrent_percent, + "RegionConcurrencyType": self.region_concurrency_type, + }, ) return response except ClientError as e: - if e.response['Error']['Code'] == 'OperationInProgressException': + if e.response["Error"]["Code"] == "OperationInProgressException": self.logger.info(self.operation_in_progress_except_msg) return {"OperationId": "OperationInProgressException"} else: self.logger.log_unhandled_exception(e) raise - def update_stack_set(self, stack_set_name, parameter, template_url, - capabilities): + def update_stack_set(self, stack_set_name, parameter, template_url, capabilities): try: parameters = [] param_dict = {} @@ -284,8 +307,8 @@ def update_stack_set(self, stack_set_name, parameter, template_url, if type(value) == list: value = ",".join(map(str, value)) - param_dict['ParameterKey'] = key - param_dict['ParameterValue'] = value + param_dict["ParameterKey"] = key + param_dict["ParameterValue"] = value parameters.append(param_dict.copy()) response = self.cfn_client.update_stack_set( @@ -293,18 +316,17 @@ def update_stack_set(self, stack_set_name, parameter, template_url, TemplateURL=template_url, Parameters=parameters, Capabilities=json.loads(capabilities), - AdministrationRoleARN=os.environ.get( - 'ADMINISTRATION_ROLE_ARN'), - ExecutionRoleName=os.environ.get('EXECUTION_ROLE_NAME'), + AdministrationRoleARN=os.environ.get("ADMINISTRATION_ROLE_ARN"), + ExecutionRoleName=os.environ.get("EXECUTION_ROLE_NAME"), OperationPreferences={ - 'FailureTolerancePercentage': self.failed_tolerance_percent, - 'MaxConcurrentPercentage': self.max_concurrent_percent, - 'RegionConcurrencyType': self.region_concurrency_type - } + "FailureTolerancePercentage": self.failed_tolerance_percent, + "MaxConcurrentPercentage": self.max_concurrent_percent, + "RegionConcurrencyType": self.region_concurrency_type, + }, ) return response except ClientError as e: - if e.response['Error']['Code'] == 'OperationInProgressException': + if e.response["Error"]["Code"] == "OperationInProgressException": self.logger.info(self.operation_in_progress_except_msg) return {"OperationId": "OperationInProgressException"} else: @@ -321,8 +343,9 @@ def delete_stack_set(self, stack_set_name): self.logger.log_unhandled_exception(e) raise - def delete_stack_instances(self, stack_set_name, account_list, region_list, - retain_condition=False): + def delete_stack_instances( + self, stack_set_name, account_list, region_list, retain_condition=False + ): try: response = self.cfn_client.delete_stack_instances( StackSetName=stack_set_name, @@ -330,14 +353,14 @@ def delete_stack_instances(self, stack_set_name, account_list, region_list, Regions=region_list, RetainStacks=retain_condition, OperationPreferences={ - 'FailureTolerancePercentage': self.failed_tolerance_percent, - 'MaxConcurrentPercentage': self.max_concurrent_percent, - 'RegionConcurrencyType': self.region_concurrency_type - } + "FailureTolerancePercentage": self.failed_tolerance_percent, + "MaxConcurrentPercentage": self.max_concurrent_percent, + "RegionConcurrencyType": self.region_concurrency_type, + }, ) return response except ClientError as e: - if e.response['Error']['Code'] == 'OperationInProgressException': + if e.response["Error"]["Code"] == "OperationInProgressException": self.logger.info(self.operation_in_progress_except_msg) return {"OperationId": "OperationInProgressException"} else: @@ -350,7 +373,7 @@ def describe_stack_instance(self, stack_set_name, account_id, region): response = self.cfn_client.describe_stack_instance( StackSetName=stack_set_name, StackInstanceAccount=account_id, - StackInstanceRegion=region + StackInstanceRegion=region, ) return response except ClientError as e: @@ -366,23 +389,27 @@ def list_stack_set_operations(self, **kwargs): self.logger.log_unhandled_exception(e) raise - def _filter_managed_stack_set_names(self, list_stackset_response: Dict[str, Any]) -> List[str]: + def _filter_managed_stack_set_names( + self, list_stackset_response: Dict[str, Any] + ) -> List[str]: """ Reduces a list of given stackset summaries to only those considered managed by CfCT """ managed_stack_set_names: List[str] = [] - for summary in list_stackset_response['Summaries']: - stack_set_name = summary['StackSetName'] + for summary in list_stackset_response["Summaries"]: + stack_set_name = summary["StackSetName"] try: - response: Dict[str, Any] = self.cfn_client.describe_stack_set(StackSetName=stack_set_name) + response: Dict[str, Any] = self.cfn_client.describe_stack_set( + StackSetName=stack_set_name + ) except ClientError as error: - if error.response['Error']['Code'] == "StackSetNotFoundException": + if error.response["Error"]["Code"] == "StackSetNotFoundException": continue raise if self.is_managed_by_cfct(describe_stackset_response=response): managed_stack_set_names.append(stack_set_name) - + return managed_stack_set_names def get_managed_stack_set_names(self) -> List[str]: @@ -394,20 +421,29 @@ def get_managed_stack_set_names(self) -> List[str]: managed_stackset_names: List[str] = [] paginator = self.cfn_client.get_paginator("list_stack_sets") for page in paginator.paginate(Status="ACTIVE"): - managed_stackset_names.extend(self._filter_managed_stack_set_names(list_stackset_response=page)) + managed_stackset_names.extend( + self._filter_managed_stack_set_names(list_stackset_response=page) + ) return managed_stackset_names def is_managed_by_cfct(self, describe_stackset_response: Dict[str, Any]) -> bool: """ A StackSet is considered managed if it has both the prefix we expect, and the proper tag """ - - has_tag = StackSet.DEPLOYED_BY_CFCT_TAG in describe_stackset_response['StackSet']['Tags'] - has_prefix = describe_stackset_response['StackSet']['StackSetName'].startswith(StackSet.CFCT_STACK_SET_PREFIX) - is_active = describe_stackset_response['StackSet']['Status'] == "ACTIVE" + + has_tag = ( + StackSet.DEPLOYED_BY_CFCT_TAG + in describe_stackset_response["StackSet"]["Tags"] + ) + has_prefix = describe_stackset_response["StackSet"]["StackSetName"].startswith( + StackSet.CFCT_STACK_SET_PREFIX + ) + is_active = describe_stackset_response["StackSet"]["Status"] == "ACTIVE" return all((has_prefix, has_tag, is_active)) - def get_stack_sets_not_present_in_manifest(self, manifest_stack_sets: List[str]) -> List[str]: + def get_stack_sets_not_present_in_manifest( + self, manifest_stack_sets: List[str] + ) -> List[str]: """ Compares list of stacksets defined in the manifest versus the stacksets in the account and returns a list of all stackset names to be deleted @@ -415,39 +451,60 @@ def get_stack_sets_not_present_in_manifest(self, manifest_stack_sets: List[str]) # Stack sets defined in the manifest will not have the CFCT_STACK_SET_PREFIX # To make comparisons simpler - manifest_stack_sets_with_prefix = [f"{StackSet.CFCT_STACK_SET_PREFIX}{name}" for name in manifest_stack_sets] + manifest_stack_sets_with_prefix = [ + f"{StackSet.CFCT_STACK_SET_PREFIX}{name}" for name in manifest_stack_sets + ] cfct_deployed_stack_sets = self.get_managed_stack_set_names() - return list(set(cfct_deployed_stack_sets).difference(set(manifest_stack_sets_with_prefix))) + return list( + set(cfct_deployed_stack_sets).difference( + set(manifest_stack_sets_with_prefix) + ) + ) - def generate_delete_request(self, stacksets_to_delete: List[str]) -> List[StackSetRequestTypeDef]: + def generate_delete_request( + self, stacksets_to_delete: List[str] + ) -> List[StackSetRequestTypeDef]: requests: List[StackSetRequestTypeDef] = [] for stackset_name in stacksets_to_delete: - deployed_instances = self._get_stackset_instances(stackset_name=stackset_name) - requests.append(StackSetRequestTypeDef( - RequestType="Delete", - ResourceProperties=ResourcePropertiesTypeDef( - StackSetName=stackset_name, - TemplateURL="DeleteStackSetNoopURL", - Capabilities=json.dumps(["CAPABILITY_NAMED_IAM", "CAPABILITY_AUTO_EXPAND"]), - Parameters={}, - AccountList=list({instance['account'] for instance in deployed_instances}), - RegionList=list({instance['region'] for instance in deployed_instances}), - SSMParameters={} - ), - SkipUpdateStackSet="yes", - )) + deployed_instances = self._get_stackset_instances( + stackset_name=stackset_name + ) + requests.append( + StackSetRequestTypeDef( + RequestType="Delete", + ResourceProperties=ResourcePropertiesTypeDef( + StackSetName=stackset_name, + TemplateURL="DeleteStackSetNoopURL", + Capabilities=json.dumps( + ["CAPABILITY_NAMED_IAM", "CAPABILITY_AUTO_EXPAND"] + ), + Parameters={}, + AccountList=list( + {instance["account"] for instance in deployed_instances} + ), + RegionList=list( + {instance["region"] for instance in deployed_instances} + ), + SSMParameters={}, + ), + SkipUpdateStackSet="yes", + ) + ) return requests - - def _get_stackset_instances(self, stackset_name: str) -> List[StackSetInstanceTypeDef]: + def _get_stackset_instances( + self, stackset_name: str + ) -> List[StackSetInstanceTypeDef]: instance_regions_and_accounts: List[StackSetInstanceTypeDef] = [] paginator = self.cfn_client.get_paginator("list_stack_instances") for page in paginator.paginate(StackSetName=stackset_name): - for summary in page['Summaries']: - instance_regions_and_accounts.append(StackSetInstanceTypeDef( - account=summary['Account'], - region=summary['Region'], - )) + for summary in page["Summaries"]: + instance_regions_and_accounts.append( + StackSetInstanceTypeDef( + account=summary["Account"], + region=summary["Region"], + ) + ) return instance_regions_and_accounts @@ -455,17 +512,15 @@ def _get_stackset_instances(self, stackset_name: str) -> List[StackSetInstanceTy class Stacks(Boto3Session): def __init__(self, logger, region, **kwargs): self.logger = logger - __service_name = 'cloudformation' - kwargs.update({'region': region}) + __service_name = "cloudformation" + kwargs.update({"region": region}) super().__init__(logger, __service_name, **kwargs) self.cfn_client = super().get_client() @try_except_retry() def describe_stacks(self, stack_name): try: - response = self.cfn_client.describe_stacks( - StackName=stack_name - ) + response = self.cfn_client.describe_stacks(StackName=stack_name) return response except ClientError as e: self.logger.log_unhandled_exception(e) diff --git a/source/src/cfct/aws/services/code_pipeline.py b/source/src/cfct/aws/services/code_pipeline.py index 5155a65..5386aaf 100644 --- a/source/src/cfct/aws/services/code_pipeline.py +++ b/source/src/cfct/aws/services/code_pipeline.py @@ -15,17 +15,19 @@ # !/bin/python import inspect + from botocore.exceptions import ClientError from cfct.aws.utils.boto3_session import Boto3Session class CodePipeline(Boto3Session): """This class make code pipeline API calls such as starts code pipeline - execution, etc. + execution, etc. """ + def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'codepipeline' + __service_name = "codepipeline" super().__init__(logger, __service_name, **kwargs) self.code_pipeline = super().get_client() diff --git a/source/src/cfct/aws/services/ec2.py b/source/src/cfct/aws/services/ec2.py index acddecd..7d7aca2 100644 --- a/source/src/cfct/aws/services/ec2.py +++ b/source/src/cfct/aws/services/ec2.py @@ -22,31 +22,24 @@ class EC2(Boto3Session): def __init__(self, logger, region, **kwargs): self.logger = logger - __service_name = 'ec2' - kwargs.update({'region': region}) + __service_name = "ec2" + kwargs.update({"region": region}) super().__init__(logger, __service_name, **kwargs) self.ec2_client = super().get_client() - def describe_availability_zones(self, name='state', value='available'): + def describe_availability_zones(self, name="state", value="available"): try: response = self.ec2_client.describe_availability_zones( - Filters=[ - { - 'Name': name, - 'Values': [value] - } - ] + Filters=[{"Name": name, "Values": [value]}] ) - return [resp['ZoneName'] for resp in response['AvailabilityZones']] + return [resp["ZoneName"] for resp in response["AvailabilityZones"]] except ClientError as e: self.logger.log_unhandled_exception(e) raise def create_key_pair(self, key_name): try: - response = self.ec2_client.create_key_pair( - KeyName=key_name - ) + response = self.ec2_client.create_key_pair(KeyName=key_name) return response except ClientError as e: self.logger.log_unhandled_exception(e) diff --git a/source/src/cfct/aws/services/kms.py b/source/src/cfct/aws/services/kms.py index 5b59f3e..1d27c3a 100644 --- a/source/src/cfct/aws/services/kms.py +++ b/source/src/cfct/aws/services/kms.py @@ -20,20 +20,18 @@ class KMS(Boto3Session): - """This class makes KMS API calls as needed. - """ + """This class makes KMS API calls as needed.""" + def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'kms' + __service_name = "kms" super().__init__(logger, __service_name, **kwargs) self.kms_client = super().get_client() def describe_key(self, alias_name): try: - key_id = 'alias/' + alias_name - response = self.kms_client.describe_key( - KeyId=key_id - ) + key_id = "alias/" + alias_name + response = self.kms_client.describe_key(KeyId=key_id) return response except ClientError as e: self.logger.log_unhandled_exception(e) @@ -44,15 +42,12 @@ def create_key(self, policy, description, tag_key, tag_value): response = self.kms_client.create_key( Policy=policy, Description=description, - KeyUsage='ENCRYPT_DECRYPT', - Origin='AWS_KMS', + KeyUsage="ENCRYPT_DECRYPT", + Origin="AWS_KMS", BypassPolicyLockoutSafetyCheck=True, Tags=[ - { - 'TagKey': tag_key, - 'TagValue': tag_value - }, - ] + {"TagKey": tag_key, "TagValue": tag_value}, + ], ) return response except ClientError as e: @@ -62,8 +57,7 @@ def create_key(self, policy, description, tag_key, tag_value): def create_alias(self, alias_name, key_name): try: response = self.kms_client.create_alias( - AliasName=alias_name, - TargetKeyId=key_name + AliasName=alias_name, TargetKeyId=key_name ) return response except ClientError as e: @@ -87,8 +81,8 @@ def put_key_policy(self, key_id, policy): KeyId=key_id, Policy=policy, # Per API docs, the only valid value is default. - PolicyName='default', - BypassPolicyLockoutSafetyCheck=True + PolicyName="default", + BypassPolicyLockoutSafetyCheck=True, ) return response except ClientError as e: @@ -100,7 +94,7 @@ def enable_key_rotation(self, key_id): response = self.get_key_rotation_status(key_id) # Enable auto key rotation only if it hasn't been enabled - if not response.get('KeyRotationEnabled'): + if not response.get("KeyRotationEnabled"): self.kms_client.enable_key_rotation(KeyId=key_id) return response except ClientError as e: @@ -109,9 +103,7 @@ def enable_key_rotation(self, key_id): def get_key_rotation_status(self, key_id): try: - response = self.kms_client.get_key_rotation_status( - KeyId=key_id - ) + response = self.kms_client.get_key_rotation_status(KeyId=key_id) return response except ClientError as e: self.logger.log_unhandled_exception(e) diff --git a/source/src/cfct/aws/services/organizations.py b/source/src/cfct/aws/services/organizations.py index c40fb53..5bf4762 100644 --- a/source/src/cfct/aws/services/organizations.py +++ b/source/src/cfct/aws/services/organizations.py @@ -16,14 +16,14 @@ # !/bin/python from botocore.exceptions import ClientError -from cfct.utils.retry_decorator import try_except_retry from cfct.aws.utils.boto3_session import Boto3Session +from cfct.utils.retry_decorator import try_except_retry class Organizations(Boto3Session): def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'organizations' + __service_name = "organizations" super().__init__(logger, __service_name, **kwargs) self.org_client = super().get_client() self.next_token_returned_msg = "Next Token Returned: {}" @@ -41,19 +41,17 @@ def list_organizational_units_for_parent(self, parent_id): ParentId=parent_id ) - ou_list = response.get('OrganizationalUnits', []) - next_token = response.get('NextToken', None) + ou_list = response.get("OrganizationalUnits", []) + next_token = response.get("NextToken", None) while next_token is not None: - self.logger.info(self.next_token_returned_msg - .format(next_token)) - response = self.org_client \ - .list_organizational_units_for_parent( - ParentId=parent_id, - NextToken=next_token) + self.logger.info(self.next_token_returned_msg.format(next_token)) + response = self.org_client.list_organizational_units_for_parent( + ParentId=parent_id, NextToken=next_token + ) self.logger.info("Extending OU List") - ou_list.extend(response.get('OrganizationalUnits', [])) - next_token = response.get('NextToken', None) + ou_list.extend(response.get("OrganizationalUnits", [])) + next_token = response.get("NextToken", None) return ou_list except ClientError as e: @@ -62,23 +60,19 @@ def list_organizational_units_for_parent(self, parent_id): def list_accounts_for_parent(self, parent_id): try: - response = self.org_client.list_accounts_for_parent( - ParentId=parent_id - ) + response = self.org_client.list_accounts_for_parent(ParentId=parent_id) - account_list = response.get('Accounts', []) - next_token = response.get('NextToken', None) + account_list = response.get("Accounts", []) + next_token = response.get("NextToken", None) while next_token is not None: - self.logger.info(self.next_token_returned_msg - .format(next_token)) + self.logger.info(self.next_token_returned_msg.format(next_token)) response = self.org_client.list_accounts_for_parent( - ParentId=parent_id, - NextToken=next_token + ParentId=parent_id, NextToken=next_token ) self.logger.info("Extending Account List") - account_list.extend(response.get('Accounts', [])) - next_token = response.get('NextToken', None) + account_list.extend(response.get("Accounts", [])) + next_token = response.get("NextToken", None) return account_list except ClientError as e: @@ -97,18 +91,15 @@ def get_accounts_in_org(self): try: response = self.org_client.list_accounts() - account_list = response.get('Accounts', []) - next_token = response.get('NextToken', None) + account_list = response.get("Accounts", []) + next_token = response.get("NextToken", None) while next_token is not None: - self.logger.info(self.next_token_returned_msg - .format(next_token)) - response = self.org_client.list_accounts( - NextToken=next_token - ) + self.logger.info(self.next_token_returned_msg.format(next_token)) + response = self.org_client.list_accounts(NextToken=next_token) self.logger.info("Extending Account List") - account_list.extend(response.get('Accounts', [])) - next_token = response.get('NextToken', None) + account_list.extend(response.get("Accounts", [])) + next_token = response.get("NextToken", None) return account_list except ClientError as e: @@ -121,4 +112,4 @@ def describe_organization(self): return response except ClientError as e: self.logger.log_unhandled_exception(e) - raise \ No newline at end of file + raise diff --git a/source/src/cfct/aws/services/s3.py b/source/src/cfct/aws/services/s3.py index ee1abe9..9d6959d 100644 --- a/source/src/cfct/aws/services/s3.py +++ b/source/src/cfct/aws/services/s3.py @@ -16,6 +16,7 @@ # !/bin/python import tempfile + from botocore.exceptions import ClientError from cfct.aws.utils.boto3_session import Boto3Session @@ -23,16 +24,14 @@ class S3(Boto3Session): def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 's3' + __service_name = "s3" super().__init__(logger, __service_name, **kwargs) self.s3_client = super().get_client() self.s3_resource = super().get_resource() def get_bucket_policy(self, bucket_name): try: - response = self.s3_client.get_bucket_policy( - Bucket=bucket_name - ) + response = self.s3_client.get_bucket_policy(Bucket=bucket_name) return response except ClientError as e: self.logger.log_unhandled_exception(e) @@ -41,25 +40,24 @@ def get_bucket_policy(self, bucket_name): def put_bucket_policy(self, bucket_name, bucket_policy): try: response = self.s3_client.put_bucket_policy( - Bucket=bucket_name, - Policy=bucket_policy + Bucket=bucket_name, Policy=bucket_policy ) return response except ClientError as e: self.logger.log_unhandled_exception(e) raise - def upload_file(self, bucket_name, local_file_location, - remote_file_location): + def upload_file(self, bucket_name, local_file_location, remote_file_location): try: self.s3_resource.Bucket(bucket_name).upload_file( - local_file_location, remote_file_location) + local_file_location, remote_file_location + ) except ClientError as e: self.logger.log_unhandled_exception(e) raise def download_file(self, bucket_name, key_name, local_file_location): - """ This function downloads the file from the S3 bucket for a given + """This function downloads the file from the S3 bucket for a given S3 path in the method attribute. Use Cases: @@ -72,12 +70,13 @@ def download_file(self, bucket_name, key_name, local_file_location): """ try: self.logger.info( - "Downloading {}/{} from S3 to {}".format(bucket_name, - key_name, - local_file_location)) - self.s3_resource\ - .Bucket(bucket_name).download_file(key_name, - local_file_location) + "Downloading {}/{} from S3 to {}".format( + bucket_name, key_name, local_file_location + ) + ) + self.s3_resource.Bucket(bucket_name).download_file( + key_name, local_file_location + ) except ClientError as e: self.logger.log_unhandled_exception(e) raise @@ -87,15 +86,15 @@ def put_bucket_encryption(self, bucket_name, key_id): self.s3_client.put_bucket_encryption( Bucket=bucket_name, ServerSideEncryptionConfiguration={ - 'Rules': [ + "Rules": [ { - 'ApplyServerSideEncryptionByDefault': { - 'SSEAlgorithm': 'aws:kms', - 'KMSMasterKeyID': key_id + "ApplyServerSideEncryptionByDefault": { + "SSEAlgorithm": "aws:kms", + "KMSMasterKeyID": key_id, } }, ] - } + }, ) except ClientError as e: @@ -106,7 +105,7 @@ def list_buckets(self): return self.s3_client.list_buckets() def get_s3_object(self, remote_s3_url): - """ This function downloads the file from the S3 bucket for a given + """This function downloads the file from the S3 bucket for a given S3 path in the method attribute. :param remote_s3_url: s3://bucket-name/key-name diff --git a/source/src/cfct/aws/services/scp.py b/source/src/cfct/aws/services/scp.py index c8e838b..50dac8d 100644 --- a/source/src/cfct/aws/services/scp.py +++ b/source/src/cfct/aws/services/scp.py @@ -22,18 +22,16 @@ class ServiceControlPolicy(Boto3Session): def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'organizations' + __service_name = "organizations" super().__init__(logger, __service_name, **kwargs) self.org_client = super().get_client() def list_policies(self, page_size=20): try: - paginator = self.org_client.get_paginator('list_policies') + paginator = self.org_client.get_paginator("list_policies") response_iterator = paginator.paginate( - Filter='SERVICE_CONTROL_POLICY', - PaginationConfig={ - 'PageSize': page_size - } + Filter="SERVICE_CONTROL_POLICY", + PaginationConfig={"PageSize": page_size}, ) return response_iterator except ClientError as e: @@ -42,14 +40,11 @@ def list_policies(self, page_size=20): def list_policies_for_target(self, target_id, page_size=20): try: - paginator = self.org_client\ - .get_paginator('list_policies_for_target') + paginator = self.org_client.get_paginator("list_policies_for_target") response_iterator = paginator.paginate( TargetId=target_id, - Filter='SERVICE_CONTROL_POLICY', - PaginationConfig={ - 'PageSize': page_size - } + Filter="SERVICE_CONTROL_POLICY", + PaginationConfig={"PageSize": page_size}, ) return response_iterator except ClientError as e: @@ -58,13 +53,9 @@ def list_policies_for_target(self, target_id, page_size=20): def list_targets_for_policy(self, policy_id, page_size=20): try: - paginator = self.org_client.get_paginator( - 'list_targets_for_policy') + paginator = self.org_client.get_paginator("list_targets_for_policy") response_iterator = paginator.paginate( - PolicyId=policy_id, - PaginationConfig={ - 'PageSize': page_size - } + PolicyId=policy_id, PaginationConfig={"PageSize": page_size} ) return response_iterator except ClientError as e: @@ -77,7 +68,7 @@ def create_policy(self, name, description, content): Content=content, Description=description, Name=name, - Type='SERVICE_CONTROL_POLICY' + Type="SERVICE_CONTROL_POLICY", ) return response except ClientError as e: @@ -87,10 +78,7 @@ def create_policy(self, name, description, content): def update_policy(self, policy_id, name, description, content): try: response = self.org_client.update_policy( - PolicyId=policy_id, - Name=name, - Description=description, - Content=content + PolicyId=policy_id, Name=name, Description=description, Content=content ) return response except ClientError as e: @@ -99,25 +87,21 @@ def update_policy(self, policy_id, name, description, content): def delete_policy(self, policy_id): try: - self.org_client.delete_policy( - PolicyId=policy_id - ) + self.org_client.delete_policy(PolicyId=policy_id) except ClientError as e: self.logger.log_unhandled_exception(e) raise def attach_policy(self, policy_id, target_id): try: - self.org_client.attach_policy( - PolicyId=policy_id, - TargetId=target_id - ) + self.org_client.attach_policy(PolicyId=policy_id, TargetId=target_id) except ClientError as e: - if e.response['Error']['Code'] == \ - 'DuplicatePolicyAttachmentException': - self.logger.exception("Caught exception " - "'DuplicatePolicyAttachmentException', " - "taking no action...") + if e.response["Error"]["Code"] == "DuplicatePolicyAttachmentException": + self.logger.exception( + "Caught exception " + "'DuplicatePolicyAttachmentException', " + "taking no action..." + ) return else: self.logger.log_unhandled_exception(e) @@ -125,15 +109,14 @@ def attach_policy(self, policy_id, target_id): def detach_policy(self, policy_id, target_id): try: - self.org_client.detach_policy( - PolicyId=policy_id, - TargetId=target_id - ) + self.org_client.detach_policy(PolicyId=policy_id, TargetId=target_id) except ClientError as e: - if e.response['Error']['Code'] == 'PolicyNotAttachedException': - self.logger.exception("Caught exception " - "'PolicyNotAttachedException'," - " taking no action...") + if e.response["Error"]["Code"] == "PolicyNotAttachedException": + self.logger.exception( + "Caught exception " + "'PolicyNotAttachedException'," + " taking no action..." + ) return else: self.logger.log_unhandled_exception(e) @@ -142,15 +125,15 @@ def detach_policy(self, policy_id, target_id): def enable_policy_type(self, root_id): try: self.org_client.enable_policy_type( - RootId=root_id, - PolicyType='SERVICE_CONTROL_POLICY' + RootId=root_id, PolicyType="SERVICE_CONTROL_POLICY" ) except ClientError as e: - if e.response['Error']['Code'] == \ - 'PolicyTypeAlreadyEnabledException': - self.logger.exception("Caught " - "PolicyTypeAlreadyEnabledException'," - " taking no action...") + if e.response["Error"]["Code"] == "PolicyTypeAlreadyEnabledException": + self.logger.exception( + "Caught " + "PolicyTypeAlreadyEnabledException'," + " taking no action..." + ) else: self.logger.log_unhandled_exception(e) raise diff --git a/source/src/cfct/aws/services/ssm.py b/source/src/cfct/aws/services/ssm.py index bb4a5cd..d5981da 100644 --- a/source/src/cfct/aws/services/ssm.py +++ b/source/src/cfct/aws/services/ssm.py @@ -15,51 +15,54 @@ # !/bin/python import os + from botocore.exceptions import ClientError -from cfct.utils.retry_decorator import try_except_retry from cfct.aws.utils.boto3_session import Boto3Session +from cfct.utils.retry_decorator import try_except_retry -ssm_region = os.environ.get('AWS_REGION') +ssm_region = os.environ.get("AWS_REGION") class SSM(Boto3Session): def __init__(self, logger, region=ssm_region, **kwargs): self.logger = logger - __service_name = 'ssm' - kwargs.update({'region': region}) + __service_name = "ssm" + kwargs.update({"region": region}) super().__init__(logger, __service_name, **kwargs) self.ssm_client = super().get_client() self.description = "This value was stored by Custom Control " "Tower Solution." - def put_parameter(self, - name, - value, - description="This value was stored by Custom Control " - "Tower Solution.", - type='String', - overwrite=True): + def put_parameter( + self, + name, + value, + description="This value was stored by Custom Control " "Tower Solution.", + type="String", + overwrite=True, + ): try: response = self.ssm_client.put_parameter( Name=name, Value=value, Description=description, Type=type, - Overwrite=overwrite + Overwrite=overwrite, ) return response except ClientError as e: self.logger.log_unhandled_exception(e) raise - def put_parameter_use_cmk(self, - name, - value, - key_id, - description="This value was stored by Custom " - "Control Tower Solution.", - type='SecureString', - overwrite=True): + def put_parameter_use_cmk( + self, + name, + value, + key_id, + description="This value was stored by Custom " "Control Tower Solution.", + type="SecureString", + overwrite=True, + ): try: response = self.ssm_client.put_parameter( Name=name, @@ -67,7 +70,7 @@ def put_parameter_use_cmk(self, Description=description, KeyId=key_id, Type=type, - Overwrite=overwrite + Overwrite=overwrite, ) return response except ClientError as e: @@ -76,14 +79,13 @@ def put_parameter_use_cmk(self, def get_parameter(self, name): try: - response = self.ssm_client.get_parameter( - Name=name, - WithDecryption=True - ) - return response.get('Parameter', {}).get('Value') + response = self.ssm_client.get_parameter(Name=name, WithDecryption=True) + return response.get("Parameter", {}).get("Value") except ClientError as e: - if e.response['Error']['Code'] == 'ParameterNotFound': - self.logger.log_unhandled_exception('The SSM Parameter {} was not found'.format(name)) + if e.response["Error"]["Code"] == "ParameterNotFound": + self.logger.log_unhandled_exception( + "The SSM Parameter {} was not found".format(name) + ) self.logger.log_unhandled_exception(e) raise @@ -101,22 +103,22 @@ def delete_parameter(self, name): def get_parameters_by_path(self, path): try: response = self.ssm_client.get_parameters_by_path( - Path=path if path.startswith('/') else '/'+path, + Path=path if path.startswith("/") else "/" + path, Recursive=False, - WithDecryption=True + WithDecryption=True, ) - params_list = response.get('Parameters', []) - next_token = response.get('NextToken', None) + params_list = response.get("Parameters", []) + next_token = response.get("NextToken", None) while next_token is not None: response = self.ssm_client.get_parameters_by_path( - Path=path if path.startswith('/') else '/' + path, + Path=path if path.startswith("/") else "/" + path, Recursive=False, WithDecryption=True, - NextToken=next_token + NextToken=next_token, ) - params_list.extend(response.get('Parameters', [])) - next_token = response.get('NextToken', None) + params_list.extend(response.get("Parameters", [])) + next_token = response.get("NextToken", None) return params_list except ClientError as e: @@ -128,7 +130,7 @@ def delete_parameters_by_path(self, name): params_list = self.get_parameters_by_path(name) if params_list: for param in params_list: - self.delete_parameter(param.get('Name')) + self.delete_parameter(param.get("Name")) except ClientError as e: self.logger.log_unhandled_exception(e) raise @@ -139,13 +141,13 @@ def describe_parameters(self, parameter_name, begins_with=False): response = self.ssm_client.describe_parameters( ParameterFilters=[ { - 'Key': 'Name', - 'Option': 'BeginsWith' if begins_with else 'Equals', - 'Values': [parameter_name] + "Key": "Name", + "Option": "BeginsWith" if begins_with else "Equals", + "Values": [parameter_name], } ] ) - parameters = response.get('Parameters', []) + parameters = response.get("Parameters", []) if parameters: return parameters[0] else: diff --git a/source/src/cfct/aws/services/state_machine.py b/source/src/cfct/aws/services/state_machine.py index bee88b3..fc7ceb3 100644 --- a/source/src/cfct/aws/services/state_machine.py +++ b/source/src/cfct/aws/services/state_machine.py @@ -15,51 +15,57 @@ # !/bin/python -import boto3 import json + +import boto3 from botocore.exceptions import ClientError -from cfct.utils.string_manipulation import sanitize from cfct.aws.utils.boto3_session import Boto3Session +from cfct.utils.string_manipulation import sanitize class StateMachine(Boto3Session): def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'stepfunctions' + __service_name = "stepfunctions" super().__init__(logger, __service_name, **kwargs) self.state_machine_client = super().get_client() def start_execution(self, state_machine_arn, input, name): try: - self.logger.info("Starting execution of state machine: {} with " - "input: {}".format(state_machine_arn, input)) + self.logger.info( + "Starting execution of state machine: {} with " + "input: {}".format(state_machine_arn, input) + ) response = self.state_machine_client.start_execution( stateMachineArn=state_machine_arn, input=json.dumps(input), - name=sanitize(name) + name=sanitize(name), + ) + self.logger.info( + "State machine Execution ARN: {}".format(response["executionArn"]) ) - self.logger.info("State machine Execution ARN: {}" - .format(response['executionArn'])) - return response.get('executionArn') + return response.get("executionArn") except ClientError as e: self.logger.log_unhandled_exception(e) raise def check_state_machine_status(self, execution_arn) -> str: try: - self.logger.info("Checking execution of state machine: {}" - .format(execution_arn)) + self.logger.info( + "Checking execution of state machine: {}".format(execution_arn) + ) response = self.state_machine_client.describe_execution( executionArn=execution_arn ) - self.logger.info("State machine Execution Status: {}" - .format(response['status'])) - if response['status'] == 'RUNNING': - return 'RUNNING' - elif response['status'] == 'SUCCEEDED': - return 'SUCCEEDED' + self.logger.info( + "State machine Execution Status: {}".format(response["status"]) + ) + if response["status"] == "RUNNING": + return "RUNNING" + elif response["status"] == "SUCCEEDED": + return "SUCCEEDED" else: - return 'FAILED' + return "FAILED" except ClientError as e: self.logger.log_unhandled_exception(e) raise diff --git a/source/src/cfct/aws/services/sts.py b/source/src/cfct/aws/services/sts.py index 97d02b0..38f35b0 100644 --- a/source/src/cfct/aws/services/sts.py +++ b/source/src/cfct/aws/services/sts.py @@ -16,6 +16,7 @@ # !/bin/python from os import environ + from botocore.exceptions import ClientError from cfct.aws.utils.boto3_session import Boto3Session from cfct.aws.utils.get_partition import get_partition @@ -28,12 +29,14 @@ def __call__(self, logger, account): # assume role session_name = "custom-control-tower-session" partition = get_partition() - role_arn = "%s%s%s%s%s%s" % ("arn:", - partition, - ":iam::", - str(account), - ":role/", - environ.get('EXECUTION_ROLE_NAME')) + role_arn = "%s%s%s%s%s%s" % ( + "arn:", + partition, + ":iam::", + str(account), + ":role/", + environ.get("EXECUTION_ROLE_NAME"), + ) credentials = sts.assume_role(role_arn, session_name) return credentials except ClientError as e: @@ -44,28 +47,26 @@ def __call__(self, logger, account): class STS(Boto3Session): def __init__(self, logger, **kwargs): self.logger = logger - __service_name = 'sts' - kwargs.update({'region': self.get_sts_region}) - kwargs.update({'endpoint_url': self.get_sts_endpoint()}) + __service_name = "sts" + kwargs.update({"region": self.get_sts_region}) + kwargs.update({"endpoint_url": self.get_sts_endpoint()}) super().__init__(logger, __service_name, **kwargs) self.sts_client = super().get_client() @property def get_sts_region(self): - return environ.get('AWS_REGION') + return environ.get("AWS_REGION") @staticmethod def get_sts_endpoint(): - return "https://sts.%s.amazonaws.com" % environ.get('AWS_REGION') + return "https://sts.%s.amazonaws.com" % environ.get("AWS_REGION") def assume_role(self, role_arn, session_name, duration=900): try: response = self.sts_client.assume_role( - RoleArn=role_arn, - RoleSessionName=session_name, - DurationSeconds=duration + RoleArn=role_arn, RoleSessionName=session_name, DurationSeconds=duration ) - return response['Credentials'] + return response["Credentials"] except ClientError as e: self.logger.log_unhandled_exception(e) raise diff --git a/source/src/cfct/aws/utils/boto3_session.py b/source/src/cfct/aws/utils/boto3_session.py index f473af5..4862be4 100644 --- a/source/src/cfct/aws/utils/boto3_session.py +++ b/source/src/cfct/aws/utils/boto3_session.py @@ -13,10 +13,11 @@ # governing permissions and limitations under the License. # ############################################################################## +from os import getenv + # !/bin/python import boto3 from botocore.config import Config -from os import getenv class Boto3Session: @@ -34,33 +35,30 @@ def __init__(self, logger, region, **kwargs): def __init__(self, logger, service_name, **kwargs): """ - Parameters - ---------- - logger : object - The logger object - region : str - AWS region name. Example: 'us-east-1' - service_name : str - AWS service name. Example: 'ec2' - credentials = dict, optional - set of temporary AWS security credentials - endpoint_url : str - The complete URL to use for the constructed client. + Parameters + ---------- + logger : object + The logger object + region : str + AWS region name. Example: 'us-east-1' + service_name : str + AWS service name. Example: 'ec2' + credentials = dict, optional + set of temporary AWS security credentials + endpoint_url : str + The complete URL to use for the constructed client. """ self.logger = logger self.service_name = service_name - self.credentials = kwargs.get('credentials', None) - self.region = kwargs.get('region', None) - self.endpoint_url = kwargs.get('endpoint_url', None) - self.solution_id = getenv('SOLUTION_ID', 'SO0089') - self.solution_version = getenv('SOLUTION_VERSION', 'undefined') - user_agent = f'AwsSolution/{self.solution_id}/{self.solution_version}' + self.credentials = kwargs.get("credentials", None) + self.region = kwargs.get("region", None) + self.endpoint_url = kwargs.get("endpoint_url", None) + self.solution_id = getenv("SOLUTION_ID", "SO0089") + self.solution_version = getenv("SOLUTION_VERSION", "undefined") + user_agent = f"AwsSolution/{self.solution_id}/{self.solution_version}" self.boto_config = Config( user_agent_extra=user_agent, - retries={ - 'mode': 'standard', - 'max_attempts': 20 - } + retries={"mode": "standard", "max_attempts": 20}, ) def get_client(self): @@ -70,35 +68,34 @@ def get_client(self): """ if self.credentials is None: if self.endpoint_url is None: - return boto3.client(self.service_name, - region_name=self.region, - config=self.boto_config) + return boto3.client( + self.service_name, region_name=self.region, config=self.boto_config + ) else: - return boto3.client(self.service_name, region_name=self.region, - config=self.boto_config, - endpoint_url=self.endpoint_url) + return boto3.client( + self.service_name, + region_name=self.region, + config=self.boto_config, + endpoint_url=self.endpoint_url, + ) else: if self.region is None: - return boto3.client(self.service_name, - aws_access_key_id=self.credentials - .get('AccessKeyId'), - aws_secret_access_key=self.credentials - .get('SecretAccessKey'), - aws_session_token=self.credentials - .get('SessionToken'), - config=self.boto_config - ) + return boto3.client( + self.service_name, + aws_access_key_id=self.credentials.get("AccessKeyId"), + aws_secret_access_key=self.credentials.get("SecretAccessKey"), + aws_session_token=self.credentials.get("SessionToken"), + config=self.boto_config, + ) else: - return boto3.client(self.service_name, - region_name=self.region, - aws_access_key_id=self.credentials - .get('AccessKeyId'), - aws_secret_access_key=self.credentials - .get('SecretAccessKey'), - aws_session_token=self.credentials - .get('SessionToken'), - config=self.boto_config - ) + return boto3.client( + self.service_name, + region_name=self.region, + aws_access_key_id=self.credentials.get("AccessKeyId"), + aws_secret_access_key=self.credentials.get("SecretAccessKey"), + aws_session_token=self.credentials.get("SessionToken"), + config=self.boto_config, + ) def get_resource(self): """Creates a boto3 resource service client object by name @@ -107,33 +104,31 @@ def get_resource(self): """ if self.credentials is None: if self.endpoint_url is None: - return boto3.resource(self.service_name, - region_name=self.region, - config=self.boto_config) + return boto3.resource( + self.service_name, region_name=self.region, config=self.boto_config + ) else: - return boto3.resource(self.service_name, - region_name=self.region, - config=self.boto_config, - endpoint_url=self.endpoint_url) + return boto3.resource( + self.service_name, + region_name=self.region, + config=self.boto_config, + endpoint_url=self.endpoint_url, + ) else: if self.region is None: - return boto3.resource(self.service_name, - aws_access_key_id=self.credentials - .get('AccessKeyId'), - aws_secret_access_key=self.credentials - .get('SecretAccessKey'), - aws_session_token=self.credentials - .get('SessionToken'), - config=self.boto_config - ) + return boto3.resource( + self.service_name, + aws_access_key_id=self.credentials.get("AccessKeyId"), + aws_secret_access_key=self.credentials.get("SecretAccessKey"), + aws_session_token=self.credentials.get("SessionToken"), + config=self.boto_config, + ) else: - return boto3.resource(self.service_name, - region_name=self.region, - aws_access_key_id=self.credentials - .get('AccessKeyId'), - aws_secret_access_key=self.credentials - .get('SecretAccessKey'), - aws_session_token=self.credentials - .get('SessionToken'), - config=self.boto_config - ) + return boto3.resource( + self.service_name, + region_name=self.region, + aws_access_key_id=self.credentials.get("AccessKeyId"), + aws_secret_access_key=self.credentials.get("SecretAccessKey"), + aws_session_token=self.credentials.get("SessionToken"), + config=self.boto_config, + ) diff --git a/source/src/cfct/aws/utils/get_partition.py b/source/src/cfct/aws/utils/get_partition.py index ddd9028..755db96 100644 --- a/source/src/cfct/aws/utils/get_partition.py +++ b/source/src/cfct/aws/utils/get_partition.py @@ -20,12 +20,12 @@ def get_partition(): """ :return: partition name for the current AWS region """ - region_name = environ.get('AWS_REGION') - china_region_name_prefix = 'cn' - us_gov_cloud_region_name_prefix = 'us-gov' - aws_regions_partition = 'aws' - aws_china_regions_partition = 'aws-cn' - aws_us_gov_cloud_regions_partition = 'aws-us-gov' + region_name = environ.get("AWS_REGION") + china_region_name_prefix = "cn" + us_gov_cloud_region_name_prefix = "us-gov" + aws_regions_partition = "aws" + aws_china_regions_partition = "aws-cn" + aws_us_gov_cloud_regions_partition = "aws-us-gov" # China regions if region_name.startswith(china_region_name_prefix): diff --git a/source/src/cfct/aws/utils/get_regions.py b/source/src/cfct/aws/utils/get_regions.py index 4d02395..d13664e 100644 --- a/source/src/cfct/aws/utils/get_regions.py +++ b/source/src/cfct/aws/utils/get_regions.py @@ -17,7 +17,7 @@ def get_available_regions(service_name): - """ Returns list of available regions given an AWS service. + """Returns list of available regions given an AWS service. Args: service_name diff --git a/source/src/cfct/aws/utils/url_conversion.py b/source/src/cfct/aws/utils/url_conversion.py index 9c1b149..49d77ba 100644 --- a/source/src/cfct/aws/utils/url_conversion.py +++ b/source/src/cfct/aws/utils/url_conversion.py @@ -13,8 +13,8 @@ # and limitations under the License. # ############################################################################### -from urllib.parse import urlparse from os import environ +from urllib.parse import urlparse def convert_s3_url_to_http_url(s3_url): @@ -39,19 +39,21 @@ def convert_s3_url_to_http_url(s3_url): def build_http_url(bucket_name, key_name): - """ Builds http url for the given bucket and key name + """Builds http url for the given bucket and key name :param bucket_name: :param key_name: :return HTTP URL: example: https://bucket-name.s3.Region.amazonaws.com/key-name """ - return "{}{}{}{}{}{}".format('https://', - bucket_name, - '.s3.', - environ.get('AWS_REGION'), - '.amazonaws.com/', - key_name) + return "{}{}{}{}{}{}".format( + "https://", + bucket_name, + ".s3.", + environ.get("AWS_REGION"), + ".amazonaws.com/", + key_name, + ) def parse_bucket_key_names(http_url): @@ -69,16 +71,16 @@ def parse_bucket_key_names(http_url): # Handle Amazon S3 path-style URL # Needed to handle response from describe_provisioning_artifact API - response['Info']['TemplateUrl'] # example: https://s3.Region.amazonaws.com/bucket-name/key-name - if http_url.startswith('https://s3.'): + if http_url.startswith("https://s3."): parsed_url = urlparse(http_url) - bucket_name = parsed_url.path.split('/', 2)[1] - key_name = parsed_url.path.split('/', 2)[2] - region = parsed_url.netloc.split('.')[1] + bucket_name = parsed_url.path.split("/", 2)[1] + key_name = parsed_url.path.split("/", 2)[2] + region = parsed_url.netloc.split(".")[1] # Handle Amazon S3 virtual-hosted–style URL # example: https://bucket-name.s3.Region.amazonaws.com/key-name else: parsed_url = urlparse(http_url) - bucket_name = parsed_url.netloc.split('.')[0] + bucket_name = parsed_url.netloc.split(".")[0] key_name = parsed_url.path[1:] - region = parsed_url.netloc.split('.')[2] + region = parsed_url.netloc.split(".")[2] return bucket_name, key_name, region diff --git a/source/src/cfct/lambda_handlers/config_deployer.py b/source/src/cfct/lambda_handlers/config_deployer.py index 5a15060..1361a09 100644 --- a/source/src/cfct/lambda_handlers/config_deployer.py +++ b/source/src/cfct/lambda_handlers/config_deployer.py @@ -15,22 +15,23 @@ # !/bin/python -import os -import json import inspect +import json +import os import zipfile from hashlib import md5 from uuid import uuid4 -from jinja2 import Environment, FileSystemLoader -from cfct.aws.services.s3 import S3 + from cfct.aws.services.kms import KMS +from cfct.aws.services.s3 import S3 from cfct.aws.services.ssm import SSM from cfct.utils.crhelper import cfn_handler -from cfct.utils.os_util import make_dir from cfct.utils.logger import Logger +from cfct.utils.os_util import make_dir +from jinja2 import Environment, FileSystemLoader # initialise logger -log_level = os.environ.get('LOG_LEVEL') +log_level = os.environ.get("LOG_LEVEL") logger = Logger(loglevel=log_level) init_failed = False @@ -42,7 +43,7 @@ def unzip_function(zip_file_name, function_path, output_path): orig_path = os.getcwd() os.chdir(function_path) - zip_file = zipfile.ZipFile(zip_file_name, 'r') + zip_file = zipfile.ZipFile(zip_file_name, "r") zip_file.extractall(output_path) zip_file.close() os.chdir(orig_path) @@ -54,7 +55,7 @@ def find_replace(function_path, file_name, destination_file, parameters): j2template = j2env.get_template(file_name) dictionary = {} for key, value in parameters.items(): - value = "\"%s\"" % value if "json" in file_name else value + value = '"%s"' % value if "json" in file_name else value dictionary.update({key: value}) logger.debug(dictionary) output = j2template.render(dictionary) @@ -71,9 +72,9 @@ def zip_function(zip_file_name, function_path, output_path, exclude_list): os.remove(zip_file_name) except OSError: pass - zip_file = zipfile.ZipFile(zip_file_name, mode='a') + zip_file = zipfile.ZipFile(zip_file_name, mode="a") os.chdir(function_path) - for folder, subs, files in os.walk('.'): + for folder, subs, files in os.walk("."): for filename in files: file_path = os.path.join(folder, filename) if not any(x in file_path for x in exclude_list): @@ -89,43 +90,47 @@ def find_alias(alias_name): alias_not_found = True while alias_not_found: response_list_alias = kms.list_aliases(marker) - truncated_flag = response_list_alias.get('Truncated') - for alias in response_list_alias.get('Aliases'): - if alias.get('AliasName') == alias_name: - logger.info('Found key attached with existing key id.') - key_id = alias.get('TargetKeyId') + truncated_flag = response_list_alias.get("Truncated") + for alias in response_list_alias.get("Aliases"): + if alias.get("AliasName") == alias_name: + logger.info("Found key attached with existing key id.") + key_id = alias.get("TargetKeyId") return key_id if not truncated_flag: - logger.info('Alias not found in the list') + logger.info("Alias not found in the list") alias_not_found = False else: - logger.info('Could not find alias in truncated response,' - ' trying again...') - marker = response_list_alias.get('NextMarker') - logger.info('Trying again with NextMarker: {}'.format(marker)) + logger.info( + "Could not find alias in truncated response," " trying again..." + ) + marker = response_list_alias.get("NextMarker") + logger.info("Trying again with NextMarker: {}".format(marker)) def create_cmk_with_alias(alias_name, event_policy): - logger.info('Creating new KMS key id and alias.') + logger.info("Creating new KMS key id and alias.") policy = str(json.dumps(event_policy)) - logger.info('Policy') + logger.info("Policy") logger.info(policy) - response_create_key = kms.create_key(str(policy), 'CMK created for Custom' - ' Control Tower Resources', - 'AWSSolutions', 'CustomControlTower') - logger.info('KMS Key created.') - key_id = response_create_key.get('KeyMetadata', {}).get('KeyId') + response_create_key = kms.create_key( + str(policy), + "CMK created for Custom" " Control Tower Resources", + "AWSSolutions", + "CustomControlTower", + ) + logger.info("KMS Key created.") + key_id = response_create_key.get("KeyMetadata", {}).get("KeyId") kms.create_alias(alias_name, key_id) - logger.info('Alias created.') + logger.info("Alias created.") return key_id def update_key_policy(key_id, event_policy): policy = str(json.dumps(event_policy)) - logger.info('Policy') + logger.info("Policy") logger.info(policy) response_update_policy = kms.put_key_policy(key_id, policy) - logger.info('Response: Update Key Policy') + logger.info("Response: Update Key Policy") logger.info(response_update_policy) @@ -141,22 +146,20 @@ def config_deployer(event): s3 = S3(logger) # set variables - source_bucket_name = event.get('BucketConfig', {}) \ - .get('SourceBucketName') - key_name = event.get('BucketConfig', {}).get('SourceS3Key') - destination_bucket_name = event.get('BucketConfig', {}) \ - .get('DestinationBucketName') - input_zip_file_name = key_name.split("/")[-1] if "/" in key_name \ - else key_name - output_zip_file_name = event.get('BucketConfig', {}) \ - .get('DestinationS3Key') - alias_name = event.get('KMSConfig', {}).get('KMSKeyAlias') - policy = event.get('KMSConfig', {}).get('KMSKeyPolicy') - flag_value = event.get('MetricsFlag') - base_path = '/tmp/custom_control_tower' + source_bucket_name = event.get("BucketConfig", {}).get("SourceBucketName") + key_name = event.get("BucketConfig", {}).get("SourceS3Key") + destination_bucket_name = event.get("BucketConfig", {}).get( + "DestinationBucketName" + ) + input_zip_file_name = key_name.split("/")[-1] if "/" in key_name else key_name + output_zip_file_name = event.get("BucketConfig", {}).get("DestinationS3Key") + alias_name = event.get("KMSConfig", {}).get("KMSKeyAlias") + policy = event.get("KMSConfig", {}).get("KMSKeyPolicy") + flag_value = event.get("MetricsFlag") + base_path = "/tmp/custom_control_tower" input_file_path = base_path + "/" + input_zip_file_name - extract_path = base_path + "/" + 'extract' - output_path = base_path + "/" + 'out' + extract_path = base_path + "/" + "extract" + output_path = base_path + "/" + "out" exclude_j2_files = [] # Search for existing KMS key alias @@ -166,13 +169,14 @@ def config_deployer(event): # new target key if not key_id: key_id = create_cmk_with_alias(alias_name, policy) - logger.info('Key ID created: {}'.format(key_id)) + logger.info("Key ID created: {}".format(key_id)) kms.enable_key_rotation(key_id) - logger.info('Automatic key rotation enabled.') + logger.info("Automatic key rotation enabled.") else: - logger.info('Key ID: {} found attached with alias: {}' - .format(key_id, alias_name)) - logger.info('Updating KMS key policy') + logger.info( + "Key ID: {} found attached with alias: {}".format(key_id, alias_name) + ) + logger.info("Updating KMS key policy") update_key_policy(key_id, policy) kms.enable_key_rotation(key_id) @@ -187,17 +191,20 @@ def config_deployer(event): unzip_function(input_zip_file_name, base_path, extract_path) # Find and replace the variable in Manifest file - for item in event.get('FindReplace'): - f = item.get('FileName') - parameters = item.get('Parameters') + for item in event.get("FindReplace"): + f = item.get("FileName") + parameters = item.get("Parameters") exclude_j2_files.append(f) filename, file_extension = os.path.splitext(f) - destination_file_path = extract_path + "/" + filename \ - if file_extension == '.j2' else extract_path + "/" + f + destination_file_path = ( + extract_path + "/" + filename + if file_extension == ".j2" + else extract_path + "/" + f + ) find_replace(extract_path, f, destination_file_path, parameters) # Zip the contents - exclude = ['zip'] + exclude_j2_files + exclude = ["zip"] + exclude_j2_files make_dir(output_path, logger) zip_function(output_zip_file_name, extract_path, output_path, exclude) @@ -207,19 +214,18 @@ def config_deployer(event): s3.upload_file(destination_bucket_name, local_file, remote_file) # create SSM parameters to send anonymous data if opted in - put_ssm_parameter('/org/primary/metrics_flag', flag_value) - put_ssm_parameter('/org/primary/customer_uuid', str(uuid4())) + put_ssm_parameter("/org/primary/metrics_flag", flag_value) + put_ssm_parameter("/org/primary/customer_uuid", str(uuid4())) return None except Exception as e: - logger.log_general_exception( - __file__.split('/')[-1], inspect.stack()[0][3], e) + logger.log_general_exception(__file__.split("/")[-1], inspect.stack()[0][3], e) raise def update_config_deployer(event): - alias_name = event.get('KMSConfig', {}).get('KMSKeyAlias') - policy = event.get('KMSConfig', {}).get('KMSKeyPolicy') - flag_value = event.get('MetricsFlag') + alias_name = event.get("KMSConfig", {}).get("KMSKeyAlias") + policy = event.get("KMSConfig", {}).get("KMSKeyPolicy") + flag_value = event.get("MetricsFlag") # Search for existing KMS key alias key_id = find_alias(alias_name) @@ -228,19 +234,20 @@ def update_config_deployer(event): # new target key if not key_id: key_id = create_cmk_with_alias(alias_name, policy) - logger.info('Key ID created: {}'.format(key_id)) + logger.info("Key ID created: {}".format(key_id)) kms.enable_key_rotation(key_id) - logger.info('Automatic key rotation enabled.') + logger.info("Automatic key rotation enabled.") else: - logger.info('Key ID: {} found attached with alias: {}' - .format(key_id, alias_name)) - logger.info('Updating KMS key policy') + logger.info( + "Key ID: {} found attached with alias: {}".format(key_id, alias_name) + ) + logger.info("Updating KMS key policy") update_key_policy(key_id, policy) kms.enable_key_rotation(key_id) # create SSM parameters to send anonymous data if opted in - put_ssm_parameter('/org/primary/metrics_flag', flag_value) - put_ssm_parameter('/org/primary/customer_uuid', str(uuid4())) + put_ssm_parameter("/org/primary/metrics_flag", flag_value) + put_ssm_parameter("/org/primary/customer_uuid", str(uuid4())) return None @@ -251,28 +258,28 @@ def create(event, context): As there is no real 'resource', and it will never be replaced, PhysicalResourceId is set to a hash of StackId and LogicalId. """ - s = '%s-%s' % (event.get('StackId'), event.get('LogicalResourceId')) - physical_resource_id = md5(s.encode('UTF-8')).hexdigest() + s = "%s-%s" % (event.get("StackId"), event.get("LogicalResourceId")) + physical_resource_id = md5(s.encode("UTF-8")).hexdigest() logger.info("physical_resource_id: {}".format(physical_resource_id)) - if event.get('ResourceType') == 'Custom::ConfigDeployer': - response = config_deployer(event.get('ResourceProperties')) + if event.get("ResourceType") == "Custom::ConfigDeployer": + response = config_deployer(event.get("ResourceProperties")) return physical_resource_id, response else: - logger.error('No valid ResourceType found!') + logger.error("No valid ResourceType found!") def update(event, context): """ Update the KMS key policy. """ - physical_resource_id = event.get('PhysicalResourceId') + physical_resource_id = event.get("PhysicalResourceId") - if event.get('ResourceType') == 'Custom::ConfigDeployer': - response = update_config_deployer(event.get('ResourceProperties')) + if event.get("ResourceType") == "Custom::ConfigDeployer": + response = update_config_deployer(event.get("ResourceProperties")) return physical_resource_id, response else: - logger.error('No valid ResourceType found!') + logger.error("No valid ResourceType found!") def delete(event, context): @@ -285,5 +292,4 @@ def lambda_handler(event, context): logger.info("<<<<<<<<<< ConfigDeployer Event >>>>>>>>>>") logger.info(event) logger.debug(context) - return cfn_handler(event, context, create, update, delete, - logger, init_failed) + return cfn_handler(event, context, create, update, delete, logger, init_failed) diff --git a/source/src/cfct/lambda_handlers/lifecycle_event_handler.py b/source/src/cfct/lambda_handlers/lifecycle_event_handler.py index 3ebec4d..47e3f6f 100644 --- a/source/src/cfct/lambda_handlers/lifecycle_event_handler.py +++ b/source/src/cfct/lambda_handlers/lifecycle_event_handler.py @@ -1,4 +1,3 @@ - ############################################################################## # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # # @@ -14,14 +13,16 @@ # governing permissions and limitations under the License. # ############################################################################## -import os import inspect -from cfct.utils.logger import Logger +import os + from cfct.aws.services.code_pipeline import CodePipeline +from cfct.utils.logger import Logger # initialise logger -log_level = 'info' if os.environ.get('LOG_LEVEL') is None \ - else os.environ.get('LOG_LEVEL') +log_level = ( + "info" if os.environ.get("LOG_LEVEL") is None else os.environ.get("LOG_LEVEL") +) logger = Logger(loglevel=log_level) init_failed = False @@ -42,22 +43,25 @@ def invoke_code_pipeline(event): response from starting pipeline execution """ msg_count = 0 - for record in event['Records']: + for record in event["Records"]: # Here validates that the event source is aws control tower. # The filtering of specific control tower lifecycle events is done # by a CWE rule, which is configured to deliver only # the matching events to the SQS queue. - if record['body'] is not None and record['body'].find('"source":"aws.controltower"') >= 0: + if ( + record["body"] is not None + and record["body"].find('"source":"aws.controltower"') >= 0 + ): msg_count += 1 if msg_count > 0: - logger.info(str(msg_count) + - " Control Tower lifecycle event(s) found in the queue." - " Start invoking code pipeline...") + logger.info( + str(msg_count) + " Control Tower lifecycle event(s) found in the queue." + " Start invoking code pipeline..." + ) cp = CodePipeline(logger) - response = cp.start_pipeline_execution( - os.environ.get('CODE_PIPELINE_NAME')) + response = cp.start_pipeline_execution(os.environ.get("CODE_PIPELINE_NAME")) else: logger.info("No lifecycle events in the queue!") @@ -75,8 +79,9 @@ def lambda_handler(event, context): context """ try: - logger.info("<<<<<<<<<< Poll Control Tower lifecyle events from" - " SQS queue >>>>>>>>>>") + logger.info( + "<<<<<<<<<< Poll Control Tower lifecyle events from" " SQS queue >>>>>>>>>>" + ) logger.info(event) logger.debug(context) @@ -85,6 +90,5 @@ def lambda_handler(event, context): logger.info("Response from Code Pipeline: ") logger.info(response) except Exception as e: - logger.log_general_exception( - __file__.split('/')[-1], inspect.stack()[0][3], e) + logger.log_general_exception(__file__.split("/")[-1], inspect.stack()[0][3], e) raise diff --git a/source/src/cfct/lambda_handlers/state_machine_router.py b/source/src/cfct/lambda_handlers/state_machine_router.py index 8d9d3f8..2cf9e05 100644 --- a/source/src/cfct/lambda_handlers/state_machine_router.py +++ b/source/src/cfct/lambda_handlers/state_machine_router.py @@ -15,39 +15,43 @@ # !/bin/python -import os import inspect -from cfct.state_machine_handler import CloudFormation, StackSetSMRequests, \ - ServiceControlPolicy +import os + +from cfct.state_machine_handler import ( + CloudFormation, + ServiceControlPolicy, + StackSetSMRequests, +) from cfct.utils.logger import Logger # initialise logger -log_level = os.environ['LOG_LEVEL'] +log_level = os.environ["LOG_LEVEL"] logger = Logger(loglevel=log_level) def cloudformation(event, function_name): logger.info("Router FunctionName: {}".format(function_name)) stack_set = CloudFormation(event, logger) - if function_name == 'describe_stack_set': + if function_name == "describe_stack_set": response = stack_set.describe_stack_set() - elif function_name == 'describe_stack_set_operation': + elif function_name == "describe_stack_set_operation": response = stack_set.describe_stack_set_operation() - elif function_name == 'list_stack_instances': + elif function_name == "list_stack_instances": response = stack_set.list_stack_instances() - elif function_name == 'list_stack_instances_account_ids': + elif function_name == "list_stack_instances_account_ids": response = stack_set.list_stack_instances_account_ids() - elif function_name == 'create_stack_set': + elif function_name == "create_stack_set": response = stack_set.create_stack_set() - elif function_name == 'create_stack_instances': + elif function_name == "create_stack_instances": response = stack_set.create_stack_instances() - elif function_name == 'update_stack_set': + elif function_name == "update_stack_set": response = stack_set.update_stack_set() - elif function_name == 'update_stack_instances': + elif function_name == "update_stack_instances": response = stack_set.update_stack_instances() - elif function_name == 'delete_stack_set': + elif function_name == "delete_stack_set": response = stack_set.delete_stack_set() - elif function_name == 'delete_stack_instances': + elif function_name == "delete_stack_instances": response = stack_set.delete_stack_instances() else: message = build_messages(1) @@ -61,32 +65,31 @@ def cloudformation(event, function_name): def service_control_policy(event, function_name): scp = ServiceControlPolicy(event, logger) logger.info("Router FunctionName: {}".format(function_name)) - if function_name == 'list_policies': + if function_name == "list_policies": response = scp.list_policies() - elif function_name == 'list_policies_for_account': + elif function_name == "list_policies_for_account": response = scp.list_policies_for_account() - elif function_name == 'list_policies_for_ou': + elif function_name == "list_policies_for_ou": response = scp.list_policies_for_ou() - elif function_name == 'create_policy': + elif function_name == "create_policy": response = scp.create_policy() - elif function_name == 'update_policy': + elif function_name == "update_policy": response = scp.update_policy() - elif function_name == 'delete_policy': + elif function_name == "delete_policy": response = scp.delete_policy() - elif function_name == 'configure_count': - policy_list = event.get('ResourceProperties').get('PolicyList', []) + elif function_name == "configure_count": + policy_list = event.get("ResourceProperties").get("PolicyList", []) logger.info("List of policies: {}".format(policy_list)) - event.update({'Index': 0}) - event.update({'Step': 1}) - event.update({'Count': len(policy_list)}) + event.update({"Index": 0}) + event.update({"Step": 1}) + event.update({"Count": len(policy_list)}) return event - elif function_name == 'iterator': - index = event.get('Index') - step = event.get('Step') - count = event.get('Count') - policy_list = event.get('ResourceProperties').get('PolicyList', []) - policy_to_apply = policy_list[index] \ - if len(policy_list) > index else None + elif function_name == "iterator": + index = event.get("Index") + step = event.get("Step") + count = event.get("Count") + policy_list = event.get("ResourceProperties").get("PolicyList", []) + policy_to_apply = policy_list[index] if len(policy_list) > index else None if index < count: _continue = True @@ -95,31 +98,31 @@ def service_control_policy(event, function_name): index = index + step - event.update({'Index': index}) - event.update({'Step': step}) - event.update({'Continue': _continue}) - event.update({'PolicyName': policy_to_apply}) + event.update({"Index": index}) + event.update({"Step": step}) + event.update({"Continue": _continue}) + event.update({"PolicyName": policy_to_apply}) return event - elif function_name == 'attach_policy': + elif function_name == "attach_policy": response = scp.attach_policy() - elif function_name == 'detach_policy': + elif function_name == "detach_policy": response = scp.detach_policy() - elif function_name == 'detach_policy_from_all_accounts': + elif function_name == "detach_policy_from_all_accounts": response = scp.detach_policy_from_all_accounts() - elif function_name == 'enable_policy_type': + elif function_name == "enable_policy_type": response = scp.enable_policy_type() - elif function_name == 'configure_count_2': - ou_list = event.get('ResourceProperties').get('OUList', []) + elif function_name == "configure_count_2": + ou_list = event.get("ResourceProperties").get("OUList", []) logger.info("List of OUs: {}".format(ou_list)) - event.update({'Index': 0}) - event.update({'Step': 1}) - event.update({'Count': len(ou_list)}) + event.update({"Index": 0}) + event.update({"Step": 1}) + event.update({"Count": len(ou_list)}) return event - elif function_name == 'iterator2': - index = event.get('Index') - step = event.get('Step') - count = event.get('Count') - ou_list = event.get('ResourceProperties').get('OUList', []) + elif function_name == "iterator2": + index = event.get("Index") + step = event.get("Step") + count = event.get("Count") + ou_list = event.get("ResourceProperties").get("OUList", []) ou_map = ou_list[index] if len(ou_list) > index else None if index < count: @@ -129,17 +132,24 @@ def service_control_policy(event, function_name): index = index + step - event.update({'Index': index}) - event.update({'Step': step}) - event.update({'Continue': _continue}) - if ou_map: # ou list example: [['ouname1','ouid1],'Attach'] - logger.info("[state_machine_router.service_control_policy] ou_map: {}".format(ou_map)) - logger.debug("[state_machine_router.service_control_policy] OUName: {}; OUId: {}; Operation: {}"\ - .format(ou_map[0][0], ou_map[0][1], ou_map[1])) - - event.update({'OUName': ou_map[0][0]}) - event.update({'OUId': ou_map[0][1]}) - event.update({'Operation': ou_map[1]}) + event.update({"Index": index}) + event.update({"Step": step}) + event.update({"Continue": _continue}) + if ou_map: # ou list example: [['ouname1','ouid1],'Attach'] + logger.info( + "[state_machine_router.service_control_policy] ou_map: {}".format( + ou_map + ) + ) + logger.debug( + "[state_machine_router.service_control_policy] OUName: {}; OUId: {}; Operation: {}".format( + ou_map[0][0], ou_map[0][1], ou_map[1] + ) + ) + + event.update({"OUName": ou_map[0][0]}) + event.update({"OUId": ou_map[0][1]}) + event.update({"Operation": ou_map[1]}) return event @@ -156,13 +166,13 @@ def stackset_sm_requests(event, function_name): sr = StackSetSMRequests(event, logger) logger.info("Router FunctionName: {}".format(function_name)) - if function_name == 'ssm_put_parameters': + if function_name == "ssm_put_parameters": response = sr.ssm_put_parameters() - elif function_name == 'export_cfn_output': + elif function_name == "export_cfn_output": response = sr.export_cfn_output() - elif function_name == 'send_execution_data': + elif function_name == "send_execution_data": response = sr.send_execution_data() - elif function_name == 'random_wait': + elif function_name == "random_wait": response = sr.random_wait() else: message = build_messages(1) @@ -182,11 +192,9 @@ def build_messages(type): message """ if type == 1: - message = "Function name does not match any function" \ - " in the handler file." + message = "Function name does not match any function" " in the handler file." elif type == 2: - message = "Class name does not match any class" \ - " in the handler file." + message = "Class name does not match any class" " in the handler file." else: message = "Class name not found in input." return message @@ -198,15 +206,15 @@ def lambda_handler(event, context): logger.debug("Lambda_handler Event") logger.debug(event) # Execute custom resource handlers - class_name = event.get('params', {}).get('ClassName') - function_name = event.get('params', {}).get('FunctionName') + class_name = event.get("params", {}).get("ClassName") + function_name = event.get("params", {}).get("FunctionName") if class_name is not None: if class_name == "CloudFormation": return cloudformation(event, function_name) - elif class_name == 'StackSetSMRequests': + elif class_name == "StackSetSMRequests": return stackset_sm_requests(event, function_name) - elif class_name == 'SCP': + elif class_name == "SCP": return service_control_policy(event, function_name) else: message = build_messages(2) @@ -217,6 +225,5 @@ def lambda_handler(event, context): logger.info(message) return {"Message": message} except Exception as e: - logger.log_general_exception( - __file__.split('/')[-1], inspect.stack()[0][3], e) + logger.log_general_exception(__file__.split("/")[-1], inspect.stack()[0][3], e) raise diff --git a/source/src/cfct/manifest/cfn_params_handler.py b/source/src/cfct/manifest/cfn_params_handler.py index 1d94f2c..4d363d0 100644 --- a/source/src/cfct/manifest/cfn_params_handler.py +++ b/source/src/cfct/manifest/cfn_params_handler.py @@ -13,25 +13,30 @@ # and limitations under the License. # ############################################################################### -import time import random +import time from os import environ -from cfct.aws.services.ssm import SSM + from cfct.aws.services.ec2 import EC2 from cfct.aws.services.kms import KMS +from cfct.aws.services.ssm import SSM from cfct.aws.services.sts import AssumeRole -from cfct.utils.string_manipulation import sanitize, trim_string_from_front, \ - convert_string_to_list from cfct.utils.password_generator import random_pwd_generator +from cfct.utils.string_manipulation import ( + convert_string_to_list, + sanitize, + trim_string_from_front, +) class CFNParamsHandler(object): """This class goes through the cfn parameters passed by users to - state machines and SSM parameters to get the correct parameter - , create parameter value and update SSM parameters as applicable. - For example, if a cfn parameter is passed, save it in - SSM parameter store. + state machines and SSM parameters to get the correct parameter + , create parameter value and update SSM parameters as applicable. + For example, if a cfn parameter is passed, save it in + SSM parameter store. """ + def __init__(self, logger): self.logger = logger self.ssm = SSM(self.logger) @@ -40,24 +45,24 @@ def __init__(self, logger): def _session(self, region, account_id=None): # instantiate EC2 session - account_id = account_id[0] if \ - isinstance(account_id, list) else account_id + account_id = account_id[0] if isinstance(account_id, list) else account_id if account_id is None: return EC2(self.logger, region) else: - return EC2(self.logger, - region, - credentials=self.assume_role(self.logger, - account_id)) + return EC2( + self.logger, + region, + credentials=self.assume_role(self.logger, account_id), + ) def _get_ssm_params(self, ssm_parm_name): return self.ssm.get_parameter(ssm_parm_name) def _get_kms_key_id(self): - alias_name = environ.get('KMS_KEY_ALIAS_NAME') + alias_name = environ.get("KMS_KEY_ALIAS_NAME") response = self.kms.describe_key(alias_name) self.logger.debug(response) - key_id = response.get('KeyMetadata', {}).get('KeyId') + key_id = response.get("KeyMetadata", {}).get("KeyId") return key_id def get_azs_from_member_account(self, region, qty, account, key_az=None): @@ -73,13 +78,13 @@ def get_azs_from_member_account(self, region, qty, account, key_az=None): list: availability zone names """ if key_az: - self.logger.info("Looking up values in SSM parameter:{}" - .format(key_az)) + self.logger.info("Looking up values in SSM parameter:{}".format(key_az)) existing_param = self.ssm.describe_parameters(key_az) if existing_param: - self.logger.info('Found existing SSM parameter, returning' - ' existing AZ list.') + self.logger.info( + "Found existing SSM parameter, returning" " existing AZ list." + ) return self.ssm.get_parameter(key_az) if account is not None: # fetch account from list for cross account assume role workflow @@ -87,12 +92,13 @@ def get_azs_from_member_account(self, region, qty, account, key_az=None): # AZ list for a given region in any account. acct = account[0] if isinstance(account, list) else account ec2 = self._session(region, acct) - self.logger.info("Getting list of AZs in region: {} from" - " account: {}".format(region, acct)) + self.logger.info( + "Getting list of AZs in region: {} from" + " account: {}".format(region, acct) + ) return self._get_az(ec2, key_az, qty) else: - self.logger.info("Creating EC2 Session in {} region" - .format(region)) + self.logger.info("Creating EC2 Session in {} region".format(region)) ec2 = EC2(self.logger, region) return self._get_az(ec2, key_az, qty) @@ -100,15 +106,20 @@ def _get_az(self, ec2, key_az, qty): # Get AZs az_list = ec2.describe_availability_zones() self.logger.info("_get_azs output: %s" % az_list) - random_az_list = ','.join(random.sample(az_list, qty)) - description = "Contains random AZs selected by Custom Control Tower" \ - "Solution" + random_az_list = ",".join(random.sample(az_list, qty)) + description = "Contains random AZs selected by Custom Control Tower" "Solution" if key_az: self.ssm.put_parameter(key_az, random_az_list, description) return random_az_list - def _create_key_pair(self, account, region, param_key_material=None, - param_key_fingerprint=None, param_key_name=None): + def _create_key_pair( + self, + account, + region, + param_key_material=None, + param_key_fingerprint=None, + param_key_name=None, + ): """Creates an ec2 key pair if it does not exist already. Args: @@ -123,39 +134,52 @@ def _create_key_pair(self, account, region, param_key_material=None, key name """ if param_key_name: - self.logger.info("Looking up values in SSM parameter:{}" - .format(param_key_name)) + self.logger.info( + "Looking up values in SSM parameter:{}".format(param_key_name) + ) existing_param = self.ssm.describe_parameters(param_key_name) if existing_param: return self.ssm.get_parameter(param_key_name) - key_name = sanitize("%s_%s_%s_%s" % ('custom_control_tower', account, - region, - time.strftime("%Y-%m-%dT%H-%M-%S") - )) + key_name = sanitize( + "%s_%s_%s_%s" + % ( + "custom_control_tower", + account, + region, + time.strftime("%Y-%m-%dT%H-%M-%S"), + ) + ) ec2 = self._session(region, account) # create EC2 key pair in member account - self.logger.info("Create key pair in the member account {} in" - " region: {}".format(account, region)) + self.logger.info( + "Create key pair in the member account {} in" + " region: {}".format(account, region) + ) response = ec2.create_key_pair(key_name) # add key material and fingerprint in the SSM Parameter Store self.logger.info("Adding Key Material and Fingerprint to SSM PS") - description = "Contains EC2 key pair asset created by Custom " \ - "Control Tower Solution: " \ - "EC2 Key Pair Custom Resource." + description = ( + "Contains EC2 key pair asset created by Custom " + "Control Tower Solution: " + "EC2 Key Pair Custom Resource." + ) # Get Custom Control Tower KMS Key ID key_id = self._get_kms_key_id() if param_key_fingerprint: - self.ssm.put_parameter_use_cmk(param_key_fingerprint, response - .get('KeyFingerprint'), - key_id, description) + self.ssm.put_parameter_use_cmk( + param_key_fingerprint, + response.get("KeyFingerprint"), + key_id, + description, + ) if param_key_material: - self.ssm.put_parameter_use_cmk(param_key_material, response - .get('KeyMaterial'), - key_id, description) + self.ssm.put_parameter_use_cmk( + param_key_material, response.get("KeyMaterial"), key_id, description + ) if param_key_name: self.ssm.put_parameter(param_key_name, key_name, description) @@ -172,34 +196,38 @@ def random_password(self, length, key_password=None, alphanum=True): alphanum (bool): [optional] if False it will also include ';:=+!@#%^&*()[]{}' in the character set """ - response = '_get_ssm_secure_string_' + key_password + response = "_get_ssm_secure_string_" + key_password param_exists = False if key_password: - self.logger.info("Looking up values in SSM parameter:{}" - .format(key_password)) + self.logger.info( + "Looking up values in SSM parameter:{}".format(key_password) + ) existing_param = self.ssm.describe_parameters(key_password) if existing_param: param_exists = True if not param_exists: - additional = '' + additional = "" if not alphanum: - additional = ';:=+!@#%^&*()[]{}' + additional = ";:=+!@#%^&*()[]{}" password = random_pwd_generator(length, additional) self.logger.info("Adding Random password to SSM Parameter Store") - description = "Contains random password created by Custom Control"\ - " Tower Solution" + description = ( + "Contains random password created by Custom Control" " Tower Solution" + ) if key_password: key_id = self._get_kms_key_id() - self.ssm.put_parameter_use_cmk(key_password, password, key_id, - description) + self.ssm.put_parameter_use_cmk( + key_password, password, key_id, description + ) return response - def update_params(self, params_in: list, account=None, region=None, - substitute_ssm_values=True): + def update_params( + self, params_in: list, account=None, region=None, substitute_ssm_values=True + ): """Updates SSM parameters Args: params_in (list): Python List of dict of input params e.g. @@ -225,21 +253,30 @@ def update_params(self, params_in: list, account=None, region=None, for param in params_in: key = param.get("ParameterKey") value = param.get("ParameterValue") - separator = ',' - value = value if separator not in value else \ - convert_string_to_list(value, separator) + separator = "," + value = ( + value + if separator not in value + else convert_string_to_list(value, separator) + ) if not isinstance(value, list): - value = self._process_alfred_helper(param, key, value, account, - region, - substitute_ssm_values) + value = self._process_alfred_helper( + param, key, value, account, region, substitute_ssm_values + ) else: new_value_list = [] for nested_value in value: new_value_list.append( - self._process_alfred_helper(param, key, nested_value, - account, region, - substitute_ssm_values)) + self._process_alfred_helper( + param, + key, + nested_value, + account, + region, + substitute_ssm_values, + ) + ) value = new_value_list params_out.update({key: value}) @@ -247,8 +284,9 @@ def update_params(self, params_in: list, account=None, region=None, self.logger.info("params out : {}".format(params_out)) return params_out - def _process_alfred_helper(self, param, key, value, account=None, - region=None, substitute_ssm_values=True): + def _process_alfred_helper( + self, param, key, value, account=None, region=None, substitute_ssm_values=True + ): """Parses and processes alfred helpers 'alfred_ ' Args: @@ -298,7 +336,7 @@ def _update_alfred_ssm(self, keyword, value, substitute_ssm_values): Return: value of the SSM parameter """ - ssm_param_name = trim_string_from_front(keyword, 'alfred_ssm_') + ssm_param_name = trim_string_from_front(keyword, "alfred_ssm_") param_flag = True if ssm_param_name: @@ -327,20 +365,23 @@ def _update_alfred_genkeypair(self, param, account, region): keymaterial_param_name = None keyfingerprint_param_name = None keyname_param_name = None - ssm_parameters = param.get('ssm_parameters', []) + ssm_parameters = param.get("ssm_parameters", []) if type(ssm_parameters) is list: for ssm_parameter in ssm_parameters: - val = ssm_parameter.get('value')[2:-1] - if val.lower() == 'keymaterial': - keymaterial_param_name = ssm_parameter.get('name') - elif val.lower() == 'keyfingerprint': - keyfingerprint_param_name = ssm_parameter.get('name') - elif val.lower() == 'keyname': - keyname_param_name = ssm_parameter.get('name') - value = self._create_key_pair(account, region, - keymaterial_param_name, - keyfingerprint_param_name, - keyname_param_name) + val = ssm_parameter.get("value")[2:-1] + if val.lower() == "keymaterial": + keymaterial_param_name = ssm_parameter.get("name") + elif val.lower() == "keyfingerprint": + keyfingerprint_param_name = ssm_parameter.get("name") + elif val.lower() == "keyname": + keyname_param_name = ssm_parameter.get("name") + value = self._create_key_pair( + account, + region, + keymaterial_param_name, + keyfingerprint_param_name, + keyname_param_name, + ) return value def _update_alfred_genpass(self, keyword, param): @@ -355,19 +396,19 @@ def _update_alfred_genpass(self, keyword, param): Return: generated random password """ - sub_string = trim_string_from_front(keyword, 'alfred_genpass_') + sub_string = trim_string_from_front(keyword, "alfred_genpass_") if sub_string: pw_length = int(sub_string) else: pw_length = 8 password_param_name = None - ssm_parameters = param.get('ssm_parameters', []) + ssm_parameters = param.get("ssm_parameters", []) if type(ssm_parameters) is list: for ssm_parameter in ssm_parameters: - val = ssm_parameter.get('value')[2:-1] - if val.lower() == 'password': - password_param_name = ssm_parameter.get('name') + val = ssm_parameter.get("value")[2:-1] + if val.lower() == "password": + password_param_name = ssm_parameter.get("name") value = self.random_password(pw_length, password_param_name, False) return value @@ -385,19 +426,20 @@ def _update_alfred_genaz(self, keyword, param, account, region): Return: list of random az's """ - sub_string = trim_string_from_front(keyword, 'alfred_genaz_') + sub_string = trim_string_from_front(keyword, "alfred_genaz_") if sub_string: no_of_az = int(sub_string) else: no_of_az = 2 az_param_name = None - ssm_parameters = param.get('ssm_parameters', []) + ssm_parameters = param.get("ssm_parameters", []) if type(ssm_parameters) is list: for ssm_parameter in ssm_parameters: - val = ssm_parameter.get('value')[2:-1] - if val.lower() == 'az': - az_param_name = ssm_parameter.get('name') + val = ssm_parameter.get("value")[2:-1] + if val.lower() == "az": + az_param_name = ssm_parameter.get("name") value = self.get_azs_from_member_account( - region, no_of_az, account, az_param_name) + region, no_of_az, account, az_param_name + ) return value diff --git a/source/src/cfct/manifest/manifest.py b/source/src/cfct/manifest/manifest.py index ee05dc4..8379f57 100755 --- a/source/src/cfct/manifest/manifest.py +++ b/source/src/cfct/manifest/manifest.py @@ -14,8 +14,7 @@ ############################################################################## import yorm -from yorm.types import String, Boolean -from yorm.types import List, AttributeDictionary +from yorm.types import AttributeDictionary, Boolean, List, String @yorm.attr(name=String) @@ -113,8 +112,7 @@ def __init__(self): @yorm.attr(description=String) @yorm.attr(apply_to_accounts_in_ou=ApplyToOUList) class Policy(AttributeDictionary): - def __init__(self, name, policy_file, description, - apply_to_accounts_in_ou): + def __init__(self, name, policy_file, description, apply_to_accounts_in_ou): super().__init__() self.name = name self.description = description @@ -137,8 +135,17 @@ def __init__(self): @yorm.attr(deployment_targets=DeployTargets) @yorm.attr(parameters=Parameters) class ResourceProps(AttributeDictionary): - def __init__(self, name, resource_file, parameters, parameter_file, - deploy_method, deployment_targets, export_outputs, regions): + def __init__( + self, + name, + resource_file, + parameters, + parameter_file, + deploy_method, + deployment_targets, + export_outputs, + regions, + ): super().__init__() self.name = name self.resource_file = resource_file diff --git a/source/src/cfct/manifest/manifest_parser.py b/source/src/cfct/manifest/manifest_parser.py index 0af74a0..8c10ceb 100644 --- a/source/src/cfct/manifest/manifest_parser.py +++ b/source/src/cfct/manifest/manifest_parser.py @@ -13,32 +13,40 @@ # and limitations under the License. # ############################################################################### +import json import os import sys -import json -from typing import List, Dict, Any -from cfct.utils.logger import Logger -from cfct.manifest.manifest import Manifest -from cfct.manifest.stage_to_s3 import StageFile -from cfct.manifest.sm_input_builder import InputBuilder, SCPResourceProperties, \ - StackSetResourceProperties -from cfct.utils.parameter_manipulation import transform_params -from cfct.utils.string_manipulation import convert_list_values_to_string, \ - empty_separator_handler, list_sanitizer -from cfct.aws.services.s3 import S3 +from typing import Any, Dict, List + +from cfct.aws.services.cloudformation import StackSet from cfct.aws.services.organizations import Organizations +from cfct.aws.services.s3 import S3 from cfct.manifest.cfn_params_handler import CFNParamsHandler +from cfct.manifest.manifest import Manifest +from cfct.manifest.sm_input_builder import ( + InputBuilder, + SCPResourceProperties, + StackSetResourceProperties, +) +from cfct.manifest.stage_to_s3 import StageFile from cfct.metrics.solution_metrics import SolutionMetrics -from cfct.aws.services.cloudformation import StackSet +from cfct.utils.logger import Logger +from cfct.utils.parameter_manipulation import transform_params +from cfct.utils.string_manipulation import ( + convert_list_values_to_string, + empty_separator_handler, + list_sanitizer, +) + +VERSION_1 = "2020-01-01" +VERSION_2 = "2021-03-15" -VERSION_1 = '2020-01-01' -VERSION_2 = '2021-03-15' +logger = Logger(loglevel=os.environ["LOG_LEVEL"]) -logger = Logger(loglevel=os.environ['LOG_LEVEL']) def scp_manifest(): # determine manifest version - manifest = Manifest(os.environ.get('MANIFEST_FILE_PATH')) + manifest = Manifest(os.environ.get("MANIFEST_FILE_PATH")) if manifest.version == VERSION_1: get_scp_input = SCPParser() return get_scp_input.parse_scp_manifest_v1() @@ -49,7 +57,7 @@ def scp_manifest(): def stack_set_manifest(): # determine manifest version - manifest = Manifest(os.environ.get('MANIFEST_FILE_PATH')) + manifest = Manifest(os.environ.get("MANIFEST_FILE_PATH")) send = SolutionMetrics(logger) if manifest.version == VERSION_1: data = {"ManifestVersion": VERSION_1} @@ -77,13 +85,13 @@ class SCPParser: def __init__(self): self.logger = logger - self.manifest = Manifest(os.environ.get('MANIFEST_FILE_PATH')) + self.manifest = Manifest(os.environ.get("MANIFEST_FILE_PATH")) def parse_scp_manifest_v1(self) -> list: state_machine_inputs = [] self.logger.info( - "Processing SCPs from {} file".format( - os.environ.get('MANIFEST_FILE_PATH'))) + "Processing SCPs from {} file".format(os.environ.get("MANIFEST_FILE_PATH")) + ) build = BuildStateMachineInput(self.manifest.region) org_data = OrganizationsData() for policy in self.manifest.organization_policies: @@ -94,20 +102,20 @@ def parse_scp_manifest_v1(self) -> list: self.logger.debug( "[manifest_parser.parse_scp_manifest_v1] attach_ou_list: {} ".format( - attach_ou_list)) + attach_ou_list + ) + ) # Add ou id to final ou list final_ou_list = org_data.get_final_ou_list(attach_ou_list) - state_machine_inputs.append(build.scp_sm_input( - final_ou_list, - policy, - policy_url)) + state_machine_inputs.append( + build.scp_sm_input(final_ou_list, policy, policy_url) + ) # Exit if there are no organization policies if len(state_machine_inputs) == 0: - self.logger.info("Organization policies not found" - " in the manifest.") + self.logger.info("Organization policies not found" " in the manifest.") sys.exit(0) else: return state_machine_inputs @@ -116,32 +124,33 @@ def parse_scp_manifest_v2(self) -> list: state_machine_inputs = [] self.logger.info( "[manifest_parser.parse_scp_manifest_v2] Processing SCPs from {} file".format( - os.environ.get('MANIFEST_FILE_PATH'))) + os.environ.get("MANIFEST_FILE_PATH") + ) + ) build = BuildStateMachineInput(self.manifest.region) org_data = OrganizationsData() for resource in self.manifest.resources: - if resource.deploy_method == 'scp': + if resource.deploy_method == "scp": local_file = StageFile(self.logger, resource.resource_file) policy_url = local_file.get_staged_file() - attach_ou_list = set( - resource.deployment_targets.organizational_units) + attach_ou_list = set(resource.deployment_targets.organizational_units) self.logger.debug( "[manifest_parser.parse_scp_manifest_v2] attach_ou_list: {} ".format( - attach_ou_list)) + attach_ou_list + ) + ) # Add ou id to final ou list final_ou_list = org_data.get_final_ou_list(attach_ou_list) - state_machine_inputs.append(build.scp_sm_input( - final_ou_list, - resource, - policy_url)) + state_machine_inputs.append( + build.scp_sm_input(final_ou_list, resource, policy_url) + ) # Exit if there are no organization policies if len(state_machine_inputs) == 0: - self.logger.info("Organization policies not found" - " in the manifest.") + self.logger.info("Organization policies not found" " in the manifest.") sys.exit(0) else: return state_machine_inputs @@ -162,13 +171,16 @@ class StackSetParser: def __init__(self): self.logger = logger self.stack_set = StackSet(logger) - self.manifest = Manifest(os.environ.get('MANIFEST_FILE_PATH')) - self.manifest_folder = os.environ.get('MANIFEST_FOLDER') + self.manifest = Manifest(os.environ.get("MANIFEST_FILE_PATH")) + self.manifest_folder = os.environ.get("MANIFEST_FOLDER") def parse_stack_set_manifest_v1(self) -> list: - self.logger.info("Parsing Core Resources from {} file" - .format(os.environ.get('MANIFEST_FILE_PATH'))) + self.logger.info( + "Parsing Core Resources from {} file".format( + os.environ.get("MANIFEST_FILE_PATH") + ) + ) build = BuildStateMachineInput(self.manifest.region) org = OrganizationsData() organizations_data = org.get_organization_details() @@ -183,48 +195,53 @@ def parse_stack_set_manifest_v1(self) -> list: accounts_in_ou = org.get_accounts_in_ou( organizations_data.get("OuIdToAccountMap"), organizations_data.get("OuNameToIdMap"), - resource.deploy_to_ou + resource.deploy_to_ou, ) # convert account numbers to string type - account_list = convert_list_values_to_string( - resource.deploy_to_account) + account_list = convert_list_values_to_string(resource.deploy_to_account) self.logger.info(">>>>>> ACCOUNT LIST") self.logger.info(account_list) sanitized_account_list = org.get_final_account_list( - account_list, organizations_data.get("AccountsInAllOUs"), - accounts_in_ou, organizations_data.get("NameToAccountMap")) + account_list, + organizations_data.get("AccountsInAllOUs"), + accounts_in_ou, + organizations_data.get("NameToAccountMap"), + ) - self.logger.info("Print merged account list - accounts in manifest" - " + account under OU in manifest") + self.logger.info( + "Print merged account list - accounts in manifest" + " + account under OU in manifest" + ) self.logger.info(sanitized_account_list) - if resource.deploy_method.lower() == 'stack_set': + if resource.deploy_method.lower() == "stack_set": sm_input = build.stack_set_state_machine_input_v1( - resource, sanitized_account_list) + resource, sanitized_account_list + ) state_machine_inputs.append(sm_input) else: raise ValueError( f"Unsupported deploy_method: {resource.deploy_method} " - f"found for resource {resource.name}") + f"found for resource {resource.name}" + ) self.logger.info(f"<<<<<<<<< FINISH : {resource.name} <<<<<<<<<") # Exit if there are no CloudFormation resources if len(state_machine_inputs) == 0: - self.logger.info("CloudFormation resources not found in the " - "manifest") + self.logger.info("CloudFormation resources not found in the " "manifest") sys.exit(0) else: return state_machine_inputs - - - def parse_stack_set_manifest_v2(self) -> list: - self.logger.info("Parsing Core Resources from {} file" - .format(os.environ.get('MANIFEST_FILE_PATH'))) + self.logger.info( + "Parsing Core Resources from {} file".format( + os.environ.get("MANIFEST_FILE_PATH") + ) + ) build = BuildStateMachineInput(self.manifest.region) org = OrganizationsData() organizations_data = org.get_organization_details() @@ -235,10 +252,18 @@ def parse_stack_set_manifest_v2(self) -> list: manifest_stacksets: List[str] = [] for resource in self.manifest.resources: if resource["deploy_method"] == StackSet.DEPLOY_METHOD: - manifest_stacksets.append(resource['name']) - - stacksets_to_be_deleted = self.stack_set.get_stack_sets_not_present_in_manifest(manifest_stack_sets=manifest_stacksets) - state_machine_inputs.extend(self.stack_set.generate_delete_request(stacksets_to_delete=stacksets_to_be_deleted)) + manifest_stacksets.append(resource["name"]) + + stacksets_to_be_deleted = ( + self.stack_set.get_stack_sets_not_present_in_manifest( + manifest_stack_sets=manifest_stacksets + ) + ) + state_machine_inputs.extend( + self.stack_set.generate_delete_request( + stacksets_to_delete=stacksets_to_be_deleted + ) + ) for resource in self.manifest.resources: if resource.deploy_method == StackSet.DEPLOY_METHOD: @@ -250,37 +275,44 @@ def parse_stack_set_manifest_v2(self) -> list: accounts_in_ou = org.get_accounts_in_ou( organizations_data.get("OuIdToAccountMap"), organizations_data.get("OuNameToIdMap"), - resource.deployment_targets.organizational_units + resource.deployment_targets.organizational_units, ) # convert account numbers to string type account_list = convert_list_values_to_string( - resource.deployment_targets.accounts) + resource.deployment_targets.accounts + ) self.logger.info(">>>>>> ACCOUNT LIST") self.logger.info(account_list) sanitized_account_list = org.get_final_account_list( - account_list, organizations_data.get("AccountsInAllNestedOUs"), - accounts_in_ou, organizations_data.get("NameToAccountMap")) + account_list, + organizations_data.get("AccountsInAllNestedOUs"), + accounts_in_ou, + organizations_data.get("NameToAccountMap"), + ) - self.logger.info("Print merged account list - accounts in " - "manifest + account under OU in manifest") + self.logger.info( + "Print merged account list - accounts in " + "manifest + account under OU in manifest" + ) self.logger.info(sanitized_account_list) - if resource.deploy_method.lower() == 'stack_set': + if resource.deploy_method.lower() == "stack_set": sm_input = build.stack_set_state_machine_input_v2( - resource, sanitized_account_list) + resource, sanitized_account_list + ) state_machine_inputs.append(sm_input) else: raise ValueError( f"Unsupported deploy_method: {resource.deploy_method} " - f"found for resource {resource.name}") + f"found for resource {resource.name}" + ) self.logger.info(f"<<<<<<<<< FINISH : {resource.name} <<<<<<<<") # Exit if there are no CloudFormation resources if len(state_machine_inputs) == 0: - self.logger.info("CloudFormation resources not found in the " - "manifest") + self.logger.info("CloudFormation resources not found in the " "manifest") sys.exit(0) else: return state_machine_inputs @@ -295,7 +327,7 @@ class BuildStateMachineInput: def __init__(self, region): self.logger = logger self.param_handler = CFNParamsHandler(logger) - self.manifest_folder = os.environ.get('MANIFEST_FOLDER') + self.manifest_folder = os.environ.get("MANIFEST_FOLDER") self.region = region self.s3 = S3(logger) @@ -303,12 +335,11 @@ def scp_sm_input(self, attach_ou_list, policy, policy_url) -> dict: ou_list = [] for ou in attach_ou_list: - ou_list.append((ou, 'Attach')) + ou_list.append((ou, "Attach")) - resource_properties = SCPResourceProperties(policy.name, - policy.description, - policy_url, - ou_list) + resource_properties = SCPResourceProperties( + policy.name, policy.description, policy_url, ou_list + ) scp_input = InputBuilder(resource_properties.get_scp_input_map()) sm_input = scp_input.input_map() @@ -338,21 +369,23 @@ def stack_set_state_machine_input_v1(self, resource, account_list) -> dict: else: parameters = [] - sm_params = self.param_handler.update_params(parameters, account_list, - region, False) + sm_params = self.param_handler.update_params( + parameters, account_list, region, False + ) ssm_parameters = self._create_ssm_input_map(resource.ssm_parameters) # generate state machine input list stack_set_name = "CustomControlTower-{}".format(resource.name) - resource_properties = StackSetResourceProperties(stack_set_name, - template_url, - sm_params, - os.environ - .get('CAPABILITIES'), - account_list, - region_list, - ssm_parameters) + resource_properties = StackSetResourceProperties( + stack_set_name, + template_url, + sm_params, + os.environ.get("CAPABILITIES"), + account_list, + region_list, + ssm_parameters, + ) ss_input = InputBuilder(resource_properties.get_stack_set_input_map()) return ss_input.input_map() @@ -372,43 +405,41 @@ def stack_set_state_machine_input_v2(self, resource, account_list) -> dict: # if parameter file link is provided for the CFN resource if resource.parameter_file == "": - self.logger.info("parameter_file property not found in the " - "manifest") + self.logger.info("parameter_file property not found in the " "manifest") self.logger.info(resource.parameter_file) self.logger.info(resource.parameters) parameters = self._load_params_from_manifest(resource.parameters) elif not resource.parameters: - self.logger.info("parameters property not found in the " - "manifest") + self.logger.info("parameters property not found in the " "manifest") self.logger.info(resource.parameter_file) self.logger.info(resource.parameters) parameters = self._load_params_from_file(resource.parameter_file) - sm_params = self.param_handler.update_params(parameters, account_list, - region, False) + sm_params = self.param_handler.update_params( + parameters, account_list, region, False + ) - self.logger.info("Input Parameters for State Machine: {}".format( - sm_params)) + self.logger.info("Input Parameters for State Machine: {}".format(sm_params)) ssm_parameters = self._create_ssm_input_map(resource.export_outputs) # generate state machine input list stack_set_name = "CustomControlTower-{}".format(resource.name) - resource_properties = StackSetResourceProperties(stack_set_name, - template_url, - sm_params, - os.environ - .get('CAPABILITIES'), - account_list, - region_list, - ssm_parameters) + resource_properties = StackSetResourceProperties( + stack_set_name, + template_url, + sm_params, + os.environ.get("CAPABILITIES"), + account_list, + region_list, + ssm_parameters, + ) ss_input = InputBuilder(resource_properties.get_stack_set_input_map()) return ss_input.input_map() def _load_params_from_manifest(self, parameter_list: list): - self.logger.info("Replace the keys with CloudFormation " - "Parameter data type") + self.logger.info("Replace the keys with CloudFormation " "Parameter data type") params_list = [] for item in parameter_list: # must initialize params inside loop to avoid overwriting values @@ -420,16 +451,14 @@ def _load_params_from_manifest(self, parameter_list: list): return params_list def _load_params_from_file(self, relative_parameter_path): - if relative_parameter_path.lower().startswith('s3://'): + if relative_parameter_path.lower().startswith("s3://"): parameter_file = self.s3.get_s3_object(relative_parameter_path) else: - parameter_file = os.path.join(self.manifest_folder, - relative_parameter_path) + parameter_file = os.path.join(self.manifest_folder, relative_parameter_path) - self.logger.info("Parsing the parameter file: {}".format( - parameter_file)) + self.logger.info("Parsing the parameter file: {}".format(parameter_file)) - with open(parameter_file, 'r') as content_file: + with open(parameter_file, "r") as content_file: parameter_file_content = content_file.read() params = json.loads(parameter_file_content) @@ -441,9 +470,7 @@ def _create_ssm_input_map(self, ssm_parameters): for ssm_parameter in ssm_parameters: key = ssm_parameter.name value = ssm_parameter.value - ssm_value = self.param_handler.update_params( - transform_params({key: value}) - ) + ssm_value = self.param_handler.update_params(transform_params({key: value})) ssm_input_map.update(ssm_value) return ssm_input_map @@ -458,52 +485,75 @@ class OrganizationsData: def __init__(self): self.logger = logger self.stack_set = StackSet(logger) - self.control_tower_baseline_config_stackset = os.environ['CONTROL_TOWER_BASELINE_CONFIG_STACKSET'] \ - if os.getenv('CONTROL_TOWER_BASELINE_CONFIG_STACKSET') is not None else 'AWSControlTowerBP-BASELINE-CONFIG' + self.control_tower_baseline_config_stackset = ( + os.environ["CONTROL_TOWER_BASELINE_CONFIG_STACKSET"] + if os.getenv("CONTROL_TOWER_BASELINE_CONFIG_STACKSET") is not None + else "AWSControlTowerBP-BASELINE-CONFIG" + ) def get_accounts_in_ou(self, ou_id_to_account_map, ou_name_to_id_map, ou_list): accounts_in_ou = [] ou_ids_manifest = [] accounts_in_nested_ou = [] - if 'Root' in ou_list: - accounts_list, region_list = self.get_accounts_in_ct_baseline_config_stack_set() + if "Root" in ou_list: + ( + accounts_list, + region_list, + ) = self.get_accounts_in_ct_baseline_config_stack_set() accounts_in_ou = accounts_list - else: + else: # convert OU Name to OU IDs for ou_name in ou_list: - if ':' in ou_name: # Process nested OU. For example: TestOU1:TestOU2:TestOU3 + if ( + ":" in ou_name + ): # Process nested OU. For example: TestOU1:TestOU2:TestOU3 ou_id = self.get_ou_id(ou_name, ":") accounts_in_nested_ou.extend(self.get_active_accounts_in_ou(ou_id)) - self.logger.debug("[manifest_parser.get_accounts_in_ou] ou_name: {}; ou_id: {}; accounts_in_nested_ou: {}" \ - .format(ou_name, ou_id, accounts_in_nested_ou)) + self.logger.debug( + "[manifest_parser.get_accounts_in_ou] ou_name: {}; ou_id: {}; accounts_in_nested_ou: {}".format( + ou_name, ou_id, accounts_in_nested_ou + ) + ) else: - ou_id = [value for key, value in ou_name_to_id_map.items() - if ou_name == key] + ou_id = [ + value + for key, value in ou_name_to_id_map.items() + if ou_name == key + ] ou_ids_manifest.extend(ou_id) - self.logger.debug("[manifest_parser.get_accounts_in_ou] ou_name: {}; ou_id: {}; ou_ids_manifest for non-nested ous: {}" \ - .format(ou_name, ou_id, ou_ids_manifest)) + self.logger.debug( + "[manifest_parser.get_accounts_in_ou] ou_name: {}; ou_id: {}; ou_ids_manifest for non-nested ous: {}".format( + ou_name, ou_id, ou_ids_manifest + ) + ) for ou_id, accounts in ou_id_to_account_map.items(): if ou_id in ou_ids_manifest: accounts_in_ou.extend(accounts) - self.logger.debug("[manifest_parser.get_accounts_in_ou] Accounts in non_nested OUs: {}" \ - .format(accounts_in_ou)) + self.logger.debug( + "[manifest_parser.get_accounts_in_ou] Accounts in non_nested OUs: {}".format( + accounts_in_ou + ) + ) - self.logger.debug("[manifest_parser.get_accounts_in_ou] Accounts in nested OUs: {}" \ - .format(accounts_in_nested_ou)) + self.logger.debug( + "[manifest_parser.get_accounts_in_ou] Accounts in nested OUs: {}".format( + accounts_in_nested_ou + ) + ) - # add accounts for nested ous + # add accounts for nested ous accounts_in_ou.extend(accounts_in_nested_ou) - self.logger.info(">>> Accounts: {} in OUs: {}" - .format(accounts_in_ou, ou_list)) + self.logger.info(">>> Accounts: {} in OUs: {}".format(accounts_in_ou, ou_list)) return accounts_in_ou - def get_final_account_list(self, account_list, accounts_in_all_ous, - accounts_in_ou, name_to_account_map): + def get_final_account_list( + self, account_list, accounts_in_all_ous, accounts_in_ou, name_to_account_map + ): # separate account id and emails name_list = [] new_account_list = [] @@ -520,12 +570,15 @@ def get_final_account_list(self, account_list, accounts_in_all_ous, if name_list: # convert OU Name to OU IDs for name in name_list: - name_account = [value for key, value in - name_to_account_map.items() - if name.lower() == key.lower()] + name_account = [ + value + for key, value in name_to_account_map.items() + if name.lower() == key.lower() + ] self.logger.info(f"==== name_account: {name_account}") - self.logger.info("%%%%%%% Name {} - Account {}" - .format(name, name_account)) + self.logger.info( + "%%%%%%% Name {} - Account {}".format(name, name_account) + ) new_account_list.extend(name_account) # Remove account ids from the manifest that is not # in the organization or not active @@ -563,8 +616,9 @@ def get_organization_details(self) -> dict: # 2) Accounts for each OU at the root level. # use case: map OU Name to account IDs # key: OU ID (str); value: Active accounts (list) - accounts_in_all_ous, ou_id_to_account_map = \ - self._get_accounts_in_ou(org, all_ou_ids) + accounts_in_all_ous, ou_id_to_account_map = self._get_accounts_in_ou( + org, all_ou_ids + ) # Returns account name in manifest to account id mapping. # key: account name; value: account id @@ -579,7 +633,7 @@ def get_organization_details(self) -> dict: "OuNameToIdMap": ou_name_to_id_map, "NameToAccountMap": name_to_account_map, "ActiveAccountsForRoot": active_account_list, - "AccountsInAllNestedOUs": accounts_in_all_nested_ous + "AccountsInAllNestedOUs": accounts_in_all_nested_ous, } def _get_ou_ids(self, org): @@ -602,10 +656,10 @@ def _get_ou_ids(self, org): for ou_at_root_level in ou_list_at_root_level: # build list of all the OU IDs under Org root - _all_ou_ids.append(ou_at_root_level.get('Id')) + _all_ou_ids.append(ou_at_root_level.get("Id")) # build a list of ou id _ou_name_to_id_map.update( - {ou_at_root_level.get('Name'): ou_at_root_level.get('Id')} + {ou_at_root_level.get("Name"): ou_at_root_level.get("Id")} ) self.logger.info("Print OU Name to OU ID Map") @@ -617,12 +671,11 @@ def _get_root_id(self, org): response = org.list_roots() self.logger.info("Response: List Roots") self.logger.info(response) - return response['Roots'][0].get('Id') + return response["Roots"][0].get("Id") def _list_ou_for_parent(self, org, parent_id): _ou_list = org.list_organizational_units_for_parent(parent_id) - self.logger.info("Print Organizational Units List under {}" - .format(parent_id)) + self.logger.info("Print Organizational Units List under {}".format(parent_id)) self.logger.info(_ou_list) return _ou_list @@ -635,23 +688,23 @@ def _get_accounts_in_ou(self, org, ou_id_list): _account_list = org.list_accounts_for_parent(_ou_id) for _account in _account_list: # filter ACTIVE and CREATED accounts - if _account.get('Status') == "ACTIVE": + if _account.get("Status") == "ACTIVE": # create a list of accounts in OU - accounts_in_all_ous.append(_account.get('Id')) - _accounts_in_ou.append(_account.get('Id')) + accounts_in_all_ous.append(_account.get("Id")) + _accounts_in_ou.append(_account.get("Id")) # create a map of accounts for each ou - self.logger.info("Creating Key:Value Mapping - " - "OU ID: {} ; Account List: {}" - .format(_ou_id, _accounts_in_ou)) + self.logger.info( + "Creating Key:Value Mapping - " + "OU ID: {} ; Account List: {}".format(_ou_id, _accounts_in_ou) + ) ou_id_to_account_map.update({_ou_id: _accounts_in_ou}) self.logger.info(ou_id_to_account_map) # reset list of accounts in the OU _accounts_in_ou = [] - self.logger.info("All accounts in OU List: {}" - .format(accounts_in_all_ous)) + self.logger.info("All accounts in OU List: {}".format(accounts_in_all_ous)) self.logger.info("OU to Account ID mapping") self.logger.info(ou_id_to_account_map) return accounts_in_all_ous, ou_id_to_account_map @@ -664,10 +717,8 @@ def get_account_for_name(self, org): _name_to_account_map = {} for account in account_list: if account.get("Status") == "ACTIVE": - active_account_list.append(account.get('Id')) - _name_to_account_map.update( - {account.get("Name"): account.get("Id")} - ) + active_account_list.append(account.get("Id")) + _name_to_account_map.update({account.get("Name"): account.get("Id")}) self.logger.info("Print Account Name > Account Mapping") self.logger.info(_name_to_account_map) @@ -677,68 +728,92 @@ def get_account_for_name(self, org): def get_final_ou_list(self, ou_list): # Get ou id given an ou name final_ou_list = [] - for ou_name in ou_list: - ou_id= self.get_ou_id(ou_name, ":") - this_ou_list= [ou_name, ou_id] + for ou_name in ou_list: + ou_id = self.get_ou_id(ou_name, ":") + this_ou_list = [ou_name, ou_id] final_ou_list.append(this_ou_list) - + self.logger.info( "[manifest_parser.get_final_ou_list] final_ou_list: {} ".format( - final_ou_list)) + final_ou_list + ) + ) return final_ou_list def get_ou_id(self, nested_ou_name, delimiter): org = Organizations(self.logger) response = org.list_roots() - root_id = response['Roots'][0].get('Id') - self.logger.info("[manifest_parser.get_ou_id] Organizations Root Id: {}".format(root_id)) + root_id = response["Roots"][0].get("Id") + self.logger.info( + "[manifest_parser.get_ou_id] Organizations Root Id: {}".format(root_id) + ) - if nested_ou_name == 'Root': + if nested_ou_name == "Root": return root_id else: - self.logger.info("[manifest_parser.get_ou_id] Looking up the OU Id for OUName: {} with nested" - " ou delimiter: '{}'".format(nested_ou_name, - delimiter)) + self.logger.info( + "[manifest_parser.get_ou_id] Looking up the OU Id for OUName: {} with nested" + " ou delimiter: '{}'".format(nested_ou_name, delimiter) + ) ou_id = self._get_ou_id(org, root_id, nested_ou_name, delimiter) if ou_id is None or len(ou_id) == 0: raise ValueError("OU id is not found for {}".format(nested_ou_name)) - + return ou_id def _get_ou_id(self, org, parent_id, nested_ou_name, delimiter): - nested_ou_name_list = empty_separator_handler( - delimiter, nested_ou_name) + nested_ou_name_list = empty_separator_handler(delimiter, nested_ou_name) response = self.list_ou_for_parent( - org, parent_id, list_sanitizer(nested_ou_name_list)) - self.logger.info("[manifest_parser._get_ou_id] _list_ou_for_parent response: {}".format(response)) + org, parent_id, list_sanitizer(nested_ou_name_list) + ) + self.logger.info( + "[manifest_parser._get_ou_id] _list_ou_for_parent response: {}".format( + response + ) + ) return response - + def list_ou_for_parent(self, org, parent_id, nested_ou_name_list): ou_list = org.list_organizational_units_for_parent(parent_id) index = 0 # always process the first item - self.logger.debug("[manifest_parser.list_ou_id_for_parent] nested_ou_name_list: {}" - .format(nested_ou_name_list)) - self.logger.debug("[manifest_parser.list_ou_id_for_parent] ou_list: {} for parent id {}" - .format(ou_list, parent_id)) + self.logger.debug( + "[manifest_parser.list_ou_id_for_parent] nested_ou_name_list: {}".format( + nested_ou_name_list + ) + ) + self.logger.debug( + "[manifest_parser.list_ou_id_for_parent] ou_list: {} for parent id {}".format( + ou_list, parent_id + ) + ) for dictionary in ou_list: - self.logger.debug("[manifest_parser.list_ou_id_for_parent] dictionary:{}".format(dictionary)) - if dictionary.get('Name') == nested_ou_name_list[index]: - self.logger.info("[manifest_parser.list_ou_id_for_parent] OU Name: {} exists under parent id: {}" - .format(dictionary.get('Name'), - parent_id)) + self.logger.debug( + "[manifest_parser.list_ou_id_for_parent] dictionary:{}".format( + dictionary + ) + ) + if dictionary.get("Name") == nested_ou_name_list[index]: + self.logger.info( + "[manifest_parser.list_ou_id_for_parent] OU Name: {} exists under parent id: {}".format( + dictionary.get("Name"), parent_id + ) + ) # pop the first item in the list nested_ou_name_list.pop(index) if len(nested_ou_name_list) == 0: - self.logger.info("[manifest_parser.list_ou_id_for_parent] Returning last level OU ID: {}" - .format(dictionary.get('Id'))) - return dictionary.get('Id') + self.logger.info( + "[manifest_parser.list_ou_id_for_parent] Returning last level OU ID: {}".format( + dictionary.get("Id") + ) + ) + return dictionary.get("Id") else: - return self.list_ou_for_parent(org, - dictionary.get('Id'), - nested_ou_name_list) + return self.list_ou_for_parent( + org, dictionary.get("Id"), nested_ou_name_list + ) def get_active_accounts_in_ou(self, ou_id): """ @@ -749,22 +824,35 @@ def get_active_accounts_in_ou(self, ou_id): account_list = org.list_accounts_for_parent(ou_id) for account in account_list: # filter ACTIVE and CREATED accounts - if account.get('Status') == "ACTIVE": - active_accounts_in_ou.append(account.get('Id')) + if account.get("Status") == "ACTIVE": + active_accounts_in_ou.append(account.get("Id")) - self.logger.info("All active accounts in nested OU %s:" %(ou_id)) + self.logger.info("All active accounts in nested OU %s:" % (ou_id)) self.logger.info(active_accounts_in_ou) - + return active_accounts_in_ou def get_accounts_in_ct_baseline_config_stack_set(self): """ This function gets active accounts which the control tower baseline config stackset deploys to """ - accounts_list, region_list = self.stack_set.get_accounts_and_regions_per_stack_set(self.control_tower_baseline_config_stackset) + ( + accounts_list, + region_list, + ) = self.stack_set.get_accounts_and_regions_per_stack_set( + self.control_tower_baseline_config_stackset + ) - self.logger.info("[manifest_parser.get_accounts_in_ct_baseline_config_stack_set] All active accounts in control tower baseline config stackset: {}".format(accounts_list)) - self.logger.info("[manifest_parser.get_accounts_in_ct_baseline_config_stack_set] All regions in control tower baseline stackset: {}".format(region_list)) + self.logger.info( + "[manifest_parser.get_accounts_in_ct_baseline_config_stack_set] All active accounts in control tower baseline config stackset: {}".format( + accounts_list + ) + ) + self.logger.info( + "[manifest_parser.get_accounts_in_ct_baseline_config_stack_set] All regions in control tower baseline stackset: {}".format( + region_list + ) + ) return accounts_list, region_list @@ -774,9 +862,12 @@ def get_master_account_id_in_org(self): """ org = Organizations(self.logger) response = org.describe_organization() - master_account_id = response['Organization'].get('MasterAccountId') + master_account_id = response["Organization"].get("MasterAccountId") - self.logger.info("[manifest_parser.get_master_account_id_in_org] Master account id: %s" %(master_account_id)) + self.logger.info( + "[manifest_parser.get_master_account_id_in_org] Master account id: %s" + % (master_account_id) + ) return master_account_id @@ -789,6 +880,10 @@ def get_all_accounts_in_all_nested_ous(self): accounts_list.append(master_account_id) - self.logger.info("[manifest_parser.get_all_accounts_in_all_ous] All active accounts in control tower baseline config stackset plus master account: {}".format(accounts_list)) + self.logger.info( + "[manifest_parser.get_all_accounts_in_all_ous] All active accounts in control tower baseline config stackset plus master account: {}".format( + accounts_list + ) + ) - return accounts_list \ No newline at end of file + return accounts_list diff --git a/source/src/cfct/manifest/sm_execution_manager.py b/source/src/cfct/manifest/sm_execution_manager.py index 99a1639..2e503bc 100644 --- a/source/src/cfct/manifest/sm_execution_manager.py +++ b/source/src/cfct/manifest/sm_execution_manager.py @@ -13,23 +13,23 @@ # and limitations under the License. # ############################################################################### +import filecmp import os -import time import tempfile -import filecmp +import time from uuid import uuid4 + from botocore.exceptions import ClientError +from cfct.aws.services.cloudformation import StackSet from cfct.aws.services.s3 import S3 from cfct.aws.services.state_machine import StateMachine -from cfct.aws.services.cloudformation import StackSet -from cfct.exceptions import StackSetHasFailedInstances -from cfct.utils.string_manipulation import trim_length_from_end from cfct.aws.utils.url_conversion import parse_bucket_key_names -from cfct.utils.parameter_manipulation import transform_params, \ - reverse_transform_params -from cfct.utils.list_comparision import compare_lists -from cfct.metrics.solution_metrics import SolutionMetrics +from cfct.exceptions import StackSetHasFailedInstances from cfct.manifest.cfn_params_handler import CFNParamsHandler +from cfct.metrics.solution_metrics import SolutionMetrics +from cfct.utils.list_comparision import compare_lists +from cfct.utils.parameter_manipulation import reverse_transform_params, transform_params +from cfct.utils.string_manipulation import trim_length_from_end class SMExecutionManager: @@ -42,51 +42,50 @@ def __init__(self, logger, sm_input_list, enforce_successful_stack_instances=Fal self.param_handler = CFNParamsHandler(logger) self.state_machine = StateMachine(logger) self.stack_set = StackSet(logger) - self.wait_time = os.environ.get('WAIT_TIME') - self.execution_mode = os.environ.get('EXECUTION_MODE') + self.wait_time = os.environ.get("WAIT_TIME") + self.execution_mode = os.environ.get("EXECUTION_MODE") self.enforce_successful_stack_instances = enforce_successful_stack_instances def launch_executions(self): self.logger.info("%%% Launching State Machine Execution %%%") - if self.execution_mode.upper() == 'PARALLEL': + if self.execution_mode.upper() == "PARALLEL": self.logger.info(" | | | | | Running Parallel Mode. | | | | |") return self.run_execution_parallel_mode() - elif self.execution_mode.upper() == 'SEQUENTIAL': + elif self.execution_mode.upper() == "SEQUENTIAL": self.logger.info(" > > > > > Running Sequential Mode. > > > > >") return self.run_execution_sequential_mode() else: - raise ValueError("Invalid execution mode: {}" - .format(self.execution_mode)) + raise ValueError("Invalid execution mode: {}".format(self.execution_mode)) def run_execution_sequential_mode(self): status, failed_execution_list = None, [] # start executions at given intervals for sm_input in self.sm_input_list: updated_sm_input = self.populate_ssm_params(sm_input) - stack_set_name = sm_input.get('ResourceProperties')\ - .get('StackSetName', '') + stack_set_name = sm_input.get("ResourceProperties").get("StackSetName", "") is_deletion = sm_input.get("RequestType").lower() == "Delete".lower() if is_deletion: start_execution_flag = True - else: - template_matched, parameters_matched = \ - self.compare_template_and_params(sm_input, stack_set_name) - - self.logger.info("Stack Set Name: {} | " - "Same Template?: {} | " - "Same Parameters?: {}" - .format(stack_set_name, - template_matched, - parameters_matched)) + else: + template_matched, parameters_matched = self.compare_template_and_params( + sm_input, stack_set_name + ) + + self.logger.info( + "Stack Set Name: {} | " + "Same Template?: {} | " + "Same Parameters?: {}".format( + stack_set_name, template_matched, parameters_matched + ) + ) if template_matched and parameters_matched and self.stack_set_exist: start_execution_flag = self.compare_stack_instances( - sm_input, - stack_set_name + sm_input, stack_set_name ) # template and parameter does not require update - updated_sm_input.update({'SkipUpdateStackSet': 'yes'}) + updated_sm_input.update({"SkipUpdateStackSet": "yes"}) else: # the template or parameters needs to be updated # start SM execution @@ -95,28 +94,34 @@ def run_execution_sequential_mode(self): if start_execution_flag: sm_exec_name = self.get_sm_exec_name(updated_sm_input) - sm_exec_arn = self.setup_execution(updated_sm_input, - sm_exec_name) + sm_exec_arn = self.setup_execution(updated_sm_input, sm_exec_name) self.list_sm_exec_arns.append(sm_exec_arn) - status, failed_execution_list = \ - self.monitor_state_machines_execution_status() - if status == 'FAILED': + ( + status, + failed_execution_list, + ) = self.monitor_state_machines_execution_status() + if status == "FAILED": return status, failed_execution_list if self.enforce_successful_stack_instances: try: self.enforce_stack_set_deployment_successful(stack_set_name) except ClientError as error: - if (is_deletion and error.response['Error']['Code'] == "StackSetNotFoundException"): + if ( + is_deletion + and error.response["Error"]["Code"] + == "StackSetNotFoundException" + ): pass else: raise error - else: - self.logger.info("State Machine execution completed. " - "Starting next execution...") + self.logger.info( + "State Machine execution completed. " + "Starting next execution..." + ) self.logger.info("All State Machine executions completed.") return status, failed_execution_list @@ -128,17 +133,15 @@ def run_execution_parallel_mode(self): self.list_sm_exec_arns.append(sm_exec_arn) time.sleep(int(self.wait_time)) # monitor execution status - status, failed_execution_list = \ - self.monitor_state_machines_execution_status() + status, failed_execution_list = self.monitor_state_machines_execution_status() return status, failed_execution_list @staticmethod def get_sm_exec_name(sm_input): - if os.environ.get('STAGE_NAME').upper() == 'SCP': - return sm_input.get('ResourceProperties')\ - .get('PolicyDocument').get('Name') - elif os.environ.get('STAGE_NAME').upper() == 'STACKSET': - return sm_input.get('ResourceProperties').get('StackSetName') + if os.environ.get("STAGE_NAME").upper() == "SCP": + return sm_input.get("ResourceProperties").get("PolicyDocument").get("Name") + elif os.environ.get("STAGE_NAME").upper() == "STACKSET": + return sm_input.get("ResourceProperties").get("StackSetName") else: return str(uuid4()) # return random string @@ -146,36 +149,38 @@ def setup_execution(self, sm_input, name): self.logger.info("State machine Input: {}".format(sm_input)) # set execution name - exec_name = "%s-%s-%s" % (sm_input.get('RequestType'), - trim_length_from_end(name.replace(" ", ""), - 50), - time.strftime("%Y-%m-%dT%H-%M-%S")) + exec_name = "%s-%s-%s" % ( + sm_input.get("RequestType"), + trim_length_from_end(name.replace(" ", ""), 50), + time.strftime("%Y-%m-%dT%H-%M-%S"), + ) # execute all SM at regular interval of wait_time - return self.state_machine.start_execution(os.environ.get('SM_ARN'), - sm_input, - exec_name) + return self.state_machine.start_execution( + os.environ.get("SM_ARN"), sm_input, exec_name + ) def populate_ssm_params(self, sm_input): """The scenario is if you have one CFN resource that exports output - from CFN stack to SSM parameter and then the next CFN resource - reads the SSM parameter as input, then it has to wait for the first - CFN resource to finish; read the SSM parameters and use its value - as input for second CFN resource's input for SM. Get the parameters - for CFN template from sm_input + from CFN stack to SSM parameter and then the next CFN resource + reads the SSM parameter as input, then it has to wait for the first + CFN resource to finish; read the SSM parameters and use its value + as input for second CFN resource's input for SM. Get the parameters + for CFN template from sm_input """ - self.logger.info("Populating SSM parameter values for SM input: {}" - .format(sm_input)) - params = sm_input.get('ResourceProperties')\ - .get('Parameters', {}) + self.logger.info( + "Populating SSM parameter values for SM input: {}".format(sm_input) + ) + params = sm_input.get("ResourceProperties").get("Parameters", {}) # First transform it from {name: value} to [{'ParameterKey': name}, # {'ParameterValue': value}] # then replace the SSM parameter names with its values sm_params = self.param_handler.update_params(transform_params(params)) # Put it back into the self.state_machine_event - sm_input.get('ResourceProperties').update({'Parameters': sm_params}) - self.logger.info("Done populating SSM parameter values for SM input:" - " {}".format(sm_input)) + sm_input.get("ResourceProperties").update({"Parameters": sm_params}) + self.logger.info( + "Done populating SSM parameter values for SM input:" " {}".format(sm_input) + ) return sm_input def compare_template_and_params(self, sm_input, stack_name): @@ -183,15 +188,14 @@ def compare_template_and_params(self, sm_input, stack_name): self.logger.info("Comparing the templates and parameters.") template_compare, params_compare = False, False if stack_name: - describe_response = self.stack_set\ - .describe_stack_set(stack_name) - self.logger.info("Print Describe Stack Set Response: {}" - .format(describe_response)) + describe_response = self.stack_set.describe_stack_set(stack_name) + self.logger.info( + "Print Describe Stack Set Response: {}".format(describe_response) + ) if describe_response is not None: self.logger.info("Found existing stack set.") - operation_status_flag = self.get_stack_set_operation_status( - stack_name) + operation_status_flag = self.get_stack_set_operation_status(stack_name) if operation_status_flag: self.logger.info("Continuing...") @@ -199,12 +203,14 @@ def compare_template_and_params(self, sm_input, stack_name): return operation_status_flag, operation_status_flag # Compare template copy - START - self.logger.info("Comparing the template of the StackSet:" - " {} with local copy of template" - .format(stack_name)) - - template_http_url = sm_input.get('ResourceProperties')\ - .get('TemplateURL', '') + self.logger.info( + "Comparing the template of the StackSet:" + " {} with local copy of template".format(stack_name) + ) + + template_http_url = sm_input.get("ResourceProperties").get( + "TemplateURL", "" + ) if template_http_url: bucket_name, key_name, region = parse_bucket_key_names( template_http_url @@ -212,49 +218,51 @@ def compare_template_and_params(self, sm_input, stack_name): local_template_file = tempfile.mkstemp()[1] s3_endpoint_url = "https://s3.%s.amazonaws.com" % region - s3 = S3(self.logger, - region=region, - endpoint_url=s3_endpoint_url) - s3.download_file(bucket_name, key_name, - local_template_file) + s3 = S3(self.logger, region=region, endpoint_url=s3_endpoint_url) + s3.download_file(bucket_name, key_name, local_template_file) else: - self.logger.error("TemplateURL in state machine input " - "is empty. Check state_machine_event" - ":{}".format(sm_input)) + self.logger.error( + "TemplateURL in state machine input " + "is empty. Check state_machine_event" + ":{}".format(sm_input) + ) return False, False cfn_template_file = tempfile.mkstemp()[1] with open(cfn_template_file, "w") as f: - f.write(describe_response.get('StackSet') - .get('TemplateBody')) + f.write(describe_response.get("StackSet").get("TemplateBody")) # cmp function return true of the contents are same - template_compare = filecmp.cmp(local_template_file, - cfn_template_file, - False) - self.logger.info("Comparing the parameters of the StackSet" - ": {} with local copy of JSON parameters" - " file".format(stack_name)) + template_compare = filecmp.cmp( + local_template_file, cfn_template_file, False + ) + self.logger.info( + "Comparing the parameters of the StackSet" + ": {} with local copy of JSON parameters" + " file".format(stack_name) + ) params_compare = True - params = sm_input.get('ResourceProperties')\ - .get('Parameters', {}) + params = sm_input.get("ResourceProperties").get("Parameters", {}) # template are same - compare parameters (skip if template # are not same) if template_compare: - cfn_params = reverse_transform_params(describe_response - .get('StackSet') - .get('Parameters') - ) + cfn_params = reverse_transform_params( + describe_response.get("StackSet").get("Parameters") + ) for key, value in params.items(): - if cfn_params.get(key, '') != value: + if cfn_params.get(key, "") != value: params_compare = False break - self.logger.info("template_compare={}; params_compare={}" - .format(template_compare, params_compare)) + self.logger.info( + "template_compare={}; params_compare={}".format( + template_compare, params_compare + ) + ) else: - self.logger.info('Stack Set does not exist. ' - 'Creating a new stack set ....') + self.logger.info( + "Stack Set does not exist. " "Creating a new stack set ...." + ) template_compare, params_compare = True, True # set this flag to create the stack set self.stack_set_exist = False @@ -262,23 +270,26 @@ def compare_template_and_params(self, sm_input, stack_name): return template_compare, params_compare def get_stack_set_operation_status(self, stack_name): - self.logger.info("Checking the status of last stack set " - "operation on {}".format(stack_name)) - response = self.stack_set. \ - list_stack_set_operations(StackSetName=stack_name, - MaxResults=1) - if response and response.get('Summaries'): - for instance in response.get('Summaries'): - self.logger.info("Status of last stack set " - "operation : {}" - .format(instance - .get('Status'))) - if instance.get('Status') != 'SUCCEEDED': - self.logger.info("The last stack operation" - " did not succeed. " - "Triggering " - " Update StackSet for {}" - .format(stack_name)) + self.logger.info( + "Checking the status of last stack set " + "operation on {}".format(stack_name) + ) + response = self.stack_set.list_stack_set_operations( + StackSetName=stack_name, MaxResults=1 + ) + if response and response.get("Summaries"): + for instance in response.get("Summaries"): + self.logger.info( + "Status of last stack set " + "operation : {}".format(instance.get("Status")) + ) + if instance.get("Status") != "SUCCEEDED": + self.logger.info( + "The last stack operation" + " did not succeed. " + "Triggering " + " Update StackSet for {}".format(stack_name) + ) return False return True @@ -293,36 +304,45 @@ def compare_stack_instances(self, sm_input: dict, stack_name: str) -> bool: on the StackSet # False: if no changes to stack instances are required """ - self.logger.info("Comparing deployed stack instances with " - "expected accounts & regions for " - "StackSet: {}".format(stack_name)) - expected_account_list = sm_input.get('ResourceProperties')\ - .get("AccountList", []) - expected_region_list = sm_input.get('ResourceProperties')\ - .get("RegionList", []) - - actual_account_list, actual_region_list = \ - self.stack_set.get_accounts_and_regions_per_stack_set(stack_name) - - self.logger.info("*** Stack instances expected to be deployed " - "in following accounts. ***") + self.logger.info( + "Comparing deployed stack instances with " + "expected accounts & regions for " + "StackSet: {}".format(stack_name) + ) + expected_account_list = sm_input.get("ResourceProperties").get( + "AccountList", [] + ) + expected_region_list = sm_input.get("ResourceProperties").get("RegionList", []) + + ( + actual_account_list, + actual_region_list, + ) = self.stack_set.get_accounts_and_regions_per_stack_set(stack_name) + + self.logger.info( + "*** Stack instances expected to be deployed " "in following accounts. ***" + ) self.logger.info(expected_account_list) - self.logger.info("*** Stack instances actually deployed " - "in following accounts. ***") + self.logger.info( + "*** Stack instances actually deployed " "in following accounts. ***" + ) self.logger.info(actual_account_list) - self.logger.info("*** Stack instances expected to be deployed " - "in following regions. ***") + self.logger.info( + "*** Stack instances expected to be deployed " "in following regions. ***" + ) self.logger.info(expected_region_list) - self.logger.info("*** Stack instances actually deployed " - "in following regions. ***") + self.logger.info( + "*** Stack instances actually deployed " "in following regions. ***" + ) self.logger.info(actual_region_list) self.logger.info("*** Comparing account lists ***") - accounts_matched = compare_lists(actual_account_list, - expected_account_list) + accounts_matched = compare_lists(actual_account_list, expected_account_list) self.logger.info("*** Comparing region lists ***") - regions_matched = compare_lists(actual_region_list, - expected_region_list,) + regions_matched = compare_lists( + actual_region_list, + expected_region_list, + ) if accounts_matched and regions_matched: self.logger.info("No need to add or remove stack instances.") return False @@ -332,49 +352,53 @@ def compare_stack_instances(self, sm_input: dict, stack_name: str) -> bool: def monitor_state_machines_execution_status(self): if self.list_sm_exec_arns: - final_status = 'RUNNING' + final_status = "RUNNING" - while final_status == 'RUNNING': + while final_status == "RUNNING": for sm_exec_arn in self.list_sm_exec_arns: - status = self.state_machine.check_state_machine_status( - sm_exec_arn) - if status == 'RUNNING': - final_status = 'RUNNING' + status = self.state_machine.check_state_machine_status(sm_exec_arn) + if status == "RUNNING": + final_status = "RUNNING" time.sleep(int(self.wait_time)) break else: - final_status = 'COMPLETED' + final_status = "COMPLETED" err_flag = False failed_sm_execution_list = [] for sm_exec_arn in self.list_sm_exec_arns: - status = self.state_machine.check_state_machine_status( - sm_exec_arn) - if status == 'SUCCEEDED': + status = self.state_machine.check_state_machine_status(sm_exec_arn) + if status == "SUCCEEDED": continue else: failed_sm_execution_list.append(sm_exec_arn) err_flag = True if err_flag: - return 'FAILED', failed_sm_execution_list + return "FAILED", failed_sm_execution_list else: - return 'SUCCEEDED', [] + return "SUCCEEDED", [] else: - self.logger.info("SM Execution List {} is empty, nothing to " - "monitor.".format(self.list_sm_exec_arns)) + self.logger.info( + "SM Execution List {} is empty, nothing to " + "monitor.".format(self.list_sm_exec_arns) + ) return None, [] def enforce_stack_set_deployment_successful(self, stack_set_name: str) -> None: - failed_detailed_statuses = [ - "CANCELLED", - "FAILED", - "INOPERABLE" + failed_detailed_statuses = ["CANCELLED", "FAILED", "INOPERABLE"] + list_filters = [ + {"Name": "DETAILED_STATUS", "Values": status} + for status in failed_detailed_statuses ] - list_filters = [{"Name": "DETAILED_STATUS", "Values": status} for status in failed_detailed_statuses] # Note that we don't paginate because if this API returns any elements, failed instances exist. for list_filter in list_filters: - response = self.stack_set.cfn_client.list_stack_instances(StackSetName=stack_set_name, Filters=[list_filter]) + response = self.stack_set.cfn_client.list_stack_instances( + StackSetName=stack_set_name, Filters=[list_filter] + ) if response.get("Summaries", []): - raise StackSetHasFailedInstances(stack_set_name=stack_set_name, failed_stack_set_instances=response["Summaries"]) + raise StackSetHasFailedInstances( + stack_set_name=stack_set_name, + failed_stack_set_instances=response["Summaries"], + ) return None diff --git a/source/src/cfct/manifest/sm_input_builder.py b/source/src/cfct/manifest/sm_input_builder.py index fc77600..8d7fa5a 100644 --- a/source/src/cfct/manifest/sm_input_builder.py +++ b/source/src/cfct/manifest/sm_input_builder.py @@ -35,8 +35,7 @@ class InputBuilder(StateMachineInput): """ - def __init__(self, resource_properties, request_type='Create', - skip_stack_set='no'): + def __init__(self, resource_properties, request_type="Create", skip_stack_set="no"): self._request_type = request_type self._resource_properties = resource_properties self._skip_stack_set = skip_stack_set @@ -44,9 +43,9 @@ def __init__(self, resource_properties, request_type='Create', def input_map(self) -> dict: input_map = { "RequestType": self._request_type, - "ResourceProperties": self._resource_properties + "ResourceProperties": self._resource_properties, } - if getenv('STAGE_NAME').upper() == 'STACKSET': + if getenv("STAGE_NAME").upper() == "STACKSET": input_map.update({"SkipUpdateStackSet": self._skip_stack_set}) return input_map @@ -67,9 +66,17 @@ class SCPResourceProperties: """ - def __init__(self, policy_name, policy_description, policy_url, ou_list, - policy_list=None, account_id='', operation='', - ou_name_delimiter=':'): + def __init__( + self, + policy_name, + policy_description, + policy_url, + ou_list, + policy_list=None, + account_id="", + operation="", + ou_name_delimiter=":", + ): self._policy_name = policy_name self._policy_description = policy_description self._policy_url = policy_url @@ -86,38 +93,46 @@ def get_scp_input_map(self): "PolicyList": self._policy_list, "Operation": self._operation, "OUList": self._ou_list, - "OUNameDelimiter": self._ou_name_delimiter + "OUNameDelimiter": self._ou_name_delimiter, } def _get_policy_document(self): return { "Name": self._policy_name, "Description": self._policy_description, - "PolicyURL": self._policy_url + "PolicyURL": self._policy_url, } class StackSetResourceProperties: """ - This class helps create and return input needed to execute Stack Set - state machine. This also defines the required keys to execute the state - machine. - - Example: - - resource_properties = StackSetResourceProperties(stack_set_name, - template_url, - parameters, - capabilities, - account_list, - region_list, - ssm_parameters) - ss_input = InputBuilder(resource_properties.get_stack_set_input_map()) - sm_input = ss_input.input_map() - """ - - def __init__(self, stack_set_name, template_url, parameters, - capabilities, account_list, region_list, ssm_parameters): + This class helps create and return input needed to execute Stack Set + state machine. This also defines the required keys to execute the state + machine. + + Example: + + resource_properties = StackSetResourceProperties(stack_set_name, + template_url, + parameters, + capabilities, + account_list, + region_list, + ssm_parameters) + ss_input = InputBuilder(resource_properties.get_stack_set_input_map()) + sm_input = ss_input.input_map() + """ + + def __init__( + self, + stack_set_name, + template_url, + parameters, + capabilities, + account_list, + region_list, + ssm_parameters, + ): self._stack_set_name = stack_set_name self._template_url = template_url self._parameters = parameters @@ -134,7 +149,7 @@ def get_stack_set_input_map(self): "Parameters": self._get_cfn_parameters(), "AccountList": self._get_account_list(), "RegionList": self._get_region_list(), - "SSMParameters": self._get_ssm_parameters() + "SSMParameters": self._get_ssm_parameters(), } def _get_cfn_parameters(self): diff --git a/source/src/cfct/manifest/stage_to_s3.py b/source/src/cfct/manifest/stage_to_s3.py index 181ef78..4c4d850 100644 --- a/source/src/cfct/manifest/stage_to_s3.py +++ b/source/src/cfct/manifest/stage_to_s3.py @@ -15,8 +15,9 @@ # !/bin/python import os + from cfct.aws.services.s3 import S3 -from cfct.aws.utils.url_conversion import convert_s3_url_to_http_url, build_http_url +from cfct.aws.utils.url_conversion import build_http_url, convert_s3_url_to_http_url class StageFile(S3): @@ -26,14 +27,15 @@ class StageFile(S3): boto_3 = Boto3Session(logger, region, service_name, **kwargs) client = boto_3.get_client() """ + def __init__(self, logger, relative_file_path): """ - Parameters - ---------- - logger : object - The logger object - relative_file_path : str - Relative Path of the file. + Parameters + ---------- + logger : object + The logger object + relative_file_path : str + Relative Path of the file. """ self.logger = logger self.relative_file_path = relative_file_path @@ -45,9 +47,9 @@ def get_staged_file(self): :return: S3 URL, type: String """ - if self.relative_file_path.lower().startswith('s3'): + if self.relative_file_path.lower().startswith("s3"): return self.convert_url() - elif self.relative_file_path.lower().startswith('http'): + elif self.relative_file_path.lower().startswith("http"): return self.relative_file_path else: return self.stage_file() @@ -64,17 +66,16 @@ def stage_file(self): :return: S3 URL, type: String """ - local_file = os.path.join(os.environ.get('MANIFEST_FOLDER'), - self.relative_file_path) - key_name = "{}/{}".format(os.environ.get('TEMPLATE_KEY_PREFIX'), - self.relative_file_path) - self.logger.info("Uploading the template file: {} to S3 bucket: {} " - "and key: {}".format(local_file, - os.environ.get('STAGING_BUCKET'), - key_name)) - super().upload_file(os.environ.get('STAGING_BUCKET'), - local_file, - key_name) - http_url = build_http_url(os.environ.get('STAGING_BUCKET'), - key_name) + local_file = os.path.join( + os.environ.get("MANIFEST_FOLDER"), self.relative_file_path + ) + key_name = "{}/{}".format( + os.environ.get("TEMPLATE_KEY_PREFIX"), self.relative_file_path + ) + self.logger.info( + "Uploading the template file: {} to S3 bucket: {} " + "and key: {}".format(local_file, os.environ.get("STAGING_BUCKET"), key_name) + ) + super().upload_file(os.environ.get("STAGING_BUCKET"), local_file, key_name) + http_url = build_http_url(os.environ.get("STAGING_BUCKET"), key_name) return http_url diff --git a/source/src/cfct/metrics/solution_metrics.py b/source/src/cfct/metrics/solution_metrics.py index 6adde77..8663d22 100644 --- a/source/src/cfct/metrics/solution_metrics.py +++ b/source/src/cfct/metrics/solution_metrics.py @@ -13,19 +13,21 @@ # and limitations under the License. # ############################################################################### -from json import dumps +import os from datetime import datetime -from cfct.aws.services.ssm import SSM +from json import dumps + import requests -import os +from cfct.aws.services.ssm import SSM from cfct.utils.decimal_encoder import DecimalEncoder class SolutionMetrics(object): """This class is used to send anonymous metrics from customer using - the solution to the Solutions Builder team when customer choose to - have their data sent during the deployment of the solution. + the solution to the Solutions Builder team when customer choose to + have their data sent during the deployment of the solution. """ + def __init__(self, logger): self.logger = logger self.ssm = SSM(logger) @@ -36,12 +38,15 @@ def _get_parameter_value(self, key): if response: value = self.ssm.get_parameter(key) else: - value = 'ssm-param-key-not-found' + value = "ssm-param-key-not-found" return value - def solution_metrics(self, data, - solution_id=os.environ.get('SOLUTION_ID'), - url=os.environ.get('METRICS_URL')): + def solution_metrics( + self, + data, + solution_id=os.environ.get("SOLUTION_ID"), + url=os.environ.get("METRICS_URL"), + ): """Sends anonymous customer metrics to s3 via API gateway owned and managed by the Solutions Builder team. @@ -54,17 +59,14 @@ def solution_metrics(self, data, Return: status code returned by https post request """ try: - send_metrics = self._get_parameter_value('/org/primary/' - 'metrics_flag') - if send_metrics.lower() == 'yes': - uuid = self._get_parameter_value('/org/primary/customer_uuid') - time_stamp = {'TimeStamp': str(datetime.utcnow().isoformat())} - params = {'Solution': solution_id, - 'UUID': uuid, - 'Data': data} + send_metrics = self._get_parameter_value("/org/primary/" "metrics_flag") + if send_metrics.lower() == "yes": + uuid = self._get_parameter_value("/org/primary/customer_uuid") + time_stamp = {"TimeStamp": str(datetime.utcnow().isoformat())} + params = {"Solution": solution_id, "UUID": uuid, "Data": data} metrics = dict(time_stamp, **params) json_data = dumps(metrics, cls=DecimalEncoder) - headers = {'content-type': 'application/json'} + headers = {"content-type": "application/json"} r = requests.post(url, data=json_data, headers=headers) code = r.status_code return code diff --git a/source/src/cfct/state_machine_handler.py b/source/src/cfct/state_machine_handler.py index 2e5b3cd..436a6c1 100644 --- a/source/src/cfct/state_machine_handler.py +++ b/source/src/cfct/state_machine_handler.py @@ -16,18 +16,19 @@ # !/bin/python import inspect import json -import time import tempfile +import time from random import randint + from botocore.exceptions import ClientError +from cfct.aws.services.cloudformation import Stacks, StackSet from cfct.aws.services.organizations import Organizations as Org +from cfct.aws.services.s3 import S3 from cfct.aws.services.scp import ServiceControlPolicy as SCP -from cfct.aws.services.cloudformation import StackSet, Stacks -from cfct.aws.services.sts import AssumeRole from cfct.aws.services.ssm import SSM -from cfct.aws.services.s3 import S3 -from cfct.metrics.solution_metrics import SolutionMetrics +from cfct.aws.services.sts import AssumeRole from cfct.aws.utils.url_conversion import parse_bucket_key_names +from cfct.metrics.solution_metrics import SolutionMetrics class CloudFormation(object): @@ -37,32 +38,32 @@ class CloudFormation(object): def __init__(self, event, logger): self.event = event - self.params = event.get('ResourceProperties') + self.params = event.get("ResourceProperties") self.logger = logger self.logger.info(self.__class__.__name__ + " Class Event") self.logger.info(event) def describe_stack_set(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # add loop flag to handle Skip StackSet Update choice - if self.event.get('LoopFlag') is None: - self.event.update({'LoopFlag': 'not-applicable'}) + if self.event.get("LoopFlag") is None: + self.event.update({"LoopFlag": "not-applicable"}) # To prevent CFN from throwing 'Response object is too long.' # when the event payload gets overloaded Deleting the # 'OldResourceProperties' from event, since it not being used in # the SM - if self.event.get('OldResourceProperties'): - self.event.pop('OldResourceProperties', '') + if self.event.get("OldResourceProperties"): + self.event.pop("OldResourceProperties", "") # Check if stack set already exist stack_set = StackSet(self.logger) - response = stack_set.describe_stack_set( - self.params.get('StackSetName')) + response = stack_set.describe_stack_set(self.params.get("StackSetName")) self.logger.info("Describe Response") self.logger.info(response) # If stack_set already exist, skip to create the stack_set_instance @@ -72,38 +73,42 @@ def describe_stack_set(self): else: value = "no" self.logger.info("Existing stack set not found.") - self.event.update({'StackSetExist': value}) + self.event.update({"StackSetExist": value}) return self.event def describe_stack_set_operation(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - self.event.update({'RetryDeleteFlag': False}) + self.event.update({"RetryDeleteFlag": False}) stack_set = StackSet(self.logger) response = stack_set.describe_stack_set_operation( - self.params.get('StackSetName'), - self.event.get('OperationId')) + self.params.get("StackSetName"), self.event.get("OperationId") + ) self.logger.info(response) - operation_status = response.get('StackSetOperation', {}).get('Status') + operation_status = response.get("StackSetOperation", {}).get("Status") self.logger.info("Operation Status: {}".format(operation_status)) - if operation_status == 'FAILED': - account_id = self.params.get('AccountList')[0] \ - if type(self.params.get('AccountList')) is list \ - else None + if operation_status == "FAILED": + account_id = ( + self.params.get("AccountList")[0] + if type(self.params.get("AccountList")) is list + else None + ) if account_id: - for region in self.params.get('RegionList'): - self.logger.info("Account: {} - describing stack " - "instance in {} region" - .format(account_id, region)) + for region in self.params.get("RegionList"): + self.logger.info( + "Account: {} - describing stack " + "instance in {} region".format(account_id, region) + ) try: resp = stack_set.describe_stack_instance( - self.params.get('StackSetName'), - account_id, - region) - self.event.update({region: resp.get( - 'StackInstance', {}).get('StatusReason')}) + self.params.get("StackSetName"), account_id, region + ) + self.event.update( + {region: resp.get("StackInstance", {}).get("StatusReason")} + ) except ClientError as e: # When CFN has triggered StackInstance delete and # the SCP is still attached (due to race condition) @@ -112,68 +117,78 @@ def describe_stack_set_operation(self): # exception back, the CFN stack in target account # ends up with 'DELETE_FAILED' state # so it should try again - if e.response['Error']['Code'] == \ - 'StackInstanceNotFoundException' and \ - self.event.get('RequestType') == 'Delete': + if ( + e.response["Error"]["Code"] + == "StackInstanceNotFoundException" + and self.event.get("RequestType") == "Delete" + ): self.logger.exception( "Caught exception" "'StackInstanceNotFoundException'," "sending the flag to go back to " - " Delete Stack Instances stage...") - self.event.update({'RetryDeleteFlag': True}) + " Delete Stack Instances stage..." + ) + self.event.update({"RetryDeleteFlag": True}) - operation_status = response.get('StackSetOperation', {}).get('Status') - self.event.update({'OperationStatus': operation_status}) + operation_status = response.get("StackSetOperation", {}).get("Status") + self.event.update({"OperationStatus": operation_status}) return self.event def list_stack_instances_account_ids(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - if self.event.get('NextToken') is None \ - or self.event.get('NextToken') == 'Complete': + if ( + self.event.get("NextToken") is None + or self.event.get("NextToken") == "Complete" + ): accounts = [] else: - accounts = self.event.get('StackInstanceAccountList', []) + accounts = self.event.get("StackInstanceAccountList", []) # Check if stack instances exist stack_set = StackSet(self.logger) - if self.event.get('NextToken') is not None and \ - self.event.get('NextToken') != 'Complete': + if ( + self.event.get("NextToken") is not None + and self.event.get("NextToken") != "Complete" + ): response = stack_set.list_stack_instances( - StackSetName=self.params.get('StackSetName'), + StackSetName=self.params.get("StackSetName"), MaxResults=20, - NextToken=self.event.get('NextToken')) + NextToken=self.event.get("NextToken"), + ) else: response = stack_set.list_stack_instances( - StackSetName=self.params.get('StackSetName'), - MaxResults=20) + StackSetName=self.params.get("StackSetName"), MaxResults=20 + ) self.logger.info("List SI Accounts Response") self.logger.info(response) if response: - if not response.get('Summaries'): # 'True' if list is empty - self.event.update({'NextToken': 'Complete'}) - self.logger.info("No existing stack instances found." - " (Summaries List: Empty)") + if not response.get("Summaries"): # 'True' if list is empty + self.event.update({"NextToken": "Complete"}) + self.logger.info( + "No existing stack instances found." " (Summaries List: Empty)" + ) else: - for instance in response.get('Summaries'): - account_id = instance.get('Account') + for instance in response.get("Summaries"): + account_id = instance.get("Account") accounts.append(account_id) - self.event.update({'StackInstanceAccountList': - list(set(accounts))}) - self.logger.info("Next Token Returned: {}" - .format(response.get('NextToken'))) - - if response.get('NextToken') is None: - self.event.update({'NextToken': 'Complete'}) - self.logger.info("No existing stack instances found." - " (Summaries List: Empty)") + self.event.update({"StackInstanceAccountList": list(set(accounts))}) + self.logger.info( + "Next Token Returned: {}".format(response.get("NextToken")) + ) + + if response.get("NextToken") is None: + self.event.update({"NextToken": "Complete"}) + self.logger.info( + "No existing stack instances found." " (Summaries List: Empty)" + ) else: - self.event.update({'NextToken': - response.get('NextToken')}) + self.event.update({"NextToken": response.get("NextToken")}) return self.event def list_stack_instances(self): @@ -187,60 +202,69 @@ def list_stack_instances(self): Raises: """ - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - if 'ParameterOverrides' in self.params.keys(): + if "ParameterOverrides" in self.params.keys(): self.logger.info("Override parameters found in the event") - self.event.update({'OverrideParametersExist': 'yes'}) + self.event.update({"OverrideParametersExist": "yes"}) else: self.logger.info("Override parameters NOT found in the event") - self.event.update({'OverrideParametersExist': 'no'}) + self.event.update({"OverrideParametersExist": "no"}) # Check if stack instances exist stack_set = StackSet(self.logger) # if account list is not present then only create StackSet # and skip stack instance creation - if type(self.params.get('AccountList')) is not list or \ - not self.params.get('AccountList'): + if type(self.params.get("AccountList")) is not list or not self.params.get( + "AccountList" + ): self._set_skip_stack_instance_operation() return self.event else: # proceed if account list exists - account_id = self.params.get('AccountList')[0] + account_id = self.params.get("AccountList")[0] # if this is 2nd round, fetch one of the existing accounts # that hasn't been processed in the first round - if self.event.get('ActiveAccountList') is not None \ - and self.event.get('ActiveRegionList') is not None \ - and self.params.get('AccountList') != \ - self.event.get('ActiveAccountList'): - account_id = self._add_list(self.params.get('AccountList'), - self.event.get('ActiveAccountList') - )[0] + if ( + self.event.get("ActiveAccountList") is not None + and self.event.get("ActiveRegionList") is not None + and self.params.get("AccountList") + != self.event.get("ActiveAccountList") + ): + account_id = self._add_list( + self.params.get("AccountList"), self.event.get("ActiveAccountList") + )[0] - self.logger.info("Account Id for list stack instance: {}" - .format(account_id)) + self.logger.info( + "Account Id for list stack instance: {}".format(account_id) + ) - if self.event.get('NextToken') is not None and \ - self.event.get('NextToken') != 'Complete': + if ( + self.event.get("NextToken") is not None + and self.event.get("NextToken") != "Complete" + ): - self.logger.info('Found next token') + self.logger.info("Found next token") response = stack_set.list_stack_instances( - StackSetName=self.params.get('StackSetName'), + StackSetName=self.params.get("StackSetName"), StackInstanceAccount=account_id, MaxResults=20, - NextToken=self.event.get('NextToken') - ) + NextToken=self.event.get("NextToken"), + ) else: - self.logger.info('Next token not found.') + self.logger.info("Next token not found.") response = stack_set.list_stack_instances( - StackSetName=self.params.get('StackSetName'), + StackSetName=self.params.get("StackSetName"), StackInstanceAccount=account_id, - MaxResults=20) - self.logger.info("List Stack Instance Response" - " for account: {}".format(account_id)) + MaxResults=20, + ) + self.logger.info( + "List Stack Instance Response" " for account: {}".format(account_id) + ) self.logger.info(response) if response is not None: @@ -250,8 +274,10 @@ def list_stack_instances(self): # instance operation is needed. # Therefore here set values as input for step functions # to trigger create operation accordingly. - if not response.get('Summaries') and \ - self.event.get('StackInstanceAccountList') is None: + if ( + not response.get("Summaries") + and self.event.get("StackInstanceAccountList") is None + ): self._set_only_create_stack_instance_operation() return self.event @@ -260,85 +286,98 @@ def list_stack_instances(self): # to determine what operations (create, update, delete) # that step functions should perform. else: - existing_region_list = [] \ - if self.event.get('ExistingRegionList') is None \ - else self.event.get('ExistingRegionList') - existing_account_list = [] \ - if self.event.get('StackInstanceAccountList') \ - is None \ - else self.event.get('StackInstanceAccountList') - - if response.get('Summaries'): - self.logger.info("Found existing stack instance for " - "AccountList.") - self.event.update({'InstanceExist': 'yes'}) - existing_region_list = \ - self._get_existing_stack_instance_info( - response.get('Summaries'), - existing_region_list) + existing_region_list = ( + [] + if self.event.get("ExistingRegionList") is None + else self.event.get("ExistingRegionList") + ) + existing_account_list = ( + [] + if self.event.get("StackInstanceAccountList") is None + else self.event.get("StackInstanceAccountList") + ) + + if response.get("Summaries"): + self.logger.info( + "Found existing stack instance for " "AccountList." + ) + self.event.update({"InstanceExist": "yes"}) + existing_region_list = self._get_existing_stack_instance_info( + response.get("Summaries"), existing_region_list + ) # If there are no stack instances for new account list # but there are some for existing accounts that are # not in the new account list, get the info about # those stack instances. - elif self.event.get('StackInstanceAccountList') \ - is not None and len(existing_region_list) == 0: - account_id = self.event.get( - 'StackInstanceAccountList')[0] + elif ( + self.event.get("StackInstanceAccountList") is not None + and len(existing_region_list) == 0 + ): + account_id = self.event.get("StackInstanceAccountList")[0] response = stack_set.list_stack_instances( - StackSetName=self.params.get('StackSetName'), + StackSetName=self.params.get("StackSetName"), StackInstanceAccount=account_id, - MaxResults=20) - self.logger.info("List Stack Instance Response for" - " StackInstanceAccountList") + MaxResults=20, + ) + self.logger.info( + "List Stack Instance Response for" + " StackInstanceAccountList" + ) self.logger.info(response) - if response.get('Summaries'): - self.logger.info("Found existing stack instances " - "for StackInstanceAccountList.") - self.event.update({'InstanceExist': 'yes'}) - existing_region_list = \ + if response.get("Summaries"): + self.logger.info( + "Found existing stack instances " + "for StackInstanceAccountList." + ) + self.event.update({"InstanceExist": "yes"}) + existing_region_list = ( self._get_existing_stack_instance_info( - response.get('Summaries'), - existing_region_list) + response.get("Summaries"), existing_region_list + ) + ) else: - existing_region_list = \ - self.params.get('RegionList') - - self.logger.info("Updated existing region List: {}" - .format(existing_region_list)) - - self.logger.info("Next Token Returned: {}" - .format(response.get('NextToken'))) - - if response.get('NextToken') is None: - - add_region_list, delete_region_list, add_account_list,\ - delete_account_list = \ - self._get_add_delete_region_account_list( - existing_region_list, - existing_account_list) - self._set_loop_flag(add_region_list, - delete_region_list, - add_account_list, - delete_account_list) - self._update_event_for_add(add_account_list, - add_region_list) - self._update_event_for_delete(delete_account_list, - delete_region_list) - self.event.update({'ExistingRegionList': - existing_region_list}) + existing_region_list = self.params.get("RegionList") + + self.logger.info( + "Updated existing region List: {}".format(existing_region_list) + ) + + self.logger.info( + "Next Token Returned: {}".format(response.get("NextToken")) + ) + + if response.get("NextToken") is None: + + ( + add_region_list, + delete_region_list, + add_account_list, + delete_account_list, + ) = self._get_add_delete_region_account_list( + existing_region_list, existing_account_list + ) + self._set_loop_flag( + add_region_list, + delete_region_list, + add_account_list, + delete_account_list, + ) + self._update_event_for_add(add_account_list, add_region_list) + self._update_event_for_delete( + delete_account_list, delete_region_list + ) + self.event.update({"ExistingRegionList": existing_region_list}) else: - self.event.update({'NextToken': - response.get('NextToken')}) + self.event.update({"NextToken": response.get("NextToken")}) # Update the self.event with existing_region_list - self.event.update({'ExistingRegionList': - existing_region_list}) + self.event.update({"ExistingRegionList": existing_region_list}) return self.event return self.event def _set_loop_flag( - self, add_region_list, delete_region_list, - add_account_list, delete_account_list): + self, add_region_list, delete_region_list, add_account_list, delete_account_list + ): """set LoopFlag used to determine if state machine will run more than once. LoopFlag - Yes. State machine executes twice @@ -349,49 +388,55 @@ def _set_loop_flag( """ # both are not empty - region and account was added if add_account_list and add_region_list: - self.event.update({'LoopFlag': 'yes'}) + self.event.update({"LoopFlag": "yes"}) # both are not empty - region and account was deleted elif delete_account_list and delete_region_list: - self.event.update({'LoopFlag': 'yes'}) + self.event.update({"LoopFlag": "yes"}) else: - self.event.update({'LoopFlag': 'no'}) + self.event.update({"LoopFlag": "no"}) def _get_add_delete_region_account_list( - self, existing_region_list, existing_account_list): + self, existing_region_list, existing_account_list + ): """build region and account list for adding and deleting stack instances operations. Returns: None """ - self.logger.info("Existing region list: {}" - .format(existing_region_list)) - self.logger.info("Existing account list: {}" - .format(existing_account_list)) - + self.logger.info("Existing region list: {}".format(existing_region_list)) + self.logger.info("Existing account list: {}".format(existing_account_list)) + # replace the region list in the self.event - add_region_list = self._add_list(self.params.get('RegionList'), - existing_region_list) + add_region_list = self._add_list( + self.params.get("RegionList"), existing_region_list + ) self.logger.info("Add region list: {}".format(add_region_list)) # Build a region list if the event is from AVM - delete_region_list = self._delete_list(self.params.get('RegionList'), - existing_region_list) + delete_region_list = self._delete_list( + self.params.get("RegionList"), existing_region_list + ) self.logger.info("Delete region list: {}".format(delete_region_list)) - add_account_list = self._add_list(self.params.get('AccountList'), - existing_account_list) + add_account_list = self._add_list( + self.params.get("AccountList"), existing_account_list + ) self.logger.info("Add account list: {}".format(add_account_list)) - delete_account_list = self._delete_list(self.params.get('AccountList'), - existing_account_list) + delete_account_list = self._delete_list( + self.params.get("AccountList"), existing_account_list + ) self.logger.info("Delete account list: {}".format(delete_account_list)) - return add_region_list, delete_region_list, \ - add_account_list, delete_account_list + return ( + add_region_list, + delete_region_list, + add_account_list, + delete_account_list, + ) - def _get_existing_stack_instance_info( - self, response_summary, existing_region_list): + def _get_existing_stack_instance_info(self, response_summary, existing_region_list): """Iterate through response to check if stack instance exists in account and region in the given self.event. Fetch region and account list for existing stack instances. @@ -400,15 +445,18 @@ def _get_existing_stack_instance_info( None """ for instance in response_summary: - if instance.get('Region') not in existing_region_list: - self.logger.info("Region {} not in the region list." - "Adding it..." - .format(instance.get('Region'))) + if instance.get("Region") not in existing_region_list: + self.logger.info( + "Region {} not in the region list." + "Adding it...".format(instance.get("Region")) + ) # appending to the list - existing_region_list.append(instance.get('Region')) + existing_region_list.append(instance.get("Region")) else: - self.logger.info("Region {} already in the region list." - "Skipping...".format(instance.get('Region'))) + self.logger.info( + "Region {} already in the region list." + "Skipping...".format(instance.get("Region")) + ) return existing_region_list def _set_only_create_stack_instance_operation(self): @@ -418,15 +466,16 @@ def _set_only_create_stack_instance_operation(self): Returns: event """ - self.event.update({'InstanceExist': 'no'}) + self.event.update({"InstanceExist": "no"}) # exit loop - self.event.update({'NextToken': 'Complete'}) + self.event.update({"NextToken": "Complete"}) # create stack instance set to yes - self.event.update({'CreateInstance': 'yes'}) + self.event.update({"CreateInstance": "yes"}) # delete stack instance set to no - self.event.update({'DeleteInstance': 'no'}) - self.logger.info("No existing stack instances found." - " (Summaries List: Empty)") + self.event.update({"DeleteInstance": "no"}) + self.logger.info( + "No existing stack instances found." " (Summaries List: Empty)" + ) def _set_skip_stack_instance_operation(self): """Set values as input for step function to @@ -435,31 +484,30 @@ def _set_skip_stack_instance_operation(self): Returns: event """ - self.event.update({'InstanceExist': 'no'}) - self.event.update({'NextToken': 'Complete'}) - self.event.update({'CreateInstance': 'no'}) - self.event.update({'DeleteInstance': 'no'}) + self.event.update({"InstanceExist": "no"}) + self.event.update({"NextToken": "Complete"}) + self.event.update({"CreateInstance": "no"}) + self.event.update({"DeleteInstance": "no"}) - def _update_event_for_delete(self, delete_account_list, - delete_region_list): + def _update_event_for_delete(self, delete_account_list, delete_region_list): if delete_account_list or delete_region_list: - self.event.update({'DeleteAccountList': delete_account_list}) - self.event.update({'DeleteRegionList': delete_region_list}) - self.event.update({'DeleteInstance': 'yes'}) - self.event.update({'NextToken': 'Complete'}) + self.event.update({"DeleteAccountList": delete_account_list}) + self.event.update({"DeleteRegionList": delete_region_list}) + self.event.update({"DeleteInstance": "yes"}) + self.event.update({"NextToken": "Complete"}) else: - self.event.update({'DeleteInstance': 'no'}) - self.event.update({'NextToken': 'Complete'}) + self.event.update({"DeleteInstance": "no"}) + self.event.update({"NextToken": "Complete"}) def _update_event_for_add(self, add_account_list, add_region_list): if add_account_list or add_region_list: - self.event.update({'AddAccountList': add_account_list}) - self.event.update({'AddRegionList': add_region_list}) - self.event.update({'CreateInstance': 'yes'}) - self.event.update({'NextToken': 'Complete'}) + self.event.update({"AddAccountList": add_account_list}) + self.event.update({"AddRegionList": add_region_list}) + self.event.update({"CreateInstance": "yes"}) + self.event.update({"NextToken": "Complete"}) else: - self.event.update({'CreateInstance': 'no'}) - self.event.update({'NextToken': 'Complete'}) + self.event.update({"CreateInstance": "no"}) + self.event.update({"NextToken": "Complete"}) def _add_list(self, new_list, existing_list): if isinstance(new_list, list) and isinstance(existing_list, list): @@ -468,9 +516,11 @@ def _add_list(self, new_list, existing_list): add_list = list(event_set - event_set.intersection(existing_set)) return add_list else: - raise ValueError("Both variables must be list.\n" - "Variable 1: {} \n " - "Variable 2: {}".format(new_list, existing_list)) + raise ValueError( + "Both variables must be list.\n" + "Variable 1: {} \n " + "Variable 2: {}".format(new_list, existing_list) + ) def _delete_list(self, new_list, existing_list): if isinstance(new_list, list) and isinstance(existing_list, list): @@ -479,13 +529,15 @@ def _delete_list(self, new_list, existing_list): delete_list = list(event_set.union(existing_set) - event_set) return delete_list else: - raise ValueError("Both variables must be list.\n" - "Variable 1: {} \n " - "Variable 2: {}".format(new_list, existing_list)) + raise ValueError( + "Both variables must be list.\n" + "Variable 1: {} \n " + "Variable 2: {}".format(new_list, existing_list) + ) def _get_ssm_secure_string(self, parameters): - if parameters.get('ALZRegion'): - ssm = SSM(self.logger, parameters.get('ALZRegion')) + if parameters.get("ALZRegion"): + ssm = SSM(self.logger, parameters.get("ALZRegion")) else: ssm = SSM(self.logger) @@ -493,227 +545,239 @@ def _get_ssm_secure_string(self, parameters): self.logger.info(parameters) copy = parameters.copy() for key, value in copy.items(): - if type(value) is str and value.startswith( - '_get_ssm_secure_string_'): - ssm_param_key = value[len('_get_ssm_secure_string_'):] + if type(value) is str and value.startswith("_get_ssm_secure_string_"): + ssm_param_key = value[len("_get_ssm_secure_string_") :] decrypted_value = ssm.get_parameter(ssm_param_key) copy.update({key: decrypted_value}) - elif type(value) is str and value.startswith( - '_alfred_decapsulation_'): - decapsulated_value = value[(len('_alfred_decapsulation_')+1):] - self.logger.info("Removing decapsulation header." - " Printing decapsulated value below:") + elif type(value) is str and value.startswith("_alfred_decapsulation_"): + decapsulated_value = value[(len("_alfred_decapsulation_") + 1) :] + self.logger.info( + "Removing decapsulation header." + " Printing decapsulated value below:" + ) copy.update({key: decapsulated_value}) return copy def create_stack_set(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # Create a new stack set stack_set = StackSet(self.logger) self.logger.info("Creating StackSet") - parameters = self._get_ssm_secure_string( - self.params.get('Parameters')) + parameters = self._get_ssm_secure_string(self.params.get("Parameters")) response = stack_set.create_stack_set( - self.params.get('StackSetName'), - self.params.get('TemplateURL'), + self.params.get("StackSetName"), + self.params.get("TemplateURL"), parameters, - self.params.get('Capabilities'), - 'AWS_Solutions', - 'CustomControlTowerStackSet') - if response.get('StackSetId') is not None: + self.params.get("Capabilities"), + "AWS_Solutions", + "CustomControlTowerStackSet", + ) + if response.get("StackSetId") is not None: value = "success" else: value = "failure" - self.event.update({'StackSetStatus': value}) + self.event.update({"StackSetStatus": value}) # set create stack instance flag to yes (Handle SM Condition: # Create or Delete Stack Instance?) # check if the account list is empty - create_flag = 'no' if not self.params.get('AccountList') else 'yes' - self.event.update({'CreateInstance': create_flag}) + create_flag = "no" if not self.params.get("AccountList") else "yes" + self.event.update({"CreateInstance": create_flag}) # set delete stack instance flag to no (Handle SM Condition: # Delete Stack Instance or Finish?) - self.event.update({'DeleteInstance': 'no'}) + self.event.update({"DeleteInstance": "no"}) return self.event def create_stack_instances(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # Create stack instances stack_set = StackSet(self.logger) # set to default values (new instance creation) - account_list = self.params.get('AccountList') - region_list = self.params.get('RegionList') + account_list = self.params.get("AccountList") + region_list = self.params.get("RegionList") # if AddAccountList is not empty - if self.event.get('AddAccountList') is not None and len(self.event.get('AddAccountList')) != 0: - account_list = self.event.get('AddAccountList') + if ( + self.event.get("AddAccountList") is not None + and len(self.event.get("AddAccountList")) != 0 + ): + account_list = self.event.get("AddAccountList") # if AddRegionList is not empty - if self.event.get('AddRegionList') is not None and len(self.event.get('AddRegionList')) != 0: - region_list = self.event.get('AddRegionList') + if ( + self.event.get("AddRegionList") is not None + and len(self.event.get("AddRegionList")) != 0 + ): + region_list = self.event.get("AddRegionList") # both AddAccountList and AddRegionList are not empty - if self.event.get('LoopFlag') == 'yes': + if self.event.get("LoopFlag") == "yes": # create new stack instance in new account only with # all regions. new stack instances in new region # for existing accounts will be deployed in the second round - if self.event.get('ActiveAccountList') is not None: - if self.event.get('ActiveAccountList') \ - == self.event.get('AddAccountList'): - account_list = \ - self._add_list(self.params.get('AccountList'), - self.event.get('ActiveAccountList')) + if self.event.get("ActiveAccountList") is not None: + if self.event.get("ActiveAccountList") == self.event.get( + "AddAccountList" + ): + account_list = self._add_list( + self.params.get("AccountList"), + self.event.get("ActiveAccountList"), + ) else: - account_list = self.event.get('AddAccountList') - region_list = self.params.get('RegionList') + account_list = self.event.get("AddAccountList") + region_list = self.params.get("RegionList") - self.event.update({'ActiveAccountList': account_list}) - self.event.update({'ActiveRegionList': region_list}) + self.event.update({"ActiveAccountList": account_list}) + self.event.update({"ActiveRegionList": region_list}) - self.logger.info("LoopFlag: {}".format(self.event.get('LoopFlag'))) - self.logger.info("Create stack instances for accounts: {}" - .format(account_list)) - self.logger.info("Create stack instances in regions: {}" - .format(region_list)) + self.logger.info("LoopFlag: {}".format(self.event.get("LoopFlag"))) + self.logger.info("Create stack instances for accounts: {}".format(account_list)) + self.logger.info("Create stack instances in regions: {}".format(region_list)) - self.logger.info("Creating StackSet Instance: {}".format( - self.params.get('StackSetName'))) - if 'ParameterOverrides' in self.params: - self.logger.info( - "Found 'ParameterOverrides' key in the event.") + self.logger.info( + "Creating StackSet Instance: {}".format(self.params.get("StackSetName")) + ) + if "ParameterOverrides" in self.params: + self.logger.info("Found 'ParameterOverrides' key in the event.") parameters = self._get_ssm_secure_string( - self.params.get('ParameterOverrides')) - response = stack_set. \ - create_stack_instances_with_override_params( - self.params.get('StackSetName'), - account_list, - region_list, - parameters) + self.params.get("ParameterOverrides") + ) + response = stack_set.create_stack_instances_with_override_params( + self.params.get("StackSetName"), account_list, region_list, parameters + ) else: response = stack_set.create_stack_instances( - self.params.get('StackSetName'), - account_list, - region_list) + self.params.get("StackSetName"), account_list, region_list + ) self.logger.info(response) - self.logger.info("Operation ID: {}" - .format(response.get('OperationId'))) - self.event.update({'OperationId': response.get('OperationId')}) + self.logger.info("Operation ID: {}".format(response.get("OperationId"))) + self.event.update({"OperationId": response.get("OperationId")}) return self.event def update_stack_set(self): # Updates the stack set and all associated stack instances. - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) stack_set = StackSet(self.logger) # Update existing StackSet - self.logger.info("Updating Stack Set: {}".format( - self.params.get('StackSetName'))) + self.logger.info( + "Updating Stack Set: {}".format(self.params.get("StackSetName")) + ) - parameters = self._get_ssm_secure_string( - self.params.get('Parameters')) + parameters = self._get_ssm_secure_string(self.params.get("Parameters")) response = stack_set.update_stack_set( - self.params.get('StackSetName'), + self.params.get("StackSetName"), parameters, - self.params.get('TemplateURL'), - self.params.get('Capabilities')) + self.params.get("TemplateURL"), + self.params.get("Capabilities"), + ) self.logger.info("Response Update Stack Set") self.logger.info(response) - self.logger.info("Operation ID: {}" - .format(response.get('OperationId'))) - self.event.update({'OperationId': response.get('OperationId')}) + self.logger.info("Operation ID: {}".format(response.get("OperationId"))) + self.event.update({"OperationId": response.get("OperationId")}) return self.event def update_stack_instances(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) stack_set = StackSet(self.logger) # this should come from the event - override_parameters = self.params.get('ParameterOverrides') - self.logger.info("override_params_list={}" - .format(override_parameters)) + override_parameters = self.params.get("ParameterOverrides") + self.logger.info("override_params_list={}".format(override_parameters)) response = stack_set.update_stack_instances( - self.params.get('StackSetName'), - self.params.get('AccountList'), - self.params.get('RegionList'), - override_parameters) + self.params.get("StackSetName"), + self.params.get("AccountList"), + self.params.get("RegionList"), + override_parameters, + ) self.logger.info("Update Stack Instance Response") self.logger.info(response) - self.logger.info("Operation ID: {}" - .format(response.get('OperationId'))) - self.event.update({'OperationId': response.get('OperationId')}) + self.logger.info("Operation ID: {}".format(response.get("OperationId"))) + self.event.update({"OperationId": response.get("OperationId")}) # need for Delete Stack Instance or Finish? choice in the # state machine. No will route to Finish path. - self.event.update({'DeleteInstance': 'no'}) + self.event.update({"DeleteInstance": "no"}) return self.event def delete_stack_set(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # Delete StackSet stack_set = StackSet(self.logger) - self.logger.info("Deleting StackSet: {}" - .format(self.params.get('StackSetName'))) - self.logger.info(stack_set.delete_stack_set( - self.params.get('StackSetName'))) + self.logger.info( + "Deleting StackSet: {}".format(self.params.get("StackSetName")) + ) + self.logger.info(stack_set.delete_stack_set(self.params.get("StackSetName"))) return self.event def delete_stack_instances(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # set to default values (new instance creation) - account_list = self.params.get('AccountList') + account_list = self.params.get("AccountList") # full region list - region_list = self.event.get('ExistingRegionList') + region_list = self.event.get("ExistingRegionList") # if DeleteAccountList is not empty - if self.event.get('DeleteAccountList') is not None and len(self.event.get('DeleteAccountList')) != 0: - account_list = self.event.get('DeleteAccountList') + if ( + self.event.get("DeleteAccountList") is not None + and len(self.event.get("DeleteAccountList")) != 0 + ): + account_list = self.event.get("DeleteAccountList") # full region list - region_list = self.event.get('ExistingRegionList') + region_list = self.event.get("ExistingRegionList") # if DeleteRegionList is not empty - if self.event.get('DeleteRegionList') is not None and len(self.event.get('DeleteRegionList')) != 0: - region_list = self.event.get('DeleteRegionList') + if ( + self.event.get("DeleteRegionList") is not None + and len(self.event.get("DeleteRegionList")) != 0 + ): + region_list = self.event.get("DeleteRegionList") # both DeleteAccountList and DeleteRegionList is not empty - if self.event.get('LoopFlag') == 'yes': + if self.event.get("LoopFlag") == "yes": # delete stack instance in deleted account with all regions # stack instances in all regions for existing accounts # will be deletion in the second round - account_list = self.event.get('DeleteAccountList') + account_list = self.event.get("DeleteAccountList") # full region list - region_list = self.event.get('ExistingRegionList') + region_list = self.event.get("ExistingRegionList") - self.event.update({'ActiveAccountList': account_list}) - self.event.update({'ActiveRegionList': region_list}) + self.event.update({"ActiveAccountList": account_list}) + self.event.update({"ActiveRegionList": region_list}) # Delete stack_set_instance(s) stack_set = StackSet(self.logger) - self.logger.info("Deleting Stack Instance: {}" - .format(self.params.get('StackSetName'))) + self.logger.info( + "Deleting Stack Instance: {}".format(self.params.get("StackSetName")) + ) response = stack_set.delete_stack_instances( - self.params.get('StackSetName'), - account_list, - region_list) + self.params.get("StackSetName"), account_list, region_list + ) self.logger.info(response) - self.logger.info("Operation ID: {}" - .format(response.get('OperationId'))) - self.event.update({'OperationId': response.get('OperationId')}) + self.logger.info("Operation ID: {}".format(response.get("OperationId"))) + self.event.update({"OperationId": response.get("OperationId")}) return self.event @@ -724,7 +788,7 @@ class ServiceControlPolicy(object): def __init__(self, event, logger): self.event = event - self.params = event.get('ResourceProperties') + self.params = event.get("ResourceProperties") self.logger = logger self.logger.info(self.__class__.__name__ + " Class Event") self.logger.info(event) @@ -733,163 +797,176 @@ def _load_policy(self, http_policy_path): bucket_name, key_name, region = parse_bucket_key_names(http_policy_path) policy_file = tempfile.mkstemp()[1] s3_endpoint_url = "https://s3.%s.amazonaws.com" % region - s3 = S3(self.logger, - region=region, - endpoint_url=s3_endpoint_url) + s3 = S3(self.logger, region=region, endpoint_url=s3_endpoint_url) s3.download_file(bucket_name, key_name, policy_file) self.logger.info("Parsing the policy file: {}".format(policy_file)) - with open(policy_file, 'r') as content_file: + with open(policy_file, "r") as content_file: policy_file_content = content_file.read() # Check if valid json json.loads(policy_file_content) # Return the Escaped JSON text - return policy_file_content\ - .replace('"', '\"')\ - .replace('\n', '\r\n')\ + return ( + policy_file_content.replace('"', '"') + .replace("\n", "\r\n") .replace(" ", "") + ) def list_policies(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) # Check if PolicyName attribute exists in event, # if so, it is called for attach or detach policy - if 'PolicyName' in self.event: - policy_name = self.event.get('PolicyName') + if "PolicyName" in self.event: + policy_name = self.event.get("PolicyName") else: - policy_name = self.params.get('PolicyDocument').get('Name') + policy_name = self.params.get("PolicyDocument").get("Name") # Check if SCP already exist scp = SCP(self.logger) pages = scp.list_policies() for page in pages: - policies_list = page.get('Policies') + policies_list = page.get("Policies") # iterate through the policies list for policy in policies_list: - if policy.get('Name') == policy_name: + if policy.get("Name") == policy_name: self.logger.info("Policy Found") - self.event.update({'PolicyId': policy.get('Id')}) - self.event.update({'PolicyArn': policy.get('Arn')}) - self.event.update({'PolicyExist': "yes"}) + self.event.update({"PolicyId": policy.get("Id")}) + self.event.update({"PolicyArn": policy.get("Arn")}) + self.event.update({"PolicyExist": "yes"}) return self.event else: continue - self.event.update({'PolicyExist': "no"}) + self.event.update({"PolicyExist": "no"}) return self.event def create_policy(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - policy_doc = self.params.get('PolicyDocument') + policy_doc = self.params.get("PolicyDocument") scp = SCP(self.logger) self.logger.info("Creating Service Control Policy") - policy_content = self._load_policy(policy_doc.get('PolicyURL')) + policy_content = self._load_policy(policy_doc.get("PolicyURL")) - response = scp.create_policy(policy_doc.get('Name'), - policy_doc.get('Description'), - policy_content) + response = scp.create_policy( + policy_doc.get("Name"), policy_doc.get("Description"), policy_content + ) self.logger.info("Create SCP Response") self.logger.info(response) - policy_id = response.get('Policy').get('PolicySummary').get('Id') - self.event.update({'PolicyId': policy_id}) + policy_id = response.get("Policy").get("PolicySummary").get("Id") + self.event.update({"PolicyId": policy_id}) return self.event def update_policy(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - policy_doc = self.params.get('PolicyDocument') - policy_id = self.event.get('PolicyId') - policy_content = self._load_policy(policy_doc.get('PolicyURL')) + policy_doc = self.params.get("PolicyDocument") + policy_id = self.event.get("PolicyId") + policy_content = self._load_policy(policy_doc.get("PolicyURL")) scp = SCP(self.logger) self.logger.info("Updating Service Control Policy") - response = scp.update_policy(policy_id, policy_doc.get('Name'), - policy_doc.get('Description'), - policy_content) + response = scp.update_policy( + policy_id, + policy_doc.get("Name"), + policy_doc.get("Description"), + policy_content, + ) self.logger.info("Update SCP Response") self.logger.info(response) - policy_id = response.get('Policy').get('PolicySummary').get('Id') - self.event.update({'PolicyId': policy_id}) + policy_id = response.get("Policy").get("PolicySummary").get("Id") + self.event.update({"PolicyId": policy_id}) return self.event def delete_policy(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - policy_id = self.event.get('PolicyId') + policy_id = self.event.get("PolicyId") scp = SCP(self.logger) self.logger.info("Deleting Service Control Policy") scp.delete_policy(policy_id) self.logger.info("Delete SCP") - status = 'Policy: {} deleted successfully'.format(policy_id) - self.event.update({'Status': status}) + status = "Policy: {} deleted successfully".format(policy_id) + self.event.update({"Status": status}) return self.event def attach_policy(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - if self.params.get('AccountId') == "": - target_id = self.event.get('OUId') + if self.params.get("AccountId") == "": + target_id = self.event.get("OUId") else: - target_id = self.params.get('AccountId') - policy_id = self.event.get('PolicyId') + target_id = self.params.get("AccountId") + policy_id = self.event.get("PolicyId") scp = SCP(self.logger) scp.attach_policy(policy_id, target_id) self.logger.info("Attach Policy") - status = "Policy: {} attached successfully to Target: {}"\ - .format(policy_id, target_id) - self.event.update({'Status': status}) + status = "Policy: {} attached successfully to Target: {}".format( + policy_id, target_id + ) + self.event.update({"Status": status}) return self.event def detach_policy(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - if self.params.get('AccountId') == "": - target_id = self.event.get('OUId') + if self.params.get("AccountId") == "": + target_id = self.event.get("OUId") else: - target_id = self.params.get('AccountId') - policy_id = self.event.get('PolicyId') + target_id = self.params.get("AccountId") + policy_id = self.event.get("PolicyId") scp = SCP(self.logger) scp.detach_policy(policy_id, target_id) self.logger.info("Detach Policy Response") - status = 'Policy: {} detached successfully from Target: {}' \ - .format(policy_id, target_id) - self.event.update({'Status': status}) + status = "Policy: {} detached successfully from Target: {}".format( + policy_id, target_id + ) + self.event.update({"Status": status}) return self.event def list_policies_for_ou(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - ou_name = self.event.get('OUName') - policy_name = self.params.get('PolicyDocument').get('Name') - ou_id = self.event.get('OUId') + ou_name = self.event.get("OUName") + policy_name = self.params.get("PolicyDocument").get("Name") + ou_id = self.event.get("OUId") if ou_id is None or len(ou_id) == 0: raise ValueError("OU id is not found for {}".format(ou_name)) - self.event.update({'OUId': ou_id}) + self.event.update({"OUId": ou_id}) self.list_policies_for_target(ou_id, policy_name) return self.event def list_policies_for_account(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - self.list_policies_for_target(self.params.get('AccountId'), - self.event.get('PolicyName')) + self.list_policies_for_target( + self.params.get("AccountId"), self.event.get("PolicyName") + ) return self.event def list_policies_for_target(self, target_id, policy_name): @@ -898,43 +975,45 @@ def list_policies_for_target(self, target_id, policy_name): pages = scp.list_policies_for_target(target_id) for page in pages: - policies_list = page.get('Policies') + policies_list = page.get("Policies") # iterate through the policies list for policy in policies_list: - if policy.get('Name') == policy_name: + if policy.get("Name") == policy_name: self.logger.info("Policy Found") - self.event.update({'PolicyId': policy.get('Id')}) - self.event.update({'PolicyArn': policy.get('Arn')}) - self.event.update({'PolicyAttached': "yes"}) + self.event.update({"PolicyId": policy.get("Id")}) + self.event.update({"PolicyArn": policy.get("Arn")}) + self.event.update({"PolicyAttached": "yes"}) return self.event else: continue - self.event.update({'PolicyAttached': "no"}) + self.event.update({"PolicyAttached": "no"}) def detach_policy_from_all_accounts(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - policy_id = self.event.get('PolicyId') + policy_id = self.event.get("PolicyId") scp = SCP(self.logger) pages = scp.list_targets_for_policy(policy_id) accounts = [] for page in pages: - target_list = page.get('Targets') + target_list = page.get("Targets") # iterate through the policies list for target in target_list: - account_id = target.get('TargetId') + account_id = target.get("TargetId") scp.detach_policy(policy_id, account_id) accounts.append(account_id) - status = 'Policy: {} detached successfully from Accounts: {}'\ - .format(policy_id, accounts) - self.event.update({'Status': status}) + status = "Policy: {} detached successfully from Accounts: {}".format( + policy_id, accounts + ) + self.event.update({"Status": status}) return self.event def enable_policy_type(self): @@ -942,7 +1021,7 @@ def enable_policy_type(self): response = org.list_roots() self.logger.info("List roots Response") self.logger.info(response) - root_id = response['Roots'][0].get('Id') + root_id = response["Roots"][0].get("Id") scp = SCP(self.logger) scp.enable_policy_type(root_id) @@ -956,25 +1035,28 @@ class StackSetSMRequests(object): def __init__(self, event, logger): self.event = event - self.params = event.get('ResourceProperties') + self.params = event.get("ResourceProperties") self.logger = logger self.logger.info(self.__class__.__name__ + " Class Event") self.logger.info(event) self.ssm = SSM(self.logger) def export_cfn_output(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - regions = self.params.get('RegionList') - accounts = self.params.get('AccountList') - stack_set_name = self.params.get('StackSetName') + regions = self.params.get("RegionList") + accounts = self.params.get("AccountList") + stack_set_name = self.params.get("StackSetName") stack_set = StackSet(self.logger) if len(accounts) == 0 or len(regions) == 0: - self.logger.info("Either AccountList or RegionList empty; so " - "skipping the export_cfn_output ") + self.logger.info( + "Either AccountList or RegionList empty; so " + "skipping the export_cfn_output " + ) return self.event self.logger.info("Picking the first account from AccountList") @@ -985,24 +1067,23 @@ def export_cfn_output(self): # First retrieve the Stack ID from the target account, # region deployed via the StackSet - response = stack_set.describe_stack_instance( - stack_set_name, account, region) + response = stack_set.describe_stack_instance(stack_set_name, account, region) stack_id, stack_name = self._retrieve_stack_info( - response, stack_set_name, account, region) + response, stack_set_name, account, region + ) # instantiate STS class _assume_role = AssumeRole() - cfn = Stacks(self.logger, - region, - credentials=_assume_role(self.logger, account)) + cfn = Stacks( + self.logger, region, credentials=_assume_role(self.logger, account) + ) response = cfn.describe_stacks(stack_id) - stacks = response.get('Stacks') + stacks = response.get("Stacks") if stacks is not None and type(stacks) is list: for stack in stacks: - self._update_event_with_stack_output( - stack, stack_id, account, region) + self._update_event_with_stack_output(stack, stack_id, account, region) return self.event def _retrieve_stack_info(self, response, stack_set_name, account, region): @@ -1017,20 +1098,20 @@ def _retrieve_stack_info(self, response, stack_set_name, account, region): Return: stack id and stack name """ - stack_id = response.get('StackInstance').get('StackId') + stack_id = response.get("StackInstance").get("StackId") self.logger.info("stack_id={}".format(stack_id)) if stack_id: - stack_name = stack_id.split('/')[1] + stack_name = stack_id.split("/")[1] else: - raise ValueError("Describe Stack Instance failed to retrieve" - " the StackId for StackSet: {} in account: " - "{} and region: {}" - .format(stack_set_name, account, region)) + raise ValueError( + "Describe Stack Instance failed to retrieve" + " the StackId for StackSet: {} in account: " + "{} and region: {}".format(stack_set_name, account, region) + ) self.logger.info("stack_name={}".format(stack_name)) return stack_id, stack_name - def _update_event_with_stack_output( - self, stack, stack_id, account, region): + def _update_event_with_stack_output(self, stack, stack_id, account, region): """update key and value in event with stack ouput Args: stack: json output of stack @@ -1041,19 +1122,19 @@ def _update_event_with_stack_output( Return: None """ - if stack.get('StackId') == stack_id: - self.logger.info("Found Stack: {}" - .format(stack.get('StackName'))) - self.logger.info("Exporting Output of Stack: {} from " - "Account: {} and region: {}" - .format(stack.get('StackName'), - str(account), region)) - outputs = stack.get('Outputs') + if stack.get("StackId") == stack_id: + self.logger.info("Found Stack: {}".format(stack.get("StackName"))) + self.logger.info( + "Exporting Output of Stack: {} from " + "Account: {} and region: {}".format( + stack.get("StackName"), str(account), region + ) + ) + outputs = stack.get("Outputs") if outputs is not None and type(outputs) is list: for output in outputs: - key = 'output_' + \ - output.get('OutputKey').lower() - value = output.get('OutputValue') + key = "output_" + output.get("OutputKey").lower() + value = output.get("OutputValue") self.event.update({key: value}) def nested_dictionary_iteration(self, dictionary): @@ -1065,11 +1146,12 @@ def nested_dictionary_iteration(self, dictionary): yield key, value def ssm_put_parameters(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) self.logger.info(self.params) - ssm_params = self.params.get('SSMParameters') - ssm_value = 'NotFound' + ssm_params = self.params.get("SSMParameters") + ssm_value = "NotFound" if ssm_params is not None and type(ssm_params) is dict: # iterate through the keys to save them in SSM Parameter Store for key, value in ssm_params.items(): @@ -1089,7 +1171,7 @@ def _save_ssm_parameters(self, key, value, ssm_value): Return: None """ - if value.startswith('$[') and value.endswith(']'): + if value.startswith("$[") and value.endswith("]"): value = value[2:-1] # Iterate through all the keys in the event # (includes the nested keys) @@ -1098,28 +1180,33 @@ def _save_ssm_parameters(self, key, value, ssm_value): ssm_value = v break else: - ssm_value = 'NotFound' - if ssm_value == 'NotFound': + ssm_value = "NotFound" + if ssm_value == "NotFound": # Print error if the key is not found in the State Machine output. # Handle scenario if only StackSet is created not stack instances. - self.logger.error("Unable to find the key: {} in the" - " State Machine Output".format(value)) + self.logger.error( + "Unable to find the key: {} in the" + " State Machine Output".format(value) + ) else: - self.logger.info("Adding value for SSM Parameter Store" - " Key: {}".format(key)) + self.logger.info( + "Adding value for SSM Parameter Store" " Key: {}".format(key) + ) self.ssm.put_parameter(key, ssm_value) def send_execution_data(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) send = SolutionMetrics(self.logger) data = {"StateMachineExecutionCount": "1"} send.solution_metrics(data) return self.event def random_wait(self): - self.logger.info("Executing: " + self.__class__.__name__ + "/" - + inspect.stack()[0][3]) + self.logger.info( + "Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3] + ) # Random wait between 1 to 14 minutes _seconds = randint(60, 840) time.sleep(_seconds) diff --git a/source/src/cfct/types.py b/source/src/cfct/types.py index 8fdabb8..0baa386 100644 --- a/source/src/cfct/types.py +++ b/source/src/cfct/types.py @@ -1,5 +1,4 @@ - -from typing import List, Dict, Any, TypedDict, Literal +from typing import Any, Dict, List, Literal, TypedDict class ResourcePropertiesTypeDef(TypedDict): @@ -24,4 +23,4 @@ class StackSetRequestTypeDef(TypedDict): class StackSetInstanceTypeDef(TypedDict): account: str - region: str \ No newline at end of file + region: str diff --git a/source/src/cfct/utils/crhelper.py b/source/src/cfct/utils/crhelper.py index c6adc8d..e372674 100644 --- a/source/src/cfct/utils/crhelper.py +++ b/source/src/cfct/utils/crhelper.py @@ -13,53 +13,63 @@ # and limitations under the License. # ############################################################################### +import json import threading + import requests -import json -def send(event, context, response_status, response_data, physical_resource_id, - logger, reason=None): - """This function sends status and response data to cloudformation. - """ - response_url = event['ResponseURL'] +def send( + event, + context, + response_status, + response_data, + physical_resource_id, + logger, + reason=None, +): + """This function sends status and response data to cloudformation.""" + response_url = event["ResponseURL"] logger.debug("CFN response URL: " + response_url) response_body = {} - response_body['Status'] = response_status - msg = 'See details in CloudWatch Log Stream: ' + context.log_stream_name + response_body["Status"] = response_status + msg = "See details in CloudWatch Log Stream: " + context.log_stream_name if not reason: - response_body['Reason'] = msg + response_body["Reason"] = msg else: - response_body['Reason'] = str(reason)[0:255] + '... ' + msg - response_body['PhysicalResourceId'] = \ + response_body["Reason"] = str(reason)[0:255] + "... " + msg + response_body["PhysicalResourceId"] = ( physical_resource_id or context.log_stream_name - response_body['StackId'] = event['StackId'] - response_body['RequestId'] = event['RequestId'] - response_body['LogicalResourceId'] = event['LogicalResourceId'] - if response_data and response_data != {} and response_data != [] \ - and isinstance(response_data, dict): - response_body['Data'] = response_data + ) + response_body["StackId"] = event["StackId"] + response_body["RequestId"] = event["RequestId"] + response_body["LogicalResourceId"] = event["LogicalResourceId"] + if ( + response_data + and response_data != {} + and response_data != [] + and isinstance(response_data, dict) + ): + response_body["Data"] = response_data logger.debug("<<<<<<< Response body >>>>>>>>>>") logger.debug(response_body) json_response_body = json.dumps(response_body) - headers = { - 'content-type': '', - 'content-length': str(len(json_response_body)) - } + headers = {"content-type": "", "content-length": str(len(json_response_body))} try: - if response_url == 'http://pre-signed-S3-url-for-response': - logger.info("CloudFormation returned status code:" - " THIS IS A TEST OUTSIDE OF CLOUDFORMATION") + if response_url == "http://pre-signed-S3-url-for-response": + logger.info( + "CloudFormation returned status code:" + " THIS IS A TEST OUTSIDE OF CLOUDFORMATION" + ) else: - response = requests.put(response_url, - data=json_response_body, - headers=headers) - logger.info("CloudFormation returned status code: " - + response.reason) + response = requests.put( + response_url, data=json_response_body, headers=headers + ) + logger.info("CloudFormation returned status code: " + response.reason) except Exception as e: logger.error("send(..) failed executing requests.put(..): " + str(e)) raise @@ -67,21 +77,30 @@ def send(event, context, response_status, response_data, physical_resource_id, def timeout(event, context, logger): """This function is executed just before lambda excecution time out - to send out time out failure message. + to send out time out failure message. """ logger.error("Execution is about to time out, sending failure message") - send(event, context, "FAILED", None, None, reason="Execution timed out", - logger=logger) + send( + event, + context, + "FAILED", + None, + None, + reason="Execution timed out", + logger=logger, + ) def cfn_handler(event, context, create, update, delete, logger, init_failed): """This handler function calls stack creation, update or deletion - based on request type and also sends status and response data - from any of the stack operations back to cloudformation, - as applicable. + based on request type and also sends status and response data + from any of the stack operations back to cloudformation, + as applicable. """ - logger.info("Lambda RequestId: %s CloudFormation RequestId: %s" % - (context.aws_request_id, event['RequestId'])) + logger.info( + "Lambda RequestId: %s CloudFormation RequestId: %s" + % (context.aws_request_id, event["RequestId"]) + ) # Define an object to place any response information you would like to send # back to CloudFormation (these keys can then be used by Fn::GetAttr) @@ -95,35 +114,58 @@ def cfn_handler(event, context, create, update, delete, logger, init_failed): logger.debug("EVENT: " + str(event)) # handle init failures if init_failed: - send(event, context, "FAILED", response_data, physical_resource_id, - logger, reason="Initialization Failed") + send( + event, + context, + "FAILED", + response_data, + physical_resource_id, + logger, + reason="Initialization Failed", + ) raise Exception("Initialization Failed") # Setup timer to catch timeouts - t = threading.Timer((context.get_remaining_time_in_millis()/1000.00)-0.5, - timeout, args=[event, context, logger]) + t = threading.Timer( + (context.get_remaining_time_in_millis() / 1000.00) - 0.5, + timeout, + args=[event, context, logger], + ) t.start() try: # Execute custom resource handlers - logger.info("Received a %s Request" % event['RequestType']) - if event['RequestType'] == 'Create': + logger.info("Received a %s Request" % event["RequestType"]) + if event["RequestType"] == "Create": physical_resource_id, response_data = create(event, context) - elif event['RequestType'] == 'Update': + elif event["RequestType"] == "Update": physical_resource_id, response_data = update(event, context) - elif event['RequestType'] == 'Delete': + elif event["RequestType"] == "Delete": delete(event, context) # Send response back to CloudFormation logger.info("Completed successfully, sending response to cfn") - send(event, context, "SUCCESS", response_data, physical_resource_id, - logger=logger) + send( + event, + context, + "SUCCESS", + response_data, + physical_resource_id, + logger=logger, + ) # Catch any exceptions, log the stacktrace, send a failure back to # CloudFormation and then raise an exception except Exception as e: logger.error(e, exc_info=True) - send(event, context, "FAILED", response_data, physical_resource_id, - reason=e, logger=logger) + send( + event, + context, + "FAILED", + response_data, + physical_resource_id, + reason=e, + logger=logger, + ) finally: t.cancel() diff --git a/source/src/cfct/utils/datetime_encoder.py b/source/src/cfct/utils/datetime_encoder.py index 362d4a4..c775628 100644 --- a/source/src/cfct/utils/datetime_encoder.py +++ b/source/src/cfct/utils/datetime_encoder.py @@ -15,7 +15,7 @@ # !/bin/python import json -from datetime import datetime, date +from datetime import date, datetime class DateTimeEncoder(json.JSONEncoder): diff --git a/source/src/cfct/utils/list_comparision.py b/source/src/cfct/utils/list_comparision.py index 290bf59..e8b30c7 100644 --- a/source/src/cfct/utils/list_comparision.py +++ b/source/src/cfct/utils/list_comparision.py @@ -14,7 +14,9 @@ ############################################################################### from cfct.utils.logger import Logger -logger = Logger('info') + +logger = Logger("info") + def compare_lists(existing_list: list, new_list: list) -> bool: """Compares two list and return boolean flag if they match @@ -38,5 +40,3 @@ def compare_lists(existing_list: list, new_list: list) -> bool: else: logger.info("Lists didn't match.") return False - - diff --git a/source/src/cfct/utils/logger.py b/source/src/cfct/utils/logger.py index 4093f3c..3f38fbe 100644 --- a/source/src/cfct/utils/logger.py +++ b/source/src/cfct/utils/logger.py @@ -15,23 +15,25 @@ import json import logging + from cfct.utils.datetime_encoder import DateTimeEncoder class Logger(object): - - def __init__(self, loglevel='warning'): + def __init__(self, loglevel="warning"): """Initializes logging""" self.config(loglevel=loglevel) - def config(self, loglevel='warning'): + def config(self, loglevel="warning"): loglevel = logging.getLevelName(loglevel.upper()) main_logger = logging.getLogger() main_logger.setLevel(loglevel) - logfmt = '{"time_stamp": "%(asctime)s",' \ - '"log_level": "%(levelname)s",' \ - '"log_message": %(message)s}\n' + logfmt = ( + '{"time_stamp": "%(asctime)s",' + '"log_level": "%(levelname)s",' + '"log_message": %(message)s}\n' + ) if len(main_logger.handlers) == 0: main_logger.addHandler(logging.StreamHandler()) main_logger.handlers[0].setFormatter(logging.Formatter(logfmt)) @@ -85,7 +87,5 @@ def log_unhandled_exception(self, message): def log_general_exception(self, file, method, exception): """log general exception""" - message = {'FILE': file, - 'METHOD': method, - 'EXCEPTION': str(exception)} + message = {"FILE": file, "METHOD": method, "EXCEPTION": str(exception)} self.log.exception(self._format(message)) diff --git a/source/src/cfct/utils/os_util.py b/source/src/cfct/utils/os_util.py index 1486d89..afa4762 100644 --- a/source/src/cfct/utils/os_util.py +++ b/source/src/cfct/utils/os_util.py @@ -30,18 +30,14 @@ def make_dir(directory, logger=None): try: os.stat(directory) if logger is None: - print("\n Directory {} already exist... skipping" - .format(directory)) + print("\n Directory {} already exist... skipping".format(directory)) else: - logger.info("Directory {} already exist... skipping" - .format(directory)) + logger.info("Directory {} already exist... skipping".format(directory)) except OSError: if logger is None: - print("\n Directory {} not found, creating now..." - .format(directory)) + print("\n Directory {} not found, creating now...".format(directory)) else: - logger.info("Directory {} not found, creating now..." - .format(directory)) + logger.info("Directory {} not found, creating now...".format(directory)) os.makedirs(directory) @@ -58,16 +54,18 @@ def remove_dir(directory, logger=None): try: os.stat(directory) if logger is None: - print("\n Directory {} already exist, deleting open-source" - " directory".format(directory)) + print( + "\n Directory {} already exist, deleting open-source" + " directory".format(directory) + ) else: - logger.info("\n Directory {} already exist, deleting open-source" - " directory".format(directory)) + logger.info( + "\n Directory {} already exist, deleting open-source" + " directory".format(directory) + ) shutil.rmtree(directory) except OSError: if logger is None: - print("\n Directory {} not found... nothing to delete" - .format(directory)) + print("\n Directory {} not found... nothing to delete".format(directory)) else: - logger.info("Directory {} not found... nothing to delete" - .format(directory)) + logger.info("Directory {} not found... nothing to delete".format(directory)) diff --git a/source/src/cfct/utils/password_generator.py b/source/src/cfct/utils/password_generator.py index 66222c4..3f6b3d7 100644 --- a/source/src/cfct/utils/password_generator.py +++ b/source/src/cfct/utils/password_generator.py @@ -13,11 +13,11 @@ # and limitations under the License. # ############################################################################### -import string import random +import string -def random_pwd_generator(length, additional_str=''): +def random_pwd_generator(length, additional_str=""): """Generate random password. Args: @@ -27,13 +27,13 @@ def random_pwd_generator(length, additional_str=''): Returns: password """ - chars = string.ascii_uppercase + string.ascii_lowercase + string.digits \ - + additional_str + chars = ( + string.ascii_uppercase + string.ascii_lowercase + string.digits + additional_str + ) # Making sure the password has two numbers and symbols at the very least - password = ''.join(random.SystemRandom().choice(chars) - for _ in range(length-4)) + \ - ''.join(random.SystemRandom().choice(string.digits) - for _ in range(2)) + \ - ''.join(random.SystemRandom().choice(additional_str) - for _ in range(2)) + password = ( + "".join(random.SystemRandom().choice(chars) for _ in range(length - 4)) + + "".join(random.SystemRandom().choice(string.digits) for _ in range(2)) + + "".join(random.SystemRandom().choice(additional_str) for _ in range(2)) + ) return password diff --git a/source/src/cfct/utils/retry_decorator.py b/source/src/cfct/utils/retry_decorator.py index 854033d..3bfd084 100644 --- a/source/src/cfct/utils/retry_decorator.py +++ b/source/src/cfct/utils/retry_decorator.py @@ -14,12 +14,13 @@ ############################################################################### import time -from random import randint from functools import wraps +from random import randint + from cfct.utils.logger import Logger # initialise logger -logger = Logger(loglevel='info') +logger = Logger(loglevel="info") def try_except_retry(count=3, multiplier=2): @@ -27,7 +28,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): _count = count - _seconds = randint(5,10) + _seconds = randint(5, 10) while _count >= 1: try: return func(*args, **kwargs) @@ -40,5 +41,7 @@ def wrapper(*args, **kwargs): logger.error("Retry attempts failed, raising the exception.") raise return func(*args, **kwargs) + return wrapper + return decorator diff --git a/source/src/cfct/utils/string_manipulation.py b/source/src/cfct/utils/string_manipulation.py index b91dc8a..4f2a903 100644 --- a/source/src/cfct/utils/string_manipulation.py +++ b/source/src/cfct/utils/string_manipulation.py @@ -16,7 +16,7 @@ import re -def sanitize(name, space_allowed=False, replace_with_character='_'): +def sanitize(name, space_allowed=False, replace_with_character="_"): """Sanitizes input string. Replaces any character other than [a-zA-Z0-9._-] in a string @@ -35,18 +35,14 @@ def sanitize(name, space_allowed=False, replace_with_character='_'): Raises: """ if space_allowed: - sanitized_name = re.sub(r'([^\sa-zA-Z0-9._-])', - replace_with_character, - name) + sanitized_name = re.sub(r"([^\sa-zA-Z0-9._-])", replace_with_character, name) else: - sanitized_name = re.sub(r'([^a-zA-Z0-9._-])', - replace_with_character, - name) + sanitized_name = re.sub(r"([^a-zA-Z0-9._-])", replace_with_character, name) return sanitized_name def trim_length_from_end(string, length): - """ Trims the length of the given string to the given length + """Trims the length of the given string to the given length :param string: :param length: @@ -59,28 +55,29 @@ def trim_length_from_end(string, length): def trim_string_from_front(string, remove_starts_with_string): - """ Remove string provided in the search_string + """Remove string provided in the search_string and returns remainder of the string. :param string: :param remove_starts_with_string: :return: trimmed string """ if string.startswith(remove_starts_with_string): - return string[len(remove_starts_with_string):] + return string[len(remove_starts_with_string) :] else: - raise ValueError('The beginning of the string does ' - 'not match the string to be trimmed.') + raise ValueError( + "The beginning of the string does " "not match the string to be trimmed." + ) def extract_string(search_str): - return str[len(search_str):] + return str[len(search_str) :] def convert_list_values_to_string(_list): return list(map(str, _list)) -def convert_string_to_list(_string, separator=','): +def convert_string_to_list(_string, separator=","): """ splits the string with give separator and remove whitespaces """ @@ -92,7 +89,7 @@ def strip_list_items(array): def remove_empty_strings(array): - return [x for x in array if x != ''] + return [x for x in array if x != ""] def list_sanitizer(array): @@ -105,4 +102,4 @@ def empty_separator_handler(delimiter, nested_ou_name): nested_ou_name_list = [nested_ou_name] else: nested_ou_name_list = nested_ou_name.split(delimiter) - return nested_ou_name_list \ No newline at end of file + return nested_ou_name_list diff --git a/source/src/cfct/validation/custom_validation.py b/source/src/cfct/validation/custom_validation.py index a901c37..9342092 100644 --- a/source/src/cfct/validation/custom_validation.py +++ b/source/src/cfct/validation/custom_validation.py @@ -14,6 +14,7 @@ ############################################################################### import logging + log = logging.getLogger(__name__) # This is a custom valiator specifically for pyKwlify Schema extensions diff --git a/source/src/setup.py b/source/src/setup.py index 1a8b639..0c1671a 100644 --- a/source/src/setup.py +++ b/source/src/setup.py @@ -7,9 +7,13 @@ long_description = fh.read() +with open("../../VERSION", "r", encoding="utf-8") as version_file: + version = version_file.read() + + setuptools.setup( name="cfct", - version="2.3.0", + version=version, author="AWS", description="Customizations for Control Tower", long_description=long_description, @@ -42,7 +46,7 @@ "pytest == 6.2.4", "mypy == 0.930", "expecter==0.3.0", - "pykwalify == 1.8.0" + "pykwalify == 1.8.0", ], "dev": [ "ipython", diff --git a/source/tests/conftest.py b/source/tests/conftest.py index a10a910..e41a78e 100644 --- a/source/tests/conftest.py +++ b/source/tests/conftest.py @@ -1,19 +1,20 @@ -import pytest -import boto3 from os import environ -from moto import mock_s3, mock_organizations, mock_ssm, mock_cloudformation + +import boto3 +import pytest +from moto import mock_cloudformation, mock_organizations, mock_s3, mock_ssm -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def aws_credentials(): """Mocked AWS Credentials for moto""" - environ['AWS_ACCESS_KEY_ID'] = 'testing' - environ['AWS_SECRET_ACCESS_KEY'] = 'testing' - environ['AWS_SECURITY_TOKEN'] = 'testing' - environ['AWS_SESSION_TOKEN'] = 'testing' + environ["AWS_ACCESS_KEY_ID"] = "testing" + environ["AWS_SECRET_ACCESS_KEY"] = "testing" + environ["AWS_SECURITY_TOKEN"] = "testing" + environ["AWS_SESSION_TOKEN"] = "testing" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def s3_client(aws_credentials): """S3 Mock Client""" with mock_s3(): @@ -21,7 +22,7 @@ def s3_client(aws_credentials): yield connection -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def s3_client_resource(aws_credentials): """S3 Mock Client""" with mock_s3(): @@ -29,7 +30,7 @@ def s3_client_resource(aws_credentials): yield connection -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def org_client(aws_credentials): """Organizations Mock Client""" with mock_organizations(): @@ -37,7 +38,7 @@ def org_client(aws_credentials): yield connection -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ssm_client(aws_credentials): """SSM Mock Client""" with mock_ssm(): diff --git a/source/tests/plugins/env_vars.py b/source/tests/plugins/env_vars.py index 5bdb0fa..fa4bb00 100644 --- a/source/tests/plugins/env_vars.py +++ b/source/tests/plugins/env_vars.py @@ -1,15 +1,15 @@ import os + import pytest BUCKET_NAME = "test-bucket" + @pytest.hookimpl(tryfirst=True) def pytest_load_initial_conftests(): - os.environ['STAGING_BUCKET'] = BUCKET_NAME - os.environ['TEMPLATE_KEY_PREFIX'] = '_custom_ct_templates_staging' - os.environ['LOG_LEVEL'] = 'info' - os.environ['STAGE_NAME'] = 'stackset' - os.environ['AWS_REGION'] = 'us-east-1' - os.environ['CAPABILITIES'] = 'CAPABILITY_NAMED_IAM, CAPABILITY_AUTO_EXPAND' - - + os.environ["STAGING_BUCKET"] = BUCKET_NAME + os.environ["TEMPLATE_KEY_PREFIX"] = "_custom_ct_templates_staging" + os.environ["LOG_LEVEL"] = "info" + os.environ["STAGE_NAME"] = "stackset" + os.environ["AWS_REGION"] = "us-east-1" + os.environ["CAPABILITIES"] = "CAPABILITY_NAMED_IAM, CAPABILITY_AUTO_EXPAND" diff --git a/source/tests/test_cfn_params_handler.py b/source/tests/test_cfn_params_handler.py index 2f1e60c..112889b 100644 --- a/source/tests/test_cfn_params_handler.py +++ b/source/tests/test_cfn_params_handler.py @@ -12,49 +12,49 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from moto import mock_ssm -from cfct.utils.logger import Logger -from cfct.manifest.cfn_params_handler import CFNParamsHandler -from cfct.aws.services.ssm import SSM import pytest +from cfct.aws.services.ssm import SSM +from cfct.manifest.cfn_params_handler import CFNParamsHandler +from cfct.utils.logger import Logger +from moto import mock_ssm -log_level = 'info' +log_level = "info" logger = Logger(loglevel=log_level) + @pytest.mark.unit def test_update_alfred_ssm(): - keyword_ssm = 'alfred_ssm_not_exist_alfred_ssm' - value_ssm = 'parameter_store_value' + keyword_ssm = "alfred_ssm_not_exist_alfred_ssm" + value_ssm = "parameter_store_value" cph = CFNParamsHandler(logger) - value_ssm, param_flag = cph._update_alfred_ssm( - keyword_ssm, value_ssm, False) + value_ssm, param_flag = cph._update_alfred_ssm(keyword_ssm, value_ssm, False) assert param_flag is True + @pytest.mark.unit @mock_ssm def test_update_params(): logger.info("-- Put new parameter keys in mock environment") ssm = SSM(logger) - ssm.put_parameter('/key1', 'value1', 'Test parameter 1', 'String') - ssm.put_parameter('/key2', 'value2', 'Test parameter 2', 'String') - ssm.put_parameter('/key3', 'value3', 'Test parameter 3', 'String') + ssm.put_parameter("/key1", "value1", "Test parameter 1", "String") + ssm.put_parameter("/key2", "value2", "Test parameter 2", "String") + ssm.put_parameter("/key3", "value3", "Test parameter 3", "String") logger.info("-- Get parameter keys using alfred_ssm") - multiple_params = [{ - "ParameterKey": "Key1", - "ParameterValue": [ - "$[alfred_ssm_/key1]", - "$[alfred_ssm_/key2]", - "$[alfred_ssm_/key3]" - ] - }] + multiple_params = [ + { + "ParameterKey": "Key1", + "ParameterValue": [ + "$[alfred_ssm_/key1]", + "$[alfred_ssm_/key2]", + "$[alfred_ssm_/key3]", + ], + } + ] cph = CFNParamsHandler(logger) values = cph.update_params(multiple_params) assert values == {"Key1": ["value1", "value2", "value3"]} - single_param = [{ - "ParameterKey": "Key2", - "ParameterValue": "$[alfred_ssm_/key1]" - }] + single_param = [{"ParameterKey": "Key2", "ParameterValue": "$[alfred_ssm_/key1]"}] value = cph.update_params(single_param) assert value == {"Key2": "value1"} diff --git a/source/tests/test_datetime_encoder.py b/source/tests/test_datetime_encoder.py index e637b17..92259f2 100644 --- a/source/tests/test_datetime_encoder.py +++ b/source/tests/test_datetime_encoder.py @@ -12,27 +12,29 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from cfct.utils.datetime_encoder import DateTimeEncoder -from datetime import datetime, date import json +from datetime import date, datetime + import pytest +from cfct.utils.datetime_encoder import DateTimeEncoder + @pytest.mark.unit def test_datetime_encoder(): - datetime_str = '02/17/20 23:38:26' - datetime_object = datetime.strptime(datetime_str, '%m/%d/%y %H:%M:%S') + datetime_str = "02/17/20 23:38:26" + datetime_object = datetime.strptime(datetime_str, "%m/%d/%y %H:%M:%S") date_object = datetime_object.date() encoder = DateTimeEncoder() assert encoder.encode({"datetime": datetime_object}) == json.dumps( - {"datetime": "2020-02-17T23:38:26"}) - assert json.dumps( - {"datetime": datetime_object}, cls=DateTimeEncoder) == json.dumps( - {"datetime": "2020-02-17T23:38:26"}) - assert encoder.encode({"date": date_object}) == json.dumps( - {"date": "2020-02-17"}) - assert json.dumps( - {"date": date_object}, cls=DateTimeEncoder) == json.dumps( - {"date": "2020-02-17"}) + {"datetime": "2020-02-17T23:38:26"} + ) + assert json.dumps({"datetime": datetime_object}, cls=DateTimeEncoder) == json.dumps( + {"datetime": "2020-02-17T23:38:26"} + ) + assert encoder.encode({"date": date_object}) == json.dumps({"date": "2020-02-17"}) + assert json.dumps({"date": date_object}, cls=DateTimeEncoder) == json.dumps( + {"date": "2020-02-17"} + ) assert encoder.encode( - {"date": date_object, "datetime": datetime_object}) == json.dumps( - {"date": "2020-02-17", "datetime": "2020-02-17T23:38:26"}) + {"date": date_object, "datetime": datetime_object} + ) == json.dumps({"date": "2020-02-17", "datetime": "2020-02-17T23:38:26"}) diff --git a/source/tests/test_decimal_encoder.py b/source/tests/test_decimal_encoder.py index b7c4a4f..85e9bdc 100644 --- a/source/tests/test_decimal_encoder.py +++ b/source/tests/test_decimal_encoder.py @@ -12,25 +12,24 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from cfct.utils.decimal_encoder import DecimalEncoder -import json import decimal +import json + import pytest +from cfct.utils.decimal_encoder import DecimalEncoder + @pytest.mark.unit def test_decimal_encoder(): - assert json.dumps( - {'x': decimal.Decimal('5.5')}, cls=DecimalEncoder) == json.dumps( - {'x': 5.5}) - assert json.dumps( - {'x': decimal.Decimal('5.0')}, cls=DecimalEncoder) == json.dumps( - {'x': 5}) + assert json.dumps({"x": decimal.Decimal("5.5")}, cls=DecimalEncoder) == json.dumps( + {"x": 5.5} + ) + assert json.dumps({"x": decimal.Decimal("5.0")}, cls=DecimalEncoder) == json.dumps( + {"x": 5} + ) encoder = DecimalEncoder() - assert encoder.encode({'x': decimal.Decimal('5.65')}) == json.dumps( - {'x': 5.65}) - assert encoder.encode({'x': decimal.Decimal('5.0')}) == json.dumps( - {'x': 5}) + assert encoder.encode({"x": decimal.Decimal("5.65")}) == json.dumps({"x": 5.65}) + assert encoder.encode({"x": decimal.Decimal("5.0")}) == json.dumps({"x": 5}) assert encoder.encode( - {'x': decimal.Decimal('5.0'), - 'y': decimal.Decimal('5.5')}) == json.dumps({'x': 5, 'y': 5.5}) - + {"x": decimal.Decimal("5.0"), "y": decimal.Decimal("5.5")} + ) == json.dumps({"x": 5, "y": 5.5}) diff --git a/source/tests/test_get_partition.py b/source/tests/test_get_partition.py index 528d63a..1ad2ab9 100644 --- a/source/tests/test_get_partition.py +++ b/source/tests/test_get_partition.py @@ -13,34 +13,38 @@ # and limitations under the License. # ############################################################################### -from cfct.aws.utils.get_partition import get_partition -from cfct.utils.logger import Logger from os import environ + import pytest +from cfct.aws.utils.get_partition import get_partition +from cfct.utils.logger import Logger + +logger = Logger("info") -logger = Logger('info') +aws_regions_partition = "aws" +aws_china_regions_partition = "aws-cn" +aws_us_gov_cloud_regions_partition = "aws-us-gov" -aws_regions_partition = 'aws' -aws_china_regions_partition = 'aws-cn' -aws_us_gov_cloud_regions_partition = 'aws-us-gov' @pytest.mark.unit def test_get_partition_for_us_region(): - environ['AWS_REGION'] = 'us-east-1' + environ["AWS_REGION"] = "us-east-1" assert aws_regions_partition == get_partition() + @pytest.mark.unit def test_get_partition_for_eu_region(): - environ['AWS_REGION'] = 'eu-west-1' - assert aws_regions_partition == get_partition() + environ["AWS_REGION"] = "eu-west-1" + assert aws_regions_partition == get_partition() + @pytest.mark.unit def test_get_partition_for_cn_region(): - environ['AWS_REGION'] = 'cn-north-1' + environ["AWS_REGION"] = "cn-north-1" assert aws_china_regions_partition == get_partition() + @pytest.mark.unit def test_get_partition_for_us_gov_cloud_region(): - environ['AWS_REGION'] = 'us-gov-west-1' + environ["AWS_REGION"] = "us-gov-west-1" assert aws_us_gov_cloud_regions_partition == get_partition() - diff --git a/source/tests/test_list_comparision.py b/source/tests/test_list_comparision.py index 4afea28..6a0878a 100644 --- a/source/tests/test_list_comparision.py +++ b/source/tests/test_list_comparision.py @@ -13,35 +13,39 @@ # and limitations under the License. # ############################################################################### +import pytest from cfct.utils.list_comparision import compare_lists from cfct.utils.logger import Logger -import pytest -logger = Logger('info') +logger = Logger("info") + +list1 = ["aa", "bb", "cc"] # add value to list 2 +list2 = ["aa", "bb"] # remove value from list 1 +list3 = ["aa", "cc", "dd"] # remove and add values from list 1 +list4 = ["ee"] # single item list to test replace single account +list5 = ["ff"] -list1 = ['aa', 'bb', 'cc'] # add value to list 2 -list2 = ['aa', 'bb'] # remove value from list 1 -list3 = ['aa', 'cc', 'dd'] # remove and add values from list 1 -list4 = ['ee'] # single item list to test replace single account -list5 = ['ff'] @pytest.mark.unit def test_add_list(): assert compare_lists(list2, list1) is False + @pytest.mark.unit def test_delete_list(): assert compare_lists(list1, list2) is False + @pytest.mark.unit def test_add_delete_list(): assert compare_lists(list1, list3) is False + @pytest.mark.unit def test_single_item_replacement(): assert compare_lists(list4, list5) is False + @pytest.mark.unit def test_no_change_list(): assert compare_lists(list1, list1) is True - diff --git a/source/tests/test_list_type.py b/source/tests/test_list_type.py index cc5e701..4ceb329 100644 --- a/source/tests/test_list_type.py +++ b/source/tests/test_list_type.py @@ -16,49 +16,58 @@ import pytest from cfct.state_machine_handler import CloudFormation from cfct.utils.logger import Logger -logger = Logger('info') -string1 = 'xx' -string2 = 'yy' -list1 = ['aa', 'bb'] -list2 = ['bb', 'dd'] +logger = Logger("info") + +string1 = "xx" +string2 = "yy" +list1 = ["aa", "bb"] +list2 = ["bb", "dd"] event = {} cf = CloudFormation(event, logger) + @pytest.mark.unit def test_add_list_type(): assert isinstance(cf._add_list(list1, list2), list) + @pytest.mark.unit def test_delete_list_type(): assert isinstance(cf._delete_list(list1, list2), list) + @pytest.mark.unit def test_add_list_string_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): cf._add_list(list1, string1) + @pytest.mark.unit def test_add_string_list_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): cf._add_list(string1, list1) + @pytest.mark.unit def test_add_strings_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): cf._add_list(string1, string2) + @pytest.mark.unit def test_del_list_string_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): cf._delete_list(list1, string1) + @pytest.mark.unit def test_del_string_list_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): cf._delete_list(string1, list1) + @pytest.mark.unit def test_del_strings_fail(): with pytest.raises(ValueError, match=r"Both variables must be list.*"): - cf._delete_list(string1, string2) \ No newline at end of file + cf._delete_list(string1, string2) diff --git a/source/tests/test_manifest_parser.py b/source/tests/test_manifest_parser.py index a1dc89a..df65fb7 100644 --- a/source/tests/test_manifest_parser.py +++ b/source/tests/test_manifest_parser.py @@ -14,21 +14,22 @@ ############################################################################### import os -import pytest -import mock -from cfct.utils.logger import Logger + import cfct.manifest.manifest_parser as parse +import mock import pytest +from cfct.utils.logger import Logger + +TESTS_DIR = "./source/tests/" -TESTS_DIR = './source/tests/' +logger = Logger(loglevel="info") -logger = Logger(loglevel='info') +os.environ["ACCOUNT_LIST"] = "" -os.environ['ACCOUNT_LIST'] = '' @pytest.fixture def bucket_name(): - return os.getenv('STAGING_BUCKET') + return os.getenv("STAGING_BUCKET") @pytest.fixture @@ -39,26 +40,22 @@ def s3_setup(s3_client, bucket_name): @pytest.fixture def organizations_setup(org_client): - dev_map = { - "AccountName": "Developer1", - "AccountEmail": "dev@mock", - "OUName": "Dev" - } + dev_map = {"AccountName": "Developer1", "AccountEmail": "dev@mock", "OUName": "Dev"} dev_map_2 = { "AccountName": "Developer1-SuperSet", "AccountEmail": "dev-2@mock", - "OUName": "Dev" + "OUName": "Dev", } prod_map = { "AccountName": "Production1", "AccountEmail": "prod@mock", - "OUName": "Prod" + "OUName": "Prod", } test_map = { "AccountName": "Testing1", "AccountEmail": "test@mock", - "OUName": "Test" + "OUName": "Test", } # create organization org_client.create_organization(FeatureSet="ALL") @@ -66,155 +63,207 @@ def organizations_setup(org_client): # create accounts dev_account_id = org_client.create_account( - AccountName=dev_map['AccountName'], - Email=dev_map['AccountEmail'])["CreateAccountStatus"]["AccountId"] + AccountName=dev_map["AccountName"], Email=dev_map["AccountEmail"] + )["CreateAccountStatus"]["AccountId"] dev_account_id_2 = org_client.create_account( - AccountName=dev_map_2['AccountName'], - Email=dev_map_2['AccountEmail'])["CreateAccountStatus"]["AccountId"] + AccountName=dev_map_2["AccountName"], Email=dev_map_2["AccountEmail"] + )["CreateAccountStatus"]["AccountId"] test_account_id = org_client.create_account( - AccountName=test_map['AccountName'], - Email=test_map['AccountEmail'])["CreateAccountStatus"]["AccountId"] + AccountName=test_map["AccountName"], Email=test_map["AccountEmail"] + )["CreateAccountStatus"]["AccountId"] prod_account_id = org_client.create_account( - AccountName=prod_map['AccountName'], - Email=prod_map['AccountEmail'])["CreateAccountStatus"]["AccountId"] + AccountName=prod_map["AccountName"], Email=prod_map["AccountEmail"] + )["CreateAccountStatus"]["AccountId"] # create org units - dev_resp = org_client.create_organizational_unit(ParentId=root_id, - Name=dev_map['OUName']) + dev_resp = org_client.create_organizational_unit( + ParentId=root_id, Name=dev_map["OUName"] + ) dev_ou_id = dev_resp["OrganizationalUnit"]["Id"] - test_resp = org_client.create_organizational_unit(ParentId=root_id, - Name=test_map['OUName']) + test_resp = org_client.create_organizational_unit( + ParentId=root_id, Name=test_map["OUName"] + ) test_ou_id = test_resp["OrganizationalUnit"]["Id"] - prod_resp = org_client.create_organizational_unit(ParentId=root_id, - Name=prod_map['OUName']) + prod_resp = org_client.create_organizational_unit( + ParentId=root_id, Name=prod_map["OUName"] + ) prod_ou_id = prod_resp["OrganizationalUnit"]["Id"] # move accounts org_client.move_account( - AccountId=dev_account_id, SourceParentId=root_id, - DestinationParentId=dev_ou_id + AccountId=dev_account_id, SourceParentId=root_id, DestinationParentId=dev_ou_id ) org_client.move_account( - AccountId=dev_account_id_2, SourceParentId=root_id, - DestinationParentId=dev_ou_id + AccountId=dev_account_id_2, + SourceParentId=root_id, + DestinationParentId=dev_ou_id, ) org_client.move_account( - AccountId=test_account_id, SourceParentId=root_id, - DestinationParentId=test_ou_id + AccountId=test_account_id, + SourceParentId=root_id, + DestinationParentId=test_ou_id, ) org_client.move_account( - AccountId=prod_account_id, SourceParentId=root_id, - DestinationParentId=prod_ou_id + AccountId=prod_account_id, + SourceParentId=root_id, + DestinationParentId=prod_ou_id, ) # Get account list - os.environ['ACCOUNT_LIST'] = dev_account_id + ','+ dev_account_id_2 + ','+ test_account_id + ','+ prod_account_id + os.environ["ACCOUNT_LIST"] = ( + dev_account_id + + "," + + dev_account_id_2 + + "," + + test_account_id + + "," + + prod_account_id + ) yield + @pytest.mark.unit def test_version_1_manifest_scp_sm_input(s3_setup, organizations_setup, ssm_client): - manifest_name = 'manifest_version_1.yaml' + manifest_name = "manifest_version_1.yaml" file_path = TESTS_DIR + manifest_name - os.environ['MANIFEST_FILE_NAME'] = manifest_name - os.environ['MANIFEST_FILE_PATH'] = file_path - os.environ['MANIFEST_FOLDER'] = file_path[:-len(manifest_name)] - os.environ['STAGE_NAME'] = 'scp' + os.environ["MANIFEST_FILE_NAME"] = manifest_name + os.environ["MANIFEST_FILE_PATH"] = file_path + os.environ["MANIFEST_FOLDER"] = file_path[: -len(manifest_name)] + os.environ["STAGE_NAME"] = "scp" sm_input_list = parse.scp_manifest() - logger.info("[test_version_1_manifest_scp_sm_input] SCP sm_input_list for manifest_version_1.yaml:") + logger.info( + "[test_version_1_manifest_scp_sm_input] SCP sm_input_list for manifest_version_1.yaml:" + ) logger.info(sm_input_list) logger.info(sm_input_list[0]) - assert sm_input_list[0]['ResourceProperties']['PolicyDocument'][ - 'Name'] == "test-preventive-guardrails" - assert sm_input_list[1]['ResourceProperties']['PolicyDocument'][ - 'Name'] == "test-guardrails-2" + assert ( + sm_input_list[0]["ResourceProperties"]["PolicyDocument"]["Name"] + == "test-preventive-guardrails" + ) + assert ( + sm_input_list[1]["ResourceProperties"]["PolicyDocument"]["Name"] + == "test-guardrails-2" + ) + @pytest.mark.unit def test_version_2_manifest_scp_sm_input(s3_setup, organizations_setup, ssm_client): - manifest_name = 'manifest_version_2.yaml' + manifest_name = "manifest_version_2.yaml" file_path = TESTS_DIR + manifest_name - os.environ['MANIFEST_FILE_NAME'] = manifest_name - os.environ['MANIFEST_FILE_PATH'] = file_path - os.environ['MANIFEST_FOLDER'] = file_path[:-len(manifest_name)] - os.environ['STAGE_NAME'] = 'scp' + os.environ["MANIFEST_FILE_NAME"] = manifest_name + os.environ["MANIFEST_FILE_PATH"] = file_path + os.environ["MANIFEST_FOLDER"] = file_path[: -len(manifest_name)] + os.environ["STAGE_NAME"] = "scp" sm_input_list = parse.scp_manifest() - logger.info("[test_version_2_manifest_scp_sm_input] SCP sm_input_list for manifest_version_2.yaml:") + logger.info( + "[test_version_2_manifest_scp_sm_input] SCP sm_input_list for manifest_version_2.yaml:" + ) logger.info(sm_input_list) - assert sm_input_list[0]['ResourceProperties']['PolicyDocument'][ - 'Name'] == "test-preventive-guardrails" - assert sm_input_list[1]['ResourceProperties']['PolicyDocument'][ - 'Name'] == "test-guardrails-2" + assert ( + sm_input_list[0]["ResourceProperties"]["PolicyDocument"]["Name"] + == "test-preventive-guardrails" + ) + assert ( + sm_input_list[1]["ResourceProperties"]["PolicyDocument"]["Name"] + == "test-guardrails-2" + ) + @pytest.mark.unit -def test_version_1_manifest_stackset_sm_input(s3_setup, organizations_setup, - ssm_client): +def test_version_1_manifest_stackset_sm_input( + s3_setup, organizations_setup, ssm_client +): # mock API call and assign return value - with mock.patch("cfct.manifest.manifest_parser.OrganizationsData.get_accounts_in_ct_baseline_config_stack_set", mock.MagicMock(return_value=[list(os.environ['ACCOUNT_LIST'].split(',')),[]])): - manifest_name = 'manifest_version_1.yaml' + with mock.patch( + "cfct.manifest.manifest_parser.OrganizationsData.get_accounts_in_ct_baseline_config_stack_set", + mock.MagicMock(return_value=[list(os.environ["ACCOUNT_LIST"].split(",")), []]), + ): + manifest_name = "manifest_version_1.yaml" file_path = TESTS_DIR + manifest_name - os.environ['MANIFEST_FILE_NAME'] = manifest_name - os.environ['MANIFEST_FILE_PATH'] = file_path - os.environ['MANIFEST_FOLDER'] = file_path[:-len(manifest_name)] - os.environ['STAGE_NAME'] = 'stackset' + os.environ["MANIFEST_FILE_NAME"] = manifest_name + os.environ["MANIFEST_FILE_PATH"] = file_path + os.environ["MANIFEST_FOLDER"] = file_path[: -len(manifest_name)] + os.environ["STAGE_NAME"] = "stackset" sm_input_list = parse.stack_set_manifest() logger.info("Stack Set sm_input_list:") logger.info(sm_input_list) - assert sm_input_list[0]['ResourceProperties']['StackSetName'] == \ - "CustomControlTower-stackset-1" - assert sm_input_list[1]['ResourceProperties']['StackSetName'] == \ - "CustomControlTower-stackset-2" + assert ( + sm_input_list[0]["ResourceProperties"]["StackSetName"] + == "CustomControlTower-stackset-1" + ) + assert ( + sm_input_list[1]["ResourceProperties"]["StackSetName"] + == "CustomControlTower-stackset-2" + ) + @pytest.mark.unit -def test_version_2_manifest_stackset_sm_input(s3_setup, organizations_setup, - ssm_client, mocker): +def test_version_2_manifest_stackset_sm_input( + s3_setup, organizations_setup, ssm_client, mocker +): - logger.info('os.environ[ACCOUNT_LIST]: {}'.format(list(os.environ['ACCOUNT_LIST'].split(',')))) + logger.info( + "os.environ[ACCOUNT_LIST]: {}".format( + list(os.environ["ACCOUNT_LIST"].split(",")) + ) + ) # mock API call and assign return value - with mock.patch("cfct.manifest.manifest_parser.OrganizationsData.get_accounts_in_ct_baseline_config_stack_set", mock.MagicMock(return_value=[list(os.environ['ACCOUNT_LIST'].split(',')),[]])): - manifest_name = 'manifest_version_2.yaml' + with mock.patch( + "cfct.manifest.manifest_parser.OrganizationsData.get_accounts_in_ct_baseline_config_stack_set", + mock.MagicMock(return_value=[list(os.environ["ACCOUNT_LIST"].split(",")), []]), + ): + manifest_name = "manifest_version_2.yaml" file_path = TESTS_DIR + manifest_name - os.environ['MANIFEST_FILE_NAME'] = manifest_name - os.environ['MANIFEST_FILE_PATH'] = file_path - os.environ['MANIFEST_FOLDER'] = file_path[:-len(manifest_name)] - os.environ['STAGE_NAME'] = 'stackset' + os.environ["MANIFEST_FILE_NAME"] = manifest_name + os.environ["MANIFEST_FILE_PATH"] = file_path + os.environ["MANIFEST_FOLDER"] = file_path[: -len(manifest_name)] + os.environ["STAGE_NAME"] = "stackset" sm_input_list = parse.stack_set_manifest() logger.info("Stack Set sm_input_list:") logger.info(sm_input_list) # check if namespace CustomControlTower is added to the stack name - assert sm_input_list[0]['ResourceProperties']['StackSetName'] == \ - "CustomControlTower-stackset-1" + assert ( + sm_input_list[0]["ResourceProperties"]["StackSetName"] + == "CustomControlTower-stackset-1" + ) # check the account list should have 2 accounts - Developer1 only (not # Developer1-SuperSet - assert len(sm_input_list[0]['ResourceProperties']['AccountList']) == 2 + assert len(sm_input_list[0]["ResourceProperties"]["AccountList"]) == 2 # check if export_outputs is not defined then SSMParameters is set to # empty dict - assert sm_input_list[1]['ResourceProperties']['SSMParameters'] == {} + assert sm_input_list[1]["ResourceProperties"]["SSMParameters"] == {} # check the account list should have 3 accounts - Developer1 only (not # Developer1-SuperSet - assert len(sm_input_list[1]['ResourceProperties']['AccountList']) == 3 + assert len(sm_input_list[1]["ResourceProperties"]["AccountList"]) == 3 # check if empty OU, account list should be empty string - assert sm_input_list[2]['ResourceProperties']['AccountList'] == [] + assert sm_input_list[2]["ResourceProperties"]["AccountList"] == [] # parameters key has empty dict - assert sm_input_list[2]['ResourceProperties']['Parameters'] == {} + assert sm_input_list[2]["ResourceProperties"]["Parameters"] == {} + @pytest.mark.unit def test_root_ou_stackset(mocker): org = parse.OrganizationsData() - mocker.patch.object(org.stack_set, 'get_accounts_and_regions_per_stack_set') - org.stack_set.get_accounts_and_regions_per_stack_set.return_value = ['000', '111'],[] + mocker.patch.object(org.stack_set, "get_accounts_and_regions_per_stack_set") + org.stack_set.get_accounts_and_regions_per_stack_set.return_value = [ + "000", + "111", + ], [] ou_id_to_account_map = {} ou_name_to_id_map = {} - ou_list = ['Root'] + ou_list = ["Root"] resp = org.get_accounts_in_ou(ou_id_to_account_map, ou_name_to_id_map, ou_list) - assert resp == ['000', '111'] + assert resp == ["000", "111"] + @pytest.mark.unit def test_root_ou_stackset_no(): org = parse.OrganizationsData() ou_id_to_account_map = {} ou_name_to_id_map = {} - ou_list = ['Dev'] + ou_list = ["Dev"] resp = org.get_accounts_in_ou(ou_id_to_account_map, ou_name_to_id_map, ou_list) - assert resp == [] \ No newline at end of file + assert resp == [] diff --git a/source/tests/test_os_util.py b/source/tests/test_os_util.py index e082d9f..0066775 100644 --- a/source/tests/test_os_util.py +++ b/source/tests/test_os_util.py @@ -12,26 +12,28 @@ # or implied. See the License for the specific language governing permissions# # and limitations under the License. # ############################################################################### -from cfct.utils import os_util -import mock -from cfct.utils.logger import Logger import os -import pytest +import mock +import pytest +from cfct.utils import os_util +from cfct.utils.logger import Logger -log_level = 'info' +log_level = "info" logger = Logger(loglevel=log_level) + @pytest.mark.unit -@mock.patch('cfct.utils.os_util.os') +@mock.patch("cfct.utils.os_util.os") def test_make_dir(self, tmpdir): os_util.make_dir(tmpdir) assert os.path.isdir(tmpdir) is True os_util.make_dir(tmpdir, logger) assert os.path.isdir(tmpdir) is True + @pytest.mark.unit -@mock.patch('cfct.utils.os_util.os') +@mock.patch("cfct.utils.os_util.os") def test_remove_dir(self, tmpdir): os_util.remove_dir(tmpdir) assert os.path.isdir(tmpdir) is False diff --git a/source/tests/test_parameter_manipulation.py b/source/tests/test_parameter_manipulation.py index e5a2aa7..ac5f7cc 100644 --- a/source/tests/test_parameter_manipulation.py +++ b/source/tests/test_parameter_manipulation.py @@ -12,32 +12,23 @@ # or implied. See the License for the specific language governing permissions# # and limitations under the License. # ############################################################################### -from cfct.utils import parameter_manipulation import pytest +from cfct.utils import parameter_manipulation -param = { - "key": "value", - "key1": "value1" -} +param = {"key": "value", "key1": "value1"} trans_params = [ - { - "ParameterKey": "key", - "ParameterValue": "value" - }, - { - "ParameterKey": "key1", - "ParameterValue": "value1" - } + {"ParameterKey": "key", "ParameterValue": "value"}, + {"ParameterKey": "key1", "ParameterValue": "value1"}, ] + @pytest.mark.unit def test_transform_params(): out_params = parameter_manipulation.transform_params(param) for idx in range(len(out_params)): - assert out_params[idx]['ParameterKey']\ - == trans_params[idx]['ParameterKey'] - assert out_params[idx]['ParameterValue']\ - == trans_params[idx]['ParameterValue'] + assert out_params[idx]["ParameterKey"] == trans_params[idx]["ParameterKey"] + assert out_params[idx]["ParameterValue"] == trans_params[idx]["ParameterValue"] + @pytest.mark.unit def test_reverse_transform_params(): diff --git a/source/tests/test_password_generator.py b/source/tests/test_password_generator.py index b8ce9a4..00538f4 100644 --- a/source/tests/test_password_generator.py +++ b/source/tests/test_password_generator.py @@ -12,14 +12,15 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from cfct.utils import password_generator import re + import pytest +from cfct.utils import password_generator + @pytest.mark.unit def test_random_pwd_generator(): - random_pwd_no_additional_string = \ - password_generator.random_pwd_generator(10, 'a') - assert len(re.sub('([^0-9])','',random_pwd_no_additional_string)) >= 2 + random_pwd_no_additional_string = password_generator.random_pwd_generator(10, "a") + assert len(re.sub("([^0-9])", "", random_pwd_no_additional_string)) >= 2 assert random_pwd_no_additional_string[8:] == "aa" assert len(random_pwd_no_additional_string) == 10 diff --git a/source/tests/test_sm_input_builder.py b/source/tests/test_sm_input_builder.py index ccddfb6..3dcd24d 100644 --- a/source/tests/test_sm_input_builder.py +++ b/source/tests/test_sm_input_builder.py @@ -13,12 +13,15 @@ # and limitations under the License. # ############################################################################### -from cfct.manifest.sm_input_builder import InputBuilder, SCPResourceProperties, \ - StackSetResourceProperties -from cfct.utils.logger import Logger import pytest +from cfct.manifest.sm_input_builder import ( + InputBuilder, + SCPResourceProperties, + StackSetResourceProperties, +) +from cfct.utils.logger import Logger -logger = Logger('info') +logger = Logger("info") # declare SCP state machine input variables name = "policy_name" @@ -27,38 +30,31 @@ account_id = "account_id_1" policy_list = [] operation = "operation_id" -ou_list = [ - [ - "ou_name_1", - "Attach" - ], - [ - "ou_name_2", - "Attach" - ] -] +ou_list = [["ou_name_1", "Attach"], ["ou_name_2", "Attach"]] delimiter = ":" # declare Stack Set state machine input variables stack_set_name = "StackSetName1" template_url = "https://s3.amazonaws.com/bucket/prefix" -parameters = {"Key1": "Value1", - "Key2": "Value2"} +parameters = {"Key1": "Value1", "Key2": "Value2"} capabilities = "CAPABILITY_NAMED_IAM" -account_list = ["account_id_1", - "account_id_2"] -region_list = ["us-east-1", - "us-east-2"] -ssm_parameters = { - "/ssm/parameter/store/key": "value" -} +account_list = ["account_id_1", "account_id_2"] +region_list = ["us-east-1", "us-east-2"] +ssm_parameters = {"/ssm/parameter/store/key": "value"} def build_scp_input(): # get SCP output - resource_properties = SCPResourceProperties(name, description, policy_url, - policy_list, account_id, - operation, ou_list, delimiter) + resource_properties = SCPResourceProperties( + name, + description, + policy_url, + policy_list, + account_id, + operation, + ou_list, + delimiter, + ) scp_input = InputBuilder(resource_properties.get_scp_input_map()) return scp_input.input_map() @@ -66,42 +62,53 @@ def build_scp_input(): def build_stack_set_input(): # get stack set output resource_properties = StackSetResourceProperties( - stack_set_name, template_url, parameters, - capabilities, account_list, region_list, - ssm_parameters) + stack_set_name, + template_url, + parameters, + capabilities, + account_list, + region_list, + ssm_parameters, + ) ss_input = InputBuilder(resource_properties.get_stack_set_input_map()) return ss_input.input_map() + @pytest.mark.unit def test_scp_input_type(): # check if returned input is of type dict scp_input = build_scp_input() assert isinstance(scp_input, dict) + @pytest.mark.unit def test_scp_resource_property_type(): # check if resource property is not None scp_input = build_scp_input() assert isinstance(scp_input.get("ResourceProperties"), dict) + @pytest.mark.unit def test_request_type_value(): # check the default request type is create scp_input = build_scp_input() assert scp_input.get("RequestType") == "Create" + @pytest.mark.unit def test_stack_set_input_type(): # check if returned input is of type dict stack_set_input = build_stack_set_input() assert isinstance(stack_set_input, dict) + @pytest.mark.unit def test_ss_resource_property_type(): # check if resource property is not None stack_set_input = build_stack_set_input() assert isinstance(stack_set_input.get("ResourceProperties"), dict) + @pytest.mark.unit def test_ss_request_type_value(): # check the default request type is create diff --git a/source/tests/test_stage_to_s3.py b/source/tests/test_stage_to_s3.py index e23179d..6d3606c 100644 --- a/source/tests/test_stage_to_s3.py +++ b/source/tests/test_stage_to_s3.py @@ -13,24 +13,30 @@ # and limitations under the License. # ############################################################################### -from cfct.manifest.stage_to_s3 import StageFile -from cfct.utils.logger import Logger from os import environ + import pytest +from cfct.manifest.stage_to_s3 import StageFile +from cfct.utils.logger import Logger + +logger = Logger("info") -logger = Logger('info') @pytest.mark.unit def test_convert_url(): - bucket_name = 'my-bucket-name' - key_name = 'my-key-name' + bucket_name = "my-bucket-name" + key_name = "my-key-name" relative_path = "s3://" + bucket_name + "/" + key_name sf = StageFile(logger, relative_path) s3_url = sf.get_staged_file() logger.info(s3_url) - assert s3_url.startswith("{}{}{}{}{}{}".format('https://', - bucket_name, - '.s3.', - environ.get('AWS_REGION'), - '.amazonaws.com/', - key_name)) \ No newline at end of file + assert s3_url.startswith( + "{}{}{}{}{}{}".format( + "https://", + bucket_name, + ".s3.", + environ.get("AWS_REGION"), + ".amazonaws.com/", + key_name, + ) + ) diff --git a/source/tests/test_string_manipulation.py b/source/tests/test_string_manipulation.py index a1adcfe..0042bd7 100644 --- a/source/tests/test_string_manipulation.py +++ b/source/tests/test_string_manipulation.py @@ -12,37 +12,47 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from cfct.utils import string_manipulation import pytest +from cfct.utils import string_manipulation + @pytest.mark.unit def test_sanitize(): - non_sanitized_string = 'I s@nitize $tring exc*pt_underscore-hypen.' - sanitized_string_allow_space = 'I s_nitize _tring exc_pt_underscore-hypen.' - sanitized_string_no_space_replace_hypen = \ - 'I-s-nitize--tring-exc-pt_underscore-hypen.' - assert string_manipulation.sanitize(non_sanitized_string,True) == \ - sanitized_string_allow_space - assert string_manipulation.sanitize(non_sanitized_string, False,'-') == \ - sanitized_string_no_space_replace_hypen + non_sanitized_string = "I s@nitize $tring exc*pt_underscore-hypen." + sanitized_string_allow_space = "I s_nitize _tring exc_pt_underscore-hypen." + sanitized_string_no_space_replace_hypen = ( + "I-s-nitize--tring-exc-pt_underscore-hypen." + ) + assert ( + string_manipulation.sanitize(non_sanitized_string, True) + == sanitized_string_allow_space + ) + assert ( + string_manipulation.sanitize(non_sanitized_string, False, "-") + == sanitized_string_no_space_replace_hypen + ) + @pytest.mark.unit def test_trim_length(): actual_sting = "EighteenCharacters" eight_char_string = "Eighteen" - assert string_manipulation\ - .trim_length_from_end(actual_sting, 8) == eight_char_string - assert string_manipulation\ - .trim_length_from_end(actual_sting, 18) == actual_sting - assert string_manipulation\ - .trim_length_from_end(actual_sting, 20) == actual_sting + assert ( + string_manipulation.trim_length_from_end(actual_sting, 8) == eight_char_string + ) + assert string_manipulation.trim_length_from_end(actual_sting, 18) == actual_sting + assert string_manipulation.trim_length_from_end(actual_sting, 20) == actual_sting + @pytest.mark.unit def test_extract_string(): actual_string = "abcdefgh" extract_string = "defgh" - assert string_manipulation.trim_string_from_front(actual_string, 'abc') == \ - extract_string + assert ( + string_manipulation.trim_string_from_front(actual_string, "abc") + == extract_string + ) + @pytest.mark.unit def test_convert_list_values_to_string(): @@ -51,52 +61,70 @@ def test_convert_list_values_to_string(): for string in list_of_strings: assert isinstance(string, str) + @pytest.mark.unit def test_convert_string_to_list_default_separator(): - separator = ',' + separator = "," value = "a, b" - list_1 = value if separator not in value else \ - string_manipulation.convert_string_to_list(value, separator) + list_1 = ( + value + if separator not in value + else string_manipulation.convert_string_to_list(value, separator) + ) assert isinstance(list_1, list) - assert list_1[0] == 'a' - assert list_1[1] == 'b' + assert list_1[0] == "a" + assert list_1[1] == "b" + @pytest.mark.unit def test_convert_string_to_list_no_separator(): - separator = ',' + separator = "," value = "a" - string = value if separator not in value else \ - string_manipulation.convert_string_to_list(value, separator) + string = ( + value + if separator not in value + else string_manipulation.convert_string_to_list(value, separator) + ) assert isinstance(string, str) - assert string == 'a' + assert string == "a" + @pytest.mark.unit def test_convert_string_to_list_custom_separator(): - separator = ';' + separator = ";" value = "a; b" - list_1 = list_1 = value if separator not in value else \ - string_manipulation.convert_string_to_list(value, separator) + list_1 = list_1 = ( + value + if separator not in value + else string_manipulation.convert_string_to_list(value, separator) + ) assert isinstance(list_1, list) - assert list_1[0] == 'a' - assert list_1[1] == 'b' + assert list_1[0] == "a" + assert list_1[1] == "b" + @pytest.mark.unit def test_strip_list_items(): - arr = [" a"," b", "c "] - assert string_manipulation.strip_list_items(arr) == ['a', 'b', 'c'] + arr = [" a", " b", "c "] + assert string_manipulation.strip_list_items(arr) == ["a", "b", "c"] + @pytest.mark.unit def test_remove_empty_strings(): - arr = ["a","b", "", "c"] - assert string_manipulation.remove_empty_strings(arr) == ['a', 'b', 'c'] + arr = ["a", "b", "", "c"] + assert string_manipulation.remove_empty_strings(arr) == ["a", "b", "c"] + @pytest.mark.unit def test_list_sanitizer(): - arr = [" a","b ", "", " c"] - assert string_manipulation.list_sanitizer(arr) == ['a', 'b', 'c'] + arr = [" a", "b ", "", " c"] + assert string_manipulation.list_sanitizer(arr) == ["a", "b", "c"] + @pytest.mark.unit def test_empty_separator_handler(): delimiter = ":" nested_ou_name_list = "testou1:testou2:testou3" - assert string_manipulation.empty_separator_handler(delimiter, nested_ou_name_list) == ['testou1', 'testou2', 'testou3'] \ No newline at end of file + assert string_manipulation.empty_separator_handler( + delimiter, nested_ou_name_list + ) == ["testou1", "testou2", "testou3"] diff --git a/source/tests/test_url_conversion.py b/source/tests/test_url_conversion.py index d34aacb..6f3b253 100644 --- a/source/tests/test_url_conversion.py +++ b/source/tests/test_url_conversion.py @@ -12,39 +12,68 @@ # KIND, express or implied. See the License for the specific language # # governing permissions and limitations under the License. # ############################################################################## -from cfct.aws.utils.url_conversion import parse_bucket_key_names, \ - convert_s3_url_to_http_url, build_http_url -from cfct.utils.logger import Logger from os import getenv + import pytest +from cfct.aws.utils.url_conversion import ( + build_http_url, + convert_s3_url_to_http_url, + parse_bucket_key_names, +) +from cfct.utils.logger import Logger + +logger = Logger("info") +bucket_name = "bucket-name" +key_name = "key-name/key2/object" -logger = Logger('info') -bucket_name = 'bucket-name' -key_name = 'key-name/key2/object' @pytest.mark.unit def test_s3_url_to_http_url(): - s3_path = '%s/%s' % (bucket_name, key_name) - s3_url = 's3://' + s3_path + s3_path = "%s/%s" % (bucket_name, key_name) + s3_url = "s3://" + s3_path http_url = convert_s3_url_to_http_url(s3_url) - assert http_url == 'https://' + bucket_name + '.s3.' + getenv('AWS_REGION')\ - + '.amazonaws.com/' + key_name + assert ( + http_url + == "https://" + + bucket_name + + ".s3." + + getenv("AWS_REGION") + + ".amazonaws.com/" + + key_name + ) + @pytest.mark.unit def test_virtual_hosted_style_http_url_to_s3_url(): - http_url = 'https://' + bucket_name + '.s3.' + getenv('AWS_REGION') + '.amazonaws.com/' + key_name + http_url = ( + "https://" + + bucket_name + + ".s3." + + getenv("AWS_REGION") + + ".amazonaws.com/" + + key_name + ) bucket, key, region = parse_bucket_key_names(http_url) assert bucket_name == bucket assert key_name == key - assert getenv('AWS_REGION') == region + assert getenv("AWS_REGION") == region + @pytest.mark.unit def test_path_style_http_url_to_s3_url(): - http_url = 'https://s3.' + getenv('AWS_REGION') + '.amazonaws.com/' + bucket_name + '/' + key_name + http_url = ( + "https://s3." + + getenv("AWS_REGION") + + ".amazonaws.com/" + + bucket_name + + "/" + + key_name + ) bucket, key, region = parse_bucket_key_names(http_url) assert bucket_name == bucket assert key_name == key - assert getenv('AWS_REGION') == region + assert getenv("AWS_REGION") == region + @pytest.mark.unit def test_build_http_url():