diff --git a/src/spl/token/instructions.py b/src/spl/token/instructions.py index 69e48052..efd3c521 100644 --- a/src/spl/token/instructions.py +++ b/src/spl/token/instructions.py @@ -10,7 +10,7 @@ from solana.utils.validate import validate_instruction_keys, validate_instruction_type from spl.token._layouts import INSTRUCTIONS_LAYOUT, InstructionType -from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID +from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID class AuthorityType(IntEnum): @@ -1211,26 +1211,51 @@ def sync_native(params: SyncNativeParams) -> Instruction: return __sync_native_instruction(params, data) -def get_associated_token_address(owner: Pubkey, mint: Pubkey) -> Pubkey: +def get_associated_token_address(owner: Pubkey, mint: Pubkey, token_program_id: Pubkey = TOKEN_PROGRAM_ID) -> Pubkey: """Derives the associated token address for the given wallet address and token mint. + Args: + owner (Pubkey): Owner's wallet address. + mint (Pubkey): The token mint address. + token_program_id (Pubkey, optional): The token program ID. Must be either `spl.token.constants.TOKEN_PROGRAM_ID` + or `spl.token.constants.TOKEN_2022_PROGRAM_ID` (default is `TOKEN_PROGRAM_ID`). + Returns: The public key of the derived associated token address. + + Raises: + ValueError: If an invalid `token_program_id` is provided. """ + if token_program_id not in [TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID]: + raise ValueError("token_program_id must be one of TOKEN_PROGRAM_ID or TOKEN_2022_PROGRAM_ID.") key, _ = Pubkey.find_program_address( - seeds=[bytes(owner), bytes(TOKEN_PROGRAM_ID), bytes(mint)], + seeds=[bytes(owner), bytes(token_program_id), bytes(mint)], program_id=ASSOCIATED_TOKEN_PROGRAM_ID, ) return key -def create_associated_token_account(payer: Pubkey, owner: Pubkey, mint: Pubkey) -> Instruction: +def create_associated_token_account( + payer: Pubkey, owner: Pubkey, mint: Pubkey, token_program_id: Pubkey = TOKEN_PROGRAM_ID +) -> Instruction: """Creates a transaction instruction to create an associated token account. + Args: + payer (Pubkey): Payer's wallet address. + owner (Pubkey): Owner's wallet address. + mint (Pubkey): The token mint address. + token_program_id (Pubkey, optional): The token program ID. Must be either `spl.token.constants.TOKEN_PROGRAM_ID` + or `spl.token.constants.TOKEN_2022_PROGRAM_ID` (default is `TOKEN_PROGRAM_ID`). + Returns: The instruction to create the associated token account. + + Raises: + ValueError: If an invalid `token_program_id` is provided. """ - associated_token_address = get_associated_token_address(owner, mint) + if token_program_id not in [TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID]: + raise ValueError("token_program_id must be one of TOKEN_PROGRAM_ID or TOKEN_2022_PROGRAM_ID.") + associated_token_address = get_associated_token_address(owner, mint, token_program_id) return Instruction( accounts=[ AccountMeta(pubkey=payer, is_signer=True, is_writable=True), @@ -1238,7 +1263,7 @@ def create_associated_token_account(payer: Pubkey, owner: Pubkey, mint: Pubkey) AccountMeta(pubkey=owner, is_signer=False, is_writable=False), AccountMeta(pubkey=mint, is_signer=False, is_writable=False), AccountMeta(pubkey=SYS_PROGRAM_ID, is_signer=False, is_writable=False), - AccountMeta(pubkey=TOKEN_PROGRAM_ID, is_signer=False, is_writable=False), + AccountMeta(pubkey=token_program_id, is_signer=False, is_writable=False), AccountMeta(pubkey=RENT, is_signer=False, is_writable=False), ], program_id=ASSOCIATED_TOKEN_PROGRAM_ID,