-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
118 lines (106 loc) · 3.26 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main
import (
"context"
"github.com/ReneKroon/ttlcache/v2"
"github.com/apache/arrow/go/v7/arrow/flight"
"github.com/apache/arrow/go/v7/arrow/memory"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
cloudevents "github.com/cloudevents/sdk-go/v2"
"github.com/ehenry2/avro-flight-decisioner/internal/arrowconv"
"github.com/ehenry2/avro-flight-decisioner/internal/avroutil"
"github.com/ehenry2/avro-flight-decisioner/internal/scoring"
"google.golang.org/grpc"
"log"
"time"
)
// contextKey is an unexported type for the keys this package stores in a
// context.Context. Using a dedicated type instead of a bare string prevents
// collisions with values stored by other packages under the same string.
type contextKey string

const (
	// codecLoaderKey is the context key under which the Avro codec loader is stored.
	codecLoaderKey contextKey = "codecLoader"
	// scorerKey is the context key under which the flight model scorer is stored.
	scorerKey contextKey = "flightScorer"
)
// HandleMessage is the receiver callback for incoming cloud events. It loads
// the Avro codec for the event's type, decodes the event payload into a
// generic map, passes that map to the flight model scorer, and logs how long
// the whole call took.
//
// Both collaborators are read from ctx: an avroutil.AvroCodecLoader stored
// under codecLoaderKey and a *scoring.FlightModelScorer stored under scorerKey.
//
// A missing or mistyped codec loader, a codec load error, a decode error, or a
// scoring error terminates the process via log.Fatal*. A missing or mistyped
// scorer is logged and the event is dropped.
func HandleMessage(ctx context.Context, event cloudevents.Event) {
	log.Println("received message")
	start := time.Now()

	// pull the codec loader out of the context and load the avro codec.
	codecLoader := ctx.Value(codecLoaderKey)
	if codecLoader == nil {
		log.Fatalln("required codec loader not in context..exiting")
	}
	loader, ok := codecLoader.(avroutil.AvroCodecLoader)
	if !ok {
		log.Fatalln("codec loader in context is not a valid AvroCodecLoader")
	}
	codec, err := loader.LoadCodec(event.Type())
	if err != nil {
		log.Fatalf("error creating avro codec: %s", err)
	}

	// convert from avro to generic map
	datum, _, err := codec.NativeFromBinary(event.Data())
	if err != nil {
		log.Fatalf("error decoding from binary: %s", err)
	}
	data, ok := datum.(map[string]interface{})
	if !ok {
		log.Fatalln("could not convert datum to map")
	}

	// pull out the flight scorer. Stop here when it is absent, mistyped, or a
	// typed nil pointer: falling through would call ScoreModel on a nil
	// *scoring.FlightModelScorer.
	scorer, ok := ctx.Value(scorerKey).(*scoring.FlightModelScorer)
	if !ok || scorer == nil {
		log.Println("could not cast to flight client")
		return
	}

	// run the scoring; the result itself is currently discarded and only the
	// error is inspected.
	if _, err = scorer.ScoreModel(data); err != nil {
		log.Fatalf("error scoring: %s", err)
	}

	elapsed := time.Since(start)
	log.Printf("took %d milliseconds", elapsed.Milliseconds())
	log.Println("done")
}
// initSchemaLoader constructs the S3-backed Avro codec loader. It loads the
// default AWS SDK configuration, creates an S3 client from it, and pairs that
// client with a TTL cache whose TTL is set to ten minutes. The bucket and
// prefix handed to the loader are fixed in this function. Any failure while
// loading the SDK configuration or setting the cache TTL terminates the
// process.
func initSchemaLoader() avroutil.AvroCodecLoader {
	const (
		schemaBucket = "dqhub-test"
		schemaPrefix = "not-a-prefix"
		codecTTL     = 10 * time.Minute
	)

	awsCfg, cfgErr := config.LoadDefaultConfig(context.Background())
	if cfgErr != nil {
		log.Fatalf("unable to load SDK config, %v", cfgErr)
	}
	s3Client := s3.NewFromConfig(awsCfg)

	codecCache := ttlcache.NewCache()
	if ttlErr := codecCache.SetTTL(codecTTL); ttlErr != nil {
		log.Fatalln(ttlErr)
	}

	return avroutil.NewS3AvroCodecLoader(codecCache, s3Client, schemaBucket, schemaPrefix)
}
// getFlightClient dials the gRPC endpoint at 127.0.0.1:9998 with the
// WithInsecure and WithBlock dial options and wraps the resulting connection
// in a Flight service client. A dial error is returned to the caller with a
// nil client.
func getFlightClient() (flight.FlightServiceClient, error) {
	const flightAddr = "127.0.0.1:9998"

	dialOpts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
	}

	cc, dialErr := grpc.Dial(flightAddr, dialOpts...)
	if dialErr != nil {
		return nil, dialErr
	}

	client := flight.NewFlightServiceClient(cc)
	return client, nil
}
// getScorer builds the flight model scorer from a Flight service client and an
// Arrow converter backed by the Go allocator. If the Flight client cannot be
// created the process is terminated.
func getScorer() *scoring.FlightModelScorer {
	client, clientErr := getFlightClient()
	if clientErr != nil {
		log.Fatalf("failed to instantiate flight client: %s", clientErr)
	}

	allocator := memory.NewGoAllocator()
	converter := arrowconv.NewArrowConverter(allocator)

	return scoring.NewFlightModelScorer(client, converter)
}
// main builds the scorer and the Avro codec loader, creates the cloud events
// HTTP client, and starts the receiver with HandleMessage as its callback.
// The scorer and loader are handed to the callback through the context passed
// to StartReceiver.
func main() {
	scorer := getScorer()
	loader := initSchemaLoader()

	log.Println("starting cloud events client")
	c, err := cloudevents.NewClientHTTP()
	if err != nil {
		// The failing call creates the cloudevents client, so name it correctly.
		log.Fatalf("error starting cloudevents client: %s", err)
	}

	log.Println("starting receiver")
	// HandleMessage looks up both dependencies in the context by these keys.
	ctx := context.WithValue(context.Background(), codecLoaderKey, loader)
	ctx = context.WithValue(ctx, scorerKey, scorer)

	// Whatever StartReceiver returns is logged and the process exits.
	log.Fatal(c.StartReceiver(ctx, HandleMessage))
}