Skip to content

Commit

Permalink
Merge pull request #56 from diogob/jwt-validate-exp
Browse files Browse the repository at this point in the history
Jwt validate exp
  • Loading branch information
diogob authored May 31, 2020
2 parents 3fa52d4 + 4744615 commit a4c0e55
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 51 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

## Next release

- Send close connection once the JWT token expires (if channel is open with a token using the `exp` claim).

## 0.6.1.0

- Add capability to unset `PGWS_ROOT_PATH` to disable static file serving.
10 changes: 9 additions & 1 deletion app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Pool as P
import Network.Wai.Application.Static
import Data.Time.Clock (UTCTime, getCurrentTime)
import Control.AutoUpdate ( defaultUpdateSettings
, mkAutoUpdate
, updateAction
)

import Network.Wai (Application, responseLBS)
import Network.HTTP.Types (status200)
Expand Down Expand Up @@ -61,11 +66,14 @@ main = do

pool <- P.acquire (configPool conf, 10, pgSettings)
multi <- newHasqlBroadcaster listenChannel pgSettings
getTime <- mkGetTime

runSettings appSettings $
postgresWsMiddleware listenChannel (configJwtSecret conf) pool multi $
postgresWsMiddleware getTime listenChannel (configJwtSecret conf) pool multi $
logStdout $ maybe dummyApp staticApp' (configPath conf)
where
mkGetTime :: IO (IO UTCTime)
mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime}
staticApp' :: Text -> Application
staticApp' = staticApp . defaultFileServerSettings . toS
dummyApp :: Application
Expand Down
2 changes: 1 addition & 1 deletion client-example/screen.css
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ h2 {
}

div#main {
width: 600px;
width: 50%;
margin: 0px auto 0px auto;
padding: 0px;
background-color: #fff;
Expand Down
24 changes: 13 additions & 11 deletions postgres-websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ library
, stringsearch >= 0.3.6.6 && < 0.4
, time >= 1.8.0.2 && < 1.9
, contravariant >= 1.5.2 && < 1.6
, alarmclock >= 0.7.0.2 && < 0.8
default-language: Haskell2010
default-extensions: OverloadedStrings, NoImplicitPrelude, LambdaCase

Expand All @@ -67,8 +68,9 @@ executable postgres-websockets
, wai >= 3.2 && < 4
, wai-extra >= 3.0.29 && < 3.1
, wai-app-static >= 3.1.7.1 && < 3.2
, http-types
, http-types >= 0.9
, envparse >= 0.4.1
, auto-update >= 0.1.6 && < 0.2
default-language: Haskell2010
default-extensions: OverloadedStrings, NoImplicitPrelude, QuasiQuotes

Expand All @@ -82,18 +84,18 @@ test-suite postgres-websockets-test
build-depends: base
, protolude >= 0.2.3
, postgres-websockets
, containers
, hspec
, hspec-wai
, hspec-wai-json
, aeson
, hasql
, hasql-pool
, hspec >= 2.7.1 && < 2.8
, hspec-wai >= 0.9.2 && < 0.10
, hspec-wai-json >= 0.9.2 && < 0.10
, aeson >= 1.4.6.0 && < 1.5
, hasql >= 0.19
, hasql-pool >= 0.4
, hasql-notifications >= 0.1.0.0 && < 0.2
, http-types
, http-types >= 0.9
, time >= 1.8.0.2 && < 1.9
, unordered-containers >= 0.2
, wai-extra
, stm
, wai-extra >= 3.0.29 && < 3.1
, stm >= 2.5.0.0 && < 2.6
ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N
default-language: Haskell2010
default-extensions: OverloadedStrings, NoImplicitPrelude
Expand Down
45 changes: 28 additions & 17 deletions src/PostgresWebsockets.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Time.Clock (UTCTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm)
import PostgresWebsockets.Broadcast (Multiplexer, onMessage)
import qualified PostgresWebsockets.Broadcast as B
import PostgresWebsockets.Claims
Expand All @@ -38,19 +40,21 @@ data Message = Message
instance A.ToJSON Message

-- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware.
postgresWsMiddleware :: Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgresWsMiddleware =
WS.websocketsOr WS.defaultConnectionOptions `compose` wsApp
where
compose = (.) . (.) . (.) . (.)
compose = (.) . (.) . (.) . (.) . (.)

-- private functions
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001

-- when the websocket is closed a ConnectionClosed Exception is triggered
-- this kills all children and frees resources for us
wsApp :: Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp dbChannel secret pool multi pendingConn =
validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp getTime dbChannel secret pool multi pendingConn =
getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
where
hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)
Expand All @@ -68,27 +72,34 @@ wsApp dbChannel secret pool multi pendingConn =
-- We should accept only after verifying JWT
conn <- WS.acceptRequest pendingConn
-- Fork a pinging thread to ensure browser connections stay alive
WS.forkPingThread conn 30
WS.withPingThread conn 30 (pure ()) $ do
case M.lookup "exp" validClaims of
Just (A.Number expClaim) -> do
connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString))
setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
Just _ -> pure ()
Nothing -> pure ()

when (hasRead mode) $
onMessage multi ch $ WS.sendTextData conn . B.payload
when (hasRead mode) $
onMessage multi ch $ WS.sendTextData conn . B.payload

when (hasWrite mode) $
let sendNotifications = void . H.notifyPool pool dbChannel . toS
in notifySession validClaims (toS ch) conn sendNotifications
when (hasWrite mode) $
let sendNotifications = void . H.notifyPool pool dbChannel . toS
in notifySession validClaims (toS ch) conn getTime sendNotifications

waitForever <- newEmptyMVar
void $ takeMVar waitForever
waitForever <- newEmptyMVar
void $ takeMVar waitForever

-- Having both channel and claims as parameters seem redundant
-- But it allows the function to ignore the claims structure and the source
-- of the channel, so all claims decoding can be coded in the caller
notifySession :: A.Object
-> Text
-> WS.Connection
-> IO UTCTime
-> (ByteString -> IO ())
-> IO ()
notifySession claimsToSend ch wsCon send =
notifySession claimsToSend ch wsCon getTime send =
withAsync (forever relayData) wait
where
relayData = jsonMsgWithTime >>= send
Expand All @@ -102,5 +113,5 @@ notifySession claimsToSend ch wsCon send =
claimsWithChannel = M.insert "channel" (A.String ch) claimsToSend
claimsWithTime :: IO (M.HashMap Text A.Value)
claimsWithTime = do
time <- getPOSIXTime
return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsWithChannel
time <- utcTimeToPOSIXSeconds <$> getTime
return $ M.insert "message_delivered_at" (A.Number $ realToFrac time) claimsWithChannel
36 changes: 20 additions & 16 deletions src/PostgresWebsockets/Claims.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@ module PostgresWebsockets.Claims
import Control.Lens
import qualified Crypto.JOSE.Types as JOSE.Types
import Crypto.JWT
import Data.Aeson (Value (..), decode, toJSON)
import qualified Data.HashMap.Strict as M
import Protolude
import Data.Time.Clock (UTCTime)
import Data.String (String, fromString)
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Types as JSON


type Claims = M.HashMap Text Value
type Claims = M.HashMap Text JSON.Value
type ConnectionInfo = (ByteString, ByteString, Claims)

{-| Given a secret, a token and a timestamp it validates the claims and returns
either an error message or a triple containing channel, mode and claims hashmap.
-}
validateClaims :: Maybe ByteString -> ByteString -> LByteString -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken =
validateClaims :: Maybe ByteString -> ByteString -> LByteString -> UTCTime -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time =
runExceptT $ do
cl <- liftIO $ jwtClaims (parseJWK secret) jwtToken
cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
cl' <- case cl of
JWTClaims c -> pure c
JWTInvalid JWTExpired -> throwError "Token expired"
_ -> throwError "Error"
channel <- claimAsJSON requestChannel "channel" cl'
mode <- claimAsJSON Nothing "mode" cl'
Expand All @@ -35,7 +39,7 @@ validateClaims requestChannel secret jwtToken =
where
claimAsJSON :: Maybe ByteString -> Text -> Claims -> ExceptT Text IO ByteString
claimAsJSON defaultVal name cl = case M.lookup name cl of
Just (String s) -> pure $ encodeUtf8 s
Just (JSON.String s) -> pure $ encodeUtf8 s
Just _ -> throwError "claim is not string value"
Nothing -> nonExistingClaim defaultVal name

Expand All @@ -53,20 +57,20 @@ validateClaims requestChannel secret jwtToken =
-}
data JWTAttempt = JWTInvalid JWTError
| JWTMissingSecret
| JWTClaims (M.HashMap Text Value)
| JWTClaims (M.HashMap Text JSON.Value)
deriving Eq

{-|
Receives the JWT secret (from config) and a JWT and returns a map
of JWT claims.
-}
jwtClaims :: JWK -> LByteString -> IO JWTAttempt
jwtClaims _ "" = return $ JWTClaims M.empty
jwtClaims secret payload = do
let validation = defaultJWTValidationSettings (const True)
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = return $ JWTClaims M.empty
jwtClaims time jwk payload = do
let config = defaultJWTValidationSettings (const True)
eJwt <- runExceptT $ do
jwt <- decodeCompact payload
verifyClaims validation secret jwt
verifyClaimsAt config jwk time jwt
return $ case eJwt of
Left e -> JWTInvalid e
Right jwt -> JWTClaims . claims2map $ jwt
Expand All @@ -75,10 +79,10 @@ jwtClaims secret payload = do
Internal helper used to turn JWT ClaimSet into something
easier to work with
-}
claims2map :: ClaimsSet -> M.HashMap Text Value
claims2map = val2map . toJSON
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
claims2map = val2map . JSON.toJSON
where
val2map (Object o) = o
val2map (JSON.Object o) = o
val2map _ = M.empty

{-|
Expand All @@ -96,4 +100,4 @@ hs256jwk key =

parseJWK :: ByteString -> JWK
parseJWK str =
fromMaybe (hs256jwk str) (decode (toS str) :: Maybe JWK)
fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK)
16 changes: 11 additions & 5 deletions test/ClaimsSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ import Protolude
import qualified Data.HashMap.Strict as M
import Test.Hspec
import Data.Aeson (Value (..) )

import Data.Time.Clock
import PostgresWebsockets.Claims

spec :: Spec
spec =
describe "validate claims"
$ it "should succeed using a matching token"
$ validateClaims Nothing "reallyreallyreallyreallyverysafe"
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58"
describe "validate claims" $ do
it "should invalidate an expired token" $ do
time <- getCurrentTime
validateClaims Nothing "reallyreallyreallyreallyverysafe"
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" time
`shouldReturn` Left "Token expired"
it "should succeed using a matching token" $ do
time <- getCurrentTime
validateClaims Nothing "reallyreallyreallyreallyverysafe"
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time
`shouldReturn` Right ("test", "r", M.fromList[("mode",String "r"),("channel",String "test")])

0 comments on commit a4c0e55

Please sign in to comment.