diff --git a/pkg/api/graphql_client.go b/pkg/api/graphql_client.go index 405b10f..3317eef 100644 --- a/pkg/api/graphql_client.go +++ b/pkg/api/graphql_client.go @@ -10,6 +10,7 @@ import ( "net/http" "strings" + "github.com/cli/go-gh/v2/pkg/auth" graphql "github.com/cli/shurcooL-graphql" ) @@ -171,8 +172,8 @@ func graphQLEndpoint(host string) string { if isGarage(host) { return fmt.Sprintf("https://%s/api/graphql", host) } - host = normalizeHostname(host) - if isEnterprise(host) { + host = auth.NormalizeHostname(host) + if auth.IsEnterprise(host) { return fmt.Sprintf("https://%s/api/graphql", host) } if strings.EqualFold(host, localhost) { diff --git a/pkg/api/http_client.go b/pkg/api/http_client.go index d20cb37..1d7c517 100644 --- a/pkg/api/http_client.go +++ b/pkg/api/http_client.go @@ -134,36 +134,6 @@ func isGarage(host string) bool { return strings.EqualFold(host, "garage.github.com") } -func isEnterprise(host string) bool { - return host != github && host != localhost && !isTenancy(host) -} - -// tenancyHost is the domain name of a tenancy GitHub instance. -const tenancyHost = "ghe.com" - -func isTenancy(host string) bool { - return strings.HasSuffix(host, "."+tenancyHost) -} - -func normalizeHostname(hostname string) string { - hostname = strings.ToLower(hostname) - if strings.HasSuffix(hostname, "."+github) { - return github - } - if strings.HasSuffix(hostname, "."+localhost) { - return localhost - } - // This has been copied over from the cli/cli NormalizeHostname function - // to ensure compatible behaviour but we don't fully understand when or - // why it would be useful here. We can't see what harm will come of - // duplicating the logic. - if before, found := strings.CutSuffix(hostname, "."+tenancyHost); found { - idx := strings.LastIndex(before, ".") - return fmt.Sprintf("%s.%s", before[idx+1:], tenancyHost) - } - return hostname -} - type headerRoundTripper struct { headers map[string]string host string diff --git a/pkg/api/http_client_test.go b/pkg/api/http_client_test.go index dee6b02..9fe0934 100644 --- a/pkg/api/http_client_test.go +++ b/pkg/api/http_client_test.go @@ -157,73 +157,6 @@ func TestNewHTTPClient(t *testing.T) { } } -func TestIsEnterprise(t *testing.T) { - tests := []struct { - name string - host string - wantOut bool - }{ - { - name: "github", - host: "github.com", - wantOut: false, - }, - { - name: "localhost", - host: "github.localhost", - wantOut: false, - }, - { - name: "enterprise", - host: "mygithub.com", - wantOut: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out := isEnterprise(tt.host) - assert.Equal(t, tt.wantOut, out) - }) - } -} - -func TestNormalizeHostname(t *testing.T) { - tests := []struct { - name string - host string - wantHost string - }{ - { - name: "github domain", - host: "test.github.com", - wantHost: "github.com", - }, - { - name: "capitalized", - host: "GitHub.com", - wantHost: "github.com", - }, - { - name: "localhost domain", - host: "test.github.localhost", - wantHost: "github.localhost", - }, - { - name: "enterprise domain", - host: "mygithub.com", - wantHost: "mygithub.com", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - normalized := normalizeHostname(tt.host) - assert.Equal(t, tt.wantHost, normalized) - }) - } -} - type tripper struct { roundTrip func(*http.Request) (*http.Response, error) } diff --git a/pkg/api/rest_client.go b/pkg/api/rest_client.go index 2d91f70..ccd86cd 100644 --- a/pkg/api/rest_client.go +++ b/pkg/api/rest_client.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "strings" + + "github.com/cli/go-gh/v2/pkg/auth" ) // RESTClient wraps methods for the different types of @@ -159,8 +161,8 @@ func restPrefix(hostname string) string { if isGarage(hostname) { return fmt.Sprintf("https://%s/api/v3/", hostname) } - hostname = normalizeHostname(hostname) - if isEnterprise(hostname) { + hostname = auth.NormalizeHostname(hostname) + if auth.IsEnterprise(hostname) { return fmt.Sprintf("https://%s/api/v3/", hostname) } if strings.EqualFold(hostname, localhost) { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 6ab996f..a903736 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -26,6 +26,7 @@ const ( hostsKey = "hosts" localhost = "github.localhost" oauthToken = "oauth_token" + tenancyHost = "ghe.com" // TenancyHost is the domain suffix of a tenancy GitHub instance. ) // TokenForHost retrieves an authentication token and the source of that token for the specified @@ -61,7 +62,7 @@ func TokenFromEnvOrConfig(host string) (string, string) { } func tokenForHost(cfg *config.Config, host string) (string, string) { - host = normalizeHostname(host) + host = NormalizeHostname(host) if IsEnterprise(host) { if token := os.Getenv(ghEnterpriseToken); token != "" { return token, ghEnterpriseToken @@ -149,24 +150,24 @@ func defaultHost(cfg *config.Config) (string, string) { return github, defaultSource } -// TenancyHost is the domain name of a tenancy GitHub instance. -const tenancyHost = "ghe.com" - // IsEnterprise determines if a provided host is a GitHub Enterprise Server instance, // rather than GitHub.com or a tenancy GitHub instance. func IsEnterprise(host string) bool { - normalizedHost := normalizeHostname(host) + normalizedHost := NormalizeHostname(host) return normalizedHost != github && normalizedHost != localhost && !IsTenancy(normalizedHost) } // IsTenancy determines if a provided host is a tenancy GitHub instance, // rather than GitHub.com or a GitHub Enterprise Server instance. func IsTenancy(host string) bool { - normalizedHost := normalizeHostname(host) + normalizedHost := NormalizeHostname(host) return strings.HasSuffix(normalizedHost, "."+tenancyHost) } -func normalizeHostname(host string) string { +// NormalizeHostname ensures the host matches the values used throughout +// the rest of the codebase with respect to hostnames. These are github, +// localhost, and tenancyHost. +func NormalizeHostname(host string) string { hostname := strings.ToLower(host) if strings.HasSuffix(hostname, "."+github) { return github diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 0588a1c..e6b7b9a 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -368,7 +368,7 @@ func TestNormalizeHostname(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - normalized := normalizeHostname(tt.host) + normalized := NormalizeHostname(tt.host) assert.Equal(t, tt.wantHost, normalized) }) }