Skip to content

Commit

Permalink
Support verifying webhook payloads using secret key (#1)
Browse files Browse the repository at this point in the history
Signed-off-by: Tamal Saha <[email protected]>
  • Loading branch information
tamalsaha authored Jul 4, 2020
1 parent 07a51f7 commit 0448c18
Showing 1 changed file with 40 additions and 68 deletions.
108 changes: 40 additions & 68 deletions cmds/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
Expand All @@ -48,6 +47,7 @@ const (
)

var (
secretKey = ""
certDir = "certs"
email = "[email protected]"
hosts = []string{"gh-ci-webhook.appscode.ninja"}
Expand All @@ -68,6 +68,7 @@ func NewCmdRun() *cobra.Command {
},
}

cmd.Flags().StringVar(&secretKey, "secret-key", secretKey, "Secret key to verify webhook payloads")
cmd.Flags().StringVar(&certDir, "cert-dir", certDir, "Directory where certs are stored")
cmd.Flags().StringVar(&email, "email", email, "Email used by Let's Encrypt to notify about problems with issued certificates")
cmd.Flags().StringSliceVar(&hosts, "hosts", hosts, "Hosts for which certificate will be issued")
Expand Down Expand Up @@ -132,9 +133,10 @@ func runServer() error {
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS)
}).Methods(http.MethodGet)
r.HandleFunc("/check-ci-runs", CheckRunsHandler).Methods(http.MethodPost)
r.HandleFunc("/check-pr-runs", CheckPrRepoRunsHandler).Methods(http.MethodPost)
r.HandleFunc("/pr", PullRequestsHandler).Methods(http.MethodPost)
r.HandleFunc("/check-ci-runs", serveHTTP).Methods(http.MethodPost)
r.HandleFunc("/check-pr-runs", serveHTTP).Methods(http.MethodPost)
r.HandleFunc("/pr", serveHTTP).Methods(http.MethodPost)
r.Use()

if !enableSSL {
addr := fmt.Sprintf(":%d", port)
Expand Down Expand Up @@ -282,25 +284,41 @@ func openPR(gh *github.Client, sh *shell.Session, event PREvent) error {
return err
}

func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1048576)

dec := json.NewDecoder(r.Body)
// dec.DisallowUnknownFields() // Read all

var event github.CheckRunEvent
err := dec.Decode(&event)
func serveHTTP(w http.ResponseWriter, r *http.Request) {
payload, err := github.ValidatePayload(r, []byte(secretKey))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
event, err := github.ParseWebHook(github.WebHookType(r), payload)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

err = dec.Decode(&struct{}{})
if err != io.EOF {
http.Error(w, msgSingleJSON, http.StatusBadRequest)
query := r.URL.Query()
switch event := event.(type) {
case *github.CheckRunEvent:
if _, ok := query["pr-repo"]; ok {
handleCIRepoEvent(event, query)
return
}
if _, ok := query["ci-repo"]; ok {
handlePRRepoEvent(event, query)
return
}
http.Error(w, "unsupported event", http.StatusOK)
return
case *github.PullRequestEvent:
handlePREvent(event, query)
return
default:
http.Error(w, "unsupported event", http.StatusOK)
return
}
}

func handlePRRepoEvent(event *github.CheckRunEvent, query url.Values) {
if event.GetCheckRun().GetApp().GetSlug() == "github-actions" &&
event.GetCheckRun().GetName() == "Build" &&
event.GetCheckRun().GetStatus() == "completed" &&
Expand All @@ -318,7 +336,7 @@ func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) {

prs <- PREvent{
PRRepoURL: strings.TrimPrefix(event.GetRepo().GetHTMLURL(), "https://"),
TestRepoURL: strings.TrimPrefix(r.URL.Query().Get("test-repo"), "https://"),
TestRepoURL: strings.TrimPrefix(query.Get("ci-repo"), "https://"),
PRNumber: event.GetCheckRun().PullRequests[0].GetNumber(),
PRTitle: pr.GetTitle(),
PRState: pr.GetState(),
Expand All @@ -329,27 +347,9 @@ func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) {
}
}

func CheckRunsHandler(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1048576)

dec := json.NewDecoder(r.Body)
// dec.DisallowUnknownFields() // Read all

var event github.CheckRunEvent
err := dec.Decode(&event)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

err = dec.Decode(&struct{}{})
if err != io.EOF {
http.Error(w, msgSingleJSON, http.StatusBadRequest)
return
}

func handleCIRepoEvent(event *github.CheckRunEvent, query url.Values) {
if event.GetCheckRun().GetApp().GetSlug() == "github-actions" {
owner, repo := lib.ParseRepoURL(r.URL.Query().Get("pr-repo"))
owner, repo := lib.ParseRepoURL(query.Get("pr-repo"))
ref := strings.Split(event.GetCheckRun().PullRequests[0].GetHead().GetRef(), "@")[0] // branch name matches pr repo's sha

var state string
Expand Down Expand Up @@ -379,51 +379,23 @@ func CheckRunsHandler(w http.ResponseWriter, r *http.Request) {
}
fmt.Println(sr)
}

encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")
err = encoder.Encode(event)
if err != nil {
panic(err)
}
}

func PullRequestsHandler(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1048576)

dec := json.NewDecoder(r.Body)
// dec.DisallowUnknownFields() // Read all

var event github.PullRequestEvent
err := dec.Decode(&event)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

err = dec.Decode(&struct{}{})
if err != io.EOF {
http.Error(w, msgSingleJSON, http.StatusBadRequest)
return
}

actions := lib.GetQueryParameter(r.URL.Query(), "actions")
func handlePREvent(event *github.PullRequestEvent, query url.Values) {
actions := lib.GetQueryParameter(query, "actions")
if actions.Len() == 0 {
actions = sets.NewString("opened", "synchronize", "closed", "reopened")
}
if actions.Has(event.GetAction()) {
prs <- PREvent{
PRRepoURL: strings.TrimPrefix(event.GetRepo().GetHTMLURL(), "https://"),
TestRepoURL: strings.TrimPrefix(r.URL.Query().Get("test-repo"), "https://"),
TestRepoURL: strings.TrimPrefix(query.Get("ci-repo"), "https://"),
PRNumber: event.GetPullRequest().GetNumber(),
PRTitle: event.GetPullRequest().GetTitle(),
PRState: event.GetPullRequest().GetState(),
PRMerged: event.GetPullRequest().GetMerged(),
HeadRef: event.GetPullRequest().GetHead().GetRef(),
HeadSHA: event.GetPullRequest().GetHead().GetSHA(),
}

_, _ = w.Write([]byte("queued"))
}
}

0 comments on commit 0448c18

Please sign in to comment.