Skip to content

Commit

Permalink
wallet management api
Browse files Browse the repository at this point in the history
  • Loading branch information
v9n committed Nov 11, 2024
1 parent fcd0dc3 commit 4a816fe
Show file tree
Hide file tree
Showing 14 changed files with 766 additions and 296 deletions.
26 changes: 17 additions & 9 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,16 @@ type RpcServer struct {
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest) (*avsproto.NonceResp, error) {
func (r *RpcServer) CreateWallet(ctx context.Context, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
}
return r.engine.CreateSmartWallet(user, payload)
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest) (*avsproto.NonceResp, error) {
ownerAddress := common.HexToAddress(payload.Owner)

nonce, err := aa.GetNonce(r.smartWalletRpc, ownerAddress, big.NewInt(0))
Expand All @@ -55,17 +63,15 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest

// GetAddress returns smart account address of the given owner in the auth key
func (r *RpcServer) GetSmartAccountAddress(ctx context.Context, payload *avsproto.AddressRequest) (*avsproto.AddressResp, error) {
ownerAddress := common.HexToAddress(payload.Owner)
salt := big.NewInt(0)
sender, err := aa.GetSenderAddress(r.smartWalletRpc, ownerAddress, salt)

user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
}

wallets, err := r.engine.GetSmartWallets(user.Address)

return &avsproto.AddressResp{
SmartAccountAddress: sender.String(),
// TODO: return the right salt
Salt: big.NewInt(0).String(),
Wallets: wallets,
}, nil
}

Expand Down Expand Up @@ -160,12 +166,14 @@ func (r *RpcServer) GetTask(ctx context.Context, taskID *avsproto.UUID) (*avspro
return task.ToProtoBuf()
}

// Operator action
func (r *RpcServer) SyncTasks(payload *avsproto.SyncTasksReq, srv avsproto.Aggregator_SyncTasksServer) error {
err := r.engine.StreamCheckToOperator(payload, srv)

return err
}

// Operator action
func (r *RpcServer) UpdateChecks(ctx context.Context, payload *avsproto.UpdateChecksReq) (*avsproto.UpdateChecksResp, error) {
if err := r.engine.AggregateChecksResult(payload.Address, payload.Id); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
agg.queue,
agg.logger,
)
agg.engine.Start()
agg.engine.MustStart()

agg.queue.MustStart()
agg.worker.MustStart()
Expand Down
11 changes: 11 additions & 0 deletions core/chainio/aa/aa.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ func GetSenderAddress(conn *ethclient.Client, ownerAddress common.Address, salt
return &sender, nil
}

// Compute smart wallet address for a particular factory
func GetSenderAddressForFactory(conn *ethclient.Client, ownerAddress common.Address, customFactoryAddress common.Address, salt *big.Int) (*common.Address, error) {
simpleFactory, err := NewSimpleFactory(customFactoryAddress, conn)
if err != nil {
return nil, err
}

sender, err := simpleFactory.GetAddress(nil, ownerAddress, salt)
return &sender, nil
}

func GetNonce(conn *ethclient.Client, ownerAddress common.Address, salt *big.Int) (*big.Int, error) {
if salt == nil {
salt = defaultSalt
Expand Down
31 changes: 21 additions & 10 deletions core/taskengine/doc.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
/*
package trigger monitor the condition on when to fire a task
there are 3 trigger types:
Task Engine handles task storage and execution. We use badgerdb for all of our task storage. We like to make sure of Go cross compiling extensively and want to leverage pure-go as much as possible. badgerdb sastify that requirement.
Interval: Repeated at certain interval
Cron: trigger on time on cron
Onchain Event: when an event is emiited from a contract
Ev
**Wallet Info**
# Storage Layout
w:<eoa>:<smart-wallet-address> = {factory_address: address, salt: salt}
Task is store into 2 storage
t:a:<task-id>: the raw json of task data
u:<task-id>: the task status
**Task Storage**
w:<eoa>:<smart-wallet-address> -> {factory, salt}
t:<task-status>:<task-id> -> task payload, the source of truth of task information
u:<eoa>:<smart-wallet-address>:<task-id> -> task status
h:<smart-wallet-address>:<task-id>:<execution-id> -> an execution history
The task storage was designed for fast retrieve time at the cost of extra storage.
The storage can also be easily back-up, sync due to simplicity of supported write operation.
**Data console**
Storage can also be inspect with telnet:
telnet /tmp/ap.sock
Then issue `get <ket>` or `list <prefix>` or `list *` to inspect current keys in the storage.
*/
package taskengine
160 changes: 106 additions & 54 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (
"github.com/AvaProtocol/ap-avs/model"
"github.com/AvaProtocol/ap-avs/storage"
sdklogging "github.com/Layr-Labs/eigensdk-go/logging"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/ethclient"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
grpcstatus "google.golang.org/grpc/status"

avsproto "github.com/AvaProtocol/ap-avs/protobuf"
Expand Down Expand Up @@ -50,6 +53,7 @@ type Engine struct {
lock *sync.Mutex
trackSyncedTasks map[string]*operatorState

smartWalletConfig *config.SmartWalletConfig
// when shutdown is true, our engine will perform the shutdown
// pending execution will be pushed out before the shutdown completely
// to force shutdown, one can type ctrl+c twice
Expand Down Expand Up @@ -91,10 +95,11 @@ func New(db storage.Storage, config *config.Config, queue *apqueue.Queue, logger
db: db,
queue: queue,

lock: &sync.Mutex{},
tasks: make(map[string]*model.Task),
trackSyncedTasks: make(map[string]*operatorState),
shutdown: false,
lock: &sync.Mutex{},
tasks: make(map[string]*model.Task),
trackSyncedTasks: make(map[string]*operatorState),
smartWalletConfig: config.SmartWallet,
shutdown: false,

logger: logger,
}
Expand All @@ -110,14 +115,15 @@ func (n *Engine) Stop() {
n.shutdown = true
}

func (n *Engine) Start() {
func (n *Engine) MustStart() {
var err error
n.seq, err = n.db.GetSequence([]byte("t:seq"), 1000)
if err != nil {
panic(err)
}

kvs, e := n.db.GetByPrefix([]byte(fmt.Sprintf("t:%s:", TaskStatusToStorageKey(avsproto.TaskStatus_Active))))
// Upon booting we will get all the active tasks to sync to operator
kvs, e := n.db.GetByPrefix(TaskByStatusStoragePrefix(avsproto.TaskStatus_Active))
if e != nil {
panic(e)
}
Expand All @@ -127,7 +133,87 @@ func (n *Engine) Start() {
n.tasks[string(item.Key)] = &task
}
}
}

func (n *Engine) GetSmartWallets(owner common.Address) ([]*avsproto.SmartWallet, error) {
// This is the default wallet with our own factory
salt := big.NewInt(0)
sender, err := aa.GetSenderAddress(rpcConn, owner, salt)
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
}

// now load the customize wallet with different salt or factory that was initialed and store in our db
wallets := []*avsproto.SmartWallet{
&avsproto.SmartWallet{
Address: sender.String(),
Factory: n.smartWalletConfig.FactoryAddress.String(),
Salt: salt.String(),
},
}

items, err := n.db.GetByPrefix(WalletByOwnerPrefix(owner))

if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
}

for _, item := range items {
w := &model.SmartWallet{}
w.FromStorageData(item.Value)

wallets = append(wallets, &avsproto.SmartWallet{
Address: w.Address.String(),
Factory: w.Factory.String(),
Salt: w.Salt.String(),
})
}

return wallets, nil
}

func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
// Verify data
// when user passing a custom factory address, we want to validate it
if payload.FactoryAddress != "" && !common.IsHexAddress(payload.FactoryAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid factory address")
}

salt := big.NewInt(0)
if payload.Salt != "" {
var ok bool
salt, ok = math.ParseBig256(payload.Salt)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "invalid salt value")
}
}

factoryAddress := n.smartWalletConfig.FactoryAddress
if payload.FactoryAddress != "" {
factoryAddress = common.HexToAddress(payload.FactoryAddress)

}

sender, err := aa.GetSenderAddressForFactory(rpcConn, user.Address, factoryAddress, salt)

wallet := &model.SmartWallet{
Owner: &user.Address,
Address: sender,
Factory: &factoryAddress,
Salt: salt,
}

updates := map[string][]byte{}

updates[string(WalletStorageKey(wallet))], err = wallet.ToJSON()

if err = n.db.BatchWrite(updates); err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), "cannot update key to storage")
}

return &avsproto.CreateWalletResp{
Address: sender.String(),
}, nil
}

func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskReq) (*model.Task, error) {
Expand All @@ -140,21 +226,16 @@ func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskRe
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_SmartWalletRpcError), "cannot get smart wallet address")
}

taskID, err := n.NewTaskID()
if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "cannot create task right now. storage unavailable")
}

task, err := model.NewTaskFromProtobuf(taskID, user, taskPayload)
task, err := model.NewTaskFromProtobuf(user, taskPayload)

if err != nil {
return nil, err
}

updates := map[string][]byte{}

updates[TaskStorageKey(task.ID, task.Status)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", avsproto.TaskStatus_Active))
updates[string(TaskStorageKey(task.ID, task.Status))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", avsproto.TaskStatus_Active))

if err = n.db.BatchWrite(updates); err != nil {
return nil, err
Expand Down Expand Up @@ -249,8 +330,8 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
n.logger.Debug("mark task in executing status", "task_id", id)

if err := n.db.Move(
[]byte(fmt.Sprintf("t:%s:%s", TaskStatusToStorageKey(avsproto.TaskStatus_Active), id)),
[]byte(fmt.Sprintf("t:%s:%s", TaskStatusToStorageKey(avsproto.TaskStatus_Executing), id)),
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Active)),
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Executing)),
); err != nil {
n.logger.Error("error moving the task storage from active to executing", "task", id, "error", err)
}
Expand All @@ -263,7 +344,7 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
}

func (n *Engine) ListTasksByUser(user *model.User) ([]*avsproto.ListTasksResp_TaskItemResp, error) {
taskIDs, err := n.db.GetByPrefix([]byte(fmt.Sprintf("u:%s", user.Address.String())))
taskIDs, err := n.db.GetByPrefix(UserTaskStoragePrefix(user.Address))

if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "storage is not ready")
Expand Down Expand Up @@ -326,8 +407,8 @@ func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error)
return false, fmt.Errorf("Only non executing task can be deleted")
}

n.db.Delete([]byte(TaskStorageKey(task.ID, task.Status)))
n.db.Delete([]byte(TaskUserKey(task)))
n.db.Delete(TaskStorageKey(task.ID, task.Status))
n.db.Delete(TaskUserKey(task))

return true, nil
}
Expand All @@ -346,13 +427,13 @@ func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error)
updates := map[string][]byte{}
oldStatus := task.Status
task.SetCanceled()
updates[TaskStorageKey(task.ID, oldStatus)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", task.Status))
updates[string(TaskStorageKey(task.ID, oldStatus))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", task.Status))

if err = n.db.BatchWrite(updates); err == nil {
n.db.Move(
[]byte(TaskStorageKey(task.ID, oldStatus)),
[]byte(TaskStorageKey(task.ID, task.Status)),
TaskStorageKey(task.ID, oldStatus),
TaskStorageKey(task.ID, task.Status),
)

delete(n.tasks, task.ID)
Expand All @@ -363,37 +444,8 @@ func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error)
return true, nil
}

func TaskStorageKey(id string, status avsproto.TaskStatus) string {
return fmt.Sprintf(
"t:%s:%s",
TaskStatusToStorageKey(status),
id,
)
}

func TaskUserKey(t *model.Task) string {
return fmt.Sprintf(
"u:%s",
t.Key(),
)
}

func TaskStatusToStorageKey(v avsproto.TaskStatus) string {
switch v {
case 1:
return "c"
case 2:
return "f"
case 3:
return "l"
case 4:
return "x"
}

return "a"
}

func (n *Engine) NewTaskID() (string, error) {
// A global counter for the task engine
func (n *Engine) NewSeqID() (string, error) {
num := uint64(0)
var err error

Expand Down
1 change: 1 addition & 0 deletions core/taskengine/engine_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package taskengine
4 changes: 2 additions & 2 deletions core/taskengine/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (c *ContractProcessor) Perform(job *apqueue.Job) error {

defer func() {
updates := map[string][]byte{}
updates[TaskStorageKey(task.ID, avsproto.TaskStatus_Executing)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", task.Status))
updates[string(TaskStorageKey(task.ID, avsproto.TaskStatus_Executing))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", task.Status))

if err = c.db.BatchWrite(updates); err == nil {
c.db.Move(
Expand Down
Loading

0 comments on commit 4a816fe

Please sign in to comment.