-
Notifications
You must be signed in to change notification settings - Fork 1
/
flaskswaggertypes.py
211 lines (156 loc) · 8.55 KB
/
flaskswaggertypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import functools
from collections import namedtuple

import marshmallow
import pkg_resources
from apispec.ext.marshmallow.swagger import fields2jsonschema, schema2jsonschema, schema2parameters, fields2parameters
from flask import Flask, request, make_response, Response, json
from werkzeug.routing import parse_rule as werkzeurg_parse_rule
# Not yet implemented:
# * Tags
# Not supported by design:
# * OpenAPI v3 (only Swagger/OpenAPI 2.0 is supported)
# * Fine-grained security (only global security definitions are supported)
class FlaskSwaggerTypes:
    """Register typed Flask routes and serve a Swagger 2.0 spec describing them.

    Routes are registered with :meth:`Fstroute`. It validates incoming
    request data against marshmallow schemas before calling the view, and it
    records the endpoint for :meth:`generate_swagger_spec`. On construction,
    two extra routes are added to the Flask app: ``/swagger_spec`` (the JSON
    spec) and ``/swagger_ui`` (a bundled HTML page).
    """

    # Optional top-level spec metadata; every field defaults to None.
    SpecMetadata = namedtuple('SpecMetadata', 'title description basePath version host securityDefinitions')
    SpecMetadata.__new__.__defaults__ = (None,) * len(SpecMetadata._fields)

    # HTTP verbs accepted by Fstroute; each is a valid Swagger 2.0 operation.
    _SUPPORTED_HTTP_VERBS = ("GET", "POST", "PUT", "PATCH", "DELETE")

    def __init__(self, flask_app, spec_metadata=None):
        """Attach to *flask_app* and register the spec and UI routes.

        :param flask_app: the Flask application to register routes on.
        :param spec_metadata: optional mapping of SpecMetadata field names
            (title, description, basePath, version, host,
            securityDefinitions) to values. Unknown keys raise TypeError.
        """
        self.flask_app = flask_app
        # Swagger-formatted path -> list of endpoint records (one per method).
        self.swagger_endpoints = dict()
        # Schema classes published under "definitions"; duplicates allowed.
        self.swagger_definition_schemas = []
        self.swagger_metadata = self.SpecMetadata(**(spec_metadata or {}))
        self.flask_app.add_url_rule("/swagger_spec", "getSwaggerSpec", self.getSwaggerSpec)
        self.flask_app.add_url_rule("/swagger_ui", "getSwaggerUi", self.getSwaggerUi)

    def getSwaggerUi(self):
        """Return the bundled ``swagger.html`` resource for ``/swagger_ui``."""
        return pkg_resources.resource_string(__name__, 'swagger.html')

    def getSwaggerSpec(self):
        """Serve the generated spec as JSON with a permissive CORS header.

        ``Access-Control-Allow-Origin: *`` lets a Swagger UI hosted on another
        origin fetch the spec.
        """
        resp = Response(self.generate_swagger_spec(), mimetype='application/json')
        resp.headers['Access-Control-Allow-Origin'] = '*'
        return resp

    def _extract_path_schema_from_werkzeug_rule(self, rule, schema_name_prefix):
        """Build a marshmallow schema class for the variables in a Werkzeug *rule*.

        Each ``<converter:name>`` placeholder becomes a required field called
        ``name``. The class is named ``<schema_name_prefix>PathSchema``.

        :raises KeyError: for converters other than the default, ``string``
            or ``int``.
        """
        types_map = {
            'default': marshmallow.fields.String,
            'string': marshmallow.fields.String,
            'int': marshmallow.fields.Integer,
        }
        schema_fields = {}
        # parse_rule yields (converter, arguments, variable) triples. The
        # converter is None for static text fragments, which carry no parameter.
        for converter, _arguments, variable in werkzeurg_parse_rule(rule):
            if converter is None:
                continue
            # Give a clearer message than the bare KeyError the lookup would raise.
            if converter not in types_map:
                raise KeyError("Unknonwn path parameter type: '" + str(converter) + "'. Flask-swagger-types only suports 'int' and 'string' types in path")
            schema_fields[variable] = types_map[converter](required=True)
        return type(schema_name_prefix + 'PathSchema', (marshmallow.Schema,), schema_fields)

    def _append_swagger_endpoint(self, rule, method, function_name, schemas, responses):
        """Record an endpoint so it appears in the generated spec.

        The Werkzeug *rule* is rewritten into Swagger path syntax
        (``/items/<int:id>`` -> ``/items/{id}``). Body and response schemas
        are queued for the ``definitions`` section.
        """
        swagger_formated_rule = "".join(
            variable if converter is None else '{' + variable + '}'
            for converter, _arguments, variable in werkzeurg_parse_rule(rule)
        )
        self.swagger_endpoints.setdefault(swagger_formated_rule, []).append({
            'method': method.lower(),
            'schemas': schemas,
            'responses': responses,
            'function_name': function_name,
        })
        if 'body' in schemas:
            self.swagger_definition_schemas.append(schemas['body'])
        # Duplicates are fine: definitions are later keyed by class name, which
        # de-duplicates them. Responses without a schema (2-tuples) are skipped.
        for response in responses:
            if len(response) > 2:
                self.swagger_definition_schemas.append(response[2])

    def _generate_swagger_definitions_tree(self):
        """Map each queued schema's class name to its JSON schema."""
        # Relies on the first parameter returned by apispec's schema2parameters
        # carrying the schema under its 'schema' key.
        return {
            schema.__name__: schema2parameters(schema)[0]['schema']
            for schema in self.swagger_definition_schemas
        }

    def _response2swagger_response_node(self, response):
        """Convert a ``(status, description[, schema])`` tuple to a Swagger response."""
        swagger_response_node = {'description': response[1]}
        if len(response) > 2:
            swagger_response_node['schema'] = {'$ref': "#/definitions/" + response[2].__name__}
        return swagger_response_node

    def _endpoint_schemas2swagger_parameters(self, endpoint_schemas):
        """Build the Swagger ``parameters`` list from an endpoint's input schemas."""
        parameters = []
        # Path, query and header schemas are expanded by apispec in this order.
        for location in ('path', 'query', 'header'):
            if location in endpoint_schemas:
                parameters += schema2parameters(endpoint_schemas[location], default_in=location)
        # The body is one required parameter that references its definition.
        if 'body' in endpoint_schemas:
            parameters.append({
                'in': 'body',
                'name': 'body',
                'required': True,
                'schema': {'$ref': "#/definitions/" + endpoint_schemas['body'].__name__},
            })
        return parameters

    def _endpoint2swagger_endpoint(self, endpoint):
        """Build the Swagger operation object for one recorded endpoint."""
        return {
            'operationId': endpoint['function_name'],
            'responses': {response[0]: self._response2swagger_response_node(response) for response in endpoint['responses']},
            'parameters': self._endpoint_schemas2swagger_parameters(endpoint['schemas']),
        }

    def _generate_swagger_paths_tree(self):
        """Build the Swagger ``paths`` object: path -> method -> operation."""
        return {
            path: {endpoint['method']: self._endpoint2swagger_endpoint(endpoint) for endpoint in endpoints}
            for path, endpoints in self.swagger_endpoints.items()
        }

    def generate_swagger_spec(self):
        """Return the full Swagger 2.0 spec as an indented JSON string.

        Top-level entries whose value is falsy (e.g. unset host, no paths)
        are omitted.
        """
        info = {
            'description': self.swagger_metadata.description,
            'version': self.swagger_metadata.version,
            'title': self.swagger_metadata.title,
        }
        swagger_spec_tree = {
            'swagger': "2.0",
            'info': info,
            'host': self.swagger_metadata.host,
            'basePath': self.swagger_metadata.basePath,
            'paths': self._generate_swagger_paths_tree(),
            'definitions': self._generate_swagger_definitions_tree(),
            'securityDefinitions': self.swagger_metadata.securityDefinitions,
        }
        return json.dumps({k: v for k, v in swagger_spec_tree.items() if v}, indent=4)

    def _validate_input_data_with_schemas(self, request, input_schemas):
        """Load request data through the endpoint's schemas.

        Sources are processed in the order body, query, path, header. Errors
        from body or query are returned immediately. This relies on the
        marshmallow 2 API, where ``Schema.load`` returns a ``(data, errors)``
        pair.

        :returns: ``(fsa_data, errors)``, where ``fsa_data`` maps each source
            name to its loaded data.
        """
        fsa_data = {}
        errors = []
        if 'body' in input_schemas:
            fsa_data['body'], errors = input_schemas['body']().load(request.get_json())
            if errors:
                return fsa_data, errors
        if 'query' in input_schemas:
            # request.args is a MultiDict. Loading it through the plain dict
            # interface ignores repeated keys, which is acceptable here. See:
            # http://werkzeug.pocoo.org/docs/0.12/datastructures/#werkzeug.datastructures.MultiDict
            fsa_data['query'], errors = input_schemas['query']().load(request.args)
            if errors:
                return fsa_data, errors
        if 'path' in input_schemas:
            # Flask's URL converters have already validated these values, so
            # path errors are ignored. Unpack like the other sources so that
            # fst_data['path'] holds the loaded data, not the (data, errors) pair.
            fsa_data['path'], _path_errors = input_schemas['path']().load(request.view_args)
        if 'header' in input_schemas:
            fsa_data['header'], errors = input_schemas['header']().load(request.headers)
        return fsa_data, errors

    def Fstroute(self, rule, http_verb, input_schemas=None, responses=None, **options):
        """Return a decorator that registers a validated, documented Flask route.

        :param rule: Werkzeug URL rule, e.g. ``/items/<int:id>``.
        :param http_verb: exactly one of GET, POST, PUT, PATCH, DELETE. Unlike
            Flask, there is no default and only one method per route.
        :param input_schemas: optional mapping of ``'body'``, ``'query'`` and
            ``'header'`` to marshmallow schema classes. A ``'path'`` schema is
            always derived from *rule* and replaces any supplied one. The
            mapping is copied and never mutated.
        :param responses: iterable of ``(status, description[, schema])``
            tuples, used only for the spec.
        :param options: forwarded to ``flask_app.add_url_rule``.
        :raises ValueError: if *http_verb* is not supported.

        The view reads validated data from ``request.fst_data``. On validation
        failure, a 400 response with the JSON-encoded errors is returned instead.
        """
        if http_verb not in self._SUPPORTED_HTTP_VERBS:
            # Raise instead of calling exit(): library code must not stop the interpreter.
            raise ValueError("invalid method " + str(http_verb))
        responses = () if responses is None else responses

        def decorator(f):
            # Build a fresh mapping per route. The 'path' entry is route-specific,
            # so routes must never share (or alter the caller's) schema dict.
            schemas = dict(input_schemas or {})
            schemas['path'] = self._extract_path_schema_from_werkzeug_rule(rule, f.__name__)
            self._append_swagger_endpoint(rule, http_verb, f.__name__, schemas, responses)

            @functools.wraps(f)
            def modified_f(*a, **kw):
                validated_data, errors = self._validate_input_data_with_schemas(request, schemas)
                if errors:
                    return make_response(json.dumps(errors), 400)
                request.fst_data = validated_data
                return f(*a, **kw)

            self.flask_app.add_url_rule(rule, f.__name__, modified_f, methods=[http_verb], **options)
            return f
        return decorator