Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allowedDomains to Microsoft connector. #3515

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions connector/microsoft/microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
DomainHint string `json:"domainHint"`

Scopes []string `json:"scopes"` // defaults to scopeUser (user.read)

AllowedDomains []string `json:"allowedDomains"`
}

// Open returns a strategy for logging in through Microsoft.
Expand All @@ -83,6 +85,7 @@
promptType: c.PromptType,
domainHint: c.DomainHint,
scopes: c.Scopes,
allowedDomains: c.AllowedDomains,
}

if m.apiURL == "" {
Expand Down Expand Up @@ -138,6 +141,7 @@
promptType string
domainHint string
scopes []string
allowedDomains []string
}

func (c *microsoftConnector) isOrgTenant() bool {
Expand Down Expand Up @@ -217,6 +221,11 @@
user.Email = strings.ToLower(user.Email)
}

// Check if the email's domain is in the allowed list
if !c.isAllowedDomain(user.Email) {
return identity, fmt.Errorf("email (%s) domain not allowed", user.Email)
}

identity = connector.Identity{
UserID: user.ID,
Username: user.Name,
Expand Down Expand Up @@ -531,3 +540,22 @@
}
return e.error + ": " + e.errorDescription
}

func (c *microsoftConnector) isAllowedDomain(email string) bool {

Check failure on line 544 in connector/microsoft/microsoft.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary leading newline (whitespace)

if len(c.allowedDomains) == 0 {
return true
}

parts := strings.Split(email, "@")
if len(parts) != 2 {
return false
}
domain := parts[1]
for _, d := range c.allowedDomains {
if d == domain {
return true
}
}
return false
}
61 changes: 61 additions & 0 deletions connector/microsoft/microsoft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package microsoft
import (
"encoding/json"
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -119,6 +120,66 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
expectEquals(t, identity.Groups, []string{"a", "b"})
}

func TestDomainNotAllowed(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {
data: user{ID: "S56767889", Name: "Jane Doe", Email: "[email protected]"},
},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, allowedDomains: []string{"dcode.tech"}}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)

assert.Error(t, err, "email ([email protected]) domain not allowed")
assert.Equal(t, connector.Identity{}, identity)
}

func TestDomainListAllowed(t *testing.T) {
testCases := []struct {
email string
allowed bool
domain string
}{
{"[email protected]", true, "dcode.tech"}, // Allowed domain
{"[email protected]", true, "example.com"}, // Allowed domain
{"[email protected]", false, "otherdomain.com"}, // Not allowed domain
}

for _, tc := range testCases {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {
data: user{ID: "S56767889", Name: "John Doe", Email: tc.email},
},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()

req, _ := http.NewRequest("GET", s.URL, nil)

// Setup the microsoftConnector with allowed domains
c := microsoftConnector{
apiURL: s.URL,
graphURL: s.URL,
tenant: tenant,
allowedDomains: []string{"dcode.tech", "example.com"},
}

identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)

if tc.allowed {
assert.NoError(t, err, "Expected no error for allowed domain: "+tc.domain)
assert.NotEqual(t, connector.Identity{}, identity, "Expected a non-empty identity struct")
} else {
assert.Error(t, err, "Expected error for non-allowed domain: "+tc.email)
assert.Equal(t, connector.Identity{}, identity, "Expected an empty identity struct")
}
}
}

func newTestServer(responses map[string]testResponse) *httptest.Server {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response, found := responses[r.RequestURI]
Expand Down
Loading