diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..de543dd45 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Ruff format +59b2965dc8f5f487ba24d148911ac2999d7bbb57 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 2b0789738..4e5c8ab76 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -36,11 +36,6 @@ jobs: # a pull request then we can checkout the head. fetch-depth: 2 - # If this run was triggered by a pull request event, then checkout - # the head of the pull request instead of the merge commit. - - run: git checkout HEAD^2 - if: ${{ github.event_name == 'pull_request' }} - # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL uses: github/codeql-action/init@v1 @@ -63,4 +58,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index eeadcadac..088b15b2a 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -15,7 +15,7 @@ jobs: strategy: max-parallel: 4 matrix: - python: [3.8, 3.9, "3.10", "3.11"] + python: [3.9, "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 diff --git a/MANIFEST.in b/MANIFEST.in index 1ca593a16..bdf2a252f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ include src/oic/py.typed +include pyproject.toml diff --git a/Makefile b/Makefile index 6ecc1dce5..d20f0c7a5 100644 --- a/Makefile +++ b/Makefile @@ -44,28 +44,10 @@ test: @pipenv run pytest $(TESTDIR) .PHONY: test -isort: - @pipenv run isort $(OICDIR) $(TESTDIR) $(OAUTH_EXAMPLE) - -check-isort: - @pipenv run isort --diff --check-only $(OICDIR) $(TESTDIR) $(OAUTH_EXAMPLE) -.PHONY: isort check-isort - -blacken: - @pipenv run black src/ tests/ oauth_example/ - -check-black: - @pipenv run black src/ tests/ oauth_example/ --check -.PHONY: blacken check-black - bandit: @pipenv run bandit -a file -r src/ oauth_example/ oidc_example/ .PHONY: bandit -check-pylama: - @pipenv run pylama $(OICDIR) $(TESTDIR) $(OAUTH_EXAMPLE) -.PHONY: check-pylama - release: @pipenv run python setup.py sdist upload -r pypi .PHONY: release diff --git a/appveyor.yml b/appveyor.yml index bb37f7cbc..c6e0fe5b2 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -3,9 +3,11 @@ image: environment: matrix: - - TOXENV: py38 - TOXENV: py39 - TOXENV: py310 + - TOXENV: py311 + - TOXENV: py312 + - TOXENV: py313 build: off diff --git a/doc/conf.py b/doc/conf.py index 7351e8368..bda2d8d6b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -2,77 +2,78 @@ import os import sys -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinxcontrib.autodoc_pydantic', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinxcontrib.autodoc_pydantic", ] -autoclass_content = 'both' # Merge the __init__ docstring into the class docstring. -autodoc_member_order = 'bysource' # Order by source ordering +autoclass_content = "both" # Merge the __init__ docstring into the class docstring. +autodoc_member_order = "bysource" # Order by source ordering autodoc_pydantic_model_show_config = True autodoc_pydantic_settings_show_json = False -templates_path = ['_templates'] +templates_path = ["_templates"] -source_suffix = '.rst' +source_suffix = ".rst" -master_doc = 'index' +master_doc = "index" -project = u'pyoidc' +project = "pyoidc" -copyright = u'2014, Roland Hedberg' +copyright = "2014, Roland Hedberg" -version = '0.1' +version = "0.1" -release = '0.1' +release = "0.1" -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] -pygments_style = 'sphinx' +pygments_style = "sphinx" html_theme_path = [alabaster.get_path()] -html_theme = 'alabaster' +html_theme = "alabaster" -html_static_path = ['_static'] +html_static_path = ["_static"] -htmlhelp_basename = 'pyoidcdoc' +htmlhelp_basename = "pyoidcdoc" html_theme_options = { - 'description': '', - 'github_button': False, - 'github_user': 'its-dirg', - 'github_repo': 'saml2testGui', - 'github_banner': False, - + "description": "", + "github_button": False, + "github_user": "its-dirg", + "github_repo": "saml2testGui", + "github_banner": False, } html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'searchbox.html', - 'donate.html', - ] + "**": [ + "about.html", + "navigation.html", + "searchbox.html", + "donate.html", + ] } -man_pages = [ - ('index', 'pyoidc', u'pyoidc Documentation', - [u'Roland Hedberg'], 1) -] +man_pages = [("index", "pyoidc", "pyoidc Documentation", ["Roland Hedberg"], 1)] latex_elements = {} latex_documents = [ - ('index', 'pyoidc.tex', u'pyoidc Documentation', - u'Roland Hedberg', 'manual'), + ("index", "pyoidc.tex", "pyoidc Documentation", "Roland Hedberg", "manual"), ] texinfo_documents = [ - ('index', 'pyoidc', u'pyoidc Documentation', - u'Roland Hedberg', 'pyoidc', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "pyoidc", + "pyoidc Documentation", + "Roland Hedberg", + "pyoidc", + "One line description of project.", + "Miscellaneous", + ), ] diff --git a/docker/oidc_op/scripts/make_test_site.py b/docker/oidc_op/scripts/make_test_site.py index f76e2e453..4b03f5286 100755 --- a/docker/oidc_op/scripts/make_test_site.py +++ b/docker/oidc_op/scripts/make_test_site.py @@ -5,7 +5,7 @@ _distroot = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) -_root = 'test_site' +_root = "test_site" if os.path.isdir(_root) is False: os.makedirs(_root) diff --git a/docker/op_test/config.py b/docker/op_test/config.py index 0bf1f758e..3164292af 100644 --- a/docker/op_test/config.py +++ b/docker/op_test/config.py @@ -11,13 +11,13 @@ PORT_MIN = 60001 PORT_MAX = 61000 -BASE_URL = 'https://op-test' +BASE_URL = "https://op-test" # The variables below are all passed on to the test tool instance -ENT_PATH = 'entities' -ENT_INFO = 'entity_info' +ENT_PATH = "entities" +ENT_INFO = "entity_info" -FLOWDIR = 'flows' +FLOWDIR = "flows" -PATH2PORT = 'path2port.csv' -TEST_SCRIPT = './op_test_tool.py' +PATH2PORT = "path2port.csv" +TEST_SCRIPT = "./op_test_tool.py" diff --git a/docker/op_test/tt_config.py b/docker/op_test/tt_config.py index 8b9c3e32a..0d0c2d339 100644 --- a/docker/op_test/tt_config.py +++ b/docker/op_test/tt_config.py @@ -9,17 +9,17 @@ VERIFY_SSL = False # Make sure BASE starts with https if TLS = True -BASE = 'https://op-test' +BASE = "https://op-test" -ENT_PATH = 'entities' -ENT_INFO = 'entity_info' -PRE_HTML = 'html/tt' +ENT_PATH = "entities" +ENT_INFO = "entity_info" +PRE_HTML = "html/tt" KEYS = [ {"key": "keys/enc.key", "type": "RSA", "use": ["enc"]}, {"key": "keys/sig.key", "type": "RSA", "use": ["sig"]}, {"crv": "P-256", "type": "EC", "use": ["sig"]}, - {"crv": "P-256", "type": "EC", "use": ["enc"]} + {"crv": "P-256", "type": "EC", "use": ["enc"]}, ] SESSION_CHANGE_URL = "{}session_change" diff --git a/docker/rp_test/conf.py b/docker/rp_test/conf.py index 219bb2f4f..ea09093db 100644 --- a/docker/rp_test/conf.py +++ b/docker/rp_test/conf.py @@ -6,7 +6,7 @@ {"type": "RSA", "key": "keys/pyoidc_enc", "use": ["enc"]}, {"type": "RSA", "key": "keys/pyoidc_sig", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] multi_keys = [ @@ -17,36 +17,41 @@ {"type": "EC", "crv": "P-256", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["enc"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] -FOS = ['https://swamid.sunet.se/oidc', - 'https://surfnet.nl/oidc'] +FOS = ["https://swamid.sunet.se/oidc", "https://surfnet.nl/oidc"] -KEYDEFS = [ - {"type": "RSA", "key": '', "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]} -] +KEYDEFS = [{"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}] GRPS = [ - "Discovery", "Dynamic Client Registration", - "Response Type and Response Mode", "claims Request Parameter", - "request_uri Request Parameter", "scope Request Parameter", - "nonce Request Parameter", "Client Authentication", - "ID Token", "Key Rotation", "Claims Types", "UserInfo Endpoint", - "3rd-Party Init SSO", "RP Initiated BackChannel Logout", - "RP Initiated FrontChannel Logout", "RP Initiated Logout", + "Discovery", + "Dynamic Client Registration", + "Response Type and Response Mode", + "claims Request Parameter", + "request_uri Request Parameter", + "scope Request Parameter", + "nonce Request Parameter", + "Client Authentication", + "ID Token", + "Key Rotation", + "Claims Types", + "UserInfo Endpoint", + "3rd-Party Init SSO", + "RP Initiated BackChannel Logout", + "RP Initiated FrontChannel Logout", + "RP Initiated Logout", ] -#Only Username and password. +# Only Username and password. AUTHENTICATION = { - #"UserPassword": {"ACR": "PASSWORD", "WEIGHT": 1, "URL": SERVICE_URL} + # "UserPassword": {"ACR": "PASSWORD", "WEIGHT": 1, "URL": SERVICE_URL} "NoAuthn": {"ACR": "PASSWORD", "WEIGHT": 1, "user": "diana"} } COOKIE = { - 'name': 'pyoic', - 'ttl': 4 * 60 # 4 hours + "name": "pyoic", + "ttl": 4 * 60, # 4 hours } SYM_KEY = "SoLittleTime,Got" @@ -56,7 +61,7 @@ CLIENT_DB = "client_db" -LOGOUT_PATH = 'logout' +LOGOUT_PATH = "logout" CHECK_SESSION_IFRAME = "{}:{{}}//check_session_iframe".format(baseurl) # ======= SIMPLE DATABASE ============== @@ -77,7 +82,7 @@ "street_address": "Umeå Universitet", "locality": "Umeå", "postal_code": "SE-90187", - "country": "Sweden" + "country": "Sweden", }, }, "babs": { @@ -103,5 +108,5 @@ "family_name": "Crust", "email": "uc@example.com", "email_verified": True, - } + }, } diff --git a/oauth_example/as/as.py b/oauth_example/as/as.py index ec2f61aa2..61ab94776 100755 --- a/oauth_example/as/as.py +++ b/oauth_example/as/as.py @@ -2,6 +2,7 @@ """ A very simple OAuth2 AS """ + import json import logging import os @@ -257,7 +258,7 @@ def application(self, environ, start_response): hostname=config.HOST, capabilities=capabilities, behavior=config.BEHAVIOR, - **kwargs + **kwargs, ) try: diff --git a/oauth_example/as/authn_setup.py b/oauth_example/as/authn_setup.py index 9bb7ca9d0..bc197121c 100644 --- a/oauth_example/as/authn_setup.py +++ b/oauth_example/as/authn_setup.py @@ -24,18 +24,14 @@ def cas_setup(item): _func = VALIDATOR[v_cnf["type"].upper()](item) _cnf = item["config"] - return CasAuthnMethod( - None, _cnf["cas_server"], item["URL"], _cnf["return_to"], _func - ) + return CasAuthnMethod(None, _cnf["cas_server"], item["URL"], _cnf["return_to"], _func) def userpwd_setup(item): from oic.utils.authn.user import UsernamePasswordMako _conf = item["config"] - return UsernamePasswordMako( - None, "login.mako", _conf["lookup"], _conf["passwd"], _conf["return_to"] - ) + return UsernamePasswordMako(None, "login.mako", _conf["lookup"], _conf["passwd"], _conf["return_to"]) AUTH_METHOD = { diff --git a/oauth_example/rp/rp.py b/oauth_example/rp/rp.py index e43a0c36f..e2c52a060 100644 --- a/oauth_example/rp/rp.py +++ b/oauth_example/rp/rp.py @@ -39,9 +39,7 @@ def as_choice(environ, start_response): - resp = Response( - mako_template="as_choice.mako", template_lookup=RP_CONF.LOOKUP, headers=[] - ) + resp = Response(mako_template="as_choice.mako", template_lookup=RP_CONF.LOOKUP, headers=[]) argv = {"as_list": RP_CONF.AS_CONF.keys(), "action": "as", "method": "POST"} return resp(environ, start_response, **argv) diff --git a/oidc_example/op1/claims_provider.py b/oidc_example/op1/claims_provider.py index 606b6bbc6..02bcee864 100755 --- a/oidc_example/op1/claims_provider.py +++ b/oidc_example/op1/claims_provider.py @@ -2,23 +2,21 @@ import logging import re -from oic.oic.claims_provider import UserClaimsEndpoint -from oic.oic.claims_provider import UserClaimsInfoEndpoint +from oic.oic.claims_provider import UserClaimsEndpoint, UserClaimsInfoEndpoint from oic.oic.message import OpenIDSchema -#from oic.oic.provider import CheckIDEndpoint -from oic.oic.provider import RegistrationEndpoint -from oic.oic.provider import UserinfoEndpoint -from oic.utils.http_util import * + +# from oic.oic.provider import CheckIDEndpoint +from oic.oic.provider import RegistrationEndpoint, UserinfoEndpoint +from oic.utils.http_util import NotFound, Response, parse_cookie from oic.utils.keyio import keybundle_from_local_file from oic.utils.userinfo import UserInfo -__author__ = 'rohe0002' - +__author__ = "rohe0002" LOGGER = logging.getLogger("") -hdlr = logging.FileHandler('oc3cp.log') -formatter = logging.Formatter('%(asctime)s %(name)s:%(levelname)s %(message)s') +hdlr = logging.FileHandler("oc3cp.log") +formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") hdlr.setFormatter(formatter) LOGGER.addHandler(hdlr) LOGGER.setLevel(logging.DEBUG) @@ -26,7 +24,7 @@ # ---------------------------------------------------------------------------- -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def verify_client(environ, req, cdb): identity = req["client_id"] secret = req["client_secret"] @@ -38,9 +36,9 @@ def verify_client(environ, req, cdb): return False -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def user_info(oicsrv, userdb, sub, client_id="", user_info_claims=None): - #print >> sys.stderr, "claims: %s" % user_info_claims + # print >> sys.stderr, "claims: %s" % user_info_claims identity = userdb[sub] if user_info_claims: @@ -58,49 +56,44 @@ def user_info(oicsrv, userdb, sub, client_id="", user_info_claims=None): return OpenIDSchema(**result) -USER2MODE = {"diana": "aggregate", - "upper": "distribute", - "babs": "aggregate"} +USER2MODE = {"diana": "aggregate", "upper": "distribute", "babs": "aggregate"} -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def claims_mode(info, uid): if USER2MODE[uid] == "aggregate": return True else: return False -FUNCTIONS = { - "verify_client": verify_client, - "userinfo": user_info, - "claims_mode": claims_mode -} + +FUNCTIONS = {"verify_client": verify_client, "userinfo": user_info, "claims_mode": claims_mode} # ---------------------------------------------------------------------------- -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def userinfo(environ, start_response, handle): _oas = environ["oic.oas"] return _oas.userinfo_endpoint(environ, start_response, LOGGER) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def check_id(environ, start_response, handle): _oas = environ["oic.oas"] return _oas.check_id_endpoint(environ, start_response, LOGGER) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def op_info(environ, start_response, handle): _oas = environ["oic.oas"] return _oas.providerinfo_endpoint(environ, start_response, LOGGER) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def userclaims(environ, start_response, handle): _oas = environ["oic.oas"] @@ -108,25 +101,25 @@ def userclaims(environ, start_response, handle): return _oas.claims_endpoint(environ, start_response, LOGGER) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def registration(environ, start_response, handle): _oas = environ["oic.oas"] return _oas.registration_endpoint(environ, start_response) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def userclaimsinfo(environ, start_response, handle): _oas = environ["oic.oas"] LOGGER.info("claims_info_endpoint") return _oas.claims_info_endpoint(environ, start_response, LOGGER) + # ---------------------------------------------------------------------------- def static(environ, start_response, path): - _txt = open(path).read() if "x509" in path: content = "text/xml" @@ -136,20 +129,19 @@ def static(environ, start_response, path): resp = Response(_txt, content=content) return resp(environ, start_response) + # ---------------------------------------------------------------------------- ENDPOINTS = [ UserinfoEndpoint(userinfo), - #CheckIDEndpoint(check_id), + # CheckIDEndpoint(check_id), RegistrationEndpoint(registration), UserClaimsEndpoint(userclaims), - UserClaimsInfoEndpoint(userclaimsinfo) + UserClaimsInfoEndpoint(userclaimsinfo), ] -URLS = [ - (r'^.well-known/openid-configuration', op_info) -] +URLS = [(r"^.well-known/openid-configuration", op_info)] for endp in ENDPOINTS: URLS.append(("^%s$" % endp.etype, endp)) @@ -171,9 +163,9 @@ def application(environ, start_response): """ global OAS - #user = environ.get("REMOTE_USER", "") - path = environ.get('PATH_INFO', '').lstrip('/') - kaka = environ.get("HTTP_COOKIE", '') + # user = environ.get("REMOTE_USER", "") + path = environ.get("PATH_INFO", "").lstrip("/") + kaka = environ.get("HTTP_COOKIE", "") if kaka: handle = parse_cookie(OAS.name, OAS.seed, kaka) @@ -191,9 +183,9 @@ def application(environ, start_response): match = re.search(regex, path) if match is not None: try: - environ['oic.url_args'] = match.groups()[0] + environ["oic.url_args"] = match.groups()[0] except IndexError: - environ['oic.url_args'] = path + environ["oic.url_args"] = path return callback(environ, start_response, handle) resp = NotFound("Couldn't find the side you asked for!") @@ -211,13 +203,13 @@ def application(environ, start_response): }, "babs": { "geolocation": {"longitude": 4.8890, "latitude": 52.3673}, - } + }, } SERVER_DB = {} -if __name__ == '__main__': +if __name__ == "__main__": import argparse import json @@ -229,9 +221,9 @@ def application(environ, start_response): from oic.utils.sdb import create_session_db parser = argparse.ArgumentParser() - parser.add_argument('-v', dest='verbose', action='store_true') - parser.add_argument('-d', dest='debug', action='store_true') - parser.add_argument('-p', dest='port', default=8093, type=int) + parser.add_argument("-v", dest="verbose", action="store_true") + parser.add_argument("-d", dest="debug", action="store_true") + parser.add_argument("-p", dest="port", default=8093, type=int) parser.add_argument(dest="config") args = parser.parse_args() @@ -241,22 +233,18 @@ def application(environ, start_response): # in memory session storage config = json.loads(open(args.config).read()) - sdb = create_session_db(config["issuer"], - config["SESSION_KEY"], - password=rndstr(16)) - OAS = ClaimsServer(config["issuer"], sdb, cdb, userinfo, - verify_client) + sdb = create_session_db(config["issuer"], config["SESSION_KEY"], password=rndstr(16)) + OAS = ClaimsServer(config["issuer"], sdb, cdb, userinfo, verify_client) if "keys" in config: for typ, info in config["keys"].items(): - OAS.keyjar.add_kb("", keybundle_from_local_file(info["key"], "rsa", - ["ver", "sig"])) + OAS.keyjar.add_kb("", keybundle_from_local_file(info["key"], "rsa", ["ver", "sig"])) try: OAS.jwks_uri.append(info["jwk"]) except KeyError: pass - #print URLS + # print URLS if args.debug: OAS.debug = True @@ -271,12 +259,10 @@ def application(environ, start_response): if not OAS.baseurl.endswith("/"): OAS.baseurl += "/" - OAS.claims_userinfo_endpoint = "%s%s" % ( - OAS.baseurl, UserClaimsInfoEndpoint.etype) + OAS.claims_userinfo_endpoint = "%s%s" % (OAS.baseurl, UserClaimsInfoEndpoint.etype) - SRV = wsgiserver.CherryPyWSGIServer(('0.0.0.0', args.port), application) # nosec - SRV.ssl_adapter = ssl_builtin.BuiltinSSLAdapter("certs/server.crt", - "certs/server.key") + SRV = wsgiserver.CherryPyWSGIServer(("0.0.0.0", args.port), application) # nosec + SRV.ssl_adapter = ssl_builtin.BuiltinSSLAdapter("certs/server.crt", "certs/server.key") LOGGER.info("Starting server") try: diff --git a/oidc_example/op1/create_jwk_from_cert.py b/oidc_example/op1/create_jwk_from_cert.py index c1b7efba3..e36cfb66d 100755 --- a/oidc_example/op1/create_jwk_from_cert.py +++ b/oidc_example/op1/create_jwk_from_cert.py @@ -4,7 +4,8 @@ from oic.utils.keystore import KeyStore from oic.utils.keystore import x509_rsa_loads -__author__ = 'rohe0002' +__author__ = "rohe0002" + def main(x509_file, out="keys.jwk"): pb = PBase() @@ -18,6 +19,8 @@ def main(x509_file, out="keys.jwk"): f.write(txt) f.close() + if __name__ == "__main__": import sys + main(*sys.argv[1:2]) diff --git a/oidc_example/op1/oc_server.py b/oidc_example/op1/oc_server.py index 67462189d..163c0425f 100755 --- a/oidc_example/op1/oc_server.py +++ b/oidc_example/op1/oc_server.py @@ -1,52 +1,44 @@ #!/usr/bin/env python import json +import logging import os import re import sys import traceback -from exceptions import AttributeError -from exceptions import Exception -from exceptions import IndexError -from exceptions import KeyboardInterrupt -from exceptions import KeyError -from exceptions import OSError from logging.handlers import BufferingHandler -from urlparse import parse_qs +from exceptions import AttributeError, Exception, IndexError, KeyboardInterrupt, KeyError, OSError from mako.lookup import TemplateLookup +from urlparse import parse_qs from oic.oic.message import ProviderConfigurationResponse -#from oic.oic.provider import CheckIDEndpoint -from oic.oic.provider import AuthorizationEndpoint -from oic.oic.provider import EndSessionEndpoint -from oic.oic.provider import Provider -from oic.oic.provider import RegistrationEndpoint -from oic.oic.provider import TokenEndpoint -from oic.oic.provider import UserinfoEndpoint + +# from oic.oic.provider import CheckIDEndpoint +from oic.oic.provider import ( + AuthorizationEndpoint, + EndSessionEndpoint, + Provider, + RegistrationEndpoint, + TokenEndpoint, + UserinfoEndpoint, +) from oic.utils.authn.authn_context import AuthnBroker from oic.utils.authn.client import verify_client from oic.utils.authz import AuthzHandling -from oic.utils.http_util import * +from oic.utils.http_util import BadRequest, NotFound, Response, ServiceError, Unauthorized, wsgi_wrapper from oic.utils.keyio import keyjar_init from oic.utils.userinfo import UserInfo -from oic.utils.webfinger import OIC_ISSUER -from oic.utils.webfinger import WebFinger - -__author__ = 'rohe0002' - - - +from oic.utils.webfinger import OIC_ISSUER, WebFinger +__author__ = "rohe0002" LOGGER = logging.getLogger("") -LOGFILE_NAME = 'oc.log' +LOGFILE_NAME = "oc.log" hdlr = logging.FileHandler(LOGFILE_NAME) -base_formatter = logging.Formatter( - "%(asctime)s %(name)s:%(levelname)s %(message)s") +base_formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") -CPC = ('%(asctime)s %(name)s:%(levelname)s ' - '[%(client)s,%(path)s,%(cid)s] %(message)s') +CPC = "%(asctime)s %(name)s:%(levelname)s " "[%(client)s,%(path)s,%(cid)s] %(message)s" cpc_formatter = logging.Formatter(CPC) hdlr.setFormatter(base_formatter) @@ -68,14 +60,16 @@ OAS = None -PASSWD = {"diana": "krall", - "babs": "howes", - "upper": "crust", - "rohe0002": "StevieRay", - "haho0032": "qwerty"} #haho0032@hashog.umdc.umu.se +PASSWD = { + "diana": "krall", + "babs": "howes", + "upper": "crust", + "rohe0002": "StevieRay", + "haho0032": "qwerty", +} # haho0032@hashog.umdc.umu.se -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def devnull(txt): pass @@ -123,6 +117,7 @@ def replace_format_handler(logger, log_format="CPC"): ACTIVE_HANDLER = format return logger + # #noinspection PyUnusedLocal # def simple_user_info(oicsrv, userdb, sub, client_id="", # user_info_claims=None): @@ -132,15 +127,15 @@ def replace_format_handler(logger, log_format="CPC"): # ---------------------------------------------------------------------------- -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def safe(environ, start_response, logger): _oas = environ["oic.oas"] _srv = _oas.server _log_info = _oas.logger.info _log_info("- safe -") - #_log_info("env: %s" % environ) - #_log_info("handle: %s" % (handle,)) + # _log_info("env: %s" % environ) + # _log_info("handle: %s" % (handle,)) try: authz = environ["HTTP_AUTHORIZATION"] @@ -152,7 +147,7 @@ def safe(environ, start_response, logger): if typ != "Bearer": resp = BadRequest("Unsupported authorization method") return resp(environ, start_response) - + try: _sinfo = _srv.sdb[code] except KeyError: @@ -164,7 +159,7 @@ def safe(environ, start_response, logger): return resp(environ, start_response) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def css(environ, start_response, logger): try: info = open(environ["PATH_INFO"]).read() @@ -174,88 +169,80 @@ def css(environ, start_response, logger): return resp(environ, start_response) + # ---------------------------------------------------------------------------- -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def token(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.token_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.token_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def authorization(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.authorization_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.authorization_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def userinfo(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.userinfo_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.userinfo_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def op_info(environ, start_response, logger): _oas = environ["oic.oas"] LOGGER.info("op_info") - return wsgi_wrapper(environ, start_response, _oas.providerinfo_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.providerinfo_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def registration(environ, start_response, logger): _oas = environ["oic.oas"] if environ["REQUEST_METHOD"] == "POST": - return wsgi_wrapper(environ, start_response, _oas.registration_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.registration_endpoint, logger=logger) elif environ["REQUEST_METHOD"] == "GET": - return wsgi_wrapper(environ, start_response, _oas.read_registration, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.read_registration, logger=logger) else: resp = ServiceError("Method not supported") return resp(environ, start_response) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def check_id(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.check_id_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.check_id_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def swd_info(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.discovery_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.discovery_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def trace_log(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.tracelog_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.tracelog_endpoint, logger=logger) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def endsession(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.endsession_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.endsession_endpoint, logger=logger) -#noinspection PyUnusedLocal + +# noinspection PyUnusedLocal def meta_info(environ, start_response, logger): """ Returns something like this @@ -285,11 +272,10 @@ def webfinger(environ, start_response, _): return resp(environ, start_response) -#noinspection PyUnusedLocal +# noinspection PyUnusedLocal def verify(environ, start_response, logger): _oas = environ["oic.oas"] - return wsgi_wrapper(environ, start_response, _oas.verify_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, _oas.verify_endpoint, logger=logger) def static_file(path): @@ -300,24 +286,24 @@ def static_file(path): return False -#noinspection PyUnresolvedReferences +# noinspection PyUnresolvedReferences def static(environ, start_response, logger, path): logger.info("[static]sending: %s" % (path,)) try: - data = open(path, 'rb').read() + data = open(path, "rb").read() if path.endswith(".ico"): - start_response('200 OK', [('Content-Type', "image/x-icon")]) + start_response("200 OK", [("Content-Type", "image/x-icon")]) elif path.endswith(".html"): - start_response('200 OK', [('Content-Type', 'text/html')]) + start_response("200 OK", [("Content-Type", "text/html")]) elif path.endswith(".json"): - start_response('200 OK', [('Content-Type', 'application/json')]) + start_response("200 OK", [("Content-Type", "application/json")]) elif path.endswith(".txt"): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) elif path.endswith(".css"): - start_response('200 OK', [('Content-Type', 'text/css')]) + start_response("200 OK", [("Content-Type", "text/css")]) else: - start_response('200 OK', [('Content-Type', "text/xml")]) + start_response("200 OK", [("Content-Type", "text/xml")]) return [data] except IOError: resp = NotFound() @@ -328,21 +314,21 @@ def static(environ, start_response, logger, path): AuthorizationEndpoint(authorization), TokenEndpoint(token), UserinfoEndpoint(userinfo), - #CheckIDEndpoint(check_id), + # CheckIDEndpoint(check_id), RegistrationEndpoint(registration), EndSessionEndpoint(endsession), ] URLS = [ - (r'^verify', verify), - (r'^.well-known/openid-configuration', op_info), - (r'^.well-known/simple-web-discovery', swd_info), - (r'^.well-known/host-meta.json', meta_info), - (r'^.well-known/webfinger', webfinger), -# (r'^.well-known/webfinger', webfinger), - (r'.+\.css$', css), - (r'safe', safe), -# (r'tracelog', trace_log), + (r"^verify", verify), + (r"^.well-known/openid-configuration", op_info), + (r"^.well-known/simple-web-discovery", swd_info), + (r"^.well-known/host-meta.json", meta_info), + (r"^.well-known/webfinger", webfinger), + # (r'^.well-known/webfinger', webfinger), + (r".+\.css$", css), + (r"safe", safe), + # (r'tracelog', trace_log), ] @@ -352,13 +338,17 @@ def add_endpoints(extra): for endp in extra: URLS.append(("^%s" % endp.etype, endp)) + # ---------------------------------------------------------------------------- -ROOT = './' +ROOT = "./" -LOOKUP = TemplateLookup(directories=[ROOT + 'templates', ROOT + 'htdocs'], - module_directory=ROOT + 'modules', - input_encoding='utf-8', output_encoding='utf-8') +LOOKUP = TemplateLookup( + directories=[ROOT + "templates", ROOT + "htdocs"], + module_directory=ROOT + "modules", + input_encoding="utf-8", + output_encoding="utf-8", +) # ---------------------------------------------------------------------------- @@ -380,38 +370,38 @@ def application(environ, start_response): global OAS - #user = environ.get("REMOTE_USER", "") - path = environ.get('PATH_INFO', '').lstrip('/') + # user = environ.get("REMOTE_USER", "") + path = environ.get("PATH_INFO", "").lstrip("/") - logger = logging.getLogger('oicServer') + logger = logging.getLogger("oicServer") if path == "robots.txt": return static(environ, start_response, logger, "static/robots.txt") environ["oic.oas"] = OAS - #remote = environ.get("REMOTE_ADDR") - #kaka = environ.get("HTTP_COOKIE", '') + # remote = environ.get("REMOTE_ADDR") + # kaka = environ.get("HTTP_COOKIE", '') if path.startswith("static/"): return static(environ, start_response, logger, path) -# elif path.startswith("oc_keys/"): -# return static(environ, start_response, logger, path) + # elif path.startswith("oc_keys/"): + # return static(environ, start_response, logger, path) for regex, callback in URLS: match = re.search(regex, path) if match is not None: try: - environ['oic.url_args'] = match.groups()[0] + environ["oic.url_args"] = match.groups()[0] except IndexError: - environ['oic.url_args'] = path + environ["oic.url_args"] = path logger.info("callback: %s" % callback) try: return callback(environ, start_response, logger) except Exception as err: - print >> sys.stderr, "%s" % err + sys.stderr.write("%s" % err) message = traceback.format_exception(*sys.exc_info()) - print >> sys.stderr, message + sys.stderr.write(message) logger.exception("%s" % err) resp = ServiceError("%s" % err) return resp(environ, start_response) @@ -423,12 +413,11 @@ def application(environ, start_response): # ---------------------------------------------------------------------------- + class TestProvider(Provider): - #noinspection PyUnusedLocal - def __init__(self, name, sdb, cdb, function, userdb, urlmap=None, - debug=0, jwt_keys=None): - Provider.__init__(self, name, sdb, cdb, function, userdb, urlmap, - jwt_keys) + # noinspection PyUnusedLocal + def __init__(self, name, sdb, cdb, function, userdb, urlmap=None, debug=0, jwt_keys=None): + Provider.__init__(self, name, sdb, cdb, function, userdb, urlmap, jwt_keys) self.test_mode = True self.trace_log = {} self.sessions = [] @@ -445,7 +434,7 @@ def dump_tracelog(self, key): return "\n".join(arr) return "" - #noinspection PyUnusedLocal + # noinspection PyUnusedLocal def tracelog_endpoint(self, environ, start_response, logger, **kwargs): handle = kwargs["handle"] tlog = self.trace_log[handle[0]] @@ -477,27 +466,28 @@ def new_trace_log(self, key): return _log -if __name__ == '__main__': +if __name__ == "__main__": import argparse - import shelve # nosec import importlib + import shelve # nosec from cherrypy import wsgiserver - #from cherrypy.wsgiserver import ssl_builtin + + # from cherrypy.wsgiserver import ssl_builtin from cherrypy.wsgiserver import ssl_pyopenssl from oic import rndstr from oic.utils.sdb import create_session_db parser = argparse.ArgumentParser() - parser.add_argument('-v', dest='verbose', action='store_true') - parser.add_argument('-d', dest='debug', action='store_true') - parser.add_argument('-p', dest='port', default=80, type=int) - parser.add_argument('-t', dest='test', action='store_true') - parser.add_argument('-X', dest='XpressConnect', action='store_true') - parser.add_argument('-A', dest='authn_as', default="") - parser.add_argument('-P', dest='provider_conf') - parser.add_argument('-k', dest='insecure', action='store_true') + parser.add_argument("-v", dest="verbose", action="store_true") + parser.add_argument("-d", dest="debug", action="store_true") + parser.add_argument("-p", dest="port", default=80, type=int) + parser.add_argument("-t", dest="test", action="store_true") + parser.add_argument("-X", dest="XpressConnect", action="store_true") + parser.add_argument("-A", dest="authn_as", default="") + parser.add_argument("-P", dest="provider_conf") + parser.add_argument("-k", dest="insecure", action="store_true") parser.add_argument(dest="config") args = parser.parse_args() @@ -514,20 +504,28 @@ def new_trace_log(self, key): for authkey, value in config.AUTHORIZATION.items(): authn = None if "CAS" == authkey: - from oic.utils.authn.user_cas import CasAuthnMethod - from oic.utils.authn.ldap_member import UserLDAPMemberValidation - config.LDAP_EXTRAVALIDATION.update(config.LDAP) - authn = CasAuthnMethod(None, config.CAS_SERVER, config.SERVICE_URL,"%s/authorization" % config.issuer, - UserLDAPMemberValidation(**config.LDAP_EXTRAVALIDATION)) + from oic.utils.authn.ldap_member import UserLDAPMemberValidation + from oic.utils.authn.user_cas import CasAuthnMethod + + config.LDAP_EXTRAVALIDATION.update(config.LDAP) + authn = CasAuthnMethod( + None, + config.CAS_SERVER, + config.SERVICE_URL, + "%s/authorization" % config.issuer, + UserLDAPMemberValidation(**config.LDAP_EXTRAVALIDATION), + ) if "UserPassword" == authkey: from oic.utils.authn.user import UsernamePasswordMako - authn = UsernamePasswordMako(None, "login.mako", LOOKUP, PASSWD, - "%s/authorization" % config.issuer) + + authn = UsernamePasswordMako(None, "login.mako", LOOKUP, PASSWD, "%s/authorization" % config.issuer) if authn is not None: - ac.add(config.AUTHORIZATION[authkey]["ACR"], - authn, - config.AUTHORIZATION[authkey]["WEIGHT"], - config.AUTHORIZATION[authkey]["URL"]) + ac.add( + config.AUTHORIZATION[authkey]["ACR"], + authn, + config.AUTHORIZATION[authkey]["WEIGHT"], + config.AUTHORIZATION[authkey]["URL"], + ) # dealing with authorization authz = AuthzHandling() @@ -539,24 +537,17 @@ def new_trace_log(self, key): kwargs = {"verify_ssl": True} # In-Memory SessionDB issuing DefaultTokens - sdb = create_session_db(config.baseurl, - secret=rndstr(32), - password=rndstr(32)) + sdb = create_session_db(config.baseurl, secret=rndstr(32), password=rndstr(32)) if args.test: - URLS.append((r'tracelog', trace_log)) - OAS = TestProvider(config.issuer, sdb, cdb, ac, - None, authz, config.SYM_KEY) + URLS.append((r"tracelog", trace_log)) + OAS = TestProvider(config.issuer, sdb, cdb, ac, None, authz, config.SYM_KEY) elif args.XpressConnect: from XpressConnect import XpressConnectProvider - OAS = XpressConnectProvider(config.issuer, sdb, - cdb, ac, None, authz, verify_client, - config.SYM_KEY) + OAS = XpressConnectProvider(config.issuer, sdb, cdb, ac, None, authz, verify_client, config.SYM_KEY) else: - OAS = Provider(config.issuer, sdb, cdb, ac, None, - authz, verify_client, config.SYM_KEY, **kwargs) - + OAS = Provider(config.issuer, sdb, cdb, ac, None, authz, verify_client, config.SYM_KEY, **kwargs) try: OAS.cookie_ttl = config.COOKIETTL @@ -568,7 +559,7 @@ def new_trace_log(self, key): except AttributeError: pass - #print URLS + # print URLS if args.debug: OAS.debug = True if args.test: @@ -580,8 +571,7 @@ def new_trace_log(self, key): OAS.authn_as = args.authn_as if args.provider_conf: - prc = ProviderConfigurationResponse().from_json( - open(args.provider_conf).read()) + prc = ProviderConfigurationResponse().from_json(open(args.provider_conf).read()) endpoints = [] for key in prc.keys(): if key.endswith("_endpoint"): @@ -619,24 +609,23 @@ def new_trace_log(self, key): if config.USERINFO == "LDAP": from oic.utils.userinfo.ldap_info import UserInfoLDAP + OAS.userinfo = UserInfoLDAP(**config.LDAP) elif config.USERINFO == "SIMPLE": OAS.userinfo = UserInfo(config.USERDB) elif config.USERINFO == "DISTRIBUTED": from oic.utils.userinfo.distaggr import DistributedAggregatedUserInfo - OAS.userinfo = DistributedAggregatedUserInfo(config.USERDB, OAS, - config.CLIENT_INFO) + + OAS.userinfo = DistributedAggregatedUserInfo(config.USERDB, OAS, config.CLIENT_INFO) LOGGER.debug("URLS: '%s" % (URLS,)) # Add the claims providers keys - SRV = wsgiserver.CherryPyWSGIServer(('0.0.0.0', args.port), application) # nosec + SRV = wsgiserver.CherryPyWSGIServer(("0.0.0.0", args.port), application) # nosec - SRV.ssl_adapter = ssl_pyopenssl.pyOpenSSLAdapter(config.SERVER_CERT, - config.SERVER_KEY, - config.CERT_CHAIN) + SRV.ssl_adapter = ssl_pyopenssl.pyOpenSSLAdapter(config.SERVER_CERT, config.SERVER_KEY, config.CERT_CHAIN) LOGGER.info("OC server starting listening on port:%s" % args.port) - print ("OC server starting listening on port:%s" % args.port) + print("OC server starting listening on port:%s" % args.port) try: SRV.start() except KeyboardInterrupt: diff --git a/oidc_example/op2/client_mgr.py b/oidc_example/op2/client_mgr.py index b401dd68c..1b1c1e1e5 100755 --- a/oidc_example/op2/client_mgr.py +++ b/oidc_example/op2/client_mgr.py @@ -3,13 +3,13 @@ from oic.utils.client_management import CDB -if __name__ == '__main__': +if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument('-l', dest='list', action='store_true') - parser.add_argument('-a', dest='add') - parser.add_argument('-d', dest='delete') + parser.add_argument("-l", dest="list", action="store_true") + parser.add_argument("-a", dest="add") + parser.add_argument("-d", dest="delete") parser.add_argument(dest="config") args = parser.parse_args() @@ -18,7 +18,7 @@ if args.list: for key, val in cdb.items(): - print('{}:{}'.format(key, val['redirect_uris'])) + print("{}:{}".format(key, val["redirect_uris"])) if args.add: fp = open(args.add) diff --git a/oidc_example/op2/config_full.py b/oidc_example/op2/config_full.py index e637b1d16..3b004e953 100644 --- a/oidc_example/op2/config_full.py +++ b/oidc_example/op2/config_full.py @@ -1,29 +1,26 @@ keys = [ {"type": "RSA", "key": "cp_keys/key.pem", "use": ["enc", "sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] -ISSUER = 'http://localhost' +ISSUER = "http://localhost" SERVICE_URL = "{issuer}/verify" -USER_PASSWORD_END_POINTS = ["user_password", "multi_user_password_saml_verify", - "multi_user_password_js_verify"] -SAML_END_POINTS = ['saml', "multi_saml_pass"] -JAVASCRIPT_END_POINTS = ['javascript_login', "multi_javascript_login"] +USER_PASSWORD_END_POINTS = ["user_password", "multi_user_password_saml_verify", "multi_user_password_js_verify"] +SAML_END_POINTS = ["saml", "multi_saml_pass"] +JAVASCRIPT_END_POINTS = ["javascript_login", "multi_javascript_login"] AUTHENTICATION = { "SAML": {"ACR": "SAML", "WEIGHT": 1, "END_POINTS": SAML_END_POINTS}, - "UserPassword": {"ACR": "PASSWORD", "WEIGHT": 2, - "END_POINTS": USER_PASSWORD_END_POINTS}, + "UserPassword": {"ACR": "PASSWORD", "WEIGHT": 2, "END_POINTS": USER_PASSWORD_END_POINTS}, "SamlPass": {"ACR": "SAML_PASS", "WEIGHT": 3}, - "JavascriptLogin": {"ACR": "JAVASCRIPT_LOGIN", "WEIGHT": 4, - "END_POINTS": JAVASCRIPT_END_POINTS}, + "JavascriptLogin": {"ACR": "JAVASCRIPT_LOGIN", "WEIGHT": 4, "END_POINTS": JAVASCRIPT_END_POINTS}, "JavascriptPass": {"ACR": "JAVASCRIPT_PASS", "WEIGHT": 5}, } -COOKIENAME= 'pyoic' -COOKIETTL = 4*60 # 4 hours +COOKIENAME = "pyoic" +COOKIETTL = 4 * 60 # 4 hours SYM_KEY = "SoLittleTime,Got" SERVER_CERT = "certs/server.crt" @@ -37,7 +34,7 @@ # User information is collected with a SAML attribute authority # USERINFO = "AA" # Name of the Service Provider configuration file. -SP_CONFIG="sp_conf" +SP_CONFIG = "sp_conf" # Dictionary with user information for the SAML users. Must be empty. SAML = {} @@ -59,7 +56,7 @@ "street_address": "Umeå Universitet", "locality": "Umeå", "postal_code": "SE-90187", - "country": "Sweden" + "country": "Sweden", }, }, "babs": { @@ -85,6 +82,5 @@ "family_name": "Crust", "email": "uc@example.com", "email_verified": True, - } + }, } - diff --git a/oidc_example/op2/config_simple.py b/oidc_example/op2/config_simple.py index f7f08219a..f14b115c8 100644 --- a/oidc_example/op2/config_simple.py +++ b/oidc_example/op2/config_simple.py @@ -1,19 +1,16 @@ keys = [ {"type": "RSA", "key": "cp_keys/key.pem", "use": ["enc", "sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] -ISSUER = 'http://localhost' +ISSUER = "http://localhost" SERVICE_URL = "{issuer}/verify" # Only Username and password. -AUTHENTICATION = { - "UserPassword": {"ACR": "PASSWORD", "WEIGHT": 1, "URL": SERVICE_URL, - "END_POINTS": ["verify"]} -} +AUTHENTICATION = {"UserPassword": {"ACR": "PASSWORD", "WEIGHT": 1, "URL": SERVICE_URL, "END_POINTS": ["verify"]}} -COOKIENAME = 'pyoic' +COOKIENAME = "pyoic" COOKIETTL = 4 * 60 # 4 hours SYM_KEY = "SoLittleTime,Got" @@ -40,7 +37,7 @@ "street_address": "Umeå Universitet", "locality": "Umeå", "postal_code": "SE-90187", - "country": "Sweden" + "country": "Sweden", }, }, "babs": { @@ -66,5 +63,5 @@ "family_name": "Crust", "email": "uc@example.com", "email_verified": True, - } + }, } diff --git a/oidc_example/op2/server.py b/oidc_example/op2/server.py index a4ad57e75..d20e49d68 100755 --- a/oidc_example/op2/server.py +++ b/oidc_example/op2/server.py @@ -1,37 +1,47 @@ #!/usr/bin/env python -from urllib.parse import parse_qs - +import importlib.util import json +import logging +import os import re import sys +import time import traceback -import importlib.util +from urllib.parse import parse_qs from mako.lookup import TemplateLookup -from oic.oic.provider import AuthorizationEndpoint -from oic.oic.provider import EndSessionEndpoint -from oic.oic.provider import Provider -from oic.oic.provider import RegistrationEndpoint -from oic.oic.provider import TokenEndpoint -from oic.oic.provider import UserinfoEndpoint +from oic.oic.provider import ( + AuthorizationEndpoint, + EndSessionEndpoint, + Provider, + RegistrationEndpoint, + TokenEndpoint, + UserinfoEndpoint, +) from oic.utils import shelve_wrapper -from oic.utils.authn.authn_context import AuthnBroker -from oic.utils.authn.authn_context import make_auth_verify +from oic.utils.authn.authn_context import AuthnBroker, make_auth_verify from oic.utils.authn.client import verify_client from oic.utils.authn.javascript_login import JavascriptFormMako -from oic.utils.authn.multi_auth import AuthnIndexedEndpointWrapper -from oic.utils.authn.multi_auth import setup_multi_auth +from oic.utils.authn.multi_auth import AuthnIndexedEndpointWrapper, setup_multi_auth from oic.utils.authn.user import UsernamePasswordMako from oic.utils.authz import AuthzHandling -from oic.utils.http_util import * +from oic.utils.http_util import ( + BadRequest, + NotFound, + Response, + ServiceError, + Unauthorized, + as_unicode, + get_post, + wsgi_wrapper, +) from oic.utils.keyio import keyjar_init from oic.utils.userinfo import UserInfo from oic.utils.userinfo.aa_info import AaUserInfo -from oic.utils.webfinger import OIC_ISSUER -from oic.utils.webfinger import WebFinger +from oic.utils.webfinger import OIC_ISSUER, WebFinger -__author__ = 'rohe0002' +__author__ = "rohe0002" # This is *NOT* good practice !! try: @@ -42,30 +52,24 @@ urllib3.disable_warnings() LOGGER = logging.getLogger("") -LOGFILE_NAME = 'oc.log' +LOGFILE_NAME = "oc.log" hdlr = logging.FileHandler(LOGFILE_NAME) -base_formatter = logging.Formatter( - "%(asctime)s %(name)s:%(levelname)s %(message)s") +base_formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") -CPC = ('%(asctime)s %(name)s:%(levelname)s ' - '[%(client)s,%(path)s,%(cid)s] %(message)s') +CPC = "%(asctime)s %(name)s:%(levelname)s " "[%(client)s,%(path)s,%(cid)s] %(message)s" cpc_formatter = logging.Formatter(CPC) hdlr.setFormatter(base_formatter) LOGGER.addHandler(hdlr) LOGGER.setLevel(logging.DEBUG) -logger = logging.getLogger('oicServer') +logger = logging.getLogger("oicServer") URLMAP = {} NAME = "pyoic" OAS = None -PASSWD = { - "diana": "krall", - "babs": "howes", - "upper": "crust" -} +PASSWD = {"diana": "krall", "babs": "howes", "upper": "crust"} JWKS_FILE_NAME = "static/jwks.json" @@ -85,19 +89,19 @@ def static(self, environ, start_response, path): logger.info("[static]sending: %s" % (path,)) try: - data = open(path, 'rb').read() + data = open(path, "rb").read() if path.endswith(".ico"): - start_response('200 OK', [('Content-Type', "image/x-icon")]) + start_response("200 OK", [("Content-Type", "image/x-icon")]) elif path.endswith(".html"): - start_response('200 OK', [('Content-Type', 'text/html')]) + start_response("200 OK", [("Content-Type", "text/html")]) elif path.endswith(".json"): - start_response('200 OK', [('Content-Type', 'application/json')]) + start_response("200 OK", [("Content-Type", "application/json")]) elif path.endswith(".txt"): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) elif path.endswith(".css"): - start_response('200 OK', [('Content-Type', 'text/css')]) + start_response("200 OK", [("Content-Type", "text/css")]) else: - start_response('200 OK', [('Content-Type', "text/xml")]) + start_response("200 OK", [("Content-Type", "text/xml")]) return [data] except IOError: resp = NotFound() @@ -105,8 +109,7 @@ def static(self, environ, start_response, path): def check_session_iframe(self, environ, start_response, logger): - return static(self, environ, start_response, - "htdocs/op_session_iframe.html") + return static(self, environ, start_response, "htdocs/op_session_iframe.html") # ---------------------------------------------------------------------------- @@ -135,11 +138,14 @@ def clear_keys(self, environ, start_response, _): # ---------------------------------------------------------------------------- -ROOT = './' +ROOT = "./" -LOOKUP = TemplateLookup(directories=[ROOT + 'templates', ROOT + 'htdocs'], - module_directory=ROOT + 'modules', - input_encoding='utf-8', output_encoding='utf-8') +LOOKUP = TemplateLookup( + directories=[ROOT + "templates", ROOT + "htdocs"], + module_directory=ROOT + "modules", + input_encoding="utf-8", + output_encoding="utf-8", +) def mako_renderer(template_name, context): @@ -164,19 +170,21 @@ def __init__(self, oas, urls): self.oas.endp = self.endpoints self.urls = urls - self.urls.extend([ - (r'^.well-known/openid-configuration', self.op_info), - (r'^.well-known/simple-web-discovery', self.swd_info), - (r'^.well-known/host-meta.json', self.meta_info), - (r'^.well-known/webfinger', self.webfinger), - # (r'^.well-known/webfinger', webfinger), - (r'.+\.css$', self.css), - (r'safe', self.safe), - (r'^keyrollover', key_rollover), - (r'^clearkeys', clear_keys), - (r'^check_session', check_session_iframe) - # (r'tracelog', trace_log), - ]) + self.urls.extend( + [ + (r"^.well-known/openid-configuration", self.op_info), + (r"^.well-known/simple-web-discovery", self.swd_info), + (r"^.well-known/host-meta.json", self.meta_info), + (r"^.well-known/webfinger", self.webfinger), + # (r'^.well-known/webfinger', webfinger), + (r".+\.css$", self.css), + (r"safe", self.safe), + (r"^keyrollover", key_rollover), + (r"^clearkeys", clear_keys), + (r"^check_session", check_session_iframe), + # (r'tracelog', trace_log), + ] + ) self.add_endpoints(self.endpoints) @@ -227,49 +235,37 @@ def css(self, environ, start_response): # ------------------------------------------------------------------------ def token(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.oas.token_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.token_endpoint, logger=logger) def authorization(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.oas.authorization_endpoint, logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.authorization_endpoint, logger=logger) def userinfo(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.oas.userinfo_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.userinfo_endpoint, logger=logger) def op_info(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.oas.providerinfo_endpoint, logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.providerinfo_endpoint, logger=logger) def registration(self, environ, start_response): if environ["REQUEST_METHOD"] == "POST": - return wsgi_wrapper(environ, start_response, - self.oas.registration_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.registration_endpoint, logger=logger) elif environ["REQUEST_METHOD"] == "GET": - return wsgi_wrapper(environ, start_response, - self.oas.read_registration, logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.read_registration, logger=logger) else: resp = ServiceError("Method not supported") return resp(environ, start_response) def check_id(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.oas.check_id_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.check_id_endpoint, logger=logger) def swd_info(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.oas.discovery_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.discovery_endpoint, logger=logger) def trace_log(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.oas.tracelog_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.tracelog_endpoint, logger=logger) def endsession(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.oas.endsession_endpoint, logger=logger) + return wsgi_wrapper(environ, start_response, self.oas.endsession_endpoint, logger=logger) # noinspection PyUnusedLocal def meta_info(self, environ, start_response): @@ -298,8 +294,7 @@ def webfinger(self, environ, start_response): resp = BadRequest("Bad issuer in request") else: wf = WebFinger() - resp = Response(wf.response(subject=resource, - base=self.oas.baseurl)) + resp = Response(wf.response(subject=resource, base=self.oas.baseurl)) return resp(environ, start_response) def application(self, environ, start_response): @@ -317,7 +312,7 @@ def application(self, environ, start_response): :return: The response as a list of lines """ # user = environ.get("REMOTE_USER", "") - path = environ.get('PATH_INFO', '').lstrip('/') + path = environ.get("PATH_INFO", "").lstrip("/") if path == "robots.txt": return static(self, environ, start_response, "static/robots.txt") @@ -333,9 +328,9 @@ def application(self, environ, start_response): match = re.search(regex, path) if match is not None: try: - environ['oic.url_args'] = match.groups()[0] + environ["oic.url_args"] = match.groups()[0] except IndexError: - environ['oic.url_args'] = path + environ["oic.url_args"] = path logger.info("callback: %s" % callback) try: @@ -355,14 +350,15 @@ def application(self, environ, start_response): # ---------------------------------------------------------------------------- + def _import_config(config_path): - import_spec = importlib.util.spec_from_file_location('config', config_path) + import_spec = importlib.util.spec_from_file_location("config", config_path) config_module = importlib.util.module_from_spec(import_spec) import_spec.loader.exec_module(config_module) return config_module -if __name__ == '__main__': +if __name__ == "__main__": import argparse from cherrypy import wsgiserver @@ -372,24 +368,23 @@ def _import_config(config_path): from oic.utils.sdb import create_session_db parser = argparse.ArgumentParser() - parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', - help='More verbose output') - parser.add_argument('-d', '--debug', dest='debug', action='store_true', - help="Enable debug output (doesn't do much)") - parser.add_argument('-p', '--port', dest='port', default=80, type=int, - help='TCP listen port') - parser.add_argument('-t', '--tls', dest='tls', action='store_true', - help='Use HTTPS') - parser.add_argument('-k', '--insecure', dest='insecure', - action='store_true', - help='Disable verification of SSL certs') - parser.add_argument('-c', '--capabilities', dest='capabilities', - help='A file containing a JSON representation of ' - 'server capabilities') - parser.add_argument('-i', '--issuer', dest='issuer', - help='Issuer ID of the OpenID Connect Provider [OP]', - nargs=1) - parser.add_argument(dest='config', help='Python config file (see examples)') + parser.add_argument("-v", "--verbose", dest="verbose", action="store_true", help="More verbose output") + parser.add_argument( + "-d", "--debug", dest="debug", action="store_true", help="Enable debug output (doesn't do much)" + ) + parser.add_argument("-p", "--port", dest="port", default=80, type=int, help="TCP listen port") + parser.add_argument("-t", "--tls", dest="tls", action="store_true", help="Use HTTPS") + parser.add_argument( + "-k", "--insecure", dest="insecure", action="store_true", help="Disable verification of SSL certs" + ) + parser.add_argument( + "-c", + "--capabilities", + dest="capabilities", + help="A file containing a JSON representation of " "server capabilities", + ) + parser.add_argument("-i", "--issuer", dest="issuer", help="Issuer ID of the OpenID Connect Provider [OP]", nargs=1) + parser.add_argument(dest="config", help="Python config file (see examples)") args = parser.parse_args() # Client data base @@ -404,12 +399,12 @@ def _import_config(config_path): _issuer = args.issuer[0] else: if args.port not in [80, 443]: - _issuer = config.ISSUER + ':{}'.format(args.port) + _issuer = config.ISSUER + ":{}".format(args.port) else: _issuer = config.ISSUER - if _issuer[-1] != '/': - _issuer += '/' + if _issuer[-1] != "/": + _issuer += "/" config.SERVICE_URL = config.SERVICE_URL.format(issuer=_issuer) @@ -420,8 +415,8 @@ def _import_config(config_path): end_points = config.AUTHENTICATION["UserPassword"]["END_POINTS"] full_end_point_paths = ["%s%s" % (_issuer, ep) for ep in end_points] username_password_authn = UsernamePasswordMako( - None, "login.mako", LOOKUP, PASSWD, "%sauthorization" % _issuer, - None, full_end_point_paths) + None, "login.mako", LOOKUP, PASSWD, "%sauthorization" % _issuer, None, full_end_point_paths + ) _urls = [] for authkey, value in config.AUTHENTICATION.items(): @@ -429,11 +424,9 @@ def _import_config(config_path): if "UserPassword" == authkey: PASSWORD_END_POINT_INDEX = 0 - end_point = config.AUTHENTICATION[authkey]["END_POINTS"][ - PASSWORD_END_POINT_INDEX] - authn = AuthnIndexedEndpointWrapper(username_password_authn, - PASSWORD_END_POINT_INDEX) - _urls.append((r'^' + end_point, make_auth_verify(authn.verify))) + end_point = config.AUTHENTICATION[authkey]["END_POINTS"][PASSWORD_END_POINT_INDEX] + authn = AuthnIndexedEndpointWrapper(username_password_authn, PASSWORD_END_POINT_INDEX) + _urls.append((r"^" + end_point, make_auth_verify(authn.verify))) # Ensure javascript_login_authn to be defined try: @@ -443,99 +436,102 @@ def _import_config(config_path): if "JavascriptLogin" == authkey: if not javascript_login_authn: - end_points = config.AUTHENTICATION[ - "JavascriptLogin"]["END_POINTS"] - full_end_point_paths = [ - "{}{}".format(_issuer, ep) for ep in end_points] + end_points = config.AUTHENTICATION["JavascriptLogin"]["END_POINTS"] + full_end_point_paths = ["{}{}".format(_issuer, ep) for ep in end_points] javascript_login_authn = JavascriptFormMako( - None, "javascript_login.mako", LOOKUP, PASSWD, - "{}authorization".format(_issuer), None, - full_end_point_paths) + None, + "javascript_login.mako", + LOOKUP, + PASSWD, + "{}authorization".format(_issuer), + None, + full_end_point_paths, + ) ac.add("", javascript_login_authn, "", "") JAVASCRIPT_END_POINT_INDEX = 0 - end_point = config.AUTHENTICATION[authkey]["END_POINTS"][ - JAVASCRIPT_END_POINT_INDEX] - authn = AuthnIndexedEndpointWrapper(javascript_login_authn, - JAVASCRIPT_END_POINT_INDEX) - _urls.append((r'^' + end_point, make_auth_verify(authn.verify))) + end_point = config.AUTHENTICATION[authkey]["END_POINTS"][JAVASCRIPT_END_POINT_INDEX] + authn = AuthnIndexedEndpointWrapper(javascript_login_authn, JAVASCRIPT_END_POINT_INDEX) + _urls.append((r"^" + end_point, make_auth_verify(authn.verify))) if authkey in {"SAML", "SamlPass"}: # https://github.com/CZ-NIC/pyoidc/issues/33 # noinspection PyUnresolvedReferences - from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST + from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT + from oic.utils.authn.saml import SAMLAuthnMethod if "SAML" == authkey: if not saml_authn: saml_authn = SAMLAuthnMethod( - None, LOOKUP, config.SAML, config.SP_CONFIG, _issuer, + None, + LOOKUP, + config.SAML, + config.SP_CONFIG, + _issuer, "{}authorization".format(_issuer), - userinfo=config.USERINFO) + userinfo=config.USERINFO, + ) ac.add("", saml_authn, "", "") SAML_END_POINT_INDEX = 0 - end_point = config.AUTHENTICATION[authkey]["END_POINTS"][ - SAML_END_POINT_INDEX] - end_point_indexes = {BINDING_HTTP_REDIRECT: 0, BINDING_HTTP_POST: 0, - "disco_end_point_index": 0} + end_point = config.AUTHENTICATION[authkey]["END_POINTS"][SAML_END_POINT_INDEX] + end_point_indexes = {BINDING_HTTP_REDIRECT: 0, BINDING_HTTP_POST: 0, "disco_end_point_index": 0} authn = AuthnIndexedEndpointWrapper(saml_authn, end_point_indexes) - _urls.append((r'^' + end_point, make_auth_verify(authn.verify))) + _urls.append((r"^" + end_point, make_auth_verify(authn.verify))) if "SamlPass" == authkey: if not saml_authn: saml_authn = SAMLAuthnMethod( - None, LOOKUP, config.SAML, config.SP_CONFIG, _issuer, + None, + LOOKUP, + config.SAML, + config.SP_CONFIG, + _issuer, "{}authorization".format(_issuer), - userinfo=config.USERINFO) + userinfo=config.USERINFO, + ) PASSWORD_END_POINT_INDEX = 1 SAML_END_POINT_INDEX = 1 - password_end_point = config.AUTHENTICATION["UserPassword"][ - "END_POINTS"][PASSWORD_END_POINT_INDEX] - saml_endpoint = config.AUTHENTICATION["SAML"]["END_POINTS"][ - SAML_END_POINT_INDEX] - - end_point_indexes = {BINDING_HTTP_REDIRECT: 1, BINDING_HTTP_POST: 1, - "disco_end_point_index": 1} - multi_saml = AuthnIndexedEndpointWrapper(saml_authn, - end_point_indexes) - multi_password = AuthnIndexedEndpointWrapper( - username_password_authn, PASSWORD_END_POINT_INDEX) - - auth_modules = [(multi_saml, r'^' + saml_endpoint), - (multi_password, r'^' + password_end_point)] + password_end_point = config.AUTHENTICATION["UserPassword"]["END_POINTS"][PASSWORD_END_POINT_INDEX] + saml_endpoint = config.AUTHENTICATION["SAML"]["END_POINTS"][SAML_END_POINT_INDEX] + + end_point_indexes = {BINDING_HTTP_REDIRECT: 1, BINDING_HTTP_POST: 1, "disco_end_point_index": 1} + multi_saml = AuthnIndexedEndpointWrapper(saml_authn, end_point_indexes) + multi_password = AuthnIndexedEndpointWrapper(username_password_authn, PASSWORD_END_POINT_INDEX) + + auth_modules = [(multi_saml, r"^" + saml_endpoint), (multi_password, r"^" + password_end_point)] authn = setup_multi_auth(ac, _urls, auth_modules) if "JavascriptPass" == authkey: if not javascript_login_authn: - end_points = config.AUTHENTICATION[ - "JavascriptLogin"]["END_POINTS"] - full_end_point_paths = [ - "{}{}".format(_issuer, ep) for ep in end_points] + end_points = config.AUTHENTICATION["JavascriptLogin"]["END_POINTS"] + full_end_point_paths = ["{}{}".format(_issuer, ep) for ep in end_points] javascript_login_authn = JavascriptFormMako( - None, "javascript_login.mako", LOOKUP, PASSWD, - "{}authorization".format(_issuer), None, - full_end_point_paths) + None, + "javascript_login.mako", + LOOKUP, + PASSWD, + "{}authorization".format(_issuer), + None, + full_end_point_paths, + ) PASSWORD_END_POINT_INDEX = 2 JAVASCRIPT_POINT_INDEX = 1 - password_end_point = config.AUTHENTICATION["UserPassword"][ - "END_POINTS"][PASSWORD_END_POINT_INDEX] - javascript_end_point = config.AUTHENTICATION["JavascriptLogin"][ - "END_POINTS"][JAVASCRIPT_POINT_INDEX] + password_end_point = config.AUTHENTICATION["UserPassword"]["END_POINTS"][PASSWORD_END_POINT_INDEX] + javascript_end_point = config.AUTHENTICATION["JavascriptLogin"]["END_POINTS"][JAVASCRIPT_POINT_INDEX] - multi_password = AuthnIndexedEndpointWrapper( - username_password_authn, PASSWORD_END_POINT_INDEX) - multi_javascript = AuthnIndexedEndpointWrapper( - javascript_login_authn, JAVASCRIPT_POINT_INDEX) + multi_password = AuthnIndexedEndpointWrapper(username_password_authn, PASSWORD_END_POINT_INDEX) + multi_javascript = AuthnIndexedEndpointWrapper(javascript_login_authn, JAVASCRIPT_POINT_INDEX) - auth_modules = [(multi_password, r'^' + password_end_point), - (multi_javascript, r'^' + javascript_end_point)] + auth_modules = [ + (multi_password, r"^" + password_end_point), + (multi_javascript, r"^" + javascript_end_point), + ] authn = setup_multi_auth(ac, _urls, auth_modules) if authn is not None: - ac.add(config.AUTHENTICATION[authkey]["ACR"], authn, - config.AUTHENTICATION[authkey]["WEIGHT"], - "") + ac.add(config.AUTHENTICATION[authkey]["ACR"], authn, config.AUTHENTICATION[authkey]["WEIGHT"], "") # dealing with authorization authz = AuthzHandling() @@ -557,12 +553,9 @@ def _import_config(config_path): pass # In-Memory non persistent SessionDB - sdb = create_session_db(_issuer, - secret=rndstr(32), - password=rndstr(32)) + sdb = create_session_db(_issuer, secret=rndstr(32), password=rndstr(32)) - OAS = Provider(_issuer, sdb, cdb, ac, None, - authz, verify_client, config.SYM_KEY, **kwargs) + OAS = Provider(_issuer, sdb, cdb, ac, None, authz, verify_client, config.SYM_KEY, **kwargs) OAS.baseurl = _issuer for authn in ac: @@ -619,21 +612,20 @@ def _import_config(config_path): _app = Application(OAS, _urls) # Setup the web server - SRV = wsgiserver.CherryPyWSGIServer(('0.0.0.0', args.port), # nosec - _app.application) + SRV = wsgiserver.CherryPyWSGIServer( + ("0.0.0.0", args.port), # nosec + _app.application, + ) https = "" if args.tls: https = "using TLS" # SRV.ssl_adapter = ssl_pyopenssl.pyOpenSSLAdapter( # config.SERVER_CERT, config.SERVER_KEY, config.CERT_CHAIN) - SRV.ssl_adapter = BuiltinSSLAdapter(config.SERVER_CERT, - config.SERVER_KEY) + SRV.ssl_adapter = BuiltinSSLAdapter(config.SERVER_CERT, config.SERVER_KEY) - LOGGER.info( - "OC server started (iss={}, port={})".format(_issuer, args.port)) - print("OC server started (iss={}, port={}) {}".format(_issuer, args.port, - https)) + LOGGER.info("OC server started (iss={}, port={})".format(_issuer, args.port)) + print("OC server started (iss={}, port={}) {}".format(_issuer, args.port, https)) try: SRV.start() except KeyboardInterrupt: diff --git a/oidc_example/op3/config.py b/oidc_example/op3/config.py index 2f9ce887a..6bacc6dab 100644 --- a/oidc_example/op3/config.py +++ b/oidc_example/op3/config.py @@ -1,23 +1,22 @@ PORT = 8040 -ISSUER = 'https://localhost' # do not include the port, it will be added in the code. +ISSUER = "https://localhost" # do not include the port, it will be added in the code. SERVICEURL = "{issuer}verify" # do not manually add issuer or port number, these will be added in the code. SERVER_CERT = "certification/server.crt" SERVER_KEY = "certification/server.key" CERT_CHAIN = None AUTHENTICATION = { - "UserPassword": - { - "ACR": "PASSWORD", - "WEIGHT": 1, - "URL": SERVICEURL, - "EndPoints": ["verify"], - } + "UserPassword": { + "ACR": "PASSWORD", + "WEIGHT": 1, + "URL": SERVICEURL, + "EndPoints": ["verify"], + } } -CLIENTDB = 'ClientDB' -SYM_KEY = "SoLittleTime,Got" # used for Symmetric key authentication only. -COOKIENAME = 'pyoic' +CLIENTDB = "ClientDB" +SYM_KEY = "SoLittleTime,Got" # used for Symmetric key authentication only. +COOKIENAME = "pyoic" COOKIETTL = 4 * 60 # 4 hours USERINFO = "SIMPLE" @@ -36,7 +35,7 @@ "street_address": "address1", "locality": "locality1", "postal_code": "5719800000", - "country": "Iran" + "country": "Iran", }, }, "user2": { @@ -54,7 +53,7 @@ "postal_code": "5719899999", "country": "Iran", }, - } + }, } # This is a JSON Web Key (JWK) object, and its members represent @@ -62,8 +61,7 @@ keys = [ {"type": "RSA", "key": "cryptography_keys/key.pem", "use": ["enc", "sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} - + {"type": "EC", "crv": "P-256", "use": ["enc"]}, # "type" or "kty" identifies the cryptographic algorithm family used with the key. # The kty values are case sensitive. The kty values should either be registered # in the IANA "JSON Web Key Types" registery or be a value that contains a diff --git a/oidc_example/op3/modules/login.mako.py b/oidc_example/op3/modules/login.mako.py index 6317e7fc2..6d6e87c4a 100644 --- a/oidc_example/op3/modules/login.mako.py +++ b/oidc_example/op3/modules/login.mako.py @@ -1,4 +1,5 @@ -from mako import runtime, filters, cache +from mako import runtime + UNDEFINED = runtime.UNDEFINED STOP_RENDERING = runtime.STOP_RENDERING __M_dict_builtin = dict @@ -6,10 +7,10 @@ _magic_number = 10 _modified_time = 1490897004.899657 _enable_loop = True -_template_filename = 'htdocs/login.mako' -_template_uri = 'login.mako' -_source_encoding = 'utf-8' -_exports = ['add_js'] +_template_filename = "htdocs/login.mako" +_template_uri = "login.mako" +_source_encoding = "utf-8" +_exports = ["add_js"] def _mako_get_namespace(context, name): @@ -18,62 +19,70 @@ def _mako_get_namespace(context, name): except KeyError: _mako_generate_namespaces(context) return context.namespaces[(__name__, name)] + + def _mako_generate_namespaces(context): pass + + def _mako_inherit(template, context): _mako_generate_namespaces(context) - return runtime._inherit_from(context, u'root.mako', _template_uri) -def render_body(context,**pageargs): + return runtime._inherit_from(context, "root.mako", _template_uri) + + +def render_body(context, **pageargs): __M_caller = context.caller_stack._push_frame() try: __M_locals = __M_dict_builtin(pageargs=pageargs) - submit_text = context.get('submit_text', UNDEFINED) - acr = context.get('acr', UNDEFINED) - title = context.get('title', UNDEFINED) - login_title = context.get('login_title', UNDEFINED) - passwd_title = context.get('passwd_title', UNDEFINED) - tos_uri = context.get('tos_uri', UNDEFINED) - policy_uri = context.get('policy_uri', UNDEFINED) - action = context.get('action', UNDEFINED) - query = context.get('query', UNDEFINED) - login = context.get('login', UNDEFINED) - password = context.get('password', UNDEFINED) - logo_uri = context.get('logo_uri', UNDEFINED) + submit_text = context.get("submit_text", UNDEFINED) + acr = context.get("acr", UNDEFINED) + title = context.get("title", UNDEFINED) + login_title = context.get("login_title", UNDEFINED) + passwd_title = context.get("passwd_title", UNDEFINED) + tos_uri = context.get("tos_uri", UNDEFINED) + policy_uri = context.get("policy_uri", UNDEFINED) + action = context.get("action", UNDEFINED) + query = context.get("query", UNDEFINED) + login = context.get("login", UNDEFINED) + password = context.get("password", UNDEFINED) + logo_uri = context.get("logo_uri", UNDEFINED) __M_writer = context.writer() - __M_writer(u'\n
\n

') - __M_writer(unicode(title)) - __M_writer(u'

\n
\n
\n \n') + __M_writer('\n\n\n\n') - __M_writer(u'\n') - return '' + __M_writer(' Client's Terms of Service\n') + __M_writer("
\n\n") + __M_writer("\n") + return "" finally: context.caller_stack._pop_frame() @@ -82,8 +91,10 @@ def render_add_js(context): __M_caller = context.caller_stack._push_frame() try: __M_writer = context.writer() - __M_writer(u'\n \n') - return '' + __M_writer( + '\n \n' + ) + return "" finally: context.caller_stack._pop_frame() diff --git a/oidc_example/op3/modules/root.mako.py b/oidc_example/op3/modules/root.mako.py index 5fc5d0770..6680f01f2 100644 --- a/oidc_example/op3/modules/root.mako.py +++ b/oidc_example/op3/modules/root.mako.py @@ -1,4 +1,5 @@ -from mako import runtime, filters, cache +from mako import runtime, filters + UNDEFINED = runtime.UNDEFINED STOP_RENDERING = runtime.STOP_RENDERING __M_dict_builtin = dict @@ -6,67 +7,71 @@ _magic_number = 10 _modified_time = 1490897004.910343 _enable_loop = True -_template_filename = u'Templates/root.mako' -_template_uri = u'root.mako' -_source_encoding = 'utf-8' -_exports = ['css_link', 'pre', 'post', 'css'] +_template_filename = "Templates/root.mako" +_template_uri = "root.mako" +_source_encoding = "utf-8" +_exports = ["css_link", "pre", "post", "css"] -def render_body(context,**pageargs): +def render_body(context, **pageargs): __M_caller = context.caller_stack._push_frame() try: __M_locals = __M_dict_builtin(pageargs=pageargs) + def pre(): return render_pre(context._locals(__M_locals)) - self = context.get('self', UNDEFINED) - set = context.get('set', UNDEFINED) + + self = context.get("self", UNDEFINED) + set = context.get("set", UNDEFINED) + def post(): return render_post(context._locals(__M_locals)) - next = context.get('next', UNDEFINED) + + next = context.get("next", UNDEFINED) __M_writer = context.writer() - self.seen_css = set() - - __M_writer(u'\n') - __M_writer(u'\n') - __M_writer(u'\n') - __M_writer(u'\n') - __M_writer(u'\n') - __M_writer(u'\nOpenID Connect provider example\n') - __M_writer(unicode(self.css())) - __M_writer(u'\n\n\n\n') - __M_writer(unicode(pre())) - __M_writer(u'\n') - __M_writer(unicode(next.body())) - __M_writer(u'\n') - __M_writer(unicode(post())) - __M_writer(u'\n\n\n') - return '' + self.seen_css = set() + + __M_writer("\n") + __M_writer("\n") + __M_writer("\n") + __M_writer("\n") + __M_writer("\n") + __M_writer("\nOpenID Connect provider example\n") + __M_writer(self.css()) + __M_writer('\n\n\n\n') + __M_writer(pre()) + __M_writer("\n") + __M_writer(next.body()) + __M_writer("\n") + __M_writer(post()) + __M_writer("\n\n\n") + return "" finally: context.caller_stack._pop_frame() -def render_css_link(context,path,media=''): +def render_css_link(context, path, media=""): __M_caller = context.caller_stack._push_frame() try: context._push_buffer() - self = context.get('self', UNDEFINED) + self = context.get("self", UNDEFINED) __M_writer = context.writer() - __M_writer(u'\n') + __M_writer("\n") if path not in self.seen_css: - __M_writer(u' \n') - __M_writer(u' ') - self.seen_css.add(path) - - __M_writer(u'\n') + __M_writer(' \n') + __M_writer(" ") + self.seen_css.add(path) + + __M_writer("\n") finally: __M_buf, __M_writer = context._pop_buffer_and_writer() context.caller_stack._pop_frame() __M_writer(filters.trim(__M_buf.getvalue())) - return '' + return "" def render_pre(context): @@ -74,12 +79,12 @@ def render_pre(context): try: context._push_buffer() __M_writer = context.writer() - __M_writer(u'\n') + __M_writer("\n") finally: __M_buf, __M_writer = context._pop_buffer_and_writer() context.caller_stack._pop_frame() __M_writer(filters.trim(__M_buf.getvalue())) - return '' + return "" def render_post(context): @@ -87,29 +92,33 @@ def render_post(context): try: context._push_buffer() __M_writer = context.writer() - __M_writer(u'\n
\n \n
\n') + __M_writer( + '\n
\n \n
\n' + ) finally: __M_buf, __M_writer = context._pop_buffer_and_writer() context.caller_stack._pop_frame() __M_writer(filters.trim(__M_buf.getvalue())) - return '' + return "" def render_css(context): __M_caller = context.caller_stack._push_frame() try: context._push_buffer() - def css_link(path,media=''): - return render_css_link(context,path,media) + + def css_link(path, media=""): + return render_css_link(context, path, media) + __M_writer = context.writer() - __M_writer(u'\n ') - __M_writer(unicode(css_link('/css/main.css', 'screen'))) - __M_writer(u'\n') + __M_writer("\n ") + __M_writer(css_link("/css/main.css", "screen")) + __M_writer("\n") finally: __M_buf, __M_writer = context._pop_buffer_and_writer() context.caller_stack._pop_frame() __M_writer(filters.trim(__M_buf.getvalue())) - return '' + return "" """ diff --git a/oidc_example/op3/server.py b/oidc_example/op3/server.py index 8a4eee6bb..0035982f8 100755 --- a/oidc_example/op3/server.py +++ b/oidc_example/op3/server.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -__author__ = 'Vahid Jalili' +__author__ = "Vahid Jalili" from urllib.parse import parse_qs @@ -10,6 +10,7 @@ import traceback import argparse import importlib +import time import logging from mako.lookup import TemplateLookup @@ -29,7 +30,8 @@ from oic.utils.authn.multi_auth import AuthnIndexedEndpointWrapper from oic.utils.authn.user import UsernamePasswordMako from oic.utils.authz import AuthzHandling -from oic.utils.http_util import * +from oic.utils.http_util import NotFound, ServiceError, Response, BadRequest, wsgi_wrapper, get_post, Unauthorized +from jwkest import as_unicode from oic.utils.keyio import keyjar_init from oic.utils.userinfo import UserInfo from oic.utils.webfinger import OIC_ISSUER @@ -42,22 +44,19 @@ from oic.utils.sdb import create_session_db - LOGGER = logging.getLogger("") -LOGFILE_NAME = 'oc.log' +LOGFILE_NAME = "oc.log" hdlr = logging.FileHandler(LOGFILE_NAME) -base_formatter = logging.Formatter( - "%(asctime)s %(name)s:%(levelname)s %(message)s") +base_formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") -CPC = ('%(asctime)s %(name)s:%(levelname)s ' - '[%(client)s,%(path)s,%(cid)s] %(message)s') +CPC = "%(asctime)s %(name)s:%(levelname)s " "[%(client)s,%(path)s,%(cid)s] %(message)s" cpc_formatter = logging.Formatter(CPC) hdlr.setFormatter(base_formatter) LOGGER.addHandler(hdlr) LOGGER.setLevel(logging.DEBUG) -logger = logging.getLogger('oicServer') +logger = logging.getLogger("oicServer") def static_file(path): @@ -73,19 +72,19 @@ def static(self, environ, start_response, path): logger.info("[static]sending: %s" % (path,)) try: - data = open(path, 'rb').read() + data = open(path, "rb").read() if path.endswith(".ico"): - start_response('200 OK', [('Content-Type', "image/x-icon")]) + start_response("200 OK", [("Content-Type", "image/x-icon")]) elif path.endswith(".html"): - start_response('200 OK', [('Content-Type', 'text/html')]) + start_response("200 OK", [("Content-Type", "text/html")]) elif path.endswith(".json"): - start_response('200 OK', [('Content-Type', 'application/json')]) + start_response("200 OK", [("Content-Type", "application/json")]) elif path.endswith(".txt"): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) elif path.endswith(".css"): - start_response('200 OK', [('Content-Type', 'text/css')]) + start_response("200 OK", [("Content-Type", "text/css")]) else: - start_response('200 OK', [('Content-Type', "text/xml")]) + start_response("200 OK", [("Content-Type", "text/xml")]) return [data] except IOError: resp = NotFound() @@ -130,17 +129,19 @@ def __init__(self, provider, urls): self.provider.endp = self.endpoints self.urls = urls - self.urls.extend([ - (r'^.well-known/openid-configuration', self.op_info), - (r'^.well-known/simple-web-discovery', self.swd_info), - (r'^.well-known/host-meta.json', self.meta_info), - (r'^.well-known/webfinger', self.webfinger), - (r'.+\.css$', self.css), - (r'safe', self.safe), - (r'^keyrollover', key_rollover), - (r'^clearkeys', clear_keys), - (r'^check_session', check_session_iframe) - ]) + self.urls.extend( + [ + (r"^.well-known/openid-configuration", self.op_info), + (r"^.well-known/simple-web-discovery", self.swd_info), + (r"^.well-known/host-meta.json", self.meta_info), + (r"^.well-known/webfinger", self.webfinger), + (r".+\.css$", self.css), + (r"safe", self.safe), + (r"^keyrollover", key_rollover), + (r"^clearkeys", clear_keys), + (r"^check_session", check_session_iframe), + ] + ) for endp in self.endpoints: self.urls.append(("^%s" % endp.etype, endp.func)) @@ -158,9 +159,9 @@ def safe(self, environ, start_response): resp = BadRequest("Missing authorization information") return resp(environ, start_response) else: - if typ != "Bearer": - resp = BadRequest("Unsupported authorization method") - return resp(environ, start_response) + if typ != "Bearer": + resp = BadRequest("Unsupported authorization method") + return resp(environ, start_response) try: _sinfo = _srv.sdb[code] @@ -184,56 +185,47 @@ def css(self, environ, start_response): # noinspection PyUnusedLocal def token(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.provider.token_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.token_endpoint, logger=logger) # noinspection PyUnusedLocal def authorization(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.provider.authorization_endpoint, logger=logger) # cookies required. + return wsgi_wrapper( + environ, start_response, self.provider.authorization_endpoint, logger=logger + ) # cookies required. # noinspection PyUnusedLocal def userinfo(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.provider.userinfo_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.userinfo_endpoint, logger=logger) # noinspection PyUnusedLocal def op_info(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.provider.providerinfo_endpoint, logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.providerinfo_endpoint, logger=logger) # noinspection PyUnusedLocal def registration(self, environ, start_response): if environ["REQUEST_METHOD"] == "POST": - return wsgi_wrapper(environ, start_response, - self.provider.registration_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.registration_endpoint, logger=logger) elif environ["REQUEST_METHOD"] == "GET": - return wsgi_wrapper(environ, start_response, - self.provider.read_registration, logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.read_registration, logger=logger) else: resp = ServiceError("Method not supported") return resp(environ, start_response) # noinspection PyUnusedLocal def check_id(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.provider.check_id_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.check_id_endpoint, logger=logger) # noinspection PyUnusedLocal def swd_info(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.provider.discovery_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.discovery_endpoint, logger=logger) # noinspection PyUnusedLocal def trace_log(self, environ, start_response): - return wsgi_wrapper(environ, start_response, self.provider.tracelog_endpoint, - logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.tracelog_endpoint, logger=logger) # noinspection PyUnusedLocal def endsession(self, environ, start_response): - return wsgi_wrapper(environ, start_response, - self.provider.endsession_endpoint, logger=logger) + return wsgi_wrapper(environ, start_response, self.provider.endsession_endpoint, logger=logger) # noinspection PyUnusedLocal def meta_info(self, environ, start_response): @@ -248,7 +240,7 @@ def meta_info(self, environ, start_response): ]} """ - print('\n in meta-info') + print("\n in meta-info") pass def webfinger(self, environ, start_response): @@ -263,8 +255,7 @@ def webfinger(self, environ, start_response): resp = BadRequest("Bad issuer in request") else: wf = WebFinger() - resp = Response(wf.response(subject=resource, - base=self.provider.baseurl)) + resp = Response(wf.response(subject=resource, base=self.provider.baseurl)) return resp(environ, start_response) def application(self, environ, start_response): @@ -281,9 +272,9 @@ def application(self, environ, start_response): request is done :return: The response as a list of lines """ - path = environ.get('PATH_INFO', '').lstrip('/') + path = environ.get("PATH_INFO", "").lstrip("/") - print('start_response: ', start_response) + print("start_response: ", start_response) if path == "robots.txt": return static(self, environ, start_response, "static/robots.txt") @@ -296,9 +287,9 @@ def application(self, environ, start_response): match = re.search(regex, path) if match is not None: try: - environ['oic.url_args'] = match.groups()[0] + environ["oic.url_args"] = match.groups()[0] except IndexError: - environ['oic.url_args'] = path + environ["oic.url_args"] = path try: return callback(environ, start_response) except Exception as err: @@ -314,21 +305,20 @@ def application(self, environ, start_response): return resp(environ, start_response) -if __name__ == '__main__': - - root = './' - lookup = TemplateLookup(directories=[root + 'Templates', root + 'htdocs'], - module_directory=root + 'modules', - input_encoding='utf-8', output_encoding='utf-8') +if __name__ == "__main__": + root = "./" + lookup = TemplateLookup( + directories=[root + "Templates", root + "htdocs"], + module_directory=root + "modules", + input_encoding="utf-8", + output_encoding="utf-8", + ) def mako_renderer(template_name, context): mte = lookup.get_template(template_name) return mte.render(**context) - usernamePasswords = { - "user1": "1", - "user2": "2" - } + usernamePasswords = {"user1": "1", "user2": "2"} passwordEndPointIndex = 0 # what is this, and what does its value mean? @@ -337,18 +327,18 @@ def mako_renderer(template_name, context): # parse the parameters parser = argparse.ArgumentParser() - parser.add_argument('-c', dest='config') - parser.add_argument('-d', dest='debug', action='store_true') + parser.add_argument("-c", dest="config") + parser.add_argument("-d", dest="debug", action="store_true") args = parser.parse_args() # parse and setup configuration config = importlib.import_module(args.config) - config.ISSUER = config.ISSUER + ':{}/'.format(config.PORT) + config.ISSUER = config.ISSUER + ":{}/".format(config.PORT) config.SERVICEURL = config.SERVICEURL.format(issuer=config.ISSUER) endPoints = config.AUTHENTICATION["UserPassword"]["EndPoints"] fullEndPointsPath = ["%s%s" % (config.ISSUER, ep) for ep in endPoints] -# TODO: why this instantiation happens so early? can I move it later? + # TODO: why this instantiation happens so early? can I move it later? # An OIDC Authorization/Authentication server is designed to # allow more than one authentication method to be used by the server. # And that is what the AuthBroker is for. @@ -362,44 +352,45 @@ def mako_renderer(template_name, context): # UsernamePasswordMako: authenticas a user using the username/password form in a # WSGI environment using Mako as template system usernamePasswordAuthn = UsernamePasswordMako( - None, # server instance - "login.mako", # a mako template - lookup, # lookup template - usernamePasswords, # username/password dictionary-like database + None, # server instance + "login.mako", # a mako template + lookup, # lookup template + usernamePasswords, # username/password dictionary-like database "%sauthorization" % config.ISSUER, # where to send the user after authentication - None, # templ_arg_func ??!! - fullEndPointsPath) # verification endpoints + None, # templ_arg_func ??!! + fullEndPointsPath, + ) # verification endpoints # AuthnIndexedEndpointWrapper is a wrapper class for using an authentication module with multiple endpoints. authnIndexedEndPointWrapper = AuthnIndexedEndpointWrapper(usernamePasswordAuthn, passwordEndPointIndex) - authnBroker.add(config.AUTHENTICATION["UserPassword"]["ACR"], # (?!) - authnIndexedEndPointWrapper, # (?!) method: an identifier of the authentication method. - config.AUTHENTICATION["UserPassword"]["WEIGHT"], # security level - "") # (?!) authentication authority + authnBroker.add( + config.AUTHENTICATION["UserPassword"]["ACR"], # (?!) + authnIndexedEndPointWrapper, # (?!) method: an identifier of the authentication method. + config.AUTHENTICATION["UserPassword"]["WEIGHT"], # security level + "", + ) # (?!) authentication authority # ?! authz = AuthzHandling() clientDB = shelve_wrapper.open(config.CLIENTDB) # In-Memory non-persistent SessionDB issuing DefaultTokens - sessionDB = create_session_db(config.ISSUER, - secret=rndstr(32), - password=rndstr(32)) + sessionDB = create_session_db(config.ISSUER, secret=rndstr(32), password=rndstr(32)) provider = Provider( - name=config.ISSUER, # name - sdb=sessionDB, # session database. - cdb=clientDB, # client database - authn_broker=authnBroker, # authn broker - userinfo=None, # user information - authz=authz, # authz - client_authn=verify_client, # client authentication - symkey=config.SYM_KEY, # Used for Symmetric key authentication + name=config.ISSUER, # name + sdb=sessionDB, # session database. + cdb=clientDB, # client database + authn_broker=authnBroker, # authn broker + userinfo=None, # user information + authz=authz, # authz + client_authn=verify_client, # client authentication + symkey=config.SYM_KEY, # Used for Symmetric key authentication # urlmap = None, # ? # keyjar = None, # ? # hostname = "", # ? - template_renderer=mako_renderer, # Rendering custom templates + template_renderer=mako_renderer, # Rendering custom templates # verify_ssl = True, # Enable SSL certs # capabilities = None, # ? # schema = OpenIDSchema, # ? @@ -407,7 +398,7 @@ def mako_renderer(template_name, context): # jwks_name = '', # ? baseurl=config.ISSUER, # client_cert = None # ? - ) + ) # SessionDB: # This is database where the provider keeps information about @@ -443,10 +434,11 @@ def mako_renderer(template_name, context): # keyjar_init configures cryptographic key # based on the provided configuration "keys". jwks = keyjar_init( - provider, # server/client instance - config.keys, # key configuration - kid_template="op%d") # template by which to build the kids (key ID parameter) - except Exception as err: + provider, # server/client instance + config.keys, # key configuration + kid_template="op%d", + ) # template by which to build the kids (key ID parameter) + except Exception: # LOGGER.error("Key setup failed: %s" % err) provider.key_setup("static", sig={"format": "jwk", "alg": "rsa"}) else: @@ -470,12 +462,12 @@ def mako_renderer(template_name, context): endPoint = config.AUTHENTICATION["UserPassword"]["EndPoints"][passwordEndPointIndex] _urls = [] - _urls.append((r'^' + endPoint, make_auth_verify(authnIndexedEndPointWrapper.verify))) + _urls.append((r"^" + endPoint, make_auth_verify(authnIndexedEndPointWrapper.verify))) _app = Application(provider, _urls) # Setup the web server - server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', config.PORT), _app.application) # nosec + server = wsgiserver.CherryPyWSGIServer(("0.0.0.0", config.PORT), _app.application) # nosec server.ssl_adapter = BuiltinSSLAdapter(config.SERVER_CERT, config.SERVER_KEY) print("OIDC Provider server started (issuer={}, port={})".format(config.ISSUER, config.PORT)) diff --git a/oidc_example/rp2/oidc.py b/oidc_example/rp2/oidc.py index 0526cd278..b3a14d416 100644 --- a/oidc_example/rp2/oidc.py +++ b/oidc_example/rp2/oidc.py @@ -15,7 +15,7 @@ from oic.utils.http_util import Response from oic.utils.webfinger import WebFinger -__author__ = 'rolandh' +__author__ = "rolandh" logger = logging.getLogger(__name__) @@ -31,8 +31,7 @@ def token_secret_key(sid): class OpenIDConnect(object): - def __init__(self, attribute_map=None, authenticating_authority=None, - name="", registration_info=None, **kwargs): + def __init__(self, attribute_map=None, authenticating_authority=None, name="", registration_info=None, **kwargs): self.attribute_map = attribute_map self.authenticating_authority = authenticating_authority self.name = name @@ -70,10 +69,9 @@ def dynamic(self, server_env, callback, logout_callback, session, key): provider_conf = client.provider_config(self.srv_discovery_url) logger.debug("Got provider config: %s", provider_conf) - session['provider'] = provider_conf["issuer"] + session["provider"] = provider_conf["issuer"] logger.debug("Registering RP") - reg_info = client.register(provider_conf["registration_endpoint"], - **_me) + reg_info = client.register(provider_conf["registration_endpoint"], **_me) logger.debug("Registration response: %s", reg_info) for prop in ["client_id", "client_secret"]: try: @@ -122,12 +120,12 @@ def begin(self, environ, server_env, start_response, session, key): try: logger.debug("FLOW type: %s", self.flow_type) logger.debug("begin environ: %s", server_env) - client = session['client'] + client = session["client"] if client is not None and self.srv_discovery_url: data = {"client_id": client.client_id} - resp = requests.get(self.srv_discovery_url + "verifyClientId", - params=data, verify=self.extra["ca_bundle"], - timeout=10) + resp = requests.get( + self.srv_discovery_url + "verifyClientId", params=data, verify=self.extra["ca_bundle"], timeout=10 + ) if not resp.ok and resp.status_code == 400: client = None server_env["OIC_CLIENT"].pop(key, None) @@ -137,42 +135,37 @@ def begin(self, environ, server_env, start_response, session, key): callback = server_env["base_url"] + key logout_callback = server_env["base_url"] if self.srv_discovery_url: - client = self.dynamic(server_env, callback, logout_callback, - session, key) + client = self.dynamic(server_env, callback, logout_callback, session, key) else: - client = self.static(server_env, callback, logout_callback, - key) - _state = session['state'] - session['client'] = client + client = self.static(server_env, callback, logout_callback, key) + _state = session["state"] + session["client"] = client acr_value = session.get_acr_value(client.authorization_endpoint) try: acr_values = client.provider_info["acr_values_supported"] - session['acr_values'] = acr_values + session["acr_values"] = acr_values except KeyError: acr_values = None - if acr_value is None and acr_values is not None and \ - len(acr_values) > 1: + if acr_value is None and acr_values is not None and len(acr_values) > 1: resp_headers = [("Location", str("/rpAcr"))] start_response("302 Found", resp_headers) return [] elif acr_values is not None and len(acr_values) == 1: acr_value = acr_values[0] - return self.create_authnrequest(environ, server_env, start_response, - session, acr_value, _state) + return self.create_authnrequest(environ, server_env, start_response, session, acr_value, _state) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) return self.result( - environ, start_response, server_env, - (False, "Cannot find the OP! Please view your configuration.")) + environ, start_response, server_env, (False, "Cannot find the OP! Please view your configuration.") + ) # noinspection PyUnusedLocal - def create_authnrequest(self, environ, server_env, start_response, session, - acr_value, state): + def create_authnrequest(self, environ, server_env, start_response, session, acr_value, state): try: - client = session['client'] + client = session["client"] session.set_acr_value(client.authorization_endpoint, acr_value) request_args = { "response_type": self.flow_type, @@ -185,12 +178,12 @@ def create_authnrequest(self, environ, server_env, start_response, session, if self.flow_type == "token": request_args["nonce"] = rndstr(16) - session['nonce'] = request_args["nonce"] + session["nonce"] = request_args["nonce"] else: use_nonce = getattr(self, "use_nonce", None) if use_nonce: request_args["nonce"] = rndstr(16) - session['nonce'] = request_args["nonce"] + session["nonce"] = request_args["nonce"] logger.info("client args: %s", list(client.__dict__.items())) logger.info("request_args: %s", request_args) @@ -199,28 +192,28 @@ def create_authnrequest(self, environ, server_env, start_response, session, message = traceback.format_exception(*sys.exc_info()) logger.error(message) return self.result( - environ, start_response, server_env, - (False, "Cannot find the OP! Please view your configuration.")) + environ, start_response, server_env, (False, "Cannot find the OP! Please view your configuration.") + ) try: - cis = client.construct_AuthorizationRequest( - request_args=request_args) + cis = client.construct_AuthorizationRequest(request_args=request_args) logger.debug("request: %s", cis) url, body, ht_args, cis = client.uri_and_body( - AuthorizationRequest, cis, method="GET", - request_args=request_args) + AuthorizationRequest, cis, method="GET", request_args=request_args + ) logger.debug("body: %s", body) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) - return self.result(environ, start_response, server_env, ( - False, "Authorization request can not be performed!")) + return self.result( + environ, start_response, server_env, (False, "Authorization request can not be performed!") + ) logger.info("URL: %s", url) logger.debug("ht_args: %s", ht_args) - session['client'] = client + session["client"] = client resp_headers = [("Location", str(url))] if ht_args: resp_headers.extend([(a, b) for a, b in ht_args.items()]) @@ -243,8 +236,8 @@ def get_accesstoken(self, client, authresp): # get the access token return client.do_access_token_request( - state=authresp["state"], response_cls=self.access_token_response, - **kwargs) + state=authresp["state"], response_cls=self.access_token_response, **kwargs + ) # noinspection PyUnusedLocal def verify_token(self, client, access_token): @@ -252,10 +245,9 @@ def verify_token(self, client, access_token): def get_userinfo(self, client, authresp, access_token, **kwargs): # use the access token to get some userinfo - return client.do_user_info_request(state=authresp["state"], - schema="openid", - access_token=access_token, - **kwargs) + return client.do_user_info_request( + state=authresp["state"], schema="openid", access_token=access_token, **kwargs + ) # noinspection PyUnusedLocal def phaseN(self, environ, query, server_env, session): @@ -263,12 +255,11 @@ def phaseN(self, environ, query, server_env, session): callback URL you can request the access token the user has approved.""" - client = session['client'] + client = session["client"] logger.debug("info: %s", query) logger.debug("keyjar: %s", client.keyjar) - authresp = client.parse_response(AuthorizationResponse, query, - sformat="dict", keyjar=client.keyjar) + authresp = client.parse_response(AuthorizationResponse, query, sformat="dict", keyjar=client.keyjar) if isinstance(authresp, ErrorResponse): return False, "Access denied" @@ -302,7 +293,7 @@ def phaseN(self, environ, query, server_env, session): if isinstance(inforesp, ErrorResponse): return False, "Invalid response %s." % inforesp["error"], session - tot_info = userinfo.update(inforesp.to_dict()) + userinfo.update(inforesp.to_dict()) logger.debug("UserInfo: %s", inforesp) @@ -328,7 +319,7 @@ def callback(self, environ, server_env, start_response, query, session): try: result = self.phaseN(environ, query, server_env, session) - session['login'] = True + session["login"] = True logger.debug("[do_%s] response: %s", _service, result) except Exception: message = traceback.format_exception(*sys.exc_info()) @@ -338,12 +329,8 @@ def callback(self, environ, server_env, start_response, query, session): return self.result(environ, start_response, server_env, result) def result(self, environ, start_response, server_env, result): - resp = Response(mako_template="opresult.mako", - template_lookup=server_env["template_lookup"], - headers=[]) - argv = { - "result": result - } + resp = Response(mako_template="opresult.mako", template_lookup=server_env["template_lookup"], headers=[]) + argv = {"result": result} return resp(environ, start_response, **argv) def find_srv_discovery_url(self, resource): diff --git a/oidc_example/rp2/rp2.py b/oidc_example/rp2/rp2.py index 71f83e2fc..59784efec 100755 --- a/oidc_example/rp2/rp2.py +++ b/oidc_example/rp2/rp2.py @@ -19,23 +19,20 @@ from oic.utils.http_util import ServiceError LOGGER = logging.getLogger("") -LOGFILE_NAME = 'rp.log' +LOGFILE_NAME = "rp.log" hdlr = logging.FileHandler(LOGFILE_NAME) -base_formatter = logging.Formatter( - "%(asctime)s %(name)s:%(levelname)s %(message)s") +base_formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") -CPC = ('%(asctime)s %(name)s:%(levelname)s ' - '[%(client)s,%(path)s,%(cid)s] %(message)s') +CPC = "%(asctime)s %(name)s:%(levelname)s " "[%(client)s,%(path)s,%(cid)s] %(message)s" cpc_formatter = logging.Formatter(CPC) hdlr.setFormatter(base_formatter) LOGGER.addHandler(hdlr) LOGGER.setLevel(logging.DEBUG) -LOOKUP = TemplateLookup(directories=['templates', 'htdocs'], - module_directory='modules', - input_encoding='utf-8', - output_encoding='utf-8') +LOOKUP = TemplateLookup( + directories=["templates", "htdocs"], module_directory="modules", input_encoding="utf-8", output_encoding="utf-8" +) SERVER_ENV = {} RP = None @@ -44,18 +41,17 @@ def setup_server_env(conf): global SERVER_ENV - SERVER_ENV = dict([(k, v) for k, v in conf.__dict__.items() - if not k.startswith("__")]) + SERVER_ENV = dict([(k, v) for k, v in conf.__dict__.items() if not k.startswith("__")]) SERVER_ENV["template_lookup"] = LOOKUP SERVER_ENV["base_url"] = conf.BASE - #SERVER_ENV["CACHE"] = {} + # SERVER_ENV["CACHE"] = {} SERVER_ENV["OIC_CLIENT"] = {} class Httpd(object): def http_request(self, url): # ignore cert validation for the example... - return requests.get(url, verify=False) # nosec + return requests.get(url, verify=False) # nosec class Session(object): @@ -63,7 +59,7 @@ def __init__(self, session): self.session = session def __getitem__(self, item): - if item == 'state': + if item == "state": return uuid.uuid4().urn try: @@ -86,29 +82,29 @@ def get_acr_value(self, key): def set_acr_value(self, key, val): try: - self.session['acr_value'][key] = val + self.session["acr_value"][key] = val except KeyError: - self.session['acr_value'] = {key: val} + self.session["acr_value"] = {key: val} -#noinspection PyUnresolvedReferences +# noinspection PyUnresolvedReferences def static(environ, start_response, logger, path): logger.info("[static]sending: %s" % (path,)) try: - data = open(path, 'rb').read() + data = open(path, "rb").read() if path.endswith(".ico"): - start_response('200 OK', [('Content-Type', "image/x-icon")]) + start_response("200 OK", [("Content-Type", "image/x-icon")]) elif path.endswith(".html"): - start_response('200 OK', [('Content-Type', 'text/html')]) + start_response("200 OK", [("Content-Type", "text/html")]) elif path.endswith(".json"): - start_response('200 OK', [('Content-Type', 'application/json')]) + start_response("200 OK", [("Content-Type", "application/json")]) elif path.endswith(".txt"): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) elif path.endswith(".css"): - start_response('200 OK', [('Content-Type', 'text/css')]) + start_response("200 OK", [("Content-Type", "text/css")]) else: - start_response('200 OK', [('Content-Type', "text/xml")]) + start_response("200 OK", [("Content-Type", "text/xml")]) return [data] except IOError: resp = NotFound() @@ -116,29 +112,20 @@ def static(environ, start_response, logger, path): def opbyuid(environ, start_response): - resp = Response(mako_template="opbyuid.mako", - template_lookup=LOOKUP, - headers=[]) - argv = { - } + resp = Response(mako_template="opbyuid.mako", template_lookup=LOOKUP, headers=[]) + argv = {} return resp(environ, start_response, **argv) def post_logout(environ, start_response): - resp = Response(mako_template="post_logout.mako", - template_lookup=LOOKUP, - headers=[]) + resp = Response(mako_template="post_logout.mako", template_lookup=LOOKUP, headers=[]) argv = {} return resp(environ, start_response, **argv) def choose_acr_value(environ, start_response, session): - resp = Response(mako_template="acrvalue.mako", - template_lookup=LOOKUP, - headers=[]) - argv = { - "acrvalues": session['acr_values'] - } + resp = Response(mako_template="acrvalue.mako", template_lookup=LOOKUP, headers=[]) + argv = {"acrvalues": session["acr_values"]} return resp(environ, start_response, **argv) @@ -152,9 +139,9 @@ def id_token_as_signed_jwt(client, alg="RS256"): def application(environ, start_response): - session = Session(environ['beaker.session']) + session = Session(environ["beaker.session"]) - path = environ.get('PATH_INFO', '').lstrip('/') + path = environ.get("PATH_INFO", "").lstrip("/") if path == "robots.txt": return static(environ, start_response, LOGGER, "static/robots.txt") @@ -165,29 +152,27 @@ def application(environ, start_response): if path == "logout": try: - logoutUrl = session['client'].end_session_endpoint + logoutUrl = session["client"].end_session_endpoint plru = "{}post_logout".format(SERVER_ENV["base_url"]) logoutUrl += "?" + urlencode({"post_logout_redirect_uri": plru}) try: - logoutUrl += "&" + urlencode({ - "id_token_hint": id_token_as_signed_jwt( - session['client'], "HS256")}) - except AttributeError as err: + logoutUrl += "&" + urlencode({"id_token_hint": id_token_as_signed_jwt(session["client"], "HS256")}) + except AttributeError: pass session.clear() resp = SeeOther(str(logoutUrl)) return resp(environ, start_response) - except Exception as err: + except Exception: LOGGER.exception("Failed to handle logout") if path == "post_logout": return post_logout(environ, start_response) - if session['callback']: + if session["callback"]: _uri = "%s%s" % (conf.BASE, path) for _cli in SERVER_ENV["OIC_CLIENT"].values(): if _uri in _cli.redirect_uris: - session['callback'] = False + session["callback"] = False func = getattr(RP, "callback") return func(environ, SERVER_ENV, start_response, query, session) @@ -195,14 +180,13 @@ def application(environ, start_response): return choose_acr_value(environ, start_response, session) if path == "rpAuth": - # Only called if multiple arc_values (that is authentications) exists. - if "acr" in query and query["acr"][0] in session['acr_values']: + # Only called if multiple arc_values (that is authentications) exists. + if "acr" in query and query["acr"][0] in session["acr_values"]: func = getattr(RP, "create_authnrequest") - return func(environ, SERVER_ENV, start_response, session, - query["acr"][0]) + return func(environ, SERVER_ENV, start_response, session, query["acr"][0]) if session["client"] is not None: - session['callback'] = True + session["callback"] = True func = getattr(RP, "begin") return func(environ, SERVER_ENV, start_response, session, "") @@ -215,36 +199,36 @@ def application(environ, start_response): return resp(environ, start_response) RP.srv_discovery_url = link - h = hashlib.new('sha256') + h = hashlib.new("sha256") h.update(link.encode("utf-8")) opkey = base64.b16encode(h.digest()).decode("utf-8") - session['callback'] = True + session["callback"] = True func = getattr(RP, "begin") return func(environ, SERVER_ENV, start_response, session, opkey) return opbyuid(environ, start_response) -if __name__ == '__main__': +if __name__ == "__main__": from oidc import OpenIDConnect import conf setup_server_env(conf) session_opts = { - 'session.type': 'memory', - 'session.cookie_expires': True, + "session.type": "memory", + "session.cookie_expires": True, #'session.data_dir': './data', - 'session.auto': True, - 'session.timeout': 900 + "session.auto": True, + "session.timeout": 900, } - RP = OpenIDConnect(registration_info=conf.ME, - ca_bundle=conf.CA_BUNDLE) + RP = OpenIDConnect(registration_info=conf.ME, ca_bundle=conf.CA_BUNDLE) - SRV = wsgiserver.CherryPyWSGIServer(('0.0.0.0', conf.PORT), # nosec - SessionMiddleware(application, - session_opts)) + SRV = wsgiserver.CherryPyWSGIServer( + ("0.0.0.0", conf.PORT), # nosec + SessionMiddleware(application, session_opts), + ) if conf.BASE.startswith("https"): from cherrypy.wsgiserver.ssl_builtin import BuiltinSSLAdapter @@ -252,7 +236,7 @@ def application(environ, start_response): SRV.ssl_adapter = BuiltinSSLAdapter(conf.SERVER_CERT, conf.SERVER_KEY, conf.CA_BUNDLE) LOGGER.info("RP server starting listening on port:%s" % conf.PORT) - print ("RP server starting listening on port:%s" % conf.PORT) + print("RP server starting listening on port:%s" % conf.PORT) try: SRV.start() except KeyboardInterrupt: diff --git a/oidc_example/rp3/conf_heart.py b/oidc_example/rp3/conf_heart.py index 0f7096dca..8057f4276 100644 --- a/oidc_example/rp3/conf_heart.py +++ b/oidc_example/rp3/conf_heart.py @@ -18,7 +18,7 @@ "redirect_uris": ["{base}authz_cb"], "post_logout_redirect_uris": ["{base}logout_success"], "response_types": ["code"], - 'token_endpoint_auth_method': ['private_key_jwt'] + "token_endpoint_auth_method": ["private_key_jwt"], } BEHAVIOUR = { @@ -35,18 +35,15 @@ # The ones that support webfinger, OP discovery and client registration # This is the default, any client that is not listed here is expected to # support dynamic discovery and registration. - "": { - "client_info": ME, - "behaviour": BEHAVIOUR - }, + "": {"client_info": ME, "behaviour": BEHAVIOUR}, } KEY_SPECIFICATION = [ {"type": "RSA", "key": "keys/pyoidc_enc", "use": ["enc"]}, {"type": "RSA", "key": "keys/pyoidc_sig", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] -CLIENT_TYPE = 'OAUTH2' # one of OIDC/OAUTH2 +CLIENT_TYPE = "OAUTH2" # one of OIDC/OAUTH2 USERINFO = False diff --git a/oidc_example/rp3/rp3.py b/oidc_example/rp3/rp3.py index 7c5f1878e..06b456e3a 100755 --- a/oidc_example/rp3/rp3.py +++ b/oidc_example/rp3/rp3.py @@ -23,23 +23,20 @@ urllib3.disable_warnings() LOGGER = logging.getLogger("") -LOGFILE_NAME = 'rp.log' +LOGFILE_NAME = "rp.log" hdlr = logging.FileHandler(LOGFILE_NAME) -base_formatter = logging.Formatter( - "%(asctime)s %(name)s:%(levelname)s %(message)s") +base_formatter = logging.Formatter("%(asctime)s %(name)s:%(levelname)s %(message)s") -CPC = ('%(asctime)s %(name)s:%(levelname)s ' - '[%(client)s,%(path)s,%(cid)s] %(message)s') +CPC = "%(asctime)s %(name)s:%(levelname)s " "[%(client)s,%(path)s,%(cid)s] %(message)s" cpc_formatter = logging.Formatter(CPC) hdlr.setFormatter(base_formatter) LOGGER.addHandler(hdlr) LOGGER.setLevel(logging.DEBUG) -LOOKUP = TemplateLookup(directories=['templates', 'htdocs'], - module_directory='modules', - input_encoding='utf-8', - output_encoding='utf-8') +LOOKUP = TemplateLookup( + directories=["templates", "htdocs"], module_directory="modules", input_encoding="utf-8", output_encoding="utf-8" +) SERVER_ENV = {} @@ -50,17 +47,17 @@ def __init__(self, logger, sid): self.id = sid def info(self, info): - _dict = {'id': self.id} + _dict = {"id": self.id} _dict.update(info) self.logger.info(json.dumps(_dict)) def error(self, info): - _dict = {'id': self.id} + _dict = {"id": self.id} _dict.update(info) self.logger.error(json.dumps(_dict)) def warning(self, info): - _dict = {'id': self.id} + _dict = {"id": self.id} _dict.update(info) self.logger.warning(json.dumps(_dict)) @@ -70,19 +67,19 @@ def static(environ, start_response, logger, path): logger.info("[static]sending: %s" % (path,)) try: - data = open(path, 'rb').read() + data = open(path, "rb").read() if path.endswith(".ico"): - start_response('200 OK', [('Content-Type', "image/x-icon")]) + start_response("200 OK", [("Content-Type", "image/x-icon")]) elif path.endswith(".html"): - start_response('200 OK', [('Content-Type', 'text/html')]) + start_response("200 OK", [("Content-Type", "text/html")]) elif path.endswith(".json"): - start_response('200 OK', [('Content-Type', 'application/json')]) + start_response("200 OK", [("Content-Type", "application/json")]) elif path.endswith(".txt"): - start_response('200 OK', [('Content-Type', 'text/plain')]) + start_response("200 OK", [("Content-Type", "text/plain")]) elif path.endswith(".css"): - start_response('200 OK', [('Content-Type', 'text/css')]) + start_response("200 OK", [("Content-Type", "text/css")]) else: - start_response('200 OK', [('Content-Type', "text/xml")]) + start_response("200 OK", [("Content-Type", "text/xml")]) return [data] except IOError: resp = NotFound() @@ -90,22 +87,16 @@ def static(environ, start_response, logger, path): def opchoice(environ, start_response, clients): - resp = Response(mako_template="opchoice.mako", - template_lookup=LOOKUP, - headers=[]) - argv = { - "op_list": list(clients.keys()) - } + resp = Response(mako_template="opchoice.mako", template_lookup=LOOKUP, headers=[]) + argv = {"op_list": list(clients.keys())} return resp(environ, start_response, **argv) def opresult(environ, start_response, **kwargs): - resp = Response(mako_template="opresult.mako", - template_lookup=LOOKUP, - headers=[]) + resp = Response(mako_template="opresult.mako", template_lookup=LOOKUP, headers=[]) _args = {} - for param in ['userinfo', 'userid', 'id_token']: + for param in ["userinfo", "userid", "id_token"]: try: _args[param] = kwargs[param] except KeyError: @@ -115,29 +106,20 @@ def opresult(environ, start_response, **kwargs): def operror(environ, start_response, error=None): - resp = Response(mako_template="operror.mako", - template_lookup=LOOKUP, - headers=[]) - argv = { - "error": error - } + resp = Response(mako_template="operror.mako", template_lookup=LOOKUP, headers=[]) + argv = {"error": error} return resp(environ, start_response, **argv) def opresult_fragment(environ, start_response): - resp = Response(mako_template="opresult_repost.mako", - template_lookup=LOOKUP, - headers=[]) + resp = Response(mako_template="opresult_repost.mako", template_lookup=LOOKUP, headers=[]) argv = {} return resp(environ, start_response, **argv) def sorry_response(environ, start_response, homepage, err): - resp = Response(mako_template="sorry.mako", - template_lookup=LOOKUP, - headers=[]) - argv = {"htmlpage": homepage, - "error": str(err)} + resp = Response(mako_template="sorry.mako", template_lookup=LOOKUP, headers=[]) + argv = {"htmlpage": homepage, "error": str(err)} return resp(environ, start_response, **argv) @@ -153,19 +135,19 @@ def id_token_as_signed_jwt(client, id_token, alg="RS256"): def url_eq(a, b): - if a.endswith('/'): - if b.endswith('/'): + if a.endswith("/"): + if b.endswith("/"): return a == b else: return a[:-1] == b else: - if b.endswith('/'): + if b.endswith("/"): return a == b[:-1] else: return a == b -KEY_MAP = {'state': 'state', 'iss': 'op'} +KEY_MAP = {"state": "state", "iss": "op"} class Application(object): @@ -200,14 +182,14 @@ def find_session(self, **kwargs): def init_client(self, client, session, query, environ, start_response): client.get_userinfo = self.userinfo try: - client.resource_server = session['resource_server'] + client.resource_server = session["resource_server"] except KeyError: pass try: - session['response_format'] = query["response_format"][0] + session["response_format"] = query["response_format"][0] except KeyError: - session['response_format'] = 'html' + session["response_format"] = "html" session["op"] = client.provider_info["issuer"] @@ -220,40 +202,39 @@ def init_client(self, client, session, query, environ, start_response): return resp(environ, start_response) def application(self, environ, start_response): - b_session = environ['beaker.session'] + b_session = environ["beaker.session"] jlog = JLog(LOGGER, b_session.id) - path = environ.get('PATH_INFO', '').lstrip('/') + path = environ.get("PATH_INFO", "").lstrip("/") try: - jlog.info({'cookie': environ['HTTP_COOKIE'].split(';'), - 'path': path}) + jlog.info({"cookie": environ["HTTP_COOKIE"].split(";"), "path": path}) except KeyError: - jlog.info({'path': path}) + jlog.info({"path": path}) if path == "robots.txt": return static(environ, start_response, LOGGER, "static/robots.txt") elif path.startswith("static/"): return static(environ, start_response, LOGGER, path) - elif '/static/' in path: - pre, post = path.split('static') - return static(environ, start_response, LOGGER, 'static' + post) + elif "/static/" in path: + pre, post = path.split("static") + return static(environ, start_response, LOGGER, "static" + post) query = parse_qs(environ["QUERY_STRING"]) try: - session = b_session['session_info'] + session = b_session["session_info"] except KeyError: session = self.find_session(**query) if session: - b_session['session_info'] = session + b_session["session_info"] = session else: session = {} - b_session['session_info'] = session + b_session["session_info"] = session self.session[b_session.id] = session - if path == '': - if 'access_token' not in session: + if path == "": + if "access_token" not in session: return opchoice(environ, start_response, self.clients) else: client = self.clients[session["op"]] @@ -265,15 +246,12 @@ def application(self, environ, start_response): session["session_management"] = { "session_state": query["session_state"][0], "client_id": client.client_id, - "issuer": client.provider_info["issuer"] + "issuer": client.provider_info["issuer"], } except KeyError: pass - kwargs = dict( - [(p, session[p]) for p in - ['id_token', 'userinfo', 'user_id'] if - p in session]) + kwargs = dict([(p, session[p]) for p in ["id_token", "userinfo", "user_id"] if p in session]) return opresult(environ, start_response, **kwargs) elif path == "rp": # After having chosen which OP to authenticate at @@ -281,23 +259,21 @@ def application(self, environ, start_response): try: client = self.clients.dynamic_client(userid=query["uid"][0]) except (ConnectionError, OIDCError) as err: - return operror(environ, start_response, '{}'.format(err)) - elif 'issuer' in query: + return operror(environ, start_response, "{}".format(err)) + elif "issuer" in query: try: client = self.clients[query["issuer"][0]] except (ConnectionError, OIDCError) as err: - return operror(environ, start_response, '{}'.format(err)) + return operror(environ, start_response, "{}".format(err)) else: client = self.clients[query["op"][0]] - return self.init_client(client, session, query, environ, - start_response) - elif path.endswith('authz_post'): + return self.init_client(client, session, query, environ, start_response) + elif path.endswith("authz_post"): try: - _iss = session['op'] + _iss = session["op"] except KeyError: - jlog.error({'reason': 'No active session', - 'remote_addr': environ['REMOTE_ADDR']}) + jlog.error({"reason": "No active session", "remote_addr": environ["REMOTE_ADDR"]}) return opchoice(environ, start_response, self.clients) else: @@ -307,75 +283,66 @@ def application(self, environ, start_response): try: info = query["fragment"][0] except KeyError: - return sorry_response(environ, start_response, self.base, - "missing fragment ?!") - if info == ['x']: - return sorry_response(environ, start_response, self.base, - "Expected fragment didn't get one ?!") + return sorry_response(environ, start_response, self.base, "missing fragment ?!") + if info == ["x"]: + return sorry_response(environ, start_response, self.base, "Expected fragment didn't get one ?!") - jlog.info({'fragment': info}) + jlog.info({"fragment": info}) try: - result = client.callback(info, session, 'urlencoded') + result = client.callback(info, session, "urlencoded") if isinstance(result, SeeOther): return result(environ, start_response) except OIDCError as err: return operror(environ, start_response, "%s" % err) - except Exception as err: + except Exception: raise else: session.update(result) - res = SeeOther(self.conf['base_url']) + res = SeeOther(self.conf["base_url"]) return res(environ, start_response) elif path in self.clients.return_paths(): # After having # authenticated at the OP - jlog.info({'query': query}) + jlog.info({"query": query}) _client = None for cli in self.clients.client.values(): - if query['state'][0] in cli.authz_req: + if query["state"][0] in cli.authz_req: _client = cli break if not _client: - jlog.error({ - 'reason': 'No active session', - 'remote_addr': environ['REMOTE_ADDR'], - 'state': query['state'][0] - }) + jlog.error( + {"reason": "No active session", "remote_addr": environ["REMOTE_ADDR"], "state": query["state"][0]} + ) return opchoice(environ, start_response, self.clients) - if 'error' in query: # something amiss - if query['error'][0] == 'access_denied': # Try reregistering - _iss = _client.provider_info['issuer'] + if "error" in query: # something amiss + if query["error"][0] == "access_denied": # Try reregistering + _iss = _client.provider_info["issuer"] del self.clients[_iss] try: client = self.clients[_iss] except (ConnectionError, OIDCError) as err: - return operror(environ, start_response, - '{}'.format(err)) - return self.init_client(client, session, query, environ, - start_response) + return operror(environ, start_response, "{}".format(err)) + return self.init_client(client, session, query, environ, start_response) try: - _iss = query['iss'][0] + _iss = query["iss"][0] except KeyError: pass else: - if _iss != _client.provider_info['issuer']: - jlog.error({'reason': 'Got response from wrong OP'}) + if _iss != _client.provider_info["issuer"]: + jlog.error({"reason": "Got response from wrong OP"}) return opchoice(environ, start_response, self.clients) _response_type = _client.behaviour["response_type"] try: - _response_mode = _client.authz_req[session['state']][ - 'response_mode'] + _response_mode = _client.authz_req[session["state"]]["response_mode"] except KeyError: - _response_mode = '' + _response_mode = "" - jlog.info({ - "response_type": _response_type, - "response_mode": _response_mode}) + jlog.info({"response_type": _response_type, "response_mode": _response_mode}) if _response_type and _response_type != "code": # Fall through if it's a query response anyway @@ -397,19 +364,17 @@ def application(self, environ, start_response): raise else: session.update(result) - res = SeeOther(self.conf['base_url']) + res = SeeOther(self.conf["base_url"]) return res(environ, start_response) elif path == "logout": # After the user has pressed the logout button try: - _iss = session['op'] + _iss = session["op"] except KeyError: - jlog.error( - {'reason': 'No active session', - 'remote_addr': environ['REMOTE_ADDR']}) + jlog.error({"reason": "No active session", "remote_addr": environ["REMOTE_ADDR"]}) return opchoice(environ, start_response, self.clients) client = self.clients[_iss] try: - del client.authz_req[session['state']] + del client.authz_req[session["state"]] except KeyError: pass @@ -419,21 +384,17 @@ def application(self, environ, start_response): # log out. That URL must be registered with the OP at client # registration. logout_url += "?" + urlencode( - {"post_logout_redirect_uri": client.registration_response[ - "post_logout_redirect_uris"][0]}) + {"post_logout_redirect_uri": client.registration_response["post_logout_redirect_uris"][0]} + ) except KeyError: pass else: # If there is an ID token send it along as a id_token_hint _idtoken = get_id_token(client, session) if _idtoken: - logout_url += "&" + urlencode({ - "id_token_hint": id_token_as_signed_jwt(client, - _idtoken, - "HS256")}) + logout_url += "&" + urlencode({"id_token_hint": id_token_as_signed_jwt(client, _idtoken, "HS256")}) # Also append the ACR values - logout_url += "&" + urlencode({"acr_values": self.acr_values}, - True) + logout_url += "&" + urlencode({"acr_values": self.acr_values}, True) session.delete() resp = SeeOther(str(logout_url)) @@ -442,19 +403,15 @@ def application(self, environ, start_response): return Response("Logout successful!")(environ, start_response) elif path == "session_iframe": # session management kwargs = session["session_management"] - resp = Response(mako_template="rp_session_iframe.mako", - template_lookup=LOOKUP) - return resp(environ, start_response, - session_change_url="{}session_change".format( - self.conf["base_url"]), - **kwargs) + resp = Response(mako_template="rp_session_iframe.mako", template_lookup=LOOKUP) + return resp( + environ, start_response, session_change_url="{}session_change".format(self.conf["base_url"]), **kwargs + ) elif path == "session_change": try: - _iss = session['op'] + _iss = session["op"] except KeyError: - jlog.error({ - 'reason': 'No active session', - 'remote_addr': environ['REMOTE_ADDR']}) + jlog.error({"reason": "No active session", "remote_addr": environ["REMOTE_ADDR"]}) return opchoice(environ, start_response, self.clients) try: @@ -466,16 +423,14 @@ def application(self, environ, start_response): # If there is an ID token send it along as a id_token_hint idt = get_id_token(client, session) if idt: - kwargs["id_token_hint"] = id_token_as_signed_jwt(client, idt, - "HS256") - resp = client.create_authn_request(session, self.acr_values, - **kwargs) + kwargs["id_token_hint"] = id_token_as_signed_jwt(client, idt, "HS256") + resp = client.create_authn_request(session, self.acr_values, **kwargs) return resp(environ, start_response) return opchoice(environ, start_response, self.clients) -if __name__ == '__main__': +if __name__ == "__main__": from oic.utils.rp import OIDCClients from oic.utils.rp import OIDCError from beaker.middleware import SessionMiddleware @@ -485,7 +440,7 @@ def application(self, environ, start_response): parser.add_argument(dest="config") parser.add_argument("-p", default=8666, dest="port", help="port of the RP") parser.add_argument("-b", dest="base_url", help="base url of the RP") - parser.add_argument('-k', dest='verify_ssl', action='store_false') + parser.add_argument("-k", dest="verify_ssl", action="store_false") args = parser.parse_args() _conf = importlib.import_module(args.config) @@ -496,16 +451,13 @@ def application(self, environ, start_response): for _client, client_conf in _conf.CLIENTS.items(): if "client_registration" in client_conf: client_reg = client_conf["client_registration"] - client_reg["redirect_uris"] = [ - url.format(base=_conf.BASE) for url in - client_reg["redirect_uris"]] + client_reg["redirect_uris"] = [url.format(base=_conf.BASE) for url in client_reg["redirect_uris"]] session_opts = { - 'session.type': 'memory', - 'session.cookie_expires': True, - 'session.auto': True, - 'session.key': "{}.beaker.session.id".format( - urlparse(_conf.BASE).netloc.replace(":", ".")) + "session.type": "memory", + "session.cookie_expires": True, + "session.auto": True, + "session.key": "{}.beaker.session.id".format(urlparse(_conf.BASE).netloc.replace(":", ".")), } try: @@ -514,50 +466,46 @@ def application(self, environ, start_response): jwks_info = {} else: jwks, keyjar, kidd = build_keyjar(key_spec) - jwks_info = { - 'jwks_uri': '{}static/jwks_uri.json'.format(_base), - 'keyjar': keyjar, - 'kid': kidd - } - f = open('static/jwks_uri.json', 'w') + jwks_info = {"jwks_uri": "{}static/jwks_uri.json".format(_base), "keyjar": keyjar, "kid": kidd} + f = open("static/jwks_uri.json", "w") f.write(json.dumps(jwks)) f.close() try: ctype = _conf.CLIENT_TYPE except KeyError: - ctype = 'OIDC' + ctype = "OIDC" - if ctype == 'OIDC': - _clients = OIDCClients(_conf, _base, jwks_info=jwks_info, - verify_ssl=args.verify_ssl) + if ctype == "OIDC": + _clients = OIDCClients(_conf, _base, jwks_info=jwks_info, verify_ssl=args.verify_ssl) else: - _clients = OAuthClients(_conf, _base, jwks_info=jwks_info, - verify_ssl=args.verify_ssl) + _clients = OAuthClients(_conf, _base, jwks_info=jwks_info, verify_ssl=args.verify_ssl) SERVER_ENV.update({"template_lookup": LOOKUP, "base_url": _base}) - app_args = {'clients': _clients, - 'acrs': _conf.ACR_VALUES, - 'conf': SERVER_ENV, - 'userinfo': _conf.USERINFO, - 'base': _conf.BASE} + app_args = { + "clients": _clients, + "acrs": _conf.ACR_VALUES, + "conf": SERVER_ENV, + "userinfo": _conf.USERINFO, + "base": _conf.BASE, + } try: - app_args['resource_server'] = _conf.RESOURCE_SERVER + app_args["resource_server"] = _conf.RESOURCE_SERVER except AttributeError: pass _app = Application(**app_args) SRV = wsgiserver.CherryPyWSGIServer( - ('0.0.0.0', int(args.port)), # nosec - SessionMiddleware(_app.application, session_opts)) + ("0.0.0.0", int(args.port)), # nosec + SessionMiddleware(_app.application, session_opts), + ) if _conf.BASE.startswith("https"): from cherrypy.wsgiserver.ssl_builtin import BuiltinSSLAdapter - SRV.ssl_adapter = BuiltinSSLAdapter(_conf.SERVER_CERT, _conf.SERVER_KEY, - _conf.CERT_CHAIN) + SRV.ssl_adapter = BuiltinSSLAdapter(_conf.SERVER_CERT, _conf.SERVER_KEY, _conf.CERT_CHAIN) extra = " using SSL/TLS" else: extra = "" diff --git a/oidc_example/simple_op/src/provider/__init__.py b/oidc_example/simple_op/src/provider/__init__.py index c207ec6e0..725d839d8 100644 --- a/oidc_example/simple_op/src/provider/__init__.py +++ b/oidc_example/simple_op/src/provider/__init__.py @@ -1 +1 @@ -__author__ = 'regu0004' +__author__ = "regu0004" diff --git a/oidc_example/simple_op/src/provider/authn/__init__.py b/oidc_example/simple_op/src/provider/authn/__init__.py index ac8929110..f7c427f3f 100644 --- a/oidc_example/simple_op/src/provider/authn/__init__.py +++ b/oidc_example/simple_op/src/provider/authn/__init__.py @@ -1,10 +1,10 @@ import importlib from oic.utils.authn.user import UserAuthnMethod -__author__ = 'regu0004' +__author__ = "regu0004" -class AuthnModule(UserAuthnMethod): +class AuthnModule(UserAuthnMethod): # override in subclass specifying suitable url endpoint to POST user input url_endpoint = "/verify" FAILED_AUTHN = (None, True) diff --git a/oidc_example/simple_op/src/provider/authn/two_factor.py b/oidc_example/simple_op/src/provider/authn/two_factor.py index c1e83774e..1a79b17c2 100644 --- a/oidc_example/simple_op/src/provider/authn/two_factor.py +++ b/oidc_example/simple_op/src/provider/authn/two_factor.py @@ -16,9 +16,17 @@ class MailTwoFactor(AuthnModule): url_endpoint = "/two_factor/verify" - def __init__(self, user_db, passwd_db, smtp_server, outgoing_sender, - template_env, code_ttl=2, template="mail_two_factor.jinja2", - **kwargs): + def __init__( + self, + user_db, + passwd_db, + smtp_server, + outgoing_sender, + template_env, + code_ttl=2, + template="mail_two_factor.jinja2", + **kwargs, + ): """ :param user_db: @@ -81,15 +89,14 @@ def verify(self, *args, **kwargs): # Generate code and send it now = time.time() secret = "%d%s" % (now, rndstr(16)) - code = hashlib.sha256(secret.encode('utf-8')).hexdigest() + code = hashlib.sha256(secret.encode("utf-8")).hexdigest() self.codes[code] = {"username": username, "time": now} self._send_mail(code, receiver) template = self.template_env.get_template(self.template) - response = Response(template.render(mail=receiver, - action=self.url_endpoint, - state=json.dumps( - kwargs["state"]))) + response = Response( + template.render(mail=receiver, action=self.url_endpoint, state=json.dumps(kwargs["state"])) + ) return response, False def _send_mail(self, code, receiver): diff --git a/oidc_example/simple_op/src/provider/authn/user_pass.py b/oidc_example/simple_op/src/provider/authn/user_pass.py index 65decd4ad..0c9042f25 100644 --- a/oidc_example/simple_op/src/provider/authn/user_pass.py +++ b/oidc_example/simple_op/src/provider/authn/user_pass.py @@ -25,14 +25,11 @@ def __init__(self, db, template_env, template="user_pass.jinja2", **kwargs): def __call__(self, *args, **kwargs): template = self.template_env.get_template(self.template) - return Response(template.render(action=self.url_endpoint, - state=json.dumps(kwargs), - **self.kwargs)) + return Response(template.render(action=self.url_endpoint, state=json.dumps(kwargs), **self.kwargs)) def verify(self, *args, **kwargs): username = kwargs["username"] - if username in self.user_db and self.user_db[username] == kwargs[ - "password"]: + if username in self.user_db and self.user_db[username] == kwargs["password"]: return username, True else: return self.FAILED_AUTHN diff --git a/oidc_example/simple_op/src/provider/authn/yubikey.py b/oidc_example/simple_op/src/provider/authn/yubikey.py index d8430e62f..41588e091 100644 --- a/oidc_example/simple_op/src/provider/authn/yubikey.py +++ b/oidc_example/simple_op/src/provider/authn/yubikey.py @@ -14,9 +14,17 @@ class YubicoOTP(AuthnModule): url_endpoint = "/yubi_otp/verify" - def __init__(self, yubikey_db, validation_server, client_id, template_env, - secret_key=None, verify_ssl=True, template="yubico_otp.jinja2", - **kwargs): + def __init__( + self, + yubikey_db, + validation_server, + client_id, + template_env, + secret_key=None, + verify_ssl=True, + template="yubico_otp.jinja2", + **kwargs, + ): super(YubicoOTP, self).__init__(None) self.template_env = template_env self.template = template @@ -24,33 +32,27 @@ def __init__(self, yubikey_db, validation_server, client_id, template_env, cls = make_cls_from_name(yubikey_db["class"]) self.yubikey_db = cls(**yubikey_db["kwargs"]) - self.client = Yubico(client_id, secret_key, - api_urls=[validation_server], - verify_cert=verify_ssl) + self.client = Yubico(client_id, secret_key, api_urls=[validation_server], verify_cert=verify_ssl) if not verify_ssl: # patch yubico-client to not find any ca bundle self.client._get_ca_bundle_path = lambda: None def __call__(self, *args, **kwargs): template = self.template_env.get_template(self.template) - return Response(template.render(action=self.url_endpoint, - state=json.dumps(kwargs))) + return Response(template.render(action=self.url_endpoint, state=json.dumps(kwargs))) def verify(self, *args, **kwargs): otp = kwargs["otp"] try: status = self.client.verify(otp, return_response=True) except yubico_exceptions.InvalidClientIdError as e: - logger.error( - "Client with id {} does not exist".format(e.client_id)) + logger.error("Client with id {} does not exist".format(e.client_id)) return self.FAILED_AUTHN except yubico_exceptions.SignatureVerificationError: logger.error("Signature verification failed") return self.FAILED_AUTHN except yubico_exceptions.StatusCodeError as e: - logger.error( - "Negative status code was returned: {}".format( - e.status_code)) + logger.error("Negative status code was returned: {}".format(e.status_code)) return self.FAILED_AUTHN if status: @@ -59,5 +61,4 @@ def verify(self, *args, **kwargs): return self.yubikey_db[yubikey_public_id], True else: - logger.error( - "No response from the servers or received other negative status code") + logger.error("No response from the servers or received other negative status code") diff --git a/oidc_example/simple_op/src/provider/server/server.py b/oidc_example/simple_op/src/provider/server/server.py index e01917957..4a02df1f8 100644 --- a/oidc_example/simple_op/src/provider/server/server.py +++ b/oidc_example/simple_op/src/provider/server/server.py @@ -45,8 +45,6 @@ from cherrypy.wsgiserver.wsgiserver2 import WSGIPathInfoDispatcher - - def VerifierMiddleware(verifier): """Common wrapper for the authentication modules. * Parses the request before passing it on to the authentication module. @@ -68,15 +66,11 @@ def wrapper(environ, start_response): set_cookie, cookie_value = verifier.create_cookie(val, "auth") cookie_value += "; path=/" - url = "{base_url}?{query_string}".format( - base_url="/authorization", - query_string=kwargs["state"]["query"]) + url = "{base_url}?{query_string}".format(base_url="/authorization", query_string=kwargs["state"]["query"]) response = SeeOther(url, headers=[(set_cookie, cookie_value)]) return response(environ, start_response) else: # Unsuccessful authentication - url = "{base_url}?{query_string}".format( - base_url="/authorization", - query_string=kwargs["state"]["query"]) + url = "{base_url}?{query_string}".format(base_url="/authorization", query_string=kwargs["state"]["query"]) response = SeeOther(url) return response(environ, start_response) @@ -124,16 +118,11 @@ def setup_endpoints(provider): """Setup the OpenID Connect Provider endpoints.""" app_routing = {} endpoints = [ - AuthorizationEndpoint( - pyoidcMiddleware(provider.authorization_endpoint)), - TokenEndpoint( - pyoidcMiddleware(provider.token_endpoint)), - UserinfoEndpoint( - pyoidcMiddleware(provider.userinfo_endpoint)), - RegistrationEndpoint( - pyoidcMiddleware(provider.registration_endpoint)), - EndSessionEndpoint( - pyoidcMiddleware(provider.endsession_endpoint)) + AuthorizationEndpoint(pyoidcMiddleware(provider.authorization_endpoint)), + TokenEndpoint(pyoidcMiddleware(provider.token_endpoint)), + UserinfoEndpoint(pyoidcMiddleware(provider.userinfo_endpoint)), + RegistrationEndpoint(pyoidcMiddleware(provider.registration_endpoint)), + EndSessionEndpoint(pyoidcMiddleware(provider.endsession_endpoint)), ] provider.endp = endpoints @@ -148,35 +137,35 @@ def _webfinger(provider, request, **kwargs): params = urlparse.parse_qs(request) if params["rel"][0] == OIC_ISSUER: wf = WebFinger() - return Response(wf.response(params["resource"][0], provider.baseurl), - headers=[("Content-Type", "application/jrd+json")]) + return Response( + wf.response(params["resource"][0], provider.baseurl), headers=[("Content-Type", "application/jrd+json")] + ) else: return BadRequest("Incorrect webfinger.") def make_static_handler(static_dir): def static(environ, start_response): - path = environ['PATH_INFO'] + path = environ["PATH_INFO"] full_path = os.path.join(static_dir, os.path.normpath(path).lstrip("/")) if os.path.exists(full_path): - with open(full_path, 'rb') as f: + with open(full_path, "rb") as f: content = f.read() content_type, encoding = mimetypes.guess_type(full_path) - headers = [('Content-Type', content_type)] + headers = [("Content-Type", content_type)] start_response("200 OK", headers) return [content] else: - response = NotFound( - "File '{}' not found.".format(environ['PATH_INFO'])) + response = NotFound("File '{}' not found.".format(environ["PATH_INFO"])) return response(environ, start_response) return static def main(): - parser = argparse.ArgumentParser(description='Example OIDC Provider.') + parser = argparse.ArgumentParser(description="Example OIDC Provider.") parser.add_argument("-p", "--port", default=80, type=int) parser.add_argument("-b", "--base", default="https://localhost", type=str) parser.add_argument("-d", "--debug", action="store_true") @@ -190,10 +179,8 @@ def main(): issuer = args.base.rstrip("/") template_dirs = settings["server"].get("template_dirs", "templates") - jinja_env = Environment(loader=FileSystemLoader(template_dirs), - autoescape=select_autoescape(['html'])) - authn_broker, auth_routing = setup_authentication_methods(settings["authn"], - jinja_env) + jinja_env = Environment(loader=FileSystemLoader(template_dirs), autoescape=select_autoescape(["html"])) + authn_broker, auth_routing = setup_authentication_methods(settings["authn"], jinja_env) # Setup userinfo userinfo_conf = settings["userinfo"] @@ -202,10 +189,8 @@ def main(): userinfo = UserInfo(i) client_db = {} - session_db = create_session_db(issuer, - secret=rndstr(32), password=rndstr(32)) - provider = Provider(issuer, session_db, client_db, authn_broker, - userinfo, AuthzHandling(), verify_client, None) + session_db = create_session_db(issuer, secret=rndstr(32), password=rndstr(32)) + provider = Provider(issuer, session_db, client_db, authn_broker, userinfo, AuthzHandling(), verify_client, None) provider.baseurl = issuer provider.symkey = rndstr(16) @@ -222,26 +207,24 @@ def main(): with open(os.path.join(path, name), "w") as f: f.write(json.dumps(jwks)) - #TODO: I take this out and it still works, what was this for? - #provider.jwks_uri.append( + # TODO: I take this out and it still works, what was this for? + # provider.jwks_uri.append( # "{}/static/{}".format(provider.baseurl, name)) # Mount the WSGI callable object (app) on the root directory app_routing = setup_endpoints(provider) - app_routing["/.well-known/openid-configuration"] = pyoidcMiddleware( - provider.providerinfo_endpoint) - app_routing["/.well-known/webfinger"] = pyoidcMiddleware( - partial(_webfinger, provider)) + app_routing["/.well-known/openid-configuration"] = pyoidcMiddleware(provider.providerinfo_endpoint) + app_routing["/.well-known/webfinger"] = pyoidcMiddleware(partial(_webfinger, provider)) routing = dict(list(auth_routing.items()) + list(app_routing.items())) routing["/static"] = make_static_handler(path) dispatcher = WSGIPathInfoDispatcher(routing) - server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', args.port), dispatcher) # nosec + server = wsgiserver.CherryPyWSGIServer(("0.0.0.0", args.port), dispatcher) # nosec # Setup SSL if provider.baseurl.startswith("https://"): server.ssl_adapter = BuiltinSSLAdapter( - settings["server"]["cert"], settings["server"]["key"], - settings["server"]["cert_chain"]) + settings["server"]["cert"], settings["server"]["key"], settings["server"]["cert_chain"] + ) # Start the CherryPy WSGI web server try: diff --git a/oidc_example/simple_rp/src/rp.py b/oidc_example/simple_rp/src/rp.py index 66d33fc33..50143ee4b 100644 --- a/oidc_example/simple_rp/src/rp.py +++ b/oidc_example/simple_rp/src/rp.py @@ -9,7 +9,7 @@ from oic.utils.keyio import build_keyjar from oic.oic.message import AuthorizationResponse -__author__ = 'regu0004' +__author__ = "regu0004" class OIDCExampleRP(object): @@ -24,8 +24,7 @@ def __init__(self, client_metadata, behaviour): def register_with_dynamic_provider(self, session, uid): issuer_url = session["client"].wf.discovery_query(uid) provider_info = session["client"].provider_config(issuer_url) - session["client"].register(provider_info["registration_endpoint"], - **self.client_metadata) + session["client"].register(provider_info["registration_endpoint"], **self.client_metadata) def make_authentication_request(self, session): session["state"] = rndstr() @@ -34,27 +33,23 @@ def make_authentication_request(self, session): "response_type": self.response_type, "state": session["state"], "nonce": session["nonce"], - "redirect_uri": self.redirect_uri + "redirect_uri": self.redirect_uri, } request_args.update(self.behaviour) - auth_req = session["client"].construct_AuthorizationRequest( - request_args=request_args) + auth_req = session["client"].construct_AuthorizationRequest(request_args=request_args) login_url = auth_req.request(session["client"].authorization_endpoint) raise cherrypy.HTTPRedirect(login_url, 303) def parse_authentication_response(self, session, query_string): - auth_response = session["client"].parse_response(AuthorizationResponse, - info=query_string, - sformat="urlencoded") + auth_response = session["client"].parse_response(AuthorizationResponse, info=query_string, sformat="urlencoded") if auth_response["state"] != session["state"]: raise "The OIDC state does not match." - if "id_token" in auth_response and \ - auth_response["id_token"]["nonce"] != session["nonce"]: + if "id_token" in auth_response and auth_response["id_token"]["nonce"] != session["nonce"]: raise "The OIDC nonce does not match." return auth_response @@ -64,20 +59,17 @@ def make_token_request(self, session, auth_code): "code": auth_code, "redirect_uri": self.redirect_uri, "client_id": session["client"].client_id, - "client_secret": session["client"].client_secret + "client_secret": session["client"].client_secret, } token_response = session["client"].do_access_token_request( - scope="openid", - state=session[ - "state"], - request_args=args) + scope="openid", state=session["state"], request_args=args + ) return token_response def make_userinfo_request(self, session, access_token): - userinfo_response = session["client"].do_user_info_request( - access_token=access_token) + userinfo_response = session["client"].do_user_info_request(access_token=access_token) return userinfo_response @@ -92,7 +84,7 @@ def index(self): @cherrypy.expose def authenticate(self, uid): - #TODO: Why did I have to do this? I am not sure this is correct + # TODO: Why did I have to do this? I am not sure this is correct keys = [ {"type": "RSA", "key": "../simple_op/keys/key.pem", "use": ["enc", "sig"]}, ] @@ -108,8 +100,7 @@ def authenticate(self, uid): @cherrypy.expose def repost_fragment(self, **kwargs): - response = self.rp.parse_authentication_response(cherrypy.session, - kwargs["url_fragment"]) + response = self.rp.parse_authentication_response(cherrypy.session, kwargs["url_fragment"]) html_page = self._load_HTML_page_from_file("htdocs/success_page.html") @@ -123,38 +114,31 @@ def repost_fragment(self, **kwargs): access_token = None try: access_token = response["access_token"] - userinfo = self.rp.make_userinfo_request(cherrypy.session, - access_token) + userinfo = self.rp.make_userinfo_request(cherrypy.session, access_token) except KeyError: pass - return html_page.format(authz_code, access_token, - response["id_token"], userinfo) + return html_page.format(authz_code, access_token, response["id_token"], userinfo) @cherrypy.expose def code_flow(self, **kwargs): if "error" in kwargs: - raise cherrypy.HTTPError(500, "{}: {}".format(kwargs["error"], - kwargs[ - "error_description"])) + raise cherrypy.HTTPError(500, "{}: {}".format(kwargs["error"], kwargs["error_description"])) qs = cherrypy.request.query_string - auth_response = self.rp.parse_authentication_response(cherrypy.session, - qs) + auth_response = self.rp.parse_authentication_response(cherrypy.session, qs) auth_code = auth_response["code"] token_response = self.rp.make_token_request(cherrypy.session, auth_code) - userinfo = self.rp.make_userinfo_request(cherrypy.session, - token_response["access_token"]) + userinfo = self.rp.make_userinfo_request(cherrypy.session, token_response["access_token"]) html_page = self._load_HTML_page_from_file("htdocs/success_page.html") - return html_page.format(auth_code, token_response["access_token"], - token_response["id_token"], userinfo) + return html_page.format(auth_code, token_response["access_token"], token_response["id_token"], userinfo) @cherrypy.expose def implicit_hybrid_flow(self, **kwargs): return self._load_HTML_page_from_file("htdocs/repost_fragment.html") def _load_HTML_page_from_file(self, path): - if not path.startswith("/"): # relative path + if not path.startswith("/"): # relative path # prepend the root package dir path = os.path.join(os.path.dirname(__file__), path) @@ -163,7 +147,7 @@ def _load_HTML_page_from_file(self, path): def main(): - parser = argparse.ArgumentParser(description='Example OIDC Client.') + parser = argparse.ArgumentParser(description="Example OIDC Client.") parser.add_argument("-p", "--port", default=80, type=int) parser.add_argument("-b", "--base", default="https://localhost", type=str) parser.add_argument("settings") @@ -175,31 +159,33 @@ def main(): baseurl = args.base.rstrip("/") # strip trailing slash if it exists registration_info = settings["registration_info"] # patch redirect_uris with proper base url - registration_info["redirect_uris"] = [url.format(base=baseurl, - port=args.port) - for url in - registration_info["redirect_uris"]] + registration_info["redirect_uris"] = [ + url.format(base=baseurl, port=args.port) for url in registration_info["redirect_uris"] + ] - rp_server = RPServer(registration_info, settings["behaviour"], - settings["server"]["verify_ssl"]) + rp_server = RPServer(registration_info, settings["behaviour"], settings["server"]["verify_ssl"]) # Mount the WSGI callable object (app) on the root directory cherrypy.tree.mount(rp_server, "/") # Set the configuration of the web server - cherrypy.config.update({ - 'tools.sessions.on': True, - 'server.socket_port': args.port, - 'server.socket_host': '0.0.0.0' # nosec - }) + cherrypy.config.update( + { + "tools.sessions.on": True, + "server.socket_port": args.port, + "server.socket_host": "0.0.0.0", # nosec + } + ) if baseurl.startswith("https://"): - cherrypy.config.update({ - 'server.ssl_module': 'builtin', - 'server.ssl_certificate': settings["server"]["cert"], - 'server.ssl_private_key': settings["server"]["key"], - 'server.ssl_certificate_chain': settings["server"]["cert_chain"] - }) + cherrypy.config.update( + { + "server.ssl_module": "builtin", + "server.ssl_certificate": settings["server"]["cert"], + "server.ssl_private_key": settings["server"]["key"], + "server.ssl_certificate_chain": settings["server"]["cert_chain"], + } + ) # Start the CherryPy WSGI web server cherrypy.engine.start() diff --git a/pylama.ini b/pylama.ini deleted file mode 100644 index 1399d0509..000000000 --- a/pylama.ini +++ /dev/null @@ -1,11 +0,0 @@ -[pylama] -linters = pyflakes,eradicate,pycodestyle,mccabe -# D203/D204 and D212/D213 are mutually exclusive, pick one -# E203 is not PEP8 compliant in pycodestyle -ignore = D203,D212,E203,C901 - -[pylama:pycodestyle] -max_line_length = 120 - -[pylama:mccabe] -complexity = 30 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..588bdfe1e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.ruff] +target-version = "py39" +line-length = 120 + +[tool.ruff.lint] +ignore = ["D203", "D212", "E203", "C901", "D413"] + +[tool.ruff.lint.pycodestyle] +max-doc-length = 120 + +[tool.ruff.lint.mccabe] +max-complexity = 15 + +[tool.ruff.lint.pydocstyle] +convention = "google" diff --git a/script/webfinger.py b/script/webfinger.py index 9be5210fa..e084076da 100644 --- a/script/webfinger.py +++ b/script/webfinger.py @@ -4,8 +4,8 @@ from oic.utils.webfinger import OIC_ISSUER from oic.utils.webfinger import WebFinger -__author__ = 'roland' +__author__ = "roland" wf = WebFinger(OIC_ISSUER) wf.httpd = PBase() -print (wf.discovery_query(sys.argv[1])) +print(wf.discovery_query(sys.argv[1])) diff --git a/setup.cfg b/setup.cfg index 6629bd357..e69de29bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +0,0 @@ -[pydocstyle] -convention = google -add-ignore = - D1, # Ignore missing docstrings - D212, # Multiline docstring not on first line - D413, # No blank line after last section -add-select = - D213 # Multiline docstring should start on first line diff --git a/setup.py b/setup.py index 8ee8f8757..e78803f26 100755 --- a/setup.py +++ b/setup.py @@ -15,58 +15,60 @@ # limitations under the License. # import re -import sys from io import open from setuptools import setup -__author__ = 'rohe0002' +__author__ = "rohe0002" -tests_requires = ['responses', 'testfixtures', 'pytest', 'freezegun'] +tests_requires = ["responses", "testfixtures", "pytest", "freezegun"] -with open('src/oic/__init__.py', 'r') as fd: - version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', - fd.read(), re.MULTILINE).group(1) +with open("src/oic/__init__.py", "r") as fd: + version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) setup( name="oic", version=version, description="Python implementation of OAuth2 and OpenID Connect", - long_description=open('README.rst', encoding='utf-8').read(), + long_description=open("README.rst", encoding="utf-8").read(), author="Roland Hedberg", author_email="roland@catalogix.se", license="Apache 2.0", - url='https://github.com/CZ-NIC/pyoidc/', + url="https://github.com/CZ-NIC/pyoidc/", packages=[ - "oic", "oic/oauth2", "oic/oic", "oic/utils", "oic/utils/authn", - "oic/utils/userinfo", 'oic/utils/rp', 'oic/extension' + "oic", + "oic/oauth2", + "oic/oic", + "oic/utils", + "oic/utils/authn", + "oic/utils/userinfo", + "oic/utils/rp", + "oic/extension", ], - entry_points={ - 'console_scripts': [ - 'oic-client-management = oic.utils.client_management:run' - ] - }, + entry_points={"console_scripts": ["oic-client-management = oic.utils.client_management:run"]}, package_dir={"": "src"}, package_data={"oic": ["py.typed"]}, include_package_data=True, classifiers=[ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Topic :: Software Development :: Libraries :: Python Modules"], - python_requires='~=3.8', + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + python_requires="~=3.9", extras_require={ - 'develop': ["cherrypy==3.2.4", "pyOpenSSL"], - 'testing': tests_requires, - 'docs': ['Sphinx', 'sphinx-autobuild', 'alabaster', 'autodoc_pydantic>=2.0.0'], - 'quality': ['pylama', 'isort', 'eradicate', 'mypy', 'black', 'bandit', 'readme_renderer[md]'], - 'types': ['types-requests'], - 'ldap_authn': ['python-ldap'], - 'examples': ['beaker'], + "develop": ["cherrypy==3.2.4", "pyOpenSSL"], + "testing": tests_requires, + "docs": ["Sphinx", "sphinx-autobuild", "alabaster", "autodoc_pydantic>=2.0.0"], + "quality": ["mypy", "ruff", "bandit", "readme_renderer[md]", "build"], + "types": ["types-requests"], + "ldap_authn": ["python-ldap"], + "examples": ["beaker"], }, install_requires=[ "requests", @@ -76,7 +78,6 @@ "mako", "cryptography", "defusedxml", - 'typing_extensions; python_version<"3.8"', ], long_description_content_type="text/x-rst", zip_safe=False, diff --git a/src/oic/extension/client.py b/src/oic/extension/client.py index a001d7a7b..87d2802e1 100644 --- a/src/oic/extension/client.py +++ b/src/oic/extension/client.py @@ -70,9 +70,7 @@ def __init__( ) self.registration_response = None - def construct_RegistrationRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_RegistrationRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("registration_endpoint") if request_args is None: @@ -80,9 +78,7 @@ def construct_RegistrationRequest( return self.construct_request(request, request_args, extra_args) - def construct_ClientUpdateRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_ClientUpdateRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("update_endpoint") if request_args is None: @@ -110,17 +106,13 @@ def _token_interaction_setup(self, request_args=None, **kwargs): return request_args - def construct_TokenIntrospectionRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_TokenIntrospectionRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("introspection_endpoint") request_args = self._token_interaction_setup(request_args, **kwargs) return self.construct_request(request, request_args, extra_args) - def construct_TokenRevocationRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_TokenRevocationRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("revocation_endpoint") request_args = self._token_interaction_setup(request_args, **kwargs) @@ -138,18 +130,14 @@ def do_op( response_cls=None, **kwargs, ): - url, body, ht_args, _ = self.request_info( - request, method, request_args, extra_args, **kwargs - ) + url, body, ht_args, _ = self.request_info(request, method, request_args, extra_args, **kwargs) if http_args is None: http_args = ht_args else: http_args.update(ht_args) - resp = self.request_and_return( - url, response_cls, method, body, body_type, http_args=http_args - ) + resp = self.request_and_return(url, response_cls, method, body, body_type, http_args=http_args) return resp @@ -276,9 +264,7 @@ def do_token_revocation( # There is no expected response, only the status code is important, # so do not use do_op(). - url, body, ht_args, _ = self.request_info( - request, method, request_args, extra_args, **kwargs - ) + url, body, ht_args, _ = self.request_info(request, method, request_args, extra_args, **kwargs) if http_args is None: http_args = ht_args @@ -372,9 +358,7 @@ def store_registration_info(self, reginfo): def handle_registration_info(self, response): if response.status_code in SUCCESSFUL: - resp = self.message_factory.get_response_type( - "registration_endpoint" - )().deserialize(response.text, "json") + resp = self.message_factory.get_response_type("registration_endpoint")().deserialize(response.text, "json") self.store_response(resp, response.text) self.store_registration_info(resp) else: diff --git a/src/oic/extension/device_flow.py b/src/oic/extension/device_flow.py index 121a95de4..b900f6b83 100644 --- a/src/oic/extension/device_flow.py +++ b/src/oic/extension/device_flow.py @@ -56,9 +56,7 @@ def device_endpoint(self, request, authn=None): self.device2user[device_code] = user_code self.user_auth[user_code] = False self.client_id2device[_req["client_id"]] = device_code - self.device_code_expire_at[device_code] = ( - time_sans_frac() + self.device_code_life_time - ) + self.device_code_expire_at[device_code] = time_sans_frac() + self.device_code_life_time def token_endpoint(self, request, authn=None): _req = TokenRequest(**request) @@ -87,19 +85,13 @@ def __init__(self, host): } def authorization_request(self, scope=""): - req = AuthorizationRequest( - client_id=self.host.client_id, response_type="device_code" - ) + req = AuthorizationRequest(client_id=self.host.client_id, response_type="device_code") if scope: req["scope"] = scope - http_response = self.host.http_request( - self.host.provider_info["device_endpoint"], "POST", req.to_urlencoded() - ) + http_response = self.host.http_request(self.host.provider_info["device_endpoint"], "POST", req.to_urlencoded()) - response = self.host.parse_request_response( - AuthorizationResponse, http_response, "json" - ) + response = self.host.parse_request_response(AuthorizationResponse, http_response, "json") return response @@ -110,12 +102,8 @@ def token_request(self, device_code=""): client_id=self.host.client_id, ) - http_response = self.host.http_request( - self.host.provider_info["token_endpoint"], "POST", req.to_urlencoded() - ) + http_response = self.host.http_request(self.host.provider_info["token_endpoint"], "POST", req.to_urlencoded()) - response = self.host.parse_request_response( - AccessTokenResponse, http_response, "json" - ) + response = self.host.parse_request_response(AccessTokenResponse, http_response, "json") return response diff --git a/src/oic/extension/heart.py b/src/oic/extension/heart.py index cd2483d0a..b35b244d1 100644 --- a/src/oic/extension/heart.py +++ b/src/oic/extension/heart.py @@ -16,7 +16,6 @@ class PrivateKeyJWT(JasonWebToken): "aud": SINGLE_REQUIRED_STRING, "iss": SINGLE_REQUIRED_STRING, "sub": SINGLE_REQUIRED_STRING, - "aud": SINGLE_REQUIRED_STRING, "exp": SINGLE_REQUIRED_INT, "iat": SINGLE_REQUIRED_INT, "jti": SINGLE_REQUIRED_STRING, diff --git a/src/oic/extension/message.py b/src/oic/extension/message.py index 759d658f9..5ac710539 100644 --- a/src/oic/extension/message.py +++ b/src/oic/extension/message.py @@ -114,9 +114,7 @@ class RegistrationRequest(Message): } def verify(self, **kwargs): - if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith( - "https:" - ): + if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith("https:"): raise RegistrationError("initiate_login_uri is not https") if "redirect_uris" in self: @@ -127,9 +125,7 @@ def verify(self, **kwargs): for uri in ["client_uri", "logo_uri", "tos_uri", "policy_uri"]: if uri in self: try: - resp = requests.request( - "GET", str(self[uri]), allow_redirects=True, verify=False - ) + resp = requests.request("GET", str(self[uri]), allow_redirects=True, verify=False) except requests.ConnectionError: raise MissingPage(self[uri]) @@ -244,9 +240,7 @@ def unpack_software_statement(software_statement, iss, keyjar): class ExtensionMessageFactory(OauthMessageFactory): """Message factory for Extension code.""" - introspection_endpoint = MessageTuple( - TokenIntrospectionRequest, TokenIntrospectionResponse - ) + introspection_endpoint = MessageTuple(TokenIntrospectionRequest, TokenIntrospectionResponse) revocation_endpoint = MessageTuple(TokenRevocationRequest, Message) registration_endpoint = MessageTuple(RegistrationRequest, ClientInfoResponse) update_endpoint = MessageTuple(ClientUpdateRequest, ClientInfoResponse) diff --git a/src/oic/extension/pop.py b/src/oic/extension/pop.py index ccc097a1f..37b211f06 100644 --- a/src/oic/extension/pop.py +++ b/src/oic/extension/pop.py @@ -129,9 +129,7 @@ def store_key(self, access_token, tir): key = load_jwks(json.dumps({"keys": [json.loads(tir["key"])]})) self.token2key[access_token] = key - def eval_signed_http_request( - self, pop_token, access_token, method, url, headers, body="" - ): + def eval_signed_http_request(self, pop_token, access_token, method, url, headers, body=""): kwargs = sign_http_args(method, url, headers, body) shr = SignedHttpRequest(self.token2key[access_token][0]) diff --git a/src/oic/extension/popjwt.py b/src/oic/extension/popjwt.py index 6863eaa2f..f5ff16278 100644 --- a/src/oic/extension/popjwt.py +++ b/src/oic/extension/popjwt.py @@ -13,9 +13,7 @@ class PJWT(JasonWebToken): class PopJWT(object): - def __init__( - self, iss="", aud="", lifetime=3600, in_a_while=0, sub="", jwe=None, keys=None - ): + def __init__(self, iss="", aud="", lifetime=3600, in_a_while=0, sub="", jwe=None, keys=None): """ Initialize the class. diff --git a/src/oic/extension/proof_of_possesion.py b/src/oic/extension/proof_of_possesion.py index 24fecd8a5..8cf2b1dec 100644 --- a/src/oic/extension/proof_of_possesion.py +++ b/src/oic/extension/proof_of_possesion.py @@ -39,9 +39,7 @@ def token_endpoint(self, request="", authn="", dtype="urlencoded", **kwargs): if "token_type" not in atr or atr["token_type"] != "pop": return resp - client_public_key = base64.urlsafe_b64decode(atr["key"].encode("utf-8")).decode( - "utf-8" - ) + client_public_key = base64.urlsafe_b64decode(atr["key"].encode("utf-8")).decode("utf-8") pop_key = json.loads(client_public_key) atr = AccessTokenResponse().deserialize(resp.message, method="json") data = self.sdb.read(atr["access_token"]) @@ -77,24 +75,18 @@ def userinfo_endpoint(self, request="", **kwargs): strict_headers_verification=False, ) except ValidationError: - return error_response( - "access_denied", descr="Could not verify proof of " "possession" - ) + return error_response("access_denied", descr="Could not verify proof of " "possession") return self._do_user_info(self.access_tokens[access_token], **kwargs) def _get_client_public_key(self, access_token): _jws = jws.factory(access_token) if _jws: - data = _jws.verify_compact( - access_token, self.keyjar.get_verify_key(owner="") - ) + data = _jws.verify_compact(access_token, self.keyjar.get_verify_key(owner="")) try: return keyrep(data["cnf"]["jwk"]) except KeyError: - raise NonPoPTokenError( - "Could not extract public key as JWK from access token" - ) + raise NonPoPTokenError("Could not extract public key as JWK from access token") raise NonPoPTokenError("Unsigned access token, maybe not PoP?") @@ -133,10 +125,7 @@ def _parse_access_token(self, request, **kwargs): return request["query"]["access_token"] elif "access_token" in request["body"]: return parse_qs(request["body"])["access_token"][0] - elif ( - "Authorization" in request["headers"] - and request["headers"]["Authorization"] - ): + elif "Authorization" in request["headers"] and request["headers"]["Authorization"]: auth_header = request["headers"]["Authorization"] if auth_header.startswith("pop "): return auth_header[len("pop ") :] diff --git a/src/oic/extension/provider.py b/src/oic/extension/provider.py index 0c4550159..54bac410c 100644 --- a/src/oic/extension/provider.py +++ b/src/oic/extension/provider.py @@ -204,9 +204,7 @@ def __init__( else: self.lifetime_policy = lifetime_policy - self.token_handler = TokenHandler( - self.baseurl, self.token_policy, keyjar=self.keyjar - ) + self.token_handler = TokenHandler(self.baseurl, self.token_policy, keyjar=self.keyjar) @staticmethod def _uris_to_tuples(uris): @@ -244,9 +242,7 @@ def load_keys(self, request, client_id, client_secret): msg = "Failed to load client keys: {}" logger.error(msg.format(sanitize(request.to_dict()))) logger.error("%s", err) - error = ClientRegistrationError( - error="invalid_configuration_parameter", error_description="%s" % err - ) + error = ClientRegistrationError(error="invalid_configuration_parameter", error_description="%s" % err) return Response( error.to_json(), content="application/json", @@ -357,17 +353,13 @@ def match_client_request(self, request): if request[_pref] not in self.capabilities[_prov]: raise CapabilitiesMisMatch("Not allowed {}".format(_pref)) else: - if not set(request[_pref]).issubset( - set(self.capabilities[_prov]) - ): + if not set(request[_pref]).issubset(set(self.capabilities[_prov])): raise CapabilitiesMisMatch("Not allowed {}".format(_pref)) def client_info(self, client_id): _cinfo = self.cdb[client_id].copy() if not valid_client_info(_cinfo): - err = ErrorResponse( - error="invalid_client", error_description="Invalid client secret" - ) + err = ErrorResponse(error="invalid_client", error_description="Invalid client secret") return BadRequest(err.to_json(), content="application/json") try: @@ -439,28 +431,22 @@ def registration_endpoint(self, **kwargs): :param kwargs: extra keyword arguments :return: A Response instance """ - _request = self.server.message_factory.get_request_type( - "registration_endpoint" - )().deserialize(kwargs["request"], "json") + _request = self.server.message_factory.get_request_type("registration_endpoint")().deserialize( + kwargs["request"], "json" + ) try: _request.verify(keyjar=self.keyjar) except InvalidRedirectUri as err: - msg = ClientRegistrationError( - error="invalid_redirect_uri", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_redirect_uri", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") except (MissingPage, VerificationError) as err: - msg = ClientRegistrationError( - error="invalid_client_metadata", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") # If authentication is necessary at registration if self.authn_at_registration: try: - self.verify_client( - kwargs["environ"], _request, self.authn_at_registration - ) + self.verify_client(kwargs["environ"], _request, self.authn_at_registration) except (AuthnFailure, UnknownAssertionType): return Unauthorized() @@ -474,14 +460,10 @@ def registration_endpoint(self, **kwargs): try: client_id = self.create_new_client(_request, client_restrictions) except CapabilitiesMisMatch as err: - msg = ClientRegistrationError( - error="invalid_client_metadata", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") except RestrictionError as err: - msg = ClientRegistrationError( - error="invalid_client_metadata", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") return self.client_info(client_id) @@ -505,9 +487,7 @@ def client_info_endpoint(self, method="GET", **kwargs): # authenticated client try: - self.verify_client( - kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id - ) + self.verify_client(kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id) except (AuthnFailure, UnknownAssertionType): return Unauthorized() @@ -515,23 +495,19 @@ def client_info_endpoint(self, method="GET", **kwargs): return self.client_info(_id) elif method == "PUT": try: - _request = self.server.message_factory.get_request_type( - "update_endpoint" - )().from_json(kwargs["request"]) + _request = self.server.message_factory.get_request_type("update_endpoint")().from_json( + kwargs["request"] + ) except ValueError as err: return BadRequest(str(err)) try: _request.verify() except InvalidRedirectUri as err: - msg = ClientRegistrationError( - error="invalid_redirect_uri", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_redirect_uri", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") except (MissingPage, VerificationError) as err: - msg = ClientRegistrationError( - error="invalid_client_metadata", error_description="%s" % err - ) + msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err) return BadRequest(msg.to_json(), content="application/json") try: @@ -548,9 +524,7 @@ def client_info_endpoint(self, method="GET", **kwargs): return NoContent() @staticmethod - def verify_code_challenge( - code_verifier, code_challenge, code_challenge_method="S256" - ): + def verify_code_challenge(code_verifier, code_challenge, code_challenge_method="S256"): """ Verify a PKCE (RFC7636) code challenge. @@ -562,9 +536,7 @@ def verify_code_challenge( _cc = b64e(_h) if _cc.decode("ascii") != code_challenge: logger.error("PCKE Code Challenge check failed") - err = TokenErrorResponse( - error="invalid_request", error_description="PCKE check failed" - ) + err = TokenErrorResponse(error="invalid_request", error_description="PCKE check failed") return Response(err.to_json(), content="application/json", status_code=401) return True @@ -591,12 +563,8 @@ def code_grant_type(self, areq): try: _info = self.sdb[areq["code"]] except KeyError: - err = TokenErrorResponse( - error="invalid_grant", error_description="Unknown access grant" - ) - return Response( - err.to_json(), content="application/json", status="401 Unauthorized" - ) + err = TokenErrorResponse(error="invalid_grant", error_description="Unknown access grant") + return Response(err.to_json(), content="application/json", status="401 Unauthorized") authzreq = json.loads(_info["authzreq"]) if "code_verifier" in areq: @@ -605,9 +573,7 @@ def code_grant_type(self, areq): except KeyError: _method = "S256" - resp = self.verify_code_challenge( - areq["code_verifier"], authzreq["code_challenge"], _method - ) + resp = self.verify_code_challenge(areq["code_verifier"], authzreq["code_challenge"], _method) if resp: return resp @@ -634,16 +600,10 @@ def code_grant_type(self, areq): issue_refresh = True try: - _tinfo = self.sdb.upgrade_to_token( - areq["code"], issue_refresh=issue_refresh - ) + _tinfo = self.sdb.upgrade_to_token(areq["code"], issue_refresh=issue_refresh) except AccessCodeUsed: - err = TokenErrorResponse( - error="invalid_grant", error_description="Access grant used" - ) - return Response( - err.to_json(), content="application/json", status="401 Unauthorized" - ) + err = TokenErrorResponse(error="invalid_grant", error_description="Access grant used") + return Response(err.to_json(), content="application/json", status="401 Unauthorized") logger.debug("_tinfo: %s" % _tinfo) @@ -652,9 +612,7 @@ def code_grant_type(self, areq): logger.debug("AccessTokenResponse: %s" % atr) - return Response( - atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS - ) + return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) def client_credentials_grant_type(self, areq): _at = self.token_handler.get_access_token( @@ -662,9 +620,7 @@ def client_credentials_grant_type(self, areq): ) _info = self.token_handler.token_factory.get_info(_at) try: - _rt = self.token_handler.get_refresh_token( - self.baseurl, _info["access_token"], "client_credentials" - ) + _rt = self.token_handler.get_refresh_token(self.baseurl, _info["access_token"], "client_credentials") except NotAllowed: atr = self.do_access_token_response(_at, _info, areq["state"]) else: @@ -684,9 +640,7 @@ def password_grant_type(self, areq): except IndexError: err = TokenErrorResponse(error="invalid_grant") return Unauthorized(err.to_json(), content="application/json") - identity, _ts = authn.authenticated_as( - username=areq["username"], password=areq["password"] - ) + identity, _ts = authn.authenticated_as(username=areq["username"], password=areq["password"]) if identity is None: err = TokenErrorResponse(error="invalid_grant") return Unauthorized(err.to_json(), content="application/json") @@ -702,14 +656,10 @@ def password_grant_type(self, areq): _at = self.sdb.upgrade_to_token(self.sdb[sid]["code"], issue_refresh=True) atr_class = self.server.message_factory.get_response_type("token_endpoint") atr = atr_class(**by_schema(atr_class, **_at)) - return Response( - atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS - ) + return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) def refresh_token_grant_type(self, areq): - at = self.token_handler.refresh_access_token( - self.baseurl, areq["access_token"], "refresh_token" - ) + at = self.token_handler.refresh_access_token(self.baseurl, areq["access_token"], "refresh_token") atr_class = self.server.message_factory.get_response_type("token_endpoint") atr = atr_class(**by_schema(atr_class, **at)) @@ -745,12 +695,8 @@ def get_token_info(self, authn, req, endpoint): client_id = self.client_authn(self, req, authn) except FailedAuthentication as err: logger.error("%s", err) - error = TokenErrorResponse( - error="unauthorized_client", error_description="%s" % err - ) - return Response( - error.to_json(), content="application/json", status="401 Unauthorized" - ) + error = TokenErrorResponse(error="unauthorized_client", error_description="%s" % err) + return Response(error.to_json(), content="application/json", status="401 Unauthorized") logger.debug("{}: {} requesting {}".format(endpoint, client_id, req.to_dict())) @@ -761,9 +707,7 @@ def get_token_info(self, authn, req, endpoint): _info = self.sdb.token_factory["access_token"].get_info(req["token"]) except Exception: try: - _info = self.sdb.token_factory["refresh_token"].get_info( - req["token"] - ) + _info = self.sdb.token_factory["refresh_token"].get_info(req["token"]) except Exception: return self._return_inactive() else: @@ -782,9 +726,7 @@ def get_token_info(self, authn, req, endpoint): return client_id, token_type, _info def _return_inactive(self): - ir = self.server.message_factory.get_response_type("introspection_endpoint")( - active=False - ) + ir = self.server.message_factory.get_response_type("introspection_endpoint")(active=False) return Response(ir.to_json(), content="application/json") def revocation_endpoint(self, authn="", request=None, **kwargs): @@ -796,9 +738,7 @@ def revocation_endpoint(self, authn="", request=None, **kwargs): :param kwargs: :return: """ - trr = self.server.message_factory.get_request_type( - "revocation_endpoint" - )().deserialize(request, "urlencoded") + trr = self.server.message_factory.get_request_type("revocation_endpoint")().deserialize(request, "urlencoded") resp = self.get_token_info(authn, trr, "revocation_endpoint") @@ -825,9 +765,9 @@ def introspection_endpoint(self, authn="", request=None, **kwargs): :param kwargs: :return: """ - tir = self.server.message_factory.get_request_type( - "introspection_endpoint" - )().deserialize(request, "urlencoded") + tir = self.server.message_factory.get_request_type("introspection_endpoint")().deserialize( + request, "urlencoded" + ) resp = self.get_token_info(authn, tir, "introspection_endpoint") diff --git a/src/oic/extension/signed_http_req.py b/src/oic/extension/signed_http_req.py index b1e4ecc6f..420e70577 100644 --- a/src/oic/extension/signed_http_req.py +++ b/src/oic/extension/signed_http_req.py @@ -172,9 +172,7 @@ def verify(self, signature, **kwargs): if "b" not in unpacked_req and "body" not in kwargs: pass elif "b" in unpacked_req and "body" in kwargs: - _equals( - b64_hash(kwargs.get("body", ""), hash_size), unpacked_req.get("b", "") - ) + _equals(b64_hash(kwargs.get("body", ""), hash_size), unpacked_req.get("b", "")) else: if "b" in unpacked_req: raise ValidationError("Body sent but not received!!") diff --git a/src/oic/oauth2/__init__.py b/src/oic/oauth2/__init__.py index 25959c4c2..d1b116cd7 100644 --- a/src/oic/oauth2/__init__.py +++ b/src/oic/oauth2/__init__.py @@ -100,9 +100,7 @@ class ExpiredToken(PyoidcError): def error_response(error, descr=None, status_code=400): logger.error("%s" % sanitize(error)) response = ErrorResponse(error=error, error_description=descr) - return Response( - response.to_json(), content="application/json", status_code=status_code - ) + return Response(response.to_json(), content="application/json", status_code=status_code) def none_response(**kwargs): @@ -389,9 +387,7 @@ def clean_tokens(self) -> None: if token.replaced or not token.is_valid(): grant.delete_token(token) - def construct_request( - self, request: Type[Message], request_args=None, extra_args=None - ): + def construct_request(self, request: Type[Message], request_args=None, extra_args=None): if request_args is None: request_args = {} @@ -467,8 +463,7 @@ def construct_AccessTokenRequest( if not grant.is_valid(): raise GrantExpired( - "Authorization Code to old %s > %s" - % (utc_time_sans_frac(), grant.grant_expiration_time) + "Authorization Code to old %s > %s" % (utc_time_sans_frac(), grant.grant_expiration_time) ) request_args["code"] = grant.code @@ -575,9 +570,7 @@ def request_info( cis.lax = lax if "authn_method" in kwargs: - h_arg = self.init_authentication_method( - cis, request_args=request_args, **kwargs - ) + h_arg = self.init_authentication_method(cis, request_args=request_args, **kwargs) else: h_arg = None @@ -713,18 +706,14 @@ def parse_response( session_update(self.sso_db, _state, "smid", resp["id_token"]["sid"]) return resp - def init_authentication_method( - self, cis, authn_method, request_args=None, http_args=None, **kwargs - ): + def init_authentication_method(self, cis, authn_method, request_args=None, http_args=None, **kwargs): if http_args is None: http_args = {} if request_args is None: request_args = {} if authn_method: - return self.client_authn_method[authn_method](self).construct( - cis, request_args, http_args, **kwargs - ) + return self.client_authn_method[authn_method](self).construct(cis, request_args, http_args, **kwargs) else: return http_args @@ -743,17 +732,12 @@ def parse_request_response( logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) raise ParseError("ERROR: Something went wrong: %s" % reqresp.text) - if reqresp.status_code in SUCCESSFUL or ( - reqresp.status_code in [400, 401] and response - ): + if reqresp.status_code in SUCCESSFUL or (reqresp.status_code in [400, 401] and response): verified_body_type = verify_header(reqresp, body_type) else: # Any other error logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) - raise HttpError( - "HTTP ERROR: %s [%s] on %s" - % (reqresp.text, reqresp.status_code, reqresp.url) - ) + raise HttpError("HTTP ERROR: %s [%s] on %s" % (reqresp.text, reqresp.status_code, reqresp.url)) # we expect some specific response message type, try to parse it if response: @@ -761,9 +745,7 @@ def parse_request_response( if verified_body_type is None: verified_body_type = "urlencoded" - return self.parse_response( - response, reqresp.text, verified_body_type, state, **kwargs - ) + return self.parse_response(response, reqresp.text, verified_body_type, state, **kwargs) # No one told us what to expect, so try to decode an error response if reqresp.status_code in [200, 400, 401]: @@ -842,9 +824,7 @@ def do_authorization_request( request_args = {"state": state} kwargs["authn_endpoint"] = "authorization" - url, body, ht_args, csi = self.request_info( - request, method, request_args, extra_args, **kwargs - ) + url, body, ht_args, csi = self.request_info(request, method, request_args, extra_args, **kwargs) try: self.authz_req[request_args["state"]] = csi @@ -1011,13 +991,9 @@ def do_any( else: http_args.update(ht_args) - return self.request_and_return( - url, response, method, body, body_type, state=state, http_args=http_args - ) + return self.request_and_return(url, response, method, body, body_type, state=state, http_args=http_args) - def fetch_protected_resource( - self, uri, method="GET", headers=None, state="", **kwargs - ): + def fetch_protected_resource(self, uri, method="GET", headers=None, state="", **kwargs): if "token" in kwargs and kwargs["token"]: token = kwargs["token"] request_args = {"access_token": token} @@ -1034,14 +1010,10 @@ def fetch_protected_resource( headers = {} if "authn_method" in kwargs: - http_args = self.init_authentication_method( - request_args=request_args, **kwargs - ) + http_args = self.init_authentication_method(request_args=request_args, **kwargs) else: # If nothing defined this is the default - http_args = self.client_authn_method["bearer_header"](self).construct( - request_args=request_args - ) + http_args = self.client_authn_method["bearer_header"](self).construct(request_args=request_args) headers.update(http_args["headers"]) @@ -1109,10 +1081,7 @@ def handle_provider_config( _issuer = issuer if not self.allow.get("issuer_mismatch", False) and _issuer != _pcr_issuer: - raise PyoidcError( - "provider info issuer mismatch '%s' != '%s'" - % (_issuer, _pcr_issuer) - ) + raise PyoidcError("provider info issuer mismatch '%s' != '%s'" % (_issuer, _pcr_issuer)) self.provider_info = pcr else: @@ -1244,9 +1213,7 @@ def parse_jwt_request( areq.verify() return areq - def parse_body_request( - self, request: Type[Message] = AccessTokenRequest, body: Optional[str] = None - ): + def parse_body_request(self, request: Type[Message] = AccessTokenRequest, body: Optional[str] = None): req = request().deserialize(body, "urlencoded") req.verify() return req @@ -1255,8 +1222,6 @@ def parse_token_request(self, body: Optional[str] = None) -> AccessTokenRequest: request = self.message_factory.get_request_type("token_endpoint") return self.parse_body_request(request, body) - def parse_refresh_token_request( - self, body: Optional[str] = None - ) -> RefreshAccessTokenRequest: + def parse_refresh_token_request(self, body: Optional[str] = None) -> RefreshAccessTokenRequest: request = self.message_factory.get_request_type("refresh_endpoint") return self.parse_body_request(request, body) diff --git a/src/oic/oauth2/base.py b/src/oic/oauth2/base.py index b3446cdcf..bc9f136d1 100644 --- a/src/oic/oauth2/base.py +++ b/src/oic/oauth2/base.py @@ -136,8 +136,7 @@ def http_request(self, url: str, method="GET", **kwargs) -> requests.Response: ) except Exception as err: logger.error( - "http_request failed: %s, url: %s, htargs: %s, method: %s" - % (err, url, sanitize(_kwargs), method) + "http_request failed: %s, url: %s, htargs: %s, method: %s" % (err, url, sanitize(_kwargs), method) ) raise @@ -160,12 +159,8 @@ def http_request(self, url: str, method="GET", **kwargs) -> requests.Response: def send(self, url, method="GET", **kwargs): return self.http_request(url, method, **kwargs) - def load_cookies_from_file( - self, filename, ignore_discard=False, ignore_expires=False - ): + def load_cookies_from_file(self, filename, ignore_discard=False, ignore_expires=False): self.cookiejar.load(filename, ignore_discard, ignore_expires) - def save_cookies_to_file( - self, filename, ignore_discard=False, ignore_expires=False - ): + def save_cookies_to_file(self, filename, ignore_discard=False, ignore_expires=False): self.cookiejar.save(filename, ignore_discard, ignore_expires) diff --git a/src/oic/oauth2/consumer.py b/src/oic/oauth2/consumer.py index dbd99418f..068c56e71 100644 --- a/src/oic/oauth2/consumer.py +++ b/src/oic/oauth2/consumer.py @@ -285,9 +285,7 @@ def handle_authorization_response(self, query="", **kwargs): if "code" in self.response_type: # Might be an error response try: - aresp = self.parse_response( - AuthorizationResponse, info=query, sformat="urlencoded" - ) + aresp = self.parse_response(AuthorizationResponse, info=query, sformat="urlencoded") except Exception as err: logger.error("%s", err) raise @@ -305,9 +303,7 @@ def handle_authorization_response(self, query="", **kwargs): return aresp else: # implicit flow - atr = self.parse_response( - AccessTokenResponse, info=query, sformat="urlencoded", extended=True - ) + atr = self.parse_response(AccessTokenResponse, info=query, sformat="urlencoded", extended=True) if isinstance(atr, Message): if atr.type().endswith("ErrorResponse"): diff --git a/src/oic/oauth2/message.py b/src/oic/oauth2/message.py index 35c63d2b4..074422573 100644 --- a/src/oic/oauth2/message.py +++ b/src/oic/oauth2/message.py @@ -318,9 +318,7 @@ def to_dict(self, lev=0): if isinstance(val, Message): _res[key] = val.to_dict(lev + 1) - elif isinstance(val, list) and isinstance( - next(iter(val or []), None), Message - ): + elif isinstance(val, list) and isinstance(next(iter(val or []), None), Message): _res[key] = [v.to_dict(lev) for v in val] else: _res[key] = val @@ -341,9 +339,7 @@ def from_dict(self, dictionary, **kwargs): continue cparam = self._extract_cparam(key, _spec) if cparam is not None: - self._add_value( - key, cparam.type, key, val, cparam.deserializer, cparam.null_allowed - ) + self._add_value(key, cparam.type, key, val, cparam.deserializer, cparam.null_allowed) else: self._dict[key] = val return self @@ -362,9 +358,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): if vtyp is bool: self._dict[skey] = val else: - raise ParameterError( - '"{}", wrong type of value for "{}"'.format(val, skey) - ) + raise ParameterError('"{}", wrong type of value for "{}"'.format(val, skey)) elif isinstance(val, vtyp): # Not necessary to do anything self._dict[skey] = val else: @@ -377,15 +371,11 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): try: self._dict[skey] = int(val) except (ValueError, TypeError): - raise ParameterError( - '"{}", wrong type of value for "{}"'.format(val, skey) - ) + raise ParameterError('"{}", wrong type of value for "{}"'.format(val, skey)) else: return elif vtyp is bool: - raise ParameterError( - '"{}", wrong type of value for "{}"'.format(val, skey) - ) + raise ParameterError('"{}", wrong type of value for "{}"'.format(val, skey)) if isinstance(val, str): self._dict[skey] = val @@ -439,9 +429,7 @@ def _add_value_list(self, skey, vtype, key, val, _deser, null_allowed): else: for v in val: if not isinstance(v, vtype): - raise DecodeError( - ERRTXT % (key, "type != %s (%s)" % (vtype, type(v))) - ) + raise DecodeError(ERRTXT % (key, "type != %s (%s)" % (vtype, type(v)))) self._dict[skey] = val return if isinstance(val, dict): @@ -487,9 +475,7 @@ def _add_key(self, keyjar, issuer, key, key_type="", kid="", no_kid_issuer=None) logger.error('Issuer "{}" not in keyjar'.format(issuer)) return - logger.debug( - "Key set summary for {}: {}".format(issuer, key_summary(keyjar, issuer)) - ) + logger.debug("Key set summary for {}: {}".format(issuer, key_summary(keyjar, issuer))) if kid: _key = keyjar.get_key_by_kid(kid, issuer) @@ -608,13 +594,9 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): if "algs" in kwargs and "encalg" in kwargs["algs"]: if kwargs["algs"]["encalg"] != _jw["alg"]: - raise WrongEncryptionAlgorithm( - "%s != %s" % (_jw["alg"], kwargs["algs"]["encalg"]) - ) + raise WrongEncryptionAlgorithm("%s != %s" % (_jw["alg"], kwargs["algs"]["encalg"])) if kwargs["algs"]["encenc"] != _jw["enc"]: - raise WrongEncryptionAlgorithm( - "%s != %s" % (_jw["enc"], kwargs["algs"]["encenc"]) - ) + raise WrongEncryptionAlgorithm("%s != %s" % (_jw["enc"], kwargs["algs"]["encenc"])) if keyjar: dkeys = keyjar.get_decrypt_key(owner="") if "sender" in kwargs: @@ -640,9 +622,7 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): if "algs" in kwargs and "sign" in kwargs["algs"]: _alg = _jw.jwt.headers["alg"] if kwargs["algs"]["sign"] != _alg: - raise WrongSigningAlgorithm( - "%s != %s" % (_alg, kwargs["algs"]["sign"]) - ) + raise WrongSigningAlgorithm("%s != %s" % (_alg, kwargs["algs"]["sign"])) try: _jwt = JWT().unpack(txt) jso = _jwt.payload() @@ -662,9 +642,7 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): pass elif verify: if keyjar: - key = self.get_verify_keys( - keyjar, key, jso, _header, _jw, **kwargs - ) + key = self.get_verify_keys(keyjar, key, jso, _header, _jw, **kwargs) if "alg" in _header and _header["alg"] != "none": if not key: @@ -676,9 +654,7 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): except NoSuitableSigningKeys: if keyjar: update_keyjar(keyjar) - key = self.get_verify_keys( - keyjar, key, jso, _header, _jw, **kwargs - ) + key = self.get_verify_keys(keyjar, key, jso, _header, _jw, **kwargs) _jw.verify_compact(txt, key) except Exception: raise @@ -725,7 +701,7 @@ def verify(self, **kwargs): if cparam.required: raise MissingRequiredAttribute("%s" % attribute) continue - if cparam.type != bool and not val: + if cparam.type is not bool and not val: if cparam.required: raise MissingRequiredAttribute("%s" % attribute) continue @@ -743,9 +719,7 @@ def verify(self, **kwargs): else: raise NotAllowedValue(val) else: - self._type_check( - cparam.type, _allowed[attribute], val, cparam.null_allowed - ) + self._type_check(cparam.type, _allowed[attribute], val, cparam.null_allowed) return True @@ -818,9 +792,7 @@ def __len__(self): return len(self._dict) def extra(self): - return dict( - [(key, val) for key, val in self._dict.items() if key not in self.c_param] - ) + return dict([(key, val) for key, val in self._dict.items() if key not in self.c_param]) def only_extras(self): extras = [key for key in self._dict.keys() if key in self.c_param] @@ -957,21 +929,11 @@ def json_deserializer(txt, sformat="urlencoded"): SINGLE_REQUIRED_STRING = ParamDefinition(str, True, None, None, False) SINGLE_OPTIONAL_STRING = ParamDefinition(str, False, None, None, False) SINGLE_OPTIONAL_INT = ParamDefinition(int, False, None, None, False) -OPTIONAL_LIST_OF_STRINGS = ParamDefinition( - [str], False, list_serializer, list_deserializer, False -) -REQUIRED_LIST_OF_STRINGS = ParamDefinition( - [str], True, list_serializer, list_deserializer, False -) -OPTIONAL_LIST_OF_SP_SEP_STRINGS = ParamDefinition( - [str], False, sp_sep_list_serializer, sp_sep_list_deserializer, False -) -REQUIRED_LIST_OF_SP_SEP_STRINGS = ParamDefinition( - [str], True, sp_sep_list_serializer, sp_sep_list_deserializer, False -) -SINGLE_OPTIONAL_JSON = ParamDefinition( - str, False, json_serializer, json_deserializer, False -) +OPTIONAL_LIST_OF_STRINGS = ParamDefinition([str], False, list_serializer, list_deserializer, False) +REQUIRED_LIST_OF_STRINGS = ParamDefinition([str], True, list_serializer, list_deserializer, False) +OPTIONAL_LIST_OF_SP_SEP_STRINGS = ParamDefinition([str], False, sp_sep_list_serializer, sp_sep_list_deserializer, False) +REQUIRED_LIST_OF_SP_SEP_STRINGS = ParamDefinition([str], True, sp_sep_list_serializer, sp_sep_list_deserializer, False) +SINGLE_OPTIONAL_JSON = ParamDefinition(str, False, json_serializer, json_deserializer, False) REQUIRED = [ SINGLE_REQUIRED_STRING, diff --git a/src/oic/oauth2/provider.py b/src/oic/oauth2/provider.py index 78a104ed2..bb73f873f 100644 --- a/src/oic/oauth2/provider.py +++ b/src/oic/oauth2/provider.py @@ -226,8 +226,7 @@ def __init__( self.sdb = sdb if not isinstance(cdb, BaseClientDatabase): warnings.warn( - "ClientDatabase should be an instance of " - "oic.utils.clientdb.BaseClientDatabase to ensure proper API." + "ClientDatabase should be an instance of " "oic.utils.clientdb.BaseClientDatabase to ensure proper API." ) self.cdb = cdb self.server = server_cls( @@ -280,9 +279,7 @@ def __init__( if capabilities: self.verify_capabilities(capabilities) - self.capabilities = message_factory.get_response_type( - "configuration_endpoint" - )(**capabilities) + self.capabilities = message_factory.get_response_type("configuration_endpoint")(**capabilities) else: self.capabilities = self.provider_features() self.capabilities["issuer"] = self.name @@ -405,9 +402,7 @@ def verify_capabilities(self, capabilities) -> bool: if unsup: not_supported[key] = unsup if not_supported: - logger.error( - "Server does not support the following features: %s", not_supported - ) + logger.error("Server does not support the following features: %s", not_supported) return False return True @@ -417,9 +412,7 @@ def provider_features(self, provider_config=None): :return: ProviderConfigurationResponse instance """ - pcr_class = self.server.message_factory.get_response_type( - "configuration_endpoint" - ) + pcr_class = self.server.message_factory.get_response_type("configuration_endpoint") _provider_info = pcr_class(**self.default_capabilities) _provider_info["scopes_supported"] = self.scopes @@ -446,9 +439,7 @@ def create_providerinfo(self, setup=None): :param setup: :return: """ - pcr_class = self.server.message_factory.get_response_type( - "configuration_endpoint" - ) + pcr_class = self.server.message_factory.get_response_type("configuration_endpoint") _provider_info = copy.deepcopy(self.capabilities.to_dict()) if self.jwks_uri and self.keyjar: @@ -459,9 +450,7 @@ def create_providerinfo(self, setup=None): baseurl = self.baseurl + "/" else: baseurl = self.baseurl - _provider_info["{}_endpoint".format(endp.etype)] = urljoin( - baseurl, endp.url - ) + _provider_info["{}_endpoint".format(endp.etype)] = urljoin(baseurl, endp.url) if setup and isinstance(setup, dict): for key in pcr_class.c_param.keys(): @@ -488,14 +477,10 @@ def providerinfo_endpoint(self, handle="", **kwargs): if handle: (key, timestamp) = handle if key.startswith(STR) and key.endswith(STR): - cookie = self.cookie_func( - key, self.cookie_name, "pinfo", self.sso_ttl - ) + cookie = self.cookie_func(key, self.cookie_name, "pinfo", self.sso_ttl) headers.append(cookie) - resp = Response( - _response.to_json(), content="application/json", headers=headers - ) + resp = Response(_response.to_json(), content="application/json", headers=headers) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) @@ -515,9 +500,7 @@ def get_redirect_uri(self, areq): self._verify_redirect_uri(areq) uri = areq["redirect_uri"] else: - raise ParameterError( - "Missing redirect_uri and more than one or none registered" - ) + raise ParameterError("Missing redirect_uri and more than one or none registered") return uri @@ -544,9 +527,7 @@ def pick_auth(self, areq, comparision_type=""): for acr in areq["acr_values"]: res = self.authn_broker.pick(acr, comparision_type) - logger.debug( - "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) - ) + logger.debug("Picked AuthN broker for ACR %s: %s" % (str(acr), str(res))) if res: # Return the best guess by pick. return res[0] @@ -558,17 +539,13 @@ def pick_auth(self, areq, comparision_type=""): else: for acr in acrs: res = self.authn_broker.pick(acr, comparision_type) - logger.debug( - "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) - ) + logger.debug("Picked AuthN broker for ACR %s: %s" % (str(acr), str(res))) if res: # Return the best guess by pick. return res[0] except KeyError as exc: - logger.debug( - "An error occured while picking the authN broker: %s" % str(exc) - ) + logger.debug("An error occured while picking the authN broker: %s" % str(exc)) # return the best I have return None, None @@ -583,9 +560,7 @@ def auth_init(self, request): :param request: The AuthorizationRequest :return: """ - request_class = self.server.message_factory.get_request_type( - "authorization_endpoint" - ) + request_class = self.server.message_factory.get_request_type("authorization_endpoint") logger.debug("Request: '%s'" % sanitize(request)) # Same serialization used for GET and POST @@ -612,9 +587,7 @@ def auth_init(self, request): except KeyError: _state = "" - return redirect_authz_error( - "invalid_request", redirect_uri, "%s" % err, _state, _rtype - ) + return redirect_authz_error("invalid_request", redirect_uri, "%s" % err, _state, _rtype) except KeyError: areq = request_class().deserialize(request, "urlencoded") # verify the redirect_uri @@ -644,9 +617,7 @@ def auth_init(self, request): try: _cinfo = self.cdb[areq["client_id"]] except KeyError: - logger.error( - "Client ID ({}) not in client database".format(areq["client_id"]) - ) + logger.error("Client ID ({}) not in client database".format(areq["client_id"])) return error_response("unauthorized_client", "unknown client") else: try: @@ -658,17 +629,13 @@ def auth_init(self, request): _wanted = set(areq["response_type"]) if _wanted not in _registered: - return error_response( - "invalid_request", "Trying to use unregistered response_typ" - ) + return error_response("invalid_request", "Trying to use unregistered response_typ") logger.debug("AuthzRequest: %s" % (sanitize(areq.to_dict()),)) try: redirect_uri = self.get_redirect_uri(areq) except (RedirectURIError, ParameterError, UnknownClient) as err: - return error_response( - "invalid_request", "{}:{}".format(err.__class__.__name__, err) - ) + return error_response("invalid_request", "{}:{}".format(err.__class__.__name__, err)) try: keyjar = self.keyjar @@ -721,9 +688,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): tup = (None, None) for acr in acrs: res = self.authn_broker.pick(acr, "exact") - logger.debug( - "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) - ) + logger.debug("Picked AuthN broker for ACR %s: %s" % (str(acr), str(res))) if res: # Return the best guess by pick. tup = res[0] break @@ -736,9 +701,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): authn, authn_class_ref = self.pick_auth(areq, "any") if authn is None: - return redirect_authz_error( - "access_denied", redirect_uri, return_type=areq["response_type"] - ) + return redirect_authz_error("access_denied", redirect_uri, return_type=areq["response_type"]) try: try: @@ -751,9 +714,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): else: _max_age = max_age(areq) - identity, _ts = authn.authenticated_as( - cookie, authorization=_auth_info, max_age=_max_age - ) + identity, _ts = authn.authenticated_as(cookie, authorization=_auth_info, max_age=_max_age) except (NoSuchAuthentication, TamperAllert): identity = None _ts = 0 @@ -793,9 +754,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): if identity is None: # No! if "prompt" in areq and "none" in areq["prompt"]: # Need to authenticate but not allowed - return redirect_authz_error( - "login_required", redirect_uri, return_type=areq["response_type"] - ) + return redirect_authz_error("login_required", redirect_uri, return_type=areq["response_type"]) else: return authn(**authn_args) else: @@ -807,11 +766,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): user = identity["uid"] if "req_user" in kwargs: sids_for_sub = self.sdb.get_by_sub(kwargs["req_user"]) - if ( - sids_for_sub - and user - != self.sdb.get_authentication_event(sids_for_sub[-1]).uid - ): + if sids_for_sub and user != self.sdb.get_authentication_event(sids_for_sub[-1]).uid: logger.debug("Wanted to be someone else!") if "prompt" in areq and "none" in areq["prompt"]: # Need to authenticate but not allowed @@ -846,9 +801,7 @@ def authorization_endpoint(self, request="", cookie="", **kwargs): _cid = info["areq"]["client_id"] cinfo = self.cdb[_cid] - authnres = self.do_auth( - info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs - ) + authnres = self.do_auth(info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs) if isinstance(authnres, Response): return authnres @@ -866,9 +819,7 @@ def aresp_check(self, aresp, areq): def create_authn_response(self, areq, sid): rtype = areq["response_type"][0] _func = self.response_type_map[rtype] - aresp = _func( - areq=areq, scode=self.sdb[sid]["code"], sdb=self.sdb, myself=self.baseurl - ) + aresp = _func(areq=areq, scode=self.sdb[sid]["code"], sdb=self.sdb, myself=self.baseurl) if rtype == "code": fragment_enc = False @@ -961,13 +912,9 @@ def _complete_authz(self, user, areq, sid, **kwargs): cookie_header = None if _kaka is not None: if self.cookie_name not in _kaka: # Don't overwrite - cookie_header = self.cookie_func( - c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl - ) + cookie_header = self.cookie_func(c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl) else: - cookie_header = self.cookie_func( - c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl - ) + cookie_header = self.cookie_func(c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl) if cookie_header is not None: headers.append(cookie_header) @@ -1007,18 +954,14 @@ def token_endpoint(self, request="", authn="", dtype="urlencoded", **kwargs): logger.debug("- token -") logger.debug("token_request: %s" % sanitize(request)) - areq = self.server.message_factory.get_request_type( - "token_endpoint" - )().deserialize(request, dtype) + areq = self.server.message_factory.get_request_type("token_endpoint")().deserialize(request, dtype) # Verify client authentication try: client_id = self.client_authn(self, areq, authn) except (FailedAuthentication, AuthnFailure) as err: logger.error("%s", err) - error = TokenErrorResponse( - error="unauthorized_client", error_description="%s" % err - ) + error = TokenErrorResponse(error="unauthorized_client", error_description="%s" % err) return Unauthorized(error.to_json(), content="application/json") logger.debug("AccessTokenRequest: %s" % sanitize(areq)) @@ -1029,19 +972,14 @@ def token_endpoint(self, request="", authn="", dtype="urlencoded", **kwargs): _info = self.sdb[areq["code"]] except KeyError: logger.error("Code not present in SessionDB") - error = TokenErrorResponse( - error="unauthorized_client", error_description="Invalid code." - ) + error = TokenErrorResponse(error="unauthorized_client", error_description="Invalid code.") return Unauthorized(error.to_json(), content="application/json") resp = self.token_scope_check(areq, _info) if resp: return resp # If redirect_uri was in the initial authorization request verify that they match - if ( - "redirect_uri" in _info - and areq["redirect_uri"] != _info["redirect_uri"] - ): + if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: logger.error("Redirect_uri mismatch") error = TokenErrorResponse( error="unauthorized_client", @@ -1080,9 +1018,7 @@ def code_grant_type(self, areq): try: _tinfo = self.sdb.upgrade_to_token(areq["code"], issue_refresh=True) except AccessCodeUsed: - error = TokenErrorResponse( - error="invalid_grant", error_description="Access grant used" - ) + error = TokenErrorResponse(error="invalid_grant", error_description="Access grant used") return Unauthorized(error.to_json(), content="application/json") logger.debug("_tinfo: %s" % sanitize(_tinfo)) @@ -1091,9 +1027,7 @@ def code_grant_type(self, areq): logger.debug("AccessTokenResponse: %s" % sanitize(atr)) - return Response( - atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS - ) + return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) def refresh_token_grant_type(self, areq): """ diff --git a/src/oic/oauth2/util.py b/src/oic/oauth2/util.py index 0eb4e430d..5c45524c7 100644 --- a/src/oic/oauth2/util.py +++ b/src/oic/oauth2/util.py @@ -57,9 +57,7 @@ ENCODINGS = Literal["json", "urlencoded", "dict", "jwt", "jwe"] -def get_or_post( - uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, accept=None, **kwargs -): +def get_or_post(uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, accept=None, **kwargs): """ Construct HTTP request. @@ -79,9 +77,7 @@ def get_or_post( _req.update(parse_qs(comp.query)) _query = str(_req.to_urlencoded()) - path = urlunsplit( - (comp.scheme, comp.netloc, comp.path, _query, comp.fragment) - ) + path = urlunsplit((comp.scheme, comp.netloc, comp.path, _query, comp.fragment)) else: path = uri body = None @@ -144,10 +140,7 @@ def set_cookie(cookiejar, kaka): std_attr["expires"] = http2time(morsel[attr]) except TimeFormatError: # Ignore cookie - logger.info( - "Time format error on %s parameter in received cookie" - % (sanitize(attr),) - ) + logger.info("Time format error on %s parameter in received cookie" % (sanitize(attr),)) continue for att, spec in PAIRS.items(): @@ -223,14 +216,13 @@ def verify_header(reqresp, body_type: Optional[ENCODINGS]) -> Optional[ENCODINGS if match_to_("application/jwt", reqresp.headers["content-type"]): body_type = "jwt" else: - raise ValueError( - "content-type: %s" % (reqresp.headers["content-type"],) - ) + raise ValueError("content-type: %s" % (reqresp.headers["content-type"],)) elif body_type == "jwt": if not match_to_("application/jwt", reqresp.headers["content-type"]): raise ValueError( - "Wrong content-type in header, got: {} expected " - "'application/jwt'".format(reqresp.headers["content-type"]) + "Wrong content-type in header, got: {} expected " "'application/jwt'".format( + reqresp.headers["content-type"] + ) ) elif body_type == "urlencoded": if not match_to_(DEFAULT_POST_CONTENT_TYPE, reqresp.headers["content-type"]): diff --git a/src/oic/oic/__init__.py b/src/oic/oic/__init__.py index 55cbb8800..e6438ac9d 100644 --- a/src/oic/oic/__init__.py +++ b/src/oic/oic/__init__.py @@ -451,9 +451,7 @@ def request_object_encryption(self, msg, **kwargs): try: encenc = self.behaviour["request_object_encryption_enc"] except KeyError: - raise MissingRequiredAttribute( - "No request_object_encryption_enc specified" - ) + raise MissingRequiredAttribute("No request_object_encryption_enc specified") _jwe = JWE(msg, alg=encalg, enc=encenc) _kty = jwe.alg2keytype(encalg) @@ -498,9 +496,7 @@ def filename_from_webname(self, webname): else: raise ValueError("Invalid webname, must start with base_url") - def construct_AuthorizationRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_AuthorizationRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request_args is not None: if "nonce" not in request_args: _rt = request_args["response_type"] @@ -570,9 +566,7 @@ def construct_AuthorizationRequest( return areq - def construct_UserInfoRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_UserInfoRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("userinfo_endpoint") if request_args is None: @@ -591,16 +585,12 @@ def construct_UserInfoRequest( return self.construct_request(request, request_args, extra_args) - def construct_RegistrationRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_RegistrationRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("registration_endpoint") return self.construct_request(request, request_args, extra_args) - def construct_RefreshSessionRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_RefreshSessionRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("refreshsession_endpoint") return self.construct_request(request, request_args, extra_args) @@ -625,23 +615,17 @@ def _id_token_based(self, request, request_args=None, extra_args=None, **kwargs) return self.construct_request(request, request_args, extra_args) - def construct_CheckSessionRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_CheckSessionRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("checksession_endpoint") return self._id_token_based(request, request_args, extra_args, **kwargs) - def construct_CheckIDRequest( - self, request=None, request_args=None, extra_args=None, **kwargs - ): + def construct_CheckIDRequest(self, request=None, request_args=None, extra_args=None, **kwargs): if request is None: request = self.message_factory.get_request_type("checkid_endpoint") # access_token is where the id_token will be placed - return self._id_token_based( - request, request_args, extra_args, prop="access_token", **kwargs - ) + return self._id_token_based(request, request_args, extra_args, prop="access_token", **kwargs) def construct_EndSessionRequest( self, @@ -659,9 +643,7 @@ def construct_EndSessionRequest( if "state" in request_args and "state" not in kwargs: kwargs["state"] = request_args["state"] - return self._id_token_based( - request, request_args, extra_args, prop=prop, **kwargs - ) + return self._id_token_based(request, request_args, extra_args, prop=prop, **kwargs) def do_authorization_request( self, @@ -750,9 +732,7 @@ def do_registration_request( http_args.update(http_args) response_cls = self.message_factory.get_response_type("registration_endpoint") - response = self.request_and_return( - url, response_cls, method, body, body_type, state=state, http_args=http_args - ) + response = self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args) return response def do_check_session_request( @@ -782,9 +762,7 @@ def do_check_session_request( else: http_args.update(http_args) - return self.request_and_return( - url, response_cls, method, body, body_type, state=state, http_args=http_args - ) + return self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args) def do_check_id_request( self, @@ -813,9 +791,7 @@ def do_check_id_request( else: http_args.update(http_args) - return self.request_and_return( - url, response_cls, method, body, body_type, state=state, http_args=http_args - ) + return self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args) def do_end_session_request( self, @@ -845,9 +821,7 @@ def do_end_session_request( else: http_args.update(http_args) - return self.request_and_return( - url, response_cls, method, body, body_type, state=state, http_args=http_args - ) + return self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args) def user_info_request(self, method="GET", state="", scope="", **kwargs): uir = self.message_factory.get_request_type("userinfo_endpoint")() @@ -872,11 +846,7 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): raise AccessDenied("invalid_token") if token.is_valid(): uir["access_token"] = token.access_token - if ( - token.token_type - and token.token_type.lower() == "bearer" - and method == "GET" - ): + if token.token_type and token.token_type.lower() == "bearer" and method == "GET": kwargs["behavior"] = "use_authorization_header" else: # raise oauth2.OldAccessToken @@ -918,9 +888,7 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): elif token: # use_authorization_header, token_in_message_body if "use_authorization_header" in _behav: - token_header = "{type} {token}".format( - type=_ttype.capitalize(), token=_token - ) + token_header = "{type} {token}".format(type=_ttype.capitalize(), token=_token) if "headers" in kwargs: kwargs["headers"].update({"Authorization": token_header}) else: @@ -936,17 +904,12 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): return path, body, method, h_args - def do_user_info_request( - self, method="POST", state="", scope="openid", request="openid", **kwargs - ): + def do_user_info_request(self, method="POST", state="", scope="openid", request="openid", **kwargs): kwargs["request"] = request - path, body, method, h_args = self.user_info_request( - method, state, scope, **kwargs - ) + path, body, method, h_args = self.user_info_request(method, state, scope, **kwargs) logger.debug( - "[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" - % (sanitize(path), sanitize(body), sanitize(h_args)) + "[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" % (sanitize(path), sanitize(body), sanitize(h_args)) ) if self.events: @@ -965,17 +928,13 @@ def do_user_info_request( elif "application/jwt" in resp.headers["content-type"]: sformat = "jwt" else: - raise PyoidcError( - "ERROR: Unexpected content-type: %s" % resp.headers["content-type"] - ) + raise PyoidcError("ERROR: Unexpected content-type: %s" % resp.headers["content-type"]) elif resp.status_code == 500: raise PyoidcError("ERROR: Something went wrong: %s" % resp.text) elif resp.status_code == 405: # Method not allowed error allowed_methods = [x.strip() for x in resp.headers["allow"].split(",")] - raise CommunicationError( - "Server responded with HTTP Error Code 405", "", allowed_methods - ) + raise CommunicationError("Server responded with HTTP Error Code 405", "", allowed_methods) elif 400 <= resp.status_code < 500: # the response text might be a OIDC message try: @@ -986,9 +945,7 @@ def do_user_info_request( self.store_response(res, resp.text) return res else: - raise PyoidcError( - "ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text) - ) + raise PyoidcError("ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text)) try: _schema = kwargs["user_info_schema"] @@ -1017,17 +974,13 @@ def do_user_info_request( idt = self.get_grant(state).get_id_token() if idt: if idt["sub"] != res["sub"]: - raise SubMismatch( - "Sub identifier not the same in userinfo and Id Token" - ) + raise SubMismatch("Sub identifier not the same in userinfo and Id Token") self.store_response(res, _txt) return res - def get_userinfo_claims( - self, access_token, endpoint, method="POST", schema_class=OpenIDSchema, **kwargs - ): + def get_userinfo_claims(self, access_token, endpoint, method="POST", schema_class=OpenIDSchema, **kwargs): uir = UserInfoRequest(access_token=access_token) h_args = dict([(k, v) for k, v in kwargs.items() if k in HTTP_ARGS]) @@ -1050,16 +1003,11 @@ def get_userinfo_claims( # FIXME: Could this also encounter application/jwt for encrypted userinfo # the do_userinfo_request method already handles it if "application/json" not in resp.headers["content-type"]: - raise PyoidcError( - "ERROR: content-type in response unexpected: %s" - % resp.headers["content-type"] - ) + raise PyoidcError("ERROR: content-type in response unexpected: %s" % resp.headers["content-type"]) elif resp.status_code == 500: raise PyoidcError("ERROR: Something went wrong: %s" % resp.text) else: - raise PyoidcError( - "ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text) - ) + raise PyoidcError("ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text)) res = schema_class().from_json(txt=resp.text) self.store_response(res, resp.text) @@ -1069,20 +1017,11 @@ def unpack_aggregated_claims(self, userinfo): if userinfo["_claim_sources"]: for csrc, spec in userinfo["_claim_sources"].items(): if "JWT" in spec: - aggregated_claims = Message().from_jwt( - spec["JWT"].encode("utf-8"), keyjar=self.keyjar, sender=csrc - ) - claims = [ - value - for value, src in userinfo["_claim_names"].items() - if src == csrc - ] + aggregated_claims = Message().from_jwt(spec["JWT"].encode("utf-8"), keyjar=self.keyjar, sender=csrc) + claims = [value for value, src in userinfo["_claim_names"].items() if src == csrc] if set(claims) != set(list(aggregated_claims.keys())): - logger.warning( - "Claims from claim source doesn't match what's in " - "the userinfo" - ) + logger.warning("Claims from claim source doesn't match what's in " "the userinfo") for key, vals in aggregated_claims.items(): userinfo[key] = vals @@ -1119,17 +1058,10 @@ def fetch_distributed_claims(self, userinfo, callback=None): verify=False, ) - claims = [ - value - for value, src in userinfo["_claim_names"].items() - if src == csrc - ] + claims = [value for value, src in userinfo["_claim_names"].items() if src == csrc] if set(claims) != set(list(_uinfo.keys())): - logger.warning( - "Claims from claim source doesn't match what's in " - "the userinfo" - ) + logger.warning("Claims from claim source doesn't match what's in " "the userinfo") for key, vals in _uinfo.items(): userinfo[key] = vals @@ -1262,9 +1194,7 @@ def handle_registration_info(self, response): err_msg = "Got error response: {}" unk_msg = "Unknown response: {}" if response.status_code in [200, 201]: - resp = self.message_factory.get_response_type( - "registration_endpoint" - )().deserialize(response.text, "json") + resp = self.message_factory.get_response_type("registration_endpoint")().deserialize(response.text, "json") # Some implementations sends back a 200 with an error message inside try: resp.verify() @@ -1371,9 +1301,7 @@ def create_registration_request(self, **kwargs): pass if "response_types" in req: - req["grant_types"] = response_types_to_grant_types( - req["response_types"], **kwargs - ) + req["grant_types"] = response_types_to_grant_types(req["response_types"], **kwargs) return req @@ -1476,14 +1404,9 @@ def _verify_id_token( raise OtherError("Passed best before date") if response_type != ["code"] and id_token.jws_header["alg"] == "none": - raise WrongSigningAlgorithm( - "none is not allowed outside Authorization Flow." - ) + raise WrongSigningAlgorithm("none is not allowed outside Authorization Flow.") - if ( - self.id_token_max_age - and _now > int(id_token["iat"]) + self.id_token_max_age - ): + if self.id_token_max_age and _now > int(id_token["iat"]) + self.id_token_max_age: raise OtherError("I think this ID token is to old") if nonce and nonce != id_token["nonce"]: @@ -1593,15 +1516,9 @@ def handle_request_uri(self, request_uri, verify=True, sender=""): # http_req.text is a signed JWT try: logger.debug("request txt: {}".format(http_req.text)) - req = self.parse_jwt_request( - txt=http_req.text, verify=verify, sender=sender - ) + req = self.parse_jwt_request(txt=http_req.text, verify=verify, sender=sender) except Exception as err: - logger.error( - "{}:{} encountered while parsing fetched request".format( - err.__class__, err - ) - ) + logger.error("{}:{} encountered while parsing fetched request".format(err.__class__, err)) raise AuthzError("invalid_openid_request_object") logger.debug("Fetched request: {}".format(req)) @@ -1632,21 +1549,15 @@ def parse_authorization_request(self, url=None, query=None, keys=None): except KeyError: pass else: - _req_req = self.handle_request_uri( - _url, verify=False, sender=_req["client_id"] - ) + _req_req = self.handle_request_uri(_url, verify=False, sender=_req["client_id"]) else: if isinstance(_request, Message): _req_req = _request else: try: - _req_req = self.parse_jwt_request( - request, txt=_request, verify=False - ) + _req_req = self.parse_jwt_request(request, txt=_request, verify=False) except Exception: - _req_req = self._parse_request( - request, _request, "urlencoded", verify=False - ) + _req_req = self._parse_request(request, _request, "urlencoded", verify=False) else: # remove JWT attributes for attr in JasonWebToken.c_param: try: @@ -1698,9 +1609,7 @@ def parse_jwt_request( DeprecationWarning, stacklevel=2, ) - return super().parse_jwt_request( - request=request, txt=txt, keyjar=keyjar, verify=verify, sender=sender - ) + return super().parse_jwt_request(request=request, txt=txt, keyjar=keyjar, verify=verify, sender=sender) def parse_check_session_request(self, url=None, query=None): param = self._parse_urlencoded(url, query) @@ -1727,9 +1636,7 @@ def _parse_request(self, request_cls, data, sformat, client_id=None, verify=True elif sformat == "dict": request = request_cls(**data) else: - raise ParseError( - "Unknown package format: '{}'".format(sformat), request_cls - ) + raise ParseError("Unknown package format: '{}'".format(sformat), request_cls) # get the verification keys if client_id: diff --git a/src/oic/oic/claims_provider.py b/src/oic/oic/claims_provider.py index adc504df2..ef78eabd7 100644 --- a/src/oic/oic/claims_provider.py +++ b/src/oic/oic/claims_provider.py @@ -155,9 +155,7 @@ def claims_endpoint(self, request, http_authz, *args): _log_info("User info claims: %s" % sanitize(uic)) # oicsrv, userdb, subject, client_id="", user_info_claims=None - info = self.userinfo( - ucreq["sub"], user_info_claims=uic, client_id=ucreq["client_id"] - ) + info = self.userinfo(ucreq["sub"], user_info_claims=uic, client_id=ucreq["client_id"]) _log_info("User info: %s" % sanitize(info)) @@ -204,9 +202,7 @@ def __init__(self, client_id=None, verify_ssl=None, settings=None): self.response2error = RESPONSE2ERROR.copy() self.response2error["UserClaimsResponse"] = ["ErrorResponse"] - def construct_UserClaimsRequest( - self, request=UserClaimsRequest, request_args=None, extra_args=None, **kwargs - ): + def construct_UserClaimsRequest(self, request=UserClaimsRequest, request_args=None, extra_args=None, **kwargs): return self.construct_request(request, request_args, extra_args) def do_claims_request( diff --git a/src/oic/oic/consumer.py b/src/oic/oic/consumer.py index e0a856ab2..fff8bea25 100644 --- a/src/oic/oic/consumer.py +++ b/src/oic/oic/consumer.py @@ -293,24 +293,18 @@ def begin(self, scope="", response_type="", use_nonce=False, path="", **kwargs): _claims = None if "user_info" in self.consumer_config: - _claims = ClaimsRequest( - userinfo=Claims(**self.consumer_config["user_info"]) - ) + _claims = ClaimsRequest(userinfo=Claims(**self.consumer_config["user_info"])) if "id_token" in self.consumer_config: if _claims: _claims["id_token"] = Claims(**self.consumer_config["id_token"]) else: - _claims = ClaimsRequest( - id_token=Claims(**self.consumer_config["id_token"]) - ) + _claims = ClaimsRequest(id_token=Claims(**self.consumer_config["id_token"])) if _claims: args["claims"] = _claims if "request_method" in self.consumer_config: - areq = self.construct_AuthorizationRequest( - request_args=args, extra_args=None, request_param="request" - ) + areq = self.construct_AuthorizationRequest(request_args=args, extra_args=None, request_param="request") if self.consumer_config["request_method"] == "file": id_request = areq["request"] @@ -333,9 +327,7 @@ def begin(self, scope="", response_type="", use_nonce=False, path="", **kwargs): if "userinfo_claims" in args: # can only be carried in an IDRequest raise PyoidcError("Need a request method") - areq = self.construct_AuthorizationRequest( - AuthorizationRequest, request_args=args - ) + areq = self.construct_AuthorizationRequest(AuthorizationRequest, request_args=args) location = areq.request(self.authorization_endpoint) @@ -349,9 +341,7 @@ def _parse_authz(self, query="", **kwargs): _log_info = logger.info # Might be an error response _log_info("Expect Authorization Response") - aresp = self.parse_response( - AuthorizationResponse, info=query, sformat="urlencoded", keyjar=self.keyjar - ) + aresp = self.parse_response(AuthorizationResponse, info=query, sformat="urlencoded", keyjar=self.keyjar) if isinstance(aresp, ErrorResponse): _log_info("ErrorResponse: %s" % sanitize(aresp)) raise AuthzError(aresp.get("error"), aresp) @@ -367,7 +357,9 @@ def _parse_authz(self, query="", **kwargs): self.redirect_uris = [self.sdb[_state]["redirect_uris"]] return aresp, _state - def parse_authz(self, query="", **kwargs) -> Union[ + def parse_authz( + self, query="", **kwargs + ) -> Union[ http_util.BadRequest, Tuple[ Optional[AuthorizationResponse], @@ -554,9 +546,7 @@ def end_session(self): # LOGOUT related - def backchannel_logout( - self, request: Optional[str] = None, request_args: Optional[Dict] = None - ) -> str: + def backchannel_logout(self, request: Optional[str] = None, request_args: Optional[Dict] = None) -> str: """ Receives a back channel logout request. @@ -587,8 +577,6 @@ def backchannel_logout( sm_id = req["logout_token"]["sid"] _sid = session_get(self.sso_db, "smid", sm_id) else: - _sid = session_extended_get( - self.sso_db, sub, "issuer", req["logout_token"]["iss"] - ) + _sid = session_extended_get(self.sso_db, sub, "issuer", req["logout_token"]["iss"]) return _sid diff --git a/src/oic/oic/message.py b/src/oic/oic/message.py index e4cb2de54..f9806358a 100644 --- a/src/oic/oic/message.py +++ b/src/oic/oic/message.py @@ -236,18 +236,12 @@ def claims_request_deser(val, sformat="json"): OPTIONAL_ADDRESS = ParamDefinition(Message, False, msg_ser, address_deser, False) OPTIONAL_LOGICAL = ParamDefinition(bool, False, None, None, False) -OPTIONAL_MULTIPLE_Claims = ParamDefinition( - Message, False, claims_ser, claims_deser, False -) +OPTIONAL_MULTIPLE_Claims = ParamDefinition(Message, False, claims_ser, claims_deser, False) SINGLE_OPTIONAL_IDTOKEN = ParamDefinition(str, False, msg_ser, None, False) -SINGLE_OPTIONAL_REGISTRATION_REQUEST = ParamDefinition( - Message, False, msg_ser, registration_request_deser, False -) -SINGLE_OPTIONAL_CLAIMSREQ = ParamDefinition( - Message, False, msg_ser_json, claims_request_deser, False -) +SINGLE_OPTIONAL_REGISTRATION_REQUEST = ParamDefinition(Message, False, msg_ser, registration_request_deser, False) +SINGLE_OPTIONAL_CLAIMSREQ = ParamDefinition(Message, False, msg_ser_json, claims_request_deser, False) OPTIONAL_MESSAGE = ParamDefinition(Message, False, msg_ser, message_deser, False) REQUIRED_MESSAGE = ParamDefinition(Message, True, msg_ser, message_deser, False) @@ -256,7 +250,7 @@ def claims_request_deser(val, sformat="json"): SCOPE_CHARSET = [] -for char in ["\x21", ("\x23", "\x5b"), ("\x5d", "\x7E")]: +for char in ["\x21", ("\x23", "\x5b"), ("\x5d", "\x7e")]: if isinstance(char, tuple): c = char[0] while c <= char[1]: @@ -517,11 +511,7 @@ class AccessTokenRequest(message.AccessTokenRequest): } ) c_default = {"grant_type": "authorization_code"} - c_allowed_values = { - "client_assertion_type": [ - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - ] - } + c_allowed_values = {"client_assertion_type": ["urn:ietf:params:oauth:client-assertion-type:jwt-bearer"]} class AddressClaim(Message): @@ -565,9 +555,7 @@ def from_dict(self, dictionary, **kwargs): result = super().from_dict(dictionary, **kwargs) # The spec allows empty fields in the UserInfo/IdToken response, but suggests # the OP should omit those. So lets drop them here. - for key_ in [ - key_ for key_, val in self._dict.items() if val is None or val == "" - ]: + for key_ in [key_ for key_, val in self._dict.items() if val is None or val == ""]: del self[key_] return result @@ -640,9 +628,7 @@ class RegistrationRequest(Message): def verify(self, **kwargs): super().verify(**kwargs) - if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith( - "https:" - ): + if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith("https:"): raise AssertionError() for param in [ @@ -660,10 +646,7 @@ def verify(self, **kwargs): if enc_param in self and alg_param not in self: raise AssertionError() - if ( - "token_endpoint_auth_signing_alg" in self - and self["token_endpoint_auth_signing_alg"] == "none" - ): + if "token_endpoint_auth_signing_alg" in self and self["token_endpoint_auth_signing_alg"] == "none": raise AssertionError() return True @@ -697,10 +680,7 @@ def verify(self, **kwargs): has_reg_at = "registration_access_token" in self if has_reg_uri != has_reg_at: raise VerificationError( - ( - "Only one of registration_client_uri" - " and registration_access_token present" - ), + ("Only one of registration_client_uri" " and registration_access_token present"), self, ) @@ -765,9 +745,7 @@ def verify(self, **kwargs): if "azp" in self: if "client_id" in kwargs: if kwargs["client_id"] != self["azp"]: - raise NotForMe( - "{} != azp:{}".format(kwargs["client_id"], self["azp"]), self - ) + raise NotForMe("{} != azp:{}".format(kwargs["client_id"], self["azp"]), self) _now = time_util.utc_time_sans_frac() @@ -811,9 +789,7 @@ class StateFullMessage(Message): class RefreshSessionRequest(StateFullMessage): c_param = StateFullMessage.c_param.copy() - c_param.update( - {"id_token": SINGLE_REQUIRED_STRING, "redirect_url": SINGLE_REQUIRED_STRING} - ) + c_param.update({"id_token": SINGLE_REQUIRED_STRING, "redirect_url": SINGLE_REQUIRED_STRING}) def verify(self, **kwargs): super(RefreshSessionRequest, self).verify(**kwargs) @@ -945,10 +921,7 @@ def verify(self, **kwargs): if parts.query or parts.fragment: raise AssertionError() - if ( - any("code" in rt for rt in self["response_types_supported"]) - and "token_endpoint" not in self - ): + if any("code" in rt for rt in self["response_types_supported"]) and "token_endpoint" not in self: raise MissingRequiredAttribute("token_endpoint") return True @@ -1063,9 +1036,7 @@ def verify(self, **kwargs): super().verify(**kwargs) if "nonce" in self: - raise MessageException( - '"nonce" is prohibited from appearing in a LogoutToken.' - ) + raise MessageException('"nonce" is prohibited from appearing in a LogoutToken.') # Check the 'events' JSON _keys = list(self["events"].keys()) @@ -1181,9 +1152,7 @@ class FrontChannelLogoutRequest(Message): def factory(msgtype): - warnings.warn( - "`factory` is deprecated. Use `OIDCMessageFactory` instead.", DeprecationWarning - ) + warnings.warn("`factory` is deprecated. Use `OIDCMessageFactory` instead.", DeprecationWarning) for _, obj in inspect.getmembers(sys.modules[__name__]): if inspect.isclass(obj) and issubclass(obj, Message): try: @@ -1210,7 +1179,5 @@ class OIDCMessageFactory(MessageFactory): endsession_endpoint = MessageTuple(EndSessionRequest, EndSessionResponse) checkid_endpoint = MessageTuple(CheckIDRequest, IdToken) checksession_endpoint = MessageTuple(CheckSessionRequest, IdToken) - refreshsession_endpoint = MessageTuple( - RefreshSessionRequest, RefreshSessionResponse - ) + refreshsession_endpoint = MessageTuple(RefreshSessionRequest, RefreshSessionResponse) discovery_endpoint = MessageTuple(DiscoveryRequest, DiscoveryResponse) diff --git a/src/oic/oic/provider.py b/src/oic/oic/provider.py index f6a557eb8..8856da070 100644 --- a/src/oic/oic/provider.py +++ b/src/oic/oic/provider.py @@ -289,9 +289,7 @@ def __init__( settings=self.settings, ) # Should be a OIC Server not an OAuth2 server - self.server = Server( - keyjar=keyjar, message_factory=message_factory, settings=self.settings - ) + self.server = Server(keyjar=keyjar, message_factory=message_factory, settings=self.settings) # Same keyjar self.keyjar: KeyJar = self.server.keyjar @@ -618,9 +616,7 @@ def filter_request(self, req): raise InvalidRequest("Contains unsupported response mode") if "response_type" in req: - if not self.match_sp_sep( - [" ".join(req["response_type"])], _cap["response_types_supported"] - ): + if not self.match_sp_sep([" ".join(req["response_type"])], _cap["response_types_supported"]): raise InvalidRequest("Contains unsupported response type") if before != req.to_dict(): @@ -667,9 +663,7 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs): kwargs["req_user"] = req_user - authnres = self.do_auth( - info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs - ) + authnres = self.do_auth(info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs) if isinstance(authnres, Response): return authnres @@ -691,12 +685,8 @@ def authz_part2(self, user, areq, sid, **kwargs): salt = rndstr() authn_event = self.sdb.get_authentication_event(sid) # use the last session state = str(authn_event.authn_time) - aresp["session_state"] = self._compute_session_state( - state, salt, areq["client_id"], redirect_uri - ) - headers.append( - self.write_session_cookie(state, http_only=False, same_site="None") - ) + aresp["session_state"] = self._compute_session_state(state, salt, areq["client_id"], redirect_uri) + headers.append(self.write_session_cookie(state, http_only=False, same_site="None")) # as per the mix-up draft don't add iss and client_id if they are # already in the id_token. @@ -742,9 +732,7 @@ def recuperate_keys(self, cid: str, client_info: Dict[str, str]) -> None: self.keyjar.issuer_keys[cid] = [] # Add client secret as a symmetric key - self.keyjar.add_symmetric( - cid, client_info["client_secret"], usage=["enc", "sig"] - ) + self.keyjar.add_symmetric(cid, client_info["client_secret"], usage=["enc", "sig"]) # Try to renew from jwks or jwks_uri if client_info.get("jwks_uri") is not None: self.keyjar.add(cid, client_info["jwks_uri"]) @@ -789,9 +777,7 @@ def encrypt(self, payload, client_info, cid, val_type="id_token", cty=""): _jwe = JWE(payload, **kwargs) return _jwe.encrypt(keys, context="public") - def sign_encrypt_id_token( - self, sinfo, client_info, areq, code=None, access_token=None, user_info=None - ): + def sign_encrypt_id_token(self, sinfo, client_info, areq, code=None, access_token=None, user_info=None): """ Sign and or encrypt a IDToken. @@ -827,9 +813,7 @@ def sign_encrypt_id_token( # Then encrypt if "id_token_encrypted_response_alg" in client_info: - id_token = self.encrypt( - id_token, client_info, areq["client_id"], "id_token", "JWT" - ) + id_token = self.encrypt(id_token, client_info, areq["client_id"], "id_token", "JWT") return id_token @@ -882,14 +866,10 @@ def code_grant_type(self, areq): if "openid" in _info["scope"]: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token( - _info, client_info, areq, user_info=userinfo - ) + _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) except (JWEException, NoSuitableSigningKeys) as err: logger.warning("%s", err) - return error_response( - "invalid_request", descr="Could not sign/encrypt id_token" - ) + return error_response("invalid_request", descr="Could not sign/encrypt id_token") _sdb.update_by_token(_access_code, "id_token", _idtoken) @@ -903,9 +883,7 @@ def code_grant_type(self, areq): logger.info("access_token_response: %s" % sanitize(atr.to_dict())) - return Response( - atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS - ) + return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) def refresh_token_grant_type(self, areq): """ @@ -930,14 +908,10 @@ def refresh_token_grant_type(self, areq): if "openid" in _info["scope"] and "authn_event" in _info: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token( - _info, client_info, areq, user_info=userinfo - ) + _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) except (JWEException, NoSuitableSigningKeys) as err: logger.warning("%s", err) - return error_response( - "invalid_request", descr="Could not sign/encrypt id_token" - ) + return error_response("invalid_request", descr="Could not sign/encrypt id_token") sid = _sdb.access_token.get_key(_info["access_token"]) _sdb.update(sid, "id_token", _idtoken) @@ -949,9 +923,7 @@ def refresh_token_grant_type(self, areq): logger.info("access_token_response: %s" % sanitize(atr.to_dict())) - return Response( - atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS - ) + return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) def client_credentials_grant_type(self, areq): """ @@ -1038,9 +1010,7 @@ def signed_userinfo(self, client_info, userinfo, session): key: List[KEYS] = [] else: if algo.startswith("HS"): - key = self.keyjar.get_signing_key( - alg2keytype(algo), client_info["client_id"], alg=algo - ) + key = self.keyjar.get_signing_key(alg2keytype(algo), client_info["client_id"], alg=algo) else: # Use my key for signing key = self.keyjar.get_signing_key(alg2keytype(algo), "", alg=algo) @@ -1050,9 +1020,7 @@ def signed_userinfo(self, client_info, userinfo, session): jinfo = userinfo.to_jwt(key, algo) if "userinfo_encrypted_response_alg" in client_info: # encrypt with clients public key - jinfo = self.encrypt( - jinfo, client_info, session["client_id"], "userinfo", "JWT" - ) + jinfo = self.encrypt(jinfo, client_info, session["client_id"], "userinfo", "JWT") return jinfo def userinfo_endpoint(self, request="", **kwargs): @@ -1097,9 +1065,7 @@ def _do_user_info(self, token, **kwargs): try: typ, key = _sdb.access_token.type_and_key(token) except Exception: - return error_response( - "invalid_token", descr="Invalid Token", status_code=401 - ) + return error_response("invalid_token", descr="Invalid Token", status_code=401) _log_debug("access_token type: '%s'" % (typ,)) @@ -1108,14 +1074,10 @@ def _do_user_info(self, token, **kwargs): raise FailedAuthentication("Wrong type of token") if _sdb.access_token.is_expired(token): - return error_response( - "invalid_token", descr="Token is expired", status_code=401 - ) + return error_response("invalid_token", descr="Token is expired", status_code=401) if _sdb.is_revoked(key): - return error_response( - "invalid_token", descr="Token is revoked", status_code=401 - ) + return error_response("invalid_token", descr="Token is revoked", status_code=401) session = _sdb[key] # Scope can translate to userinfo_claims @@ -1133,9 +1095,7 @@ def _do_user_info(self, token, **kwargs): content_type = "application/jwt" elif "userinfo_encrypted_response_alg" in _cinfo: jinfo = info.to_json() - jinfo = self.encrypt( - jinfo, _cinfo, session["client_id"], "userinfo", "" - ) + jinfo = self.encrypt(jinfo, _cinfo, session["client_id"], "userinfo", "") content_type = "application/jwt" else: jinfo = info.to_json() @@ -1191,9 +1151,7 @@ def match_client_request(self, request): if request[_pref] not in self.capabilities[_prov]: raise CapabilitiesMisMatch(_pref) else: - if not set(request[_pref]).issubset( - set(self.capabilities[_prov]) - ): + if not set(request[_pref]).issubset(set(self.capabilities[_prov])): raise CapabilitiesMisMatch(_pref) def do_client_registration(self, request, client_id, ignore=None): @@ -1226,12 +1184,8 @@ def do_client_registration(self, request, client_id, ignore=None): ruri = self.verify_redirect_uris(request) _cinfo["redirect_uris"] = ruri except InvalidRedirectURIError as e: - error = ClientRegistrationErrorResponse( - error="invalid_redirect_uri", error_description=str(e) - ) - return Response( - error.to_json(), content="application/json", status_code=400 - ) + error = ClientRegistrationErrorResponse(error="invalid_redirect_uri", error_description=str(e)) + return Response(error.to_json(), content="application/json", status_code=400) if "sector_identifier_uri" in request: try: @@ -1292,9 +1246,7 @@ def do_client_registration(self, request, client_id, ignore=None): error = ClientRegistrationErrorResponse( error="invalid_configuration_parameter", error_description="%s" % err ) - return Response( - error.to_json(), content="application/json", status="400 Bad Request" - ) + return Response(error.to_json(), content="application/json", status="400 Bad Request") return _cinfo @@ -1330,9 +1282,7 @@ def verify_redirect_uris(registration_request): p.scheme, p.hostname, ) - raise InvalidRedirectURIError( - "Redirect_uri must use custom scheme or http and localhost" - ) + raise InvalidRedirectURIError("Redirect_uri must use custom scheme or http and localhost") elif must_https and p.scheme != "https": raise InvalidRedirectURIError("None https redirect_uri not allowed") elif p.fragment: @@ -1353,9 +1303,7 @@ def _verify_post_logout_uri(self, request): for uri in request["post_logout_redirect_uris"]: part = urlparse(uri) if part.fragment: - raise InvalidPostLogoutUri( - "post_logout_redirect_uris contains fragment" - ) + raise InvalidPostLogoutUri("post_logout_redirect_uris contains fragment") query = part.query if part.query else None base = part._replace(query="").geturl() if query: @@ -1387,17 +1335,13 @@ def _verify_sector_identifier(self, request): try: si_redirects = json.loads(res.text) except ValueError: - raise InvalidSectorIdentifier( - "Error deserializing sector_identifier_uri content" - ) + raise InvalidSectorIdentifier("Error deserializing sector_identifier_uri content") if "redirect_uris" in request: logger.debug("redirect_uris: %s", request["redirect_uris"]) for uri in request["redirect_uris"]: if uri not in si_redirects: - raise InvalidSectorIdentifier( - "redirect_uri missing from sector_identifiers" - ) + raise InvalidSectorIdentifier("redirect_uri missing from sector_identifiers") return si_redirects, si_url @@ -1410,9 +1354,7 @@ def comb_uri(args): val = [] for base, query_dict in args[param]: if query_dict: - query_string = urlencode( - [(key, v) for key in query_dict for v in query_dict[key]] - ) + query_string = urlencode([(key, v) for key in query_dict for v in query_dict[key]]) val.append("%s?%s" % (base, query_string)) else: val.append(base) @@ -1422,9 +1364,7 @@ def comb_uri(args): def create_registration(self, authn=None, request=None, **kwargs): logger.debug("@registration_endpoint: <<%s>>" % sanitize(request)) - request_cls = self.server.message_factory.get_request_type( - "registration_endpoint" - ) + request_cls = self.server.message_factory.get_request_type("registration_endpoint") try: request = request_cls().deserialize(request, "json") except MessageException: @@ -1458,17 +1398,13 @@ def client_registration_setup(self, request): if "type" not in request: return error_response("invalid_type", descr="%s" % err) else: - return error_response( - "invalid_configuration_parameter", descr="%s" % err - ) + return error_response("invalid_configuration_parameter", descr="%s" % err) request.rm_blanks() try: self.match_client_request(request) except CapabilitiesMisMatch as err: - return error_response( - "invalid_request", descr="Don't support proposed %s" % err - ) + return error_response("invalid_request", descr="Don't support proposed %s" % err) # create new id och secret client_id = rndstr(12) @@ -1502,9 +1438,7 @@ def client_registration_setup(self, request): if isinstance(_cinfo, Response): return _cinfo - response_cls = self.server.message_factory.get_response_type( - "registration_endpoint" - ) + response_cls = self.server.message_factory.get_response_type("registration_endpoint") args = dict([(k, v) for k, v in _cinfo.items() if k in response_cls.c_param]) self.comb_uri(args) @@ -1571,16 +1505,8 @@ def read_registration(self, authn, request, **kwargs): return Unauthorized() logger.debug("Client '%s' reads client info" % client_id) - response_cls = self.server.message_factory.get_response_type( - "registration_endpoint" - ) - args = dict( - [ - (k, v) - for k, v in self.cdb[client_id].items() - if k in response_cls.c_param - ] - ) + response_cls = self.server.message_factory.get_response_type("registration_endpoint") + args = dict([(k, v) for k, v in self.cdb[client_id].items() if k in response_cls.c_param]) self.comb_uri(args) response = response_cls(**args) @@ -1677,9 +1603,9 @@ def discovery_endpoint(self, request, handle=None, **kwargs): _log_debug("@discovery_endpoint") - request = self.server.message_factory.get_request_type( - "discovery_endpoint" - )().deserialize(request, "urlencoded") + request = self.server.message_factory.get_request_type("discovery_endpoint")().deserialize( + request, "urlencoded" + ) _log_debug("discovery_request:%s" % (sanitize(request.to_dict()),)) if request["service"] != SWD_ISSUER: @@ -1687,9 +1613,7 @@ def discovery_endpoint(self, request, handle=None, **kwargs): # verify that the principal is one of mine - _response = self.server.message_factory.get_response_type("discovery_endpoint")( - locations=[self.baseurl] - ) + _response = self.server.message_factory.get_response_type("discovery_endpoint")(locations=[self.baseurl]) _log_debug("discovery_response:%s" % (sanitize(_response.to_dict()),)) @@ -1699,9 +1623,7 @@ def discovery_endpoint(self, request, handle=None, **kwargs): cookie = self.cookie_func(key, self.cookie_name, "disc", self.sso_ttl) headers.append(cookie) - return Response( - _response.to_json(), content="application/json", headers=headers - ) + return Response(_response.to_json(), content="application/json", headers=headers) def aresp_check(self, aresp, areq): # Use of the nonce is REQUIRED for all requests where an ID Token is @@ -1718,16 +1640,12 @@ def response_mode(self, areq, fragment_enc, **kwargs): "action": kwargs["redirect_uri"], "inputs": kwargs["aresp"].to_dict(), } - return Response( - self.template_renderer("form_post", context), headers=kwargs["headers"] - ) + return Response(self.template_renderer("form_post", context), headers=kwargs["headers"]) return None def create_authn_response(self, areq, sid): # create the response - aresp = self.server.message_factory.get_response_type( - "authorization_endpoint" - )() + aresp = self.server.message_factory.get_response_type("authorization_endpoint")() try: aresp["state"] = areq["state"] except KeyError: @@ -1794,14 +1712,10 @@ def create_authn_response(self, areq, sid): # or 'code id_token' try: - id_token = self.sign_encrypt_id_token( - _sinfo, client_info, areq, user_info=user_info, **hargs - ) + id_token = self.sign_encrypt_id_token(_sinfo, client_info, areq, user_info=user_info, **hargs) except (JWEException, NoSuitableSigningKeys) as err: logger.warning("%s", err) - return error_response( - "invalid_request", descr="Could not sign/encrypt id_token" - ) + return error_response("invalid_request", descr="Could not sign/encrypt id_token") aresp["id_token"] = id_token _sinfo["id_token"] = id_token @@ -1913,9 +1827,7 @@ def get_by_sub_and_(self, sub: str, key: str, val: Any) -> Optional[str]: # Below are LOGOUT related methods - def verify_post_logout_redirect_uri( - self, esreq: Message, client_id: str - ) -> Optional[str]: + def verify_post_logout_redirect_uri(self, esreq: Message, client_id: str) -> Optional[str]: """ Verify a post logout URI. @@ -1978,9 +1890,7 @@ def let_user_verify_logout( "redirect": redirect, "action": "/" + EndSessionEndpoint("").etype, } - return Response( - self.template_renderer("verify_logout", context), headers=headers - ) + return Response(self.template_renderer("verify_logout", context), headers=headers) def _get_uid_from_cookie( self, cookie: Optional[Union[str, SimpleCookie]] @@ -2005,9 +1915,7 @@ def _get_uid_from_cookie( return cookie_dealer, client_id, uid - def do_back_channel_logout( - self, cinfo: dict, sub: str, sid: str - ) -> Optional[Tuple[str, str]]: + def do_back_channel_logout(self, cinfo: dict, sub: str, sid: str) -> Optional[Tuple[str, str]]: """ Prepare information to be used to do a back-channel logout. @@ -2052,9 +1960,7 @@ def clean_sessions(self, usids: List[str]): for sid in usids: del _sdb[sid] - def logout_info_for_all_clients( - self, uid: Optional[str] = "", sid: Optional[str] = "" - ) -> Dict: + def logout_info_for_all_clients(self, uid: Optional[str] = "", sid: Optional[str] = "") -> Dict: """ Collect information necessary to logout one user from all clients he/she has been using. @@ -2089,9 +1995,7 @@ def logout_info_for_all_clients( bc_logouts[_cid] = self.do_back_channel_logout(_cdb[_cid], _sub, _csid) if "frontchannel_logout_uri" in _cdb[_cid]: # Construct an IFrame - fc_iframes[_cid] = self.do_front_channel_logout_iframe( - _cdb[_cid], _iss, _csid - ) + fc_iframes[_cid] = self.do_front_channel_logout_iframe(_cdb[_cid], _iss, _csid) return {"back_channel": bc_logouts, "front_channel": fc_iframes} @@ -2113,15 +2017,11 @@ def logout_info_for_one_client(self, session_id: str, client_id: str) -> Dict: if "backchannel_logout_uri" in self.cdb[client_id]: _subject_id = self.sdb[session_id]["sub"] logout_spec["back_channel"] = { - client_id: self.do_back_channel_logout( - self.cdb[client_id], _subject_id, session_id - ) + client_id: self.do_back_channel_logout(self.cdb[client_id], _subject_id, session_id) } elif "frontchannel_logout_uri" in self.cdb[client_id]: # Construct an IFrame - _iframe = self.do_front_channel_logout_iframe( - self.cdb[client_id], self.name, session_id - ) + _iframe = self.do_front_channel_logout_iframe(self.cdb[client_id], self.name, session_id) logout_spec["front_channel"] = {client_id: _iframe} return logout_spec @@ -2161,9 +2061,7 @@ def end_session_endpoint( sid = "" if "id_token_hint" in esr: - id_token_hint = IdToken().from_jwt( - esr["id_token_hint"], keyjar=self.keyjar, verify=True - ) + id_token_hint = IdToken().from_jwt(esr["id_token_hint"], keyjar=self.keyjar, verify=True) far_away = 86400 * 30 # 30 days if client_id: @@ -2172,9 +2070,7 @@ def end_session_endpoint( args = {} try: - id_token_hint.verify( - iss=self.baseurl, skew=far_away, nonce_storage_time=far_away, **args - ) + id_token_hint.verify(iss=self.baseurl, skew=far_away, nonce_storage_time=far_away, **args) except (VerificationError, NotForMe) as err: logger.warning("Verification error on id_token_hint: %s", err) return error_response("invalid_request", "Bad Id Token hint") @@ -2203,9 +2099,7 @@ def end_session_endpoint( break if not matching_client_id: - return error_response( - "invalid_request", "Could not find a matching client ID" - ) + return error_response("invalid_request", "Could not find a matching client ID") if not client_id: return error_response("invalid_request", "Could not find client ID") @@ -2223,18 +2117,14 @@ def end_session_endpoint( except KeyError: if self.post_logout_page is None: logger.warning("No post logout page configured for %s", client_id) - return error_response( - "server_error", "Have no post logout page configured" - ) + return error_response("server_error", "Have no post logout page configured") else: redirect_uri = self.post_logout_page else: if len(_ruri) == 1: _base, _query = _ruri[0] if _query: - query_string = urlencode( - [(key, v) for key in _query for v in _query[key]] - ) + query_string = urlencode([(key, v) for key in _query for v in _query[key]]) redirect_uri = "%s?%s" % (_base, query_string) else: redirect_uri = _base @@ -2258,9 +2148,7 @@ def end_session_endpoint( self.events.store("object args", "{}".format(payload)) # From me to me - _jws = JWT( - self.keyjar, iss=self.name, lifetime=86400, sign_alg=self.signing_alg - ) + _jws = JWT(self.keyjar, iss=self.name, lifetime=86400, sign_alg=self.signing_alg) sjwt = _jws.pack(aud=[self.name], **payload) location = "{}?{}".format(self.logout_verify_url, urlencode({"sjwt": sjwt})) @@ -2292,9 +2180,7 @@ def do_verified_logout( # Find all the session IDs this user has gotten sids = session_get(self.sdb, "uid", uid) else: - logout_spec = self.logout_info_for_one_client( - session_id=sid, client_id=client_id - ) + logout_spec = self.logout_info_for_one_client(session_id=sid, client_id=client_id) sids = [sid] if self.events: @@ -2302,12 +2188,8 @@ def do_verified_logout( if not logout_spec["back_channel"] and not logout_spec["front_channel"]: # kill cookies - kaka1 = self.write_session_cookie( - "removed", http_only=False, same_site="None" - ) - kaka2 = self.cookie_func( - "", typ="sso", cookie_name=self.sso_cookie_name, kill=True - ) + kaka1 = self.write_session_cookie("removed", http_only=False, same_site="None") + kaka2 = self.cookie_func("", typ="sso", cookie_name=self.sso_cookie_name, kill=True) return {"cookie": [kaka1, kaka2]} # take care of Back channel logout first @@ -2349,9 +2231,7 @@ def do_verified_logout( # kill cookies kaka1 = self.write_session_cookie("removed", http_only=False, same_site="None") - kaka2 = self.cookie_func( - "", typ="sso", cookie_name=self.sso_cookie_name, kill=True - ) + kaka2 = self.cookie_func("", typ="sso", cookie_name=self.sso_cookie_name, kill=True) res = {"cookie": [kaka1, kaka2]} if logout_spec["front_channel"]: @@ -2365,9 +2245,7 @@ def do_verified_logout( return res @staticmethod - def do_front_channel_logout_iframe( - client_info: Dict, issuer: str, session_id: str - ) -> Optional[str]: + def do_front_channel_logout_iframe(client_info: Dict, issuer: str, session_id: str) -> Optional[str]: """ Construct a front channel logout IFrame. @@ -2396,9 +2274,7 @@ def do_front_channel_logout_iframe( _np = p._replace(query="") frontchannel_logout_uri = _np.geturl() - _iframe = '