Skip to content

Commit

Permalink
update task management behavior
Browse files Browse the repository at this point in the history
- create task with specific wallet
- list task by the smart wallet
- return all task data in list
  • Loading branch information
v9n committed Nov 11, 2024
1 parent 6b13ec3 commit aaa7740
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 363 deletions.
9 changes: 7 additions & 2 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,21 @@ func (r *RpcServer) CreateTask(ctx context.Context, taskPayload *avsproto.Create
}, nil
}

func (r *RpcServer) ListTasks(ctx context.Context, _ *avsproto.ListTasksReq) (*avsproto.ListTasksResp, error) {
func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksReq) (*avsproto.ListTasksResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
}

r.config.Logger.Info("Process List Task",
"user", user.Address.String(),
"smart_wallet_address", payload.SmartWalletAddress,
)
tasks, err := r.engine.ListTasksByUser(user)
tasks, err := r.engine.ListTasksByUser(user, payload)

if err != nil {
return nil, err
}

return &avsproto.ListTasksResp{
Tasks: tasks,
Expand Down
58 changes: 42 additions & 16 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal

updates := map[string][]byte{}

updates[string(WalletStorageKey(wallet))], err = wallet.ToJSON()
updates[string(WalletStorageKey(user.Address, sender.Hex()))], err = wallet.ToJSON()

if err = n.db.BatchWrite(updates); err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), "cannot update key to storage")
Expand All @@ -216,16 +216,19 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal
}, nil
}

// CreateTask records submission data
func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskReq) (*model.Task, error) {
var err error
salt := big.NewInt(0)

user.SmartAccountAddress, err = aa.GetSenderAddress(rpcConn, user.Address, salt)
if taskPayload.SmartWalletAddress != "" {
if !ValidWalletAddress(taskPayload.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}

if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_SmartWalletRpcError), "cannot get smart wallet address")
if valid, _ := ValidWalletOwner(n.db, user, common.HexToAddress(taskPayload.SmartWalletAddress)); !valid {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}
}

task, err := model.NewTaskFromProtobuf(user, taskPayload)

if err != nil {
Expand Down Expand Up @@ -343,21 +346,46 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
return nil
}

func (n *Engine) ListTasksByUser(user *model.User) ([]*avsproto.ListTasksResp_TaskItemResp, error) {
taskIDs, err := n.db.GetByPrefix(UserTaskStoragePrefix(user.Address))
func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksReq) ([]*avsproto.Task, error) {
// by default show the task from the default smart wallet, if proving we look into that wallet specifically
owner := user.SmartAccountAddress
if payload.SmartWalletAddress != "" {
if !ValidWalletAddress(payload.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}

if valid, _ := ValidWalletOwner(n.db, user, common.HexToAddress(payload.SmartWalletAddress)); !valid {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}

smartWallet := common.HexToAddress(payload.SmartWalletAddress)
owner = &smartWallet
}

taskIDs, err := n.db.GetByPrefix(SmartWalletTaskStoragePrefix(user.Address, *owner))

if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "storage is not ready")
}

tasks := make([]*avsproto.ListTasksResp_TaskItemResp, len(taskIDs))
tasks := make([]*avsproto.Task, len(taskIDs))
for i, kv := range taskIDs {

status, _ := strconv.Atoi(string(kv.Value))
tasks[i] = &avsproto.ListTasksResp_TaskItemResp{
Id: string(model.TaskKeyToId(kv.Key[2:])),
Status: avsproto.TaskStatus(status),
taskID := string(model.TaskKeyToId(kv.Key[2:]))
taskRawByte, err := n.db.GetKey(TaskStorageKey(taskID, avsproto.TaskStatus(status)))
if err != nil {
continue
}

task := &model.Task{
ID: taskID,
Owner: user.Address.Hex(),
}
if err := task.FromStorageData(taskRawByte); err != nil {
continue
}

tasks[i], _ = task.ToProtoBuf()
}

return tasks, nil
Expand All @@ -376,9 +404,7 @@ func (n *Engine) GetTaskByUser(user *model.User, taskID string) (*model.Task, er
}
status, _ := strconv.Atoi(string(rawStatus))

taskRawByte, err := n.db.GetKey([]byte(
TaskStorageKey(taskID, avsproto.TaskStatus(status)),
))
taskRawByte, err := n.db.GetKey(TaskStorageKey(taskID, avsproto.TaskStatus(status)))

if err != nil {
taskRawByte, err = n.db.GetKey([]byte(
Expand Down
15 changes: 10 additions & 5 deletions core/taskengine/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"github.com/ethereum/go-ethereum/common"
)

// Prefix
func UserTaskStoragePrefix(address common.Address) []byte {
return []byte(fmt.Sprintf("u:%s", strings.ToLower(address.String())))
}

func SmartWalletTaskStoragePrefix(owner common.Address, smartWalletAddress common.Address) []byte {
return []byte(fmt.Sprintf("u:%s:%s", strings.ToLower(owner.Hex()), strings.ToLower(smartWalletAddress.Hex())))
}

func TaskByStatusStoragePrefix(status avsproto.TaskStatus) []byte {
return []byte(fmt.Sprintf("t:%s:", TaskStatusToStorageKey(status)))
}
Expand All @@ -25,11 +28,11 @@ func WalletByOwnerPrefix(owner common.Address) []byte {
))
}

func WalletStorageKey(w *model.SmartWallet) string {
func WalletStorageKey(owner common.Address, smartWalletAddress string) string {
return fmt.Sprintf(
"w:%s:%s",
strings.ToLower(w.Owner.String()),
strings.ToLower(w.Address.String()),
strings.ToLower(owner.Hex()),
strings.ToLower(smartWalletAddress),
)
}

Expand All @@ -43,7 +46,9 @@ func TaskStorageKey(id string, status avsproto.TaskStatus) []byte {

func TaskUserKey(t *model.Task) []byte {
return []byte(fmt.Sprintf(
"u:%s",
"u:%s:%s:%s",
strings.ToLower(t.Owner),
strings.ToLower(t.SmartWalletAddress),
t.Key(),
))
}
Expand Down
26 changes: 26 additions & 0 deletions core/taskengine/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package taskengine

import (
"github.com/AvaProtocol/ap-avs/model"
"github.com/AvaProtocol/ap-avs/storage"
"github.com/ethereum/go-ethereum/common"
)

func ValidWalletAddress(address string) bool {
return common.IsHexAddress(address)
}

func ValidWalletOwner(db storage.Storage, u *model.User, smartWalletAddress common.Address) (bool, error) {
// the smart wallet adress is the default one
if u.Address.Hex() == smartWalletAddress.Hex() {
return true, nil
}

// not default, look up in our storage
exists, err := db.Exist([]byte(WalletStorageKey(u.Address, smartWalletAddress.Hex())))
if exists {
return true, nil
}

return false, err
}
9 changes: 7 additions & 2 deletions examples/example.js
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ async function listTask(owner, token) {
const metadata = new grpc.Metadata();
metadata.add("authkey", token);

const result = await asyncRPC(client, "ListTasks", {}, metadata);
const result = await asyncRPC(client, "ListTasks", {
smart_wallet_address: process.argv[3]
}, metadata);

console.log("Tasks that has created by", owner, "\n", result);
console.log("Tasks that has created by", process.argv[3], "\n", result);
}

async function getTask(owner, token, taskId) {
Expand Down Expand Up @@ -353,6 +355,9 @@ async function scheduleERC20TransferJob(owner, token, taskCondition) {
client,
'CreateTask',
{
// salt = 0
//smart_wallet_address: "0x5Df343de7d99fd64b2479189692C1dAb8f46184a",
smart_wallet_address: "0xdD85693fd14b522a819CC669D6bA388B4FCd158d",
actions: [{
task_type: TaskType.CONTRACTEXECUTIONTASK,
// id need to be unique
Expand Down
27 changes: 19 additions & 8 deletions model/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"fmt"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/oklog/ulid/v2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)
Expand All @@ -19,7 +22,7 @@ type Task struct {

// The smartwallet that deploy this, it is important to store this because
// there are maybe more than one AA per owner
SmartAccountAddress string `json:"smart_account_address"`
SmartWalletAddress string `json:"smart_wallet_address"`

// trigger defined whether the task can be executed
// trigger can be time based, price based, or contract call based
Expand Down Expand Up @@ -53,16 +56,23 @@ func NewTaskFromProtobuf(user *User, body *avsproto.CreateTaskReq) (*Task, error
}

owner := user.Address
aaAddress := user.SmartAccountAddress
aaAddress := *user.SmartAccountAddress

if body.SmartWalletAddress != "" {
if !common.IsHexAddress(body.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}
aaAddress = common.HexToAddress(body.SmartWalletAddress)
}

taskID := GenerateTaskID()

t := &Task{
ID: taskID,

// convert back to string with EIP55-compliant
Owner: owner.Hex(),
SmartAccountAddress: aaAddress.Hex(),
Owner: owner.Hex(),
SmartWalletAddress: aaAddress.Hex(),

Trigger: body.Trigger,
Nodes: body.Actions,
Expand Down Expand Up @@ -96,8 +106,8 @@ func (t *Task) Validate() bool {
// Convert to protobuf
func (t *Task) ToProtoBuf() (*avsproto.Task, error) {
protoTask := avsproto.Task{
Owner: t.Owner,
SmartAccountAddress: t.SmartAccountAddress,
Owner: t.Owner,
SmartWalletAddress: t.SmartWalletAddress,

Id: &avsproto.UUID{
Bytes: t.ID,
Expand Down Expand Up @@ -125,7 +135,7 @@ func (t *Task) FromStorageData(body []byte) error {

// Generate a global unique key for the task in our system
func (t *Task) Key() []byte {
return []byte(fmt.Sprintf("%s:%s", t.Owner, t.ID))
return []byte(t.ID)
}

func (t *Task) SetCompleted() {
Expand Down Expand Up @@ -158,6 +168,7 @@ func (t *Task) AppendExecution(epoch int64, userOpHash string, err error) {

// Given a task key generated from Key(), extract the ID part
func TaskKeyToId(key []byte) []byte {
// <43-byte>:<43-byte>:
// the first 43 bytes is owner address
return key[43:]
return key[86:]
}
Loading

0 comments on commit aaa7740

Please sign in to comment.