diff --git a/spec/lucky/rate_limit_spec.cr b/spec/lucky/rate_limit_spec.cr new file mode 100644 index 000000000..60eb72387 --- /dev/null +++ b/spec/lucky/rate_limit_spec.cr @@ -0,0 +1,43 @@ +require "../spec_helper" + +include ContextHelper + +class RateLimitRoutes::Index < TestAction + include Lucky::RateLimit + + get "/rate_limit" do + plain_text "hello" + end + + private def rate_limit : NamedTuple(to: Int32, within: Time::Span) + {to: 1, within: 1.minute} + end +end + +describe Lucky::RateLimit do + describe "RateLimit" do + it "when request count is less than the rate limit" do + headers = HTTP::Headers.new + headers["X_FORWARDED_FOR"] = "127.0.0.1" + request = HTTP::Request.new("GET", "/rate_limit", body: "", headers: headers) + context = build_context(request) + + route = RateLimitRoutes::Index.new(context, params).call + route.context.response.status.should eq(HTTP::Status::OK) + end + + it "when request count is over the rate limit" do + headers = HTTP::Headers.new + headers["X_FORWARDED_FOR"] = "127.0.0.1" + request = HTTP::Request.new("GET", "/rate_limit", body: "", headers: headers) + context = build_context(request) + + 10.times do + RateLimitRoutes::Index.new(context, params).call + end + + route = RateLimitRoutes::Index.new(context, params).call + route.context.response.status.should eq(HTTP::Status::TOO_MANY_REQUESTS) + end + end +end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index df67856c0..545b6e7f8 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -48,4 +48,8 @@ Lucky::ForceSSLHandler.configure do |settings| settings.enabled = true end +LuckyCache.configure do |settings| + settings.storage = LuckyCache::MemoryStore.new +end + Habitat.raise_if_missing_settings! diff --git a/src/lucky/rate_limit.cr b/src/lucky/rate_limit.cr new file mode 100644 index 000000000..0ceb544fb --- /dev/null +++ b/src/lucky/rate_limit.cr @@ -0,0 +1,41 @@ +module Lucky::RateLimit + macro included + before enforce_rate_limit + end + + abstract def rate_limit : NamedTuple(to: Int32, within: Time::Span) + + private def enforce_rate_limit + cache = LuckyCache.settings.storage + count = cache.fetch(rate_limit_key, as: Int32, expires_in: rate_limit["within"]) { 0 } + cache.write(rate_limit_key, expires_in: rate_limit["within"]) { count + 1 } + + if count > rate_limit["to"] + context.response.status = HTTP::Status::TOO_MANY_REQUESTS + context.response.headers["Retry-After"] = rate_limit["within"].to_s + plain_text("Rate limit exceeded") + else + continue + end + end + + private def rate_limit_key : String + klass = self.class.to_s.downcase.gsub("::", ":") + "ratelimit:#{klass}:#{rate_limit_identifier}" + end + + private def rate_limit_identifier : Socket::Address | Nil + request = context.request + + if x_forwarded = request.headers["X_FORWARDED_FOR"]?.try(&.split(',').first?).presence + begin + Socket::IPAddress.new(x_forwarded, 0) + rescue Socket::Error + # if the x_forwarded is not a valid ip address we fallback to request.remote_address + request.remote_address + end + else + request.remote_address + end + end +end