附:Lambda检查逻辑

上面我们分析了一个账号,也可以将其扩展到多个账号:使用 Lambda 扫描并分析其他账号中的 KMS 密钥策略。当然,需要先更新 DynamoDB 表,把这些账号添加进去。

Lambda 中一些评估项如下:

def checkManageableThroughIAM(principal_service: str) -> str:
    """Flag statements whose principal is an account root (``...:root``).

    A root principal in a key policy lets IAM policies in that account
    control access to the key.

    ``principal_service`` is either a single principal string or the
    stringified list form ``"['arn:...:root'; 'arn:...']"``. Every element is
    checked, not just the last five characters of the whole string.
    Returns an empty string when no concern is found.
    """
    if not principal_service:
        return ""
    principals = str(principal_service).strip("[]").split(";")
    if any(p.strip().strip("'\"").endswith(":root") for p in principals):
        return "key access management manageable through IAM"
    return ""


def checkThirdPartyManaged(account_number: str, current_account_number: str) -> str:
    """Warn when ``account_number`` is not the account being scanned.

    Returns an empty string when the two account numbers are equal.
    """
    same_account = account_number == current_account_number
    return "" if same_account else "3rd party account can use KMS!!!"


def checkKMSPolciy(action: str) -> str:
    """Flag statements that grant every KMS action.

    Action names are case-insensitive, so ``kms:*``, ``KMS:*`` and a bare
    ``*`` all grant every operation on the key. ``action`` may be a single
    action or the stringified list form ``"['kms:*'; 'kms:Decrypt']"``.
    Returns an empty string when no concern is found.
    """
    if not action:
        return ""
    actions = {
        a.strip().strip("'\"").lower() for a in str(action).strip("[]").split(";")
    }
    if actions & {"kms:*", "*"}:
        return "key policy is overly permissive"
    return ""

也可以根据企业自身的安全要求添加其他评估项。例如,检查是否直接向 IAM 用户授予了访问权限——我们希望所有访问权限都授予 IAM 角色,而不是 IAM 用户:

def checkManageableThroughKMS(principal_service: str) -> str:
    """Warn when a key-policy principal is an IAM user instead of a role.

    IAM user ARNs contain ``:user`` (``arn:aws:iam::<account>:user/<name>``).
    Returns an empty string when no concern is found.
    """
    grants_user = ":user" in principal_service
    return "Access provided to a IAM user" if grants_user else ""

前面的三个检查函数都由 concernFiller 统一调用(新增的检查函数,如 checkManageableThroughKMS,也需要加入 concernFiller 才会生效):

def concernFiller(
    principal_service: str,
    account_number: str,
    current_account_number: str,
    action: str,
):
    """Run every policy check and join the non-empty findings with ``;``.

    Returns an empty string when no check reports a concern.
    """
    concern_list = [
        checkManageableThroughIAM(principal_service=principal_service),
        checkThirdPartyManaged(
            account_number=account_number, current_account_number=current_account_number
        ),
        # Must match the spelling of the checkKMSPolciy definition shown above;
        # calling checkKmsPolciy here would be a NameError for this snippet.
        checkKMSPolciy(action=action),
    ]
    return ";".join(concern for concern in concern_list if concern)

基于这些检查结果,最终可以生成如下图表:

image-20240622123742804

Lambda 完整代码:

import csv
import json
import logging
import os
from datetime import date, datetime

# Define Imports
import boto3
import botocore

# Module-level logger. getLogger() with no name returns the root logger; its
# level is set to INFO, so the logger.debug(...) calls in this file are
# suppressed unless the level is lowered.
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def getKeys(kms):
    """Return every KMS key entry (``KeyId``/``KeyArn``) in the client's region.

    A single ``list_keys`` call returns only one page (100 keys by default),
    so the boto3 paginator is used to collect all pages.
    """
    keys = []
    for page in kms.get_paginator("list_keys").paginate():
        logger.debug(f"getKeys Response: {json.dumps(page, indent=3, default=str)}")
        keys.extend(page["Keys"])
    return keys


def getKeyPolicies(kms, keyid):
    """Return the policy names attached to ``keyid``.

    An AccessDeniedException is logged as a warning and yields ``None`` so
    the caller can skip the key; any other client error is logged and
    re-raised.
    """
    try:
        result = kms.list_key_policies(KeyId=keyid)
        logger.debug(f"getKeyPolicies Response: {json.dumps(result,indent=3)}")
        return result["PolicyNames"]
    except botocore.exceptions.ClientError as err:
        details = err.response["Error"]
        if details["Code"] != "AccessDeniedException":
            logger.error(f"ERROR RESPONSE: {err.response}")
            raise
        logger.warning(f"Unable to get Policies: {details['Message']}")
        return None


def getKeyPoliciesTags(kms, keyid):
    """Return the resource tags attached to ``keyid``.

    An AccessDeniedException is logged as a warning and yields ``None``; any
    other client error is logged and re-raised.
    """
    try:
        response = kms.list_resource_tags(KeyId=keyid)
        logger.debug(f"getKeyPoliciesTags Response: {json.dumps(response,indent=3)}")
        return response["Tags"]
    except botocore.exceptions.ClientError as error:
        if error.response["Error"]["Code"] == "AccessDeniedException":
            # This call lists tags; the previous "Unable to get Policies"
            # message was copied from getKeyPolicies and was misleading.
            logger.warning(
                f"Unable to get Tags: {error.response['Error']['Message']}"
            )
            return None
        logger.error(f"ERROR RESPONSE: {error.response}")
        raise


def whoami(session, region="us-east-1"):
    """Return the STS caller identity of ``session``, queried in ``region``."""
    return session.client("sts", region_name=region).get_caller_identity()


def get_accounts(
    dynamo_accountID: str,
    dynamodb_role: str = "dynamodb_role",
    dynamodb_table: str = "accounts",
    region: str = "us-east-1",
):
    """Read the list of accounts to scan from a DynamoDB table.

    Assumes ``dynamodb_role`` in ``dynamo_accountID`` and scans
    ``dynamodb_table`` (in ``region``) for ``accountId``/``accountName``.
    Passing ``None`` for role, table or region falls back to the defaults
    above, so callers can forward optional environment variables directly.

    Returns the list of items, or ``None`` if any step fails (the error is
    logged).
    """
    dynamodb_role = dynamodb_role or "dynamodb_role"
    dynamodb_table = dynamodb_table or "accounts"
    region = region or "us-east-1"

    try:
        session = getAssumedRoleSession(dynamo_accountID, dynamodb_role)
        if session is None:
            # getAssumedRoleSession has already logged the reason.
            return None
        logger.info(f"Account Number: {getAccountNumber(session, region)}")
        logger.debug(f"Who Am I? [{whoami(session, region)}]")

        table = session.resource("dynamodb", region_name=region).Table(dynamodb_table)

        # One Scan call returns at most 1 MB of data; follow LastEvaluatedKey
        # until the whole table has been read.
        scan_kwargs = {"ProjectionExpression": "accountId, accountName"}
        items = []
        while True:
            page = table.scan(**scan_kwargs)
            items.extend(page["Items"])
            last_key = page.get("LastEvaluatedKey")
            if not last_key:
                break
            scan_kwargs["ExclusiveStartKey"] = last_key

        logger.info(f"Accounts: {items}")
        return items

    except Exception as e:
        logger.error("Issue reading accounts from DynamoDB: " + str(e))
        return None


def getAccountNumber(session, region: str) -> str:
    """Return the AWS account ID that ``session`` is authenticated as.

    An AccessDeniedException is logged as a warning and yields ``None``;
    other client errors are logged and re-raised.
    """
    sts = session.client("sts", region_name=region)
    try:
        identity = sts.get_caller_identity()
        logger.debug(f"caller_id: {identity}")
        return identity.get("Account")
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] != "AccessDeniedException":
            logger.error(f"ERROR RESPONSE: {err.response}")
            raise
        logger.warning(
            f"Failed trying to retrieve STS get_caller_identity - {err.response['Error']['Message']}"
        )
        return None


def getPolicy(kms, keyId, policy):
    """Fetch key policy ``policy`` of ``keyId`` and return it parsed from JSON."""
    result = kms.get_key_policy(KeyId=keyId, PolicyName=policy)
    logger.debug(f"getPolicy Response: {result}")
    policy_document = result["Policy"]
    return json.loads(policy_document)


def getCreationDate(kms, keyId):
    """Return the key's creation time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    result = kms.describe_key(KeyId=keyId)
    logger.debug(f"getCreationDate Response: {result}")
    created = result["KeyMetadata"]["CreationDate"]
    return created.strftime("%Y-%m-%d %H:%M:%S")


def getTag(kms, keyId):
    """Return the list of resource tags attached to ``keyId``."""
    result = kms.list_resource_tags(KeyId=keyId)
    logger.debug(f"getTags Response: {result}")
    tags = result["Tags"]
    return tags


def getAliases(kms, keyId):
    """Return every alias entry associated with ``keyId``.

    ``list_aliases`` is paginated (50 entries per page by default), so the
    boto3 paginator is used to collect all pages.
    """
    aliases = []
    for page in kms.get_paginator("list_aliases").paginate(KeyId=keyId):
        logger.debug(f"getAliases Response: {page}")
        aliases.extend(page["Aliases"])
    return aliases


def getEverythingJson(kms):
    """Collect aliases, policies, tags and creation date for every KMS key.

    Returns ``{"kms_keys": [{"KeyId", "Aliases", "Policies", "Tags",
    "CreationDate"}, ...]}``. Tags and CreationDate are only fetched when
    the key's policies could be listed; otherwise they are ``[]`` and
    ``None`` for that key.
    """
    kms_keys = getKeys(kms)
    logger.debug(f"Keys: {kms_keys}")

    keyMap = {"kms_keys": []}

    for kms_key in kms_keys:
        logger.debug(f"Key: {kms_key}")
        keyid = kms_key["KeyId"]
        logger.debug(f"KeyId: {keyid}")

        # Reset for every key. These used to be assigned only inside the
        # ``if kms_key_policies`` branch, so a key whose policies could not be
        # listed either raised UnboundLocalError (first key) or silently
        # inherited the previous key's tags and creation date.
        key_tag = []
        creation_date = None
        list_of_policies = []

        kms_key_policies = getKeyPolicies(kms, keyid)
        logger.debug(f"Keys Policies: {kms_key_policies}")

        kms_aliases = getAliases(kms, keyid)
        logger.debug(f"Aliases: {kms_aliases}")
        list_of_aliases = [alias["AliasName"] for alias in kms_aliases or []]

        if kms_key_policies:
            key_tag = getTag(kms, keyid)
            creation_date = getCreationDate(kms, keyid)
            for policy in kms_key_policies:
                key_policy = getPolicy(kms, keyid, policy)
                list_of_policies.append(key_policy)
                logger.debug(f"KMS KeyId: {keyid} - [{policy}] - policy: {key_policy}")

        logger.debug(f"Aliases: {list_of_aliases}")
        logger.debug(f"Policies: {list_of_policies}")

        keyMap["kms_keys"].append(
            {
                "KeyId": keyid,
                "Aliases": list_of_aliases,
                "Policies": list_of_policies,
                "Tags": key_tag,
                "CreationDate": creation_date,
            }
        )

    logger.debug(json.dumps(keyMap, indent=3))
    return keyMap


def getAssumedRoleSession(aws_account, role_name="XA-KMSRead-Role"):
    """Assume ``role_name`` in ``aws_account`` and return a boto3 Session.

    Returns ``None`` (after logging the error) if the role cannot be assumed.
    Errors from the initial ``get_caller_identity`` call are not caught.
    """
    # An f-string instead of ``+`` so non-str account IDs (e.g. a DynamoDB
    # Number attribute, which boto3 returns as Decimal) don't raise TypeError.
    role_to_assume_arn = f"arn:aws:iam::{aws_account}:role/{role_name}"
    sts_client = boto3.client("sts")

    logged_on_arn = sts_client.get_caller_identity()["Arn"]
    logger.info(
        f"Logged on user: '{logged_on_arn}' assuming role '{role_to_assume_arn}'"
    )

    try:
        response = sts_client.assume_role(
            RoleArn=role_to_assume_arn, RoleSessionName=role_name
        )
        creds = response["Credentials"]
        return boto3.session.Session(
            aws_access_key_id=creds["AccessKeyId"],
            aws_secret_access_key=creds["SecretAccessKey"],
            aws_session_token=creds["SessionToken"],
        )
    except Exception as e:
        logger.error(
            f"Unable to assume role '{role_name}' in account '{aws_account}': {e}"
        )
        return None


def getAssumedRoleKMSSession(aws_account, role_name="XA-KMSRead-Role"):
    """Assume ``role_name`` in ``aws_account`` and return a KMS client.

    The client uses the temporary credentials from STS. No region is passed,
    so boto3's default region resolution applies. Returns ``None`` (after
    logging) if assuming the role or creating the client fails.
    """
    role_arn = "arn:aws:iam::" + aws_account + ":role/" + role_name
    sts = boto3.client("sts")

    try:
        assumed = sts.assume_role(
            RoleArn=role_arn, RoleSessionName="cross_acct_lambda"
        )
        credentials = assumed["Credentials"]
        return boto3.client(
            "kms",
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )
    except Exception as err:
        logger.error(f"Unable to assume role in account {aws_account}: {err}")
        return None


def getEverythingToCSV(accountNumber: str, filename: str, keyMap: dict, region: str):
    """Write one CSV row per key-policy statement to ``/tmp/<filename>``.

    Rows come from grabPolicyStatementDetailsList for every key in
    ``keyMap["kms_keys"]`` that has at least one policy. The CSV is written
    once, after all keys are processed. Previously the write sat inside the
    key loop and the file was rewritten for every key.

    If ``keyMap`` contains no keys, no CSV is written. Any file with the same
    name left in /tmp by an earlier invocation is deleted, so pushToS3 finds
    no file rather than stale data.
    """
    logger.info(f"{len(keyMap['kms_keys'])} KMS keys found in [{region}]")
    filename = "/tmp/" + filename
    logger.debug(f"Filename: {filename}")

    if not keyMap["kms_keys"]:
        # Lambda can reuse /tmp across warm invocations.
        if os.path.isfile(filename):
            os.remove(filename)
        return

    kmsKeysWithPolicies = []

    for key in keyMap["kms_keys"]:
        logger.debug(f"Key: {key}")
        if not key["Policies"]:
            continue
        logger.debug(f"Policy: {key['Policies']}")

        keyId = key["KeyId"]
        logger.debug(f"KeyId: {keyId}")

        # Only the first alias (if any) is reported.
        keyAlias = key["Aliases"][0] if key["Aliases"] else None
        logger.debug(f"KeyAlias: {key['Aliases']}")

        # A key normally has a single ("default") policy; handle any number.
        for policy in key["Policies"]:
            kmsKeysWithPolicies.extend(
                grabPolicyStatementDetailsList(
                    accountNumber,
                    region,
                    keyId,
                    keyAlias,
                    policy["Statement"],
                    key["Tags"],
                    key["CreationDate"],
                )
            )
            logger.debug(
                f"KeyId: {keyId} - keyAlias: {keyAlias} - Policy: {kmsKeysWithPolicies}"
            )

    header = [
        "Date",
        "AccountNumber",
        "Region",
        "KeyId",
        "Alias",
        "Sid",
        "Effect",
        "Principal",
        "Principal Service",
        "Action",
        "Condition",
        "Concern",
        "Resource",
        "Tags",
        "CreationDate",
    ]
    logger.debug(f"kmsKeysWithPolicies: {kmsKeysWithPolicies}")
    # newline="" is what the csv module requires; utf-8 because tag values
    # may contain non-ASCII text.
    with open(filename, "w", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=header)
        writer.writeheader()
        writer.writerows(kmsKeysWithPolicies)


def grabPolicyStatementDetailsList(
    accountNumber: str,
    region: str,
    keyId: str,
    keyAlias: str,
    policyStatements: "list | dict",
    Tags: list,
    CreationDate: str,
) -> list:
    """Turn the statements of one key policy into CSV row dicts.

    Each row carries:
      * the account, region, key ID and alias;
      * the statement's Sid and Effect;
      * the first Principal type and its value(s);
      * Action, Resource and Condition;
      * the key's tags and creation date;
      * today's date and the ``Concern`` string from concernFiller.

    Commas in Principal values, Action, Condition and Tags are replaced by
    ``;``, as before.

    A statement that cannot be processed is logged and skipped; the rest
    are still returned.
    """
    # IAM policy grammar allows "Statement" to be a single object.
    if isinstance(policyStatements, dict):
        policyStatements = [policyStatements]

    logger.debug(f"Statement: {policyStatements}")
    if not policyStatements:
        return []
    logger.debug(f"Number of Policy Statements: {len(policyStatements)}")

    today = datetime.now().strftime("%Y-%m-%d")
    rows = []
    for index, statement in enumerate(policyStatements):
        logger.debug(f"policyStatements{index}: {statement}")
        # Isolate failures per statement. Previously one try wrapped the whole
        # loop, so a single odd statement (e.g. "Principal": "*", or no
        # Action) silently dropped every statement after it.
        try:
            rows.append(
                _statement_to_row(
                    statement,
                    accountNumber,
                    region,
                    keyId,
                    keyAlias,
                    Tags,
                    CreationDate,
                    today,
                )
            )
        except Exception as e:
            logger.error(f"Error processing statement {index} of key {keyId}: {e}")

    return rows


def _statement_to_row(
    statement, accountNumber, region, keyId, keyAlias, Tags, CreationDate, today
) -> dict:
    """Build the CSV row dict for a single policy statement."""
    row = {
        "AccountNumber": accountNumber,
        "Region": region,
        "KeyId": keyId,
        "Alias": keyAlias,
    }
    if "Sid" in statement:
        row["Sid"] = statement["Sid"]
    if "Effect" in statement:
        row["Effect"] = statement["Effect"]

    principal = statement.get("Principal")
    if isinstance(principal, dict) and principal:
        # Only the first principal type is reported in the CSV columns.
        principal_type, principal_value = next(iter(principal.items()))
        row["Principal"] = principal_type
        row["Principal Service"] = str(principal_value).replace(",", ";")
    elif principal is not None and not isinstance(principal, dict):
        # e.g. "Principal": "*" (a plain string rather than a dict)
        row["Principal"] = principal
        row["Principal Service"] = str(principal).replace(",", ";")

    if "Action" in statement:
        row["Action"] = str(statement["Action"]).replace(",", ";")
    if "Resource" in statement:
        row["Resource"] = statement["Resource"]
    if "Condition" in statement:
        row["Condition"] = str(statement["Condition"]).replace(",", ";")
    row["Tags"] = str(Tags).replace(",", ";")
    row["CreationDate"] = CreationDate
    row["Date"] = today

    # The third-party check used to compare the DynamoDB account
    # (DYNAMODBACCOUNT) with the scanned account. That flagged every statement
    # of every other scanned account regardless of its principal. Compare the
    # accounts named in the principal with the key's own account instead.
    foreign_accounts = [
        acct for acct in _principal_account_ids(principal) if acct != str(accountNumber)
    ]
    row["Concern"] = concernFiller(
        principal_service=row.get("Principal Service", ""),
        account_number=foreign_accounts[0] if foreign_accounts else accountNumber,
        current_account_number=accountNumber,
        action=row.get("Action", ""),
    )
    return row


def _principal_account_ids(principal) -> list:
    """Return the 12-digit account IDs named by a statement's raw Principal.

    ``principal`` is a string (e.g. ``"*"``), a dict mapping principal types
    to a string or list of strings, or ``None``. Both bare account IDs and
    IAM/STS ARNs are recognised.
    """
    import re

    values = list(principal.values()) if isinstance(principal, dict) else [principal]
    account_ids = []
    for value in values:
        for item in value if isinstance(value, list) else [value]:
            item = str(item)
            if re.fullmatch(r"\d{12}", item):
                account_ids.append(item)
                continue
            match = re.match(r"arn:[^:]+:(?:iam|sts)::(\d{12}):", item)
            if match:
                account_ids.append(match.group(1))
    return account_ids


def pushToS3(filename: str, bucketName: str):
    """Upload ``/tmp/<filename>`` to ``s3://<bucketName>/<filename>``.

    Behaviour:
      * If the local file does not exist, only a debug message is logged.
      * If no bucket name is configured, an error is logged and nothing is
        uploaded.
      * Access-denied client errors are logged.
      * Any other client error is logged and re-raised.
    """
    file_path = "/tmp/" + filename

    if not os.path.isfile(file_path):
        logger.debug(f"{file_path} not found. Probably because no keys found")
        return
    logger.debug(f"The file {file_path} exists")

    if not bucketName:
        # DEST_BUCKET may be unset; there is no bucket to upload to.
        logger.error(f"No destination bucket configured; not uploading {filename}")
        return

    bucket = boto3.resource("s3").Bucket(bucketName)
    logger.debug(f"About to try and push [{filename}] to s3://{bucketName}")
    try:
        bucket.upload_file(file_path, filename)
        logger.info(f"Successfully uploaded [{filename}] to s3://{bucketName}")
    except FileNotFoundError:
        logger.error(f"{filename} not found!")
    except botocore.exceptions.ClientError as error:
        # S3 reports "AccessDenied"; "AccessDeniedException" is the code the
        # KMS/STS handlers in this file check for, kept for safety.
        if error.response["Error"]["Code"] in ("AccessDenied", "AccessDeniedException"):
            logger.error(
                f"Access Denied PUTting {filename} to s3://{bucketName}: {error.response['Error']['Message']}"
            )
        else:
            logger.error(f"ERROR RESPONSE: {error.response}")
            raise
    # NOTE(review): boto3's managed upload_file may wrap service errors in
    # boto3.exceptions.S3UploadFailedError, which is not caught here and will
    # propagate - confirm whether that should be handled like the above.


def lambda_handler(event, context):
    """Scan KMS key policies in every account listed in DynamoDB.

    For each account returned by get_accounts, assume XA-KMSRead-Role. Then,
    for each region in REGIONS, write a CSV of key-policy statements to /tmp
    and upload it to DEST_BUCKET.

    Environment variables:
        DEST_BUCKET: S3 bucket for the CSV files.
        REGIONS: comma-separated regions, optionally wrapped in [].
        DYNAMODBACCOUNT: account holding the accounts table.
        DYNAMODBROLE / DYNAMODBTABLE / DYNAMODBREGION: optional; when unset,
            get_accounts' own defaults are used.

    Parameters:
        event (dict): Assumed to be empty/test data, as the handler is invoked manually.
        context: Lambda context object (unused).

    Raises:
        RuntimeError: after all accounts/regions were attempted, if any
            account/region scan failed.
    """
    destination_s3_bucket = os.environ.get("DEST_BUCKET")

    # REGIONS may look like "[us-east-1, eu-west-1]" or "us-east-1,eu-west-1".
    regions = [
        entry.strip().strip("'\"")
        for entry in os.environ.get("REGIONS", "").strip().strip("[]").split(",")
    ]
    regions = [r for r in regions if r]
    if not regions:
        logger.error("REGIONS is unset or empty; nothing to scan")
        return

    dynamo_accountID = os.environ.get("DYNAMODBACCOUNT")
    # Forward only the settings that are configured, so get_accounts keeps its
    # defaults for the rest instead of receiving None.
    optional_env = {
        "dynamodb_role": "DYNAMODBROLE",
        "dynamodb_table": "DYNAMODBTABLE",
        "region": "DYNAMODBREGION",
    }
    dynamo_kwargs = {
        param: os.environ[var] for param, var in optional_env.items() if var in os.environ
    }

    account_ids = get_accounts(dynamo_accountID, **dynamo_kwargs)
    logger.info(f"Accounts: {account_ids}")

    if not account_ids:
        logger.error(f"No Account IDs found: account_ids={account_ids}")
        return

    failures = []
    for account in account_ids:
        # str(): DynamoDB Number attributes come back as Decimal.
        account_number = str(account["accountId"])
        logger.info(f"Processing Account Number: {account_number}")
        session = getAssumedRoleSession(account_number, role_name="XA-KMSRead-Role")
        if not session:
            logger.error(f"Unable to create session for {account_number}")
            continue

        for region in regions:
            # Keep scanning other regions/accounts when one fails (e.g. a region
            # not enabled for this account); report the failures at the end.
            try:
                kms = session.client("kms", region_name=region)
                keyMap = getEverythingJson(kms)
                filename = account_number + "-" + region + "-KMS-details.csv"
                getEverythingToCSV(account_number, filename, keyMap, region)
                pushToS3(filename, destination_s3_bucket)
            except Exception:
                logger.exception(f"Failed to scan account {account_number} in {region}")
                failures.append(f"{account_number}/{region}")

    if failures:
        raise RuntimeError(f"KMS scan failed for: {', '.join(failures)}")


def concernFiller(
    principal_service: str,
    account_number: str,
    current_account_number: str,
    action: str,
):
    """Run every key-policy check and return the findings joined by ``;``.

    Checks that report nothing (empty string) are left out, so a clean
    statement yields ``""``.
    """
    findings = (
        checkManageableThroughIAM(principal_service=principal_service),
        checkThirdPartyManaged(
            account_number=account_number,
            current_account_number=current_account_number,
        ),
        checkKmsPolciy(action=action),
    )
    return ";".join(filter(lambda finding: finding != "", findings))


def checkManageableThroughIAM(principal_service: str) -> str:
    """Flag statements whose principal is an account root (``...:root``).

    A root principal in a key policy lets IAM policies in that account
    control access to the key.

    ``principal_service`` is either a single principal string or the
    stringified list form ``"['arn:...:root'; 'arn:...']"`` built by
    grabPolicyStatementDetailsList. Every element is checked, not just the
    last five characters of the whole string.
    Returns an empty string when no concern is found.
    """
    if not principal_service:
        return ""
    principals = str(principal_service).strip("[]").split(";")
    if any(p.strip().strip("'\"").endswith(":root") for p in principals):
        return "key access management manageable through IAM"
    return ""


def checkThirdPartyManaged(account_number: str, current_account_number: str) -> str:
    """Warn when ``account_number`` is not the account being scanned.

    Returns an empty string when the two account numbers are equal.
    """
    same_account = account_number == current_account_number
    return "" if same_account else "3rd party account can use kms!!!"


def checkKmsPolciy(action: str) -> str:
    """Flag statements that grant every KMS action.

    Action names are case-insensitive, so ``kms:*``, ``KMS:*`` and a bare
    ``*`` all grant every operation on the key. ``action`` may be a single
    action or the stringified list form ``"['kms:*'; 'kms:Decrypt']"``.
    Returns an empty string when no concern is found.
    """
    if not action:
        return ""
    actions = {
        a.strip().strip("'\"").lower() for a in str(action).strip("[]").split(";")
    }
    if actions & {"kms:*", "*"}:
        return "key policy is overly permissive"
    return ""