Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

August-W: Made ServiceProvider more configurable without needing to extend it #20

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
13 changes: 3 additions & 10 deletions examples/sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,9 @@
from tests.idp.base import CERTIFICATE as IDP_CERTIFICATE
from tests.sp.base import CERTIFICATE, PRIVATE_KEY


class ExampleServiceProvider(ServiceProvider):
def get_logout_return_url(self):
return url_for('index', _external=True)

def get_default_login_return_url(self):
return url_for('index', _external=True)


sp = ExampleServiceProvider()
sp = ServiceProvider()
sp.default_login_return_endpoint = 'index'
sp.logout_return_endpoint = 'index'

app = Flask(__name__)
app.debug = True
Expand Down
51 changes: 43 additions & 8 deletions flask_saml2/sp/sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ class ServiceProvider:
#: The name of the blueprint to generate.
blueprint_name = 'flask_saml2_sp'

#: Set this to http or https
scheme = 'http'

#: Set these to your desired endpoints
logout_return_endpoint = None
default_login_return_endpoint = None
acs_redirect_endpoint = None

"""
Set this value to override the default metadata return value
of :meth: `get_sp_entity_id`. By setting this, you can return
only the entity_id value, rather than the url to the full metadata xml.
"""
entity_id = None

def login_successful(
self,
auth_data: AuthData,
Expand Down Expand Up @@ -83,7 +98,10 @@ def get_sp_entity_id(self) -> str:

See :func:`get_metadata_url`.
"""
return self.get_metadata_url()
if self.entity_id is None:
return self.get_metadata_url()
else:
return self.entity_id

def get_sp_certificate(self) -> Optional[X509]:
"""Get the public certificate for this SP."""
Expand Down Expand Up @@ -156,6 +174,8 @@ def get_metadata_url(self) -> str:
def get_default_login_return_url(self) -> Optional[str]:
"""The default URL to redirect users to once the have logged in.
"""
if self.default_login_return_endpoint is not None:
return url_for(self.default_login_return_endpoint, _external=True)
return None

def get_login_return_url(self) -> Optional[str]:
Expand All @@ -177,6 +197,8 @@ def get_login_return_url(self) -> Optional[str]:
def get_logout_return_url(self) -> Optional[str]:
"""The URL to redirect users to once they have logged out.
"""
if self.logout_return_endpoint is not None:
return url_for(self.logout_return_endpoint, _external=True)
return None

def is_valid_redirect_url(self, url: str) -> str:
Expand Down Expand Up @@ -306,22 +328,35 @@ def get_metadata_context(self) -> dict:
'contacts': [],
}

def create_blueprint(self) -> Blueprint:
def get_scheme(self) -> str:
return self.scheme

def get_acs_redirect_endpoint(self) -> str:
return self.acs_redirect_endpoint

# With acs_redirect_url, you can set the url that the Access Consumer Service redirects to upon successful login
# This is unnecessary if you expect a "relay_state" parameter in the SAML request to the ACS
def create_blueprint(self, login_endpoint='/login/', login_idp_endpoint='/login/idp/',
logout_endpoint='/logout/', acs_endpoint='/acs/', sls_endpoint='/sls/',
metadata_endpoint='/metadata.xml', scheme='http') -> Blueprint:
Comment on lines +339 to +341
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m far from sure about this, but shouldn’t the new entity_id not also be passed as an argument here, same as the scheme? (And then of course also something at line 345)

Suggested change
def create_blueprint(self, login_endpoint='/login/', login_idp_endpoint='/login/idp/',
logout_endpoint='/logout/', acs_endpoint='/acs/', sls_endpoint='/sls/',
metadata_endpoint='/metadata.xml', scheme='http') -> Blueprint:
def create_blueprint(self, login_endpoint='/login/', login_idp_endpoint='/login/idp/',
logout_endpoint='/logout/', acs_endpoint='/acs/', sls_endpoint='/sls/',
metadata_endpoint='/metadata.xml', scheme='http', entity_id=None) -> Blueprint:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about that. The main reason I didn't pass entity_id there is that it seemed to be a little outside the scope of what this create_blueprint function should do. To me, it makes sense to set the scheme together with the various endpoints, as they are related. But you could make a case that even scheme doesn't belong here, as it is setting a value in sp rather than in idp_bp.

I'm open to including entity_id or removing scheme. Maybe creating a separate config_and_create_blueprint function which sets the entity_id and scheme and then calls create_blueprint. Or other ideas? Not sure what the cleanest solution is.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aah, yeah, got it. In that case I’m out of my depth, I wouldn’t know what the cleanest solution would be I’m afraid.


"""Create a Flask :class:`flask.Blueprint` for this Service Provider.
"""
self.scheme = scheme

idp_bp = Blueprint(self.blueprint_name, 'flask_saml2.sp', template_folder='templates')

idp_bp.add_url_rule('/login/', view_func=Login.as_view(
idp_bp.add_url_rule(login_endpoint, view_func=Login.as_view(
'login', sp=self))
idp_bp.add_url_rule('/login/idp/', view_func=LoginIdP.as_view(
idp_bp.add_url_rule(login_idp_endpoint, view_func=LoginIdP.as_view(
'login_idp', sp=self))
idp_bp.add_url_rule('/logout/', view_func=Logout.as_view(
idp_bp.add_url_rule(logout_endpoint, view_func=Logout.as_view(
'logout', sp=self))
idp_bp.add_url_rule('/acs/', view_func=AssertionConsumer.as_view(
idp_bp.add_url_rule(acs_endpoint, view_func=AssertionConsumer.as_view(
'acs', sp=self))
idp_bp.add_url_rule('/sls/', view_func=SingleLogout.as_view(
idp_bp.add_url_rule(sls_endpoint, view_func=SingleLogout.as_view(
'sls', sp=self))
idp_bp.add_url_rule('/metadata.xml', view_func=Metadata.as_view(
idp_bp.add_url_rule(metadata_endpoint, view_func=Metadata.as_view(
'metadata', sp=self))

idp_bp.register_error_handler(CannotHandleAssertion, CannotHandleAssertionView.as_view(
Expand Down
8 changes: 6 additions & 2 deletions flask_saml2/sp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def get(self):
handler = self.sp.get_default_idp_handler()
login_next = self.sp.get_login_return_url()
if handler:
return redirect(url_for('.login_idp', entity_id=handler.entity_id, next=login_next))
return redirect(url_for('.login_idp', entity_id=handler.entity_id, next=login_next,
_scheme=self.sp.get_scheme(), _external=True))
return self.sp.render_template(
'flask_saml2_sp/choose_idp.html',
login_next=login_next,
Expand Down Expand Up @@ -79,7 +80,10 @@ def do_logout(self, handler):
class AssertionConsumer(SAML2View):
def post(self):
saml_request = request.form['SAMLResponse']
relay_state = request.form['RelayState']
if self.sp.get_acs_redirect_endpoint() is None:
relay_state = request.form['RelayState']
else:
relay_state = self.sp.make_absolute_url(self.sp.get_acs_redirect_endpoint())

for handler in self.sp.get_idp_handlers():
try:
Expand Down