diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c2e09d..987297c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # CHANGELOG +## Next release + +- Send close connection once the JWT token expires (if channel is open with a token using the `exp` claim). + ## 0.6.1.0 - Add capability to unset `PGWS_ROOT_PATH` to disable static file serving. diff --git a/app/Main.hs b/app/Main.hs index 3caf39a..636b052 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -19,6 +19,11 @@ import qualified Hasql.Decoders as HD import qualified Hasql.Encoders as HE import qualified Hasql.Pool as P import Network.Wai.Application.Static +import Data.Time.Clock (UTCTime, getCurrentTime) +import Control.AutoUpdate ( defaultUpdateSettings + , mkAutoUpdate + , updateAction + ) import Network.Wai (Application, responseLBS) import Network.HTTP.Types (status200) @@ -61,11 +66,14 @@ main = do pool <- P.acquire (configPool conf, 10, pgSettings) multi <- newHasqlBroadcaster listenChannel pgSettings + getTime <- mkGetTime runSettings appSettings $ - postgresWsMiddleware listenChannel (configJwtSecret conf) pool multi $ + postgresWsMiddleware getTime listenChannel (configJwtSecret conf) pool multi $ logStdout $ maybe dummyApp staticApp' (configPath conf) where + mkGetTime :: IO (IO UTCTime) + mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime} staticApp' :: Text -> Application staticApp' = staticApp . defaultFileServerSettings . toS dummyApp :: Application diff --git a/client-example/screen.css b/client-example/screen.css index 373ba74..3c25296 100644 --- a/client-example/screen.css +++ b/client-example/screen.css @@ -21,7 +21,7 @@ h2 { } div#main { - width: 600px; + width: 50%; margin: 0px auto 0px auto; padding: 0px; background-color: #fff; diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 212e144..ecf9541 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -44,6 +44,7 @@ library , stringsearch >= 0.3.6.6 && < 0.4 , time >= 1.8.0.2 && < 1.9 , contravariant >= 1.5.2 && < 1.6 + , alarmclock >= 0.7.0.2 && < 0.8 default-language: Haskell2010 default-extensions: OverloadedStrings, NoImplicitPrelude, LambdaCase @@ -67,8 +68,9 @@ executable postgres-websockets , wai >= 3.2 && < 4 , wai-extra >= 3.0.29 && < 3.1 , wai-app-static >= 3.1.7.1 && < 3.2 - , http-types + , http-types >= 0.9 , envparse >= 0.4.1 + , auto-update >= 0.1.6 && < 0.2 default-language: Haskell2010 default-extensions: OverloadedStrings, NoImplicitPrelude, QuasiQuotes @@ -82,18 +84,18 @@ test-suite postgres-websockets-test build-depends: base , protolude >= 0.2.3 , postgres-websockets - , containers - , hspec - , hspec-wai - , hspec-wai-json - , aeson - , hasql - , hasql-pool + , hspec >= 2.7.1 && < 2.8 + , hspec-wai >= 0.9.2 && < 0.10 + , hspec-wai-json >= 0.9.2 && < 0.10 + , aeson >= 1.4.6.0 && < 1.5 + , hasql >= 0.19 + , hasql-pool >= 0.4 , hasql-notifications >= 0.1.0.0 && < 0.2 - , http-types + , http-types >= 0.9 + , time >= 1.8.0.2 && < 1.9 , unordered-containers >= 0.2 - , wai-extra - , stm + , wai-extra >= 3.0.29 && < 3.1 + , stm >= 2.5.0.0 && < 2.6 ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N default-language: Haskell2010 default-extensions: OverloadedStrings, NoImplicitPrelude diff --git a/src/PostgresWebsockets.hs b/src/PostgresWebsockets.hs index 649393d..9ee2f2a 100644 --- a/src/PostgresWebsockets.hs +++ b/src/PostgresWebsockets.hs @@ -22,7 +22,9 @@ import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy as BL import qualified Data.HashMap.Strict as M import qualified Data.Text.Encoding.Error as T -import Data.Time.Clock.POSIX (getPOSIXTime) +import Data.Time.Clock (UTCTime) +import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime) +import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm) import PostgresWebsockets.Broadcast (Multiplexer, onMessage) import qualified PostgresWebsockets.Broadcast as B import PostgresWebsockets.Claims @@ -38,19 +40,21 @@ data Message = Message instance A.ToJSON Message -- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware. -postgresWsMiddleware :: Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application +postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application postgresWsMiddleware = WS.websocketsOr WS.defaultConnectionOptions `compose` wsApp where - compose = (.) . (.) . (.) . (.) + compose = (.) . (.) . (.) . (.) . (.) -- private functions +jwtExpirationStatusCode :: Word16 +jwtExpirationStatusCode = 3001 -- when the websocket is closed a ConnectionClosed Exception is triggered -- this kills all children and frees resources for us -wsApp :: Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp -wsApp dbChannel secret pool multi pendingConn = - validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions +wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp +wsApp getTime dbChannel secret pool multi pendingConn = + getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions where hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString) hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString) @@ -68,17 +72,23 @@ wsApp dbChannel secret pool multi pendingConn = -- We should accept only after verifying JWT conn <- WS.acceptRequest pendingConn -- Fork a pinging thread to ensure browser connections stay alive - WS.forkPingThread conn 30 + WS.withPingThread conn 30 (pure ()) $ do + case M.lookup "exp" validClaims of + Just (A.Number expClaim) -> do + connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)) + setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim) + Just _ -> pure () + Nothing -> pure () - when (hasRead mode) $ - onMessage multi ch $ WS.sendTextData conn . B.payload + when (hasRead mode) $ + onMessage multi ch $ WS.sendTextData conn . B.payload - when (hasWrite mode) $ - let sendNotifications = void . H.notifyPool pool dbChannel . toS - in notifySession validClaims (toS ch) conn sendNotifications + when (hasWrite mode) $ + let sendNotifications = void . H.notifyPool pool dbChannel . toS + in notifySession validClaims (toS ch) conn getTime sendNotifications - waitForever <- newEmptyMVar - void $ takeMVar waitForever + waitForever <- newEmptyMVar + void $ takeMVar waitForever -- Having both channel and claims as parameters seem redundant -- But it allows the function to ignore the claims structure and the source @@ -86,9 +96,10 @@ wsApp dbChannel secret pool multi pendingConn = notifySession :: A.Object -> Text -> WS.Connection + -> IO UTCTime -> (ByteString -> IO ()) -> IO () -notifySession claimsToSend ch wsCon send = +notifySession claimsToSend ch wsCon getTime send = withAsync (forever relayData) wait where relayData = jsonMsgWithTime >>= send @@ -102,5 +113,5 @@ notifySession claimsToSend ch wsCon send = claimsWithChannel = M.insert "channel" (A.String ch) claimsToSend claimsWithTime :: IO (M.HashMap Text A.Value) claimsWithTime = do - time <- getPOSIXTime - return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsWithChannel + time <- utcTimeToPOSIXSeconds <$> getTime + return $ M.insert "message_delivered_at" (A.Number $ realToFrac time) claimsWithChannel diff --git a/src/PostgresWebsockets/Claims.hs b/src/PostgresWebsockets/Claims.hs index dce6d91..d889dd8 100644 --- a/src/PostgresWebsockets/Claims.hs +++ b/src/PostgresWebsockets/Claims.hs @@ -10,23 +10,27 @@ module PostgresWebsockets.Claims import Control.Lens import qualified Crypto.JOSE.Types as JOSE.Types import Crypto.JWT -import Data.Aeson (Value (..), decode, toJSON) import qualified Data.HashMap.Strict as M import Protolude +import Data.Time.Clock (UTCTime) +import Data.String (String, fromString) +import qualified Data.Aeson as JSON +import qualified Data.Aeson.Types as JSON -type Claims = M.HashMap Text Value +type Claims = M.HashMap Text JSON.Value type ConnectionInfo = (ByteString, ByteString, Claims) {-| Given a secret, a token and a timestamp it validates the claims and returns either an error message or a triple containing channel, mode and claims hashmap. -} -validateClaims :: Maybe ByteString -> ByteString -> LByteString -> IO (Either Text ConnectionInfo) -validateClaims requestChannel secret jwtToken = +validateClaims :: Maybe ByteString -> ByteString -> LByteString -> UTCTime -> IO (Either Text ConnectionInfo) +validateClaims requestChannel secret jwtToken time = runExceptT $ do - cl <- liftIO $ jwtClaims (parseJWK secret) jwtToken + cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken cl' <- case cl of JWTClaims c -> pure c + JWTInvalid JWTExpired -> throwError "Token expired" _ -> throwError "Error" channel <- claimAsJSON requestChannel "channel" cl' mode <- claimAsJSON Nothing "mode" cl' @@ -35,7 +39,7 @@ validateClaims requestChannel secret jwtToken = where claimAsJSON :: Maybe ByteString -> Text -> Claims -> ExceptT Text IO ByteString claimAsJSON defaultVal name cl = case M.lookup name cl of - Just (String s) -> pure $ encodeUtf8 s + Just (JSON.String s) -> pure $ encodeUtf8 s Just _ -> throwError "claim is not string value" Nothing -> nonExistingClaim defaultVal name @@ -53,20 +57,20 @@ validateClaims requestChannel secret jwtToken = -} data JWTAttempt = JWTInvalid JWTError | JWTMissingSecret - | JWTClaims (M.HashMap Text Value) + | JWTClaims (M.HashMap Text JSON.Value) deriving Eq {-| Receives the JWT secret (from config) and a JWT and returns a map of JWT claims. -} -jwtClaims :: JWK -> LByteString -> IO JWTAttempt -jwtClaims _ "" = return $ JWTClaims M.empty -jwtClaims secret payload = do - let validation = defaultJWTValidationSettings (const True) +jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt +jwtClaims _ _ "" = return $ JWTClaims M.empty +jwtClaims time jwk payload = do + let config = defaultJWTValidationSettings (const True) eJwt <- runExceptT $ do jwt <- decodeCompact payload - verifyClaims validation secret jwt + verifyClaimsAt config jwk time jwt return $ case eJwt of Left e -> JWTInvalid e Right jwt -> JWTClaims . claims2map $ jwt @@ -75,10 +79,10 @@ jwtClaims secret payload = do Internal helper used to turn JWT ClaimSet into something easier to work with -} -claims2map :: ClaimsSet -> M.HashMap Text Value -claims2map = val2map . toJSON +claims2map :: ClaimsSet -> M.HashMap Text JSON.Value +claims2map = val2map . JSON.toJSON where - val2map (Object o) = o + val2map (JSON.Object o) = o val2map _ = M.empty {-| @@ -96,4 +100,4 @@ hs256jwk key = parseJWK :: ByteString -> JWK parseJWK str = - fromMaybe (hs256jwk str) (decode (toS str) :: Maybe JWK) + fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK) diff --git a/test/ClaimsSpec.hs b/test/ClaimsSpec.hs index 2fbb943..3d019c4 100644 --- a/test/ClaimsSpec.hs +++ b/test/ClaimsSpec.hs @@ -5,13 +5,19 @@ import Protolude import qualified Data.HashMap.Strict as M import Test.Hspec import Data.Aeson (Value (..) ) - +import Data.Time.Clock import PostgresWebsockets.Claims spec :: Spec spec = - describe "validate claims" - $ it "should succeed using a matching token" - $ validateClaims Nothing "reallyreallyreallyreallyverysafe" - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" + describe "validate claims" $ do + it "should invalidate an expired token" $ do + time <- getCurrentTime + validateClaims Nothing "reallyreallyreallyreallyverysafe" + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" time + `shouldReturn` Left "Token expired" + it "should succeed using a matching token" $ do + time <- getCurrentTime + validateClaims Nothing "reallyreallyreallyreallyverysafe" + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time `shouldReturn` Right ("test", "r", M.fromList[("mode",String "r"),("channel",String "test")])