Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spl-token: Add token_program_id parameter to associated token endpoints #456

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/spl/token/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from solana.utils.validate import validate_instruction_keys, validate_instruction_type
from spl.token._layouts import INSTRUCTIONS_LAYOUT, InstructionType
from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID
from spl.token.constants import ASSOCIATED_TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID


class AuthorityType(IntEnum):
Expand Down Expand Up @@ -1211,34 +1211,59 @@ def sync_native(params: SyncNativeParams) -> Instruction:
return __sync_native_instruction(params, data)


def get_associated_token_address(owner: Pubkey, mint: Pubkey) -> Pubkey:
def get_associated_token_address(owner: Pubkey, mint: Pubkey, token_program_id: Pubkey = TOKEN_PROGRAM_ID) -> Pubkey:
"""Derives the associated token address for the given wallet address and token mint.

Args:
owner (Pubkey): Owner's wallet address.
mint (Pubkey): The token mint address.
token_program_id (Pubkey, optional): The token program ID. Must be either `spl.token.constants.TOKEN_PROGRAM_ID`
or `spl.token.constants.TOKEN_2022_PROGRAM_ID` (default is `TOKEN_PROGRAM_ID`).

Returns:
The public key of the derived associated token address.

Raises:
ValueError: If an invalid `token_program_id` is provided.
"""
if token_program_id not in [TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID]:
raise ValueError("token_program_id must be one of TOKEN_PROGRAM_ID or TOKEN_2022_PROGRAM_ID.")
key, _ = Pubkey.find_program_address(
seeds=[bytes(owner), bytes(TOKEN_PROGRAM_ID), bytes(mint)],
seeds=[bytes(owner), bytes(token_program_id), bytes(mint)],
program_id=ASSOCIATED_TOKEN_PROGRAM_ID,
)
return key


def create_associated_token_account(payer: Pubkey, owner: Pubkey, mint: Pubkey) -> Instruction:
def create_associated_token_account(
payer: Pubkey, owner: Pubkey, mint: Pubkey, token_program_id: Pubkey = TOKEN_PROGRAM_ID
) -> Instruction:
"""Creates a transaction instruction to create an associated token account.

Args:
payer (Pubkey): Payer's wallet address.
owner (Pubkey): Owner's wallet address.
mint (Pubkey): The token mint address.
token_program_id (Pubkey, optional): The token program ID. Must be either `spl.token.constants.TOKEN_PROGRAM_ID`
or `spl.token.constants.TOKEN_2022_PROGRAM_ID` (default is `TOKEN_PROGRAM_ID`).

Returns:
The instruction to create the associated token account.

Raises:
ValueError: If an invalid `token_program_id` is provided.
"""
associated_token_address = get_associated_token_address(owner, mint)
if token_program_id not in [TOKEN_PROGRAM_ID, TOKEN_2022_PROGRAM_ID]:
raise ValueError("token_program_id must be one of TOKEN_PROGRAM_ID or TOKEN_2022_PROGRAM_ID.")
associated_token_address = get_associated_token_address(owner, mint, token_program_id)
return Instruction(
accounts=[
AccountMeta(pubkey=payer, is_signer=True, is_writable=True),
AccountMeta(pubkey=associated_token_address, is_signer=False, is_writable=True),
AccountMeta(pubkey=owner, is_signer=False, is_writable=False),
AccountMeta(pubkey=mint, is_signer=False, is_writable=False),
AccountMeta(pubkey=SYS_PROGRAM_ID, is_signer=False, is_writable=False),
AccountMeta(pubkey=TOKEN_PROGRAM_ID, is_signer=False, is_writable=False),
AccountMeta(pubkey=token_program_id, is_signer=False, is_writable=False),
AccountMeta(pubkey=RENT, is_signer=False, is_writable=False),
],
program_id=ASSOCIATED_TOKEN_PROGRAM_ID,
Expand Down