diff --git a/cmds/webhook.go b/cmds/webhook.go index ab406a28..fa341d6e 100644 --- a/cmds/webhook.go +++ b/cmds/webhook.go @@ -21,11 +21,10 @@ import ( "bytes" "context" "crypto/tls" - "encoding/json" "fmt" - "io" "io/ioutil" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -48,6 +47,7 @@ const ( ) var ( + secretKey = "" certDir = "certs" email = "tamal@appscode.com" hosts = []string{"gh-ci-webhook.appscode.ninja"} @@ -68,6 +68,7 @@ func NewCmdRun() *cobra.Command { }, } + cmd.Flags().StringVar(&secretKey, "secret-key", secretKey, "Secret key to verify webhook payloads") cmd.Flags().StringVar(&certDir, "cert-dir", certDir, "Directory where certs are stored") cmd.Flags().StringVar(&email, "email", email, "Email used by Let's Encrypt to notify about problems with issued certificates") cmd.Flags().StringSliceVar(&hosts, "hosts", hosts, "Hosts for which certificate will be issued") @@ -132,9 +133,10 @@ func runServer() error { r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS) }).Methods(http.MethodGet) - r.HandleFunc("/check-ci-runs", CheckRunsHandler).Methods(http.MethodPost) - r.HandleFunc("/check-pr-runs", CheckPrRepoRunsHandler).Methods(http.MethodPost) - r.HandleFunc("/pr", PullRequestsHandler).Methods(http.MethodPost) + r.HandleFunc("/check-ci-runs", serveHTTP).Methods(http.MethodPost) + r.HandleFunc("/check-pr-runs", serveHTTP).Methods(http.MethodPost) + r.HandleFunc("/pr", serveHTTP).Methods(http.MethodPost) + r.Use() if !enableSSL { addr := fmt.Sprintf(":%d", port) @@ -282,25 +284,41 @@ func openPR(gh *github.Client, sh *shell.Session, event PREvent) error { return err } -func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 1048576) - - dec := json.NewDecoder(r.Body) - // dec.DisallowUnknownFields() // Read all - - var event github.CheckRunEvent - err := dec.Decode(&event) +func serveHTTP(w http.ResponseWriter, r *http.Request) { + payload, err := github.ValidatePayload(r, []byte(secretKey)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + event, err := github.ParseWebHook(github.WebHookType(r), payload) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - err = dec.Decode(&struct{}{}) - if err != io.EOF { - http.Error(w, msgSingleJSON, http.StatusBadRequest) + query := r.URL.Query() + switch event := event.(type) { + case *github.CheckRunEvent: + if _, ok := query["pr-repo"]; ok { + handleCIRepoEvent(event, query) + return + } + if _, ok := query["ci-repo"]; ok { + handlePRRepoEvent(event, query) + return + } + http.Error(w, "unsupported event", http.StatusOK) + return + case *github.PullRequestEvent: + handlePREvent(event, query) + return + default: + http.Error(w, "unsupported event", http.StatusOK) return } +} +func handlePRRepoEvent(event *github.CheckRunEvent, query url.Values) { if event.GetCheckRun().GetApp().GetSlug() == "github-actions" && event.GetCheckRun().GetName() == "Build" && event.GetCheckRun().GetStatus() == "completed" && @@ -318,7 +336,7 @@ func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) { prs <- PREvent{ PRRepoURL: strings.TrimPrefix(event.GetRepo().GetHTMLURL(), "https://"), - TestRepoURL: strings.TrimPrefix(r.URL.Query().Get("test-repo"), "https://"), + TestRepoURL: strings.TrimPrefix(query.Get("ci-repo"), "https://"), PRNumber: event.GetCheckRun().PullRequests[0].GetNumber(), PRTitle: pr.GetTitle(), PRState: pr.GetState(), @@ -329,27 +347,9 @@ func CheckPrRepoRunsHandler(w http.ResponseWriter, r *http.Request) { } } -func CheckRunsHandler(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 1048576) - - dec := json.NewDecoder(r.Body) - // dec.DisallowUnknownFields() // Read all - - var event github.CheckRunEvent - err := dec.Decode(&event) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - err = dec.Decode(&struct{}{}) - if err != io.EOF { - http.Error(w, msgSingleJSON, http.StatusBadRequest) - return - } - +func handleCIRepoEvent(event *github.CheckRunEvent, query url.Values) { if event.GetCheckRun().GetApp().GetSlug() == "github-actions" { - owner, repo := lib.ParseRepoURL(r.URL.Query().Get("pr-repo")) + owner, repo := lib.ParseRepoURL(query.Get("pr-repo")) ref := strings.Split(event.GetCheckRun().PullRequests[0].GetHead().GetRef(), "@")[0] // branch name matches pr repo's sha var state string @@ -379,43 +379,17 @@ func CheckRunsHandler(w http.ResponseWriter, r *http.Request) { } fmt.Println(sr) } - - encoder := json.NewEncoder(w) - encoder.SetEscapeHTML(false) - encoder.SetIndent("", " ") - err = encoder.Encode(event) - if err != nil { - panic(err) - } } -func PullRequestsHandler(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 1048576) - - dec := json.NewDecoder(r.Body) - // dec.DisallowUnknownFields() // Read all - - var event github.PullRequestEvent - err := dec.Decode(&event) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - err = dec.Decode(&struct{}{}) - if err != io.EOF { - http.Error(w, msgSingleJSON, http.StatusBadRequest) - return - } - - actions := lib.GetQueryParameter(r.URL.Query(), "actions") +func handlePREvent(event *github.PullRequestEvent, query url.Values) { + actions := lib.GetQueryParameter(query, "actions") if actions.Len() == 0 { actions = sets.NewString("opened", "synchronize", "closed", "reopened") } if actions.Has(event.GetAction()) { prs <- PREvent{ PRRepoURL: strings.TrimPrefix(event.GetRepo().GetHTMLURL(), "https://"), - TestRepoURL: strings.TrimPrefix(r.URL.Query().Get("test-repo"), "https://"), + TestRepoURL: strings.TrimPrefix(query.Get("ci-repo"), "https://"), PRNumber: event.GetPullRequest().GetNumber(), PRTitle: event.GetPullRequest().GetTitle(), PRState: event.GetPullRequest().GetState(), @@ -423,7 +397,5 @@ func PullRequestsHandler(w http.ResponseWriter, r *http.Request) { HeadRef: event.GetPullRequest().GetHead().GetRef(), HeadSHA: event.GetPullRequest().GetHead().GetSHA(), } - - _, _ = w.Write([]byte("queued")) } }