Skip to content

Commit

Permalink
Populate zone in Redis node type and use pod zone for replica selection
Browse files Browse the repository at this point in the history
- Populate the `zone` field in the Redis `Node` type
- Use `POD_ZONE` information to decide which node to hit, including replicas
- Ensure zone-aware routing for better read distribution
  • Loading branch information
vijaygupta18 committed Jan 23, 2025
1 parent 5547eb3 commit 68245b5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 30 deletions.
64 changes: 46 additions & 18 deletions src/Database/Redis/Cluster.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ module Database.Redis.Cluster
, requestMasterNodes
, nodes
, createNodePool
, getZoneInfoFromSubnet
) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import qualified Data.IORef as IOR
import Data.Maybe(mapMaybe, fromMaybe)
import Data.Maybe(mapMaybe, fromMaybe, isJust)
import Data.List(sortBy, find)
import Data.List.Extra (nubOrd)
import Data.Map(fromListWith, assocs)
import Data.Function(on)
import Control.Exception(Exception, SomeException, throwIO, BlockedIndefinitelyOnMVar(..), catches, Handler(..), try, fromException)
import Data.Pool(Pool, createPool, withResource, destroyAllResources)
import System.Random (randomRIO)
import Control.Concurrent.MVar(MVar, newMVar, readMVar, modifyMVar)
import Control.Monad(zipWithM, replicateM)
import Database.Redis.Cluster.HashSlot(HashSlot, keyToSlot)
Expand Down Expand Up @@ -103,10 +105,11 @@ data NodeRole = Master | Slave deriving (Show, Eq, Ord)
-- | Host name or address of a cluster node, kept as a plain string.
type Host = String
-- | Port number of a cluster node.
type Port = Int
-- | Identifier of a cluster node; node connections are looked up by this value.
type NodeID = B.ByteString
-- | Zone a node is located in, or 'Nothing' when no zone is known for it.
type Zone = Maybe String
-- Represents a single node. This type deliberately does not include the
-- connection to the node, because the shard map can be shared amongst
-- multiple connections.
-- Fields, in order: node id, role (master or slave), host, port and, in the
-- five-field form, the node's zone.
data Node = Node NodeID NodeRole Host Port deriving (Show, Eq, Ord)
data Node = Node NodeID NodeRole Host Port Zone deriving (Show, Eq, Ord)

-- Both aliases are plain synonyms for 'Node'; they only document the intended
-- role at use sites and give no type-level distinction.
type MasterNode = Node
type SlaveNode = Node
Expand Down Expand Up @@ -142,6 +145,9 @@ instance Exception NoNodeException
-- | Exception carrying a free-form message string. The 'Exception' instance
-- below makes it throwable and catchable through "Control.Exception".
-- NOTE(review): the throw sites are not visible in this excerpt.
data TimeoutException = TimeoutException String deriving (Show, Typeable)
instance Exception TimeoutException

-- | Determine the zone of a host from a lookup table.
--
-- Placeholder implementation: both arguments are ignored and the result is
-- always 'Nothing', so callers never obtain a zone from this function.
-- NOTE(review): presumably the map is meant to associate subnets with zone
-- names; its key/value format is not established anywhere in this module.
getZoneInfoFromSubnet :: HM.HashMap String String -> Host -> Zone
getZoneInfoFromSubnet _subnetMap _host = Nothing

createClusterConnectionPools :: (Host -> CC.PortID -> IO CC.ConnectionContext) -> Int -> Time.NominalDiffTime -> [CMD.CommandInfo] -> ShardMap -> IO Connection
createClusterConnectionPools withAuth maxResources idleTime commandInfos shardMap = do
nodeConns <- nodeConnections
Expand All @@ -157,7 +163,7 @@ createClusterConnectionPools withAuth maxResources idleTime commandInfos shardMa
return $ HM.fromList nodeConnectionsList

createNodePool :: (Host -> CC.PortID -> IO CC.ConnectionContext) -> Int -> Time.NominalDiffTime -> Node -> IO (NodeID, NodeConnection)
createNodePool withAuth maxResources idleTime (Node nodeid _ host port) = do
createNodePool withAuth maxResources idleTime (Node nodeid _ host port _zone) = do
connectionPool <- createPool (do
connectionContext <- withAuth host (CC.PortNumber $ toEnum port)
ref <- IOR.newIORef Nothing
Expand All @@ -174,15 +180,15 @@ destroyNodeResources (Connection shardNodeVar _ _) =
-- Add a request to the current pipeline for this connection. The pipeline will
-- be executed implicitly as soon as any result returned from this function is
-- evaluated.
requestPipelined :: (Maybe NodeConnection -> IO (ShardMap, NodeConnectionMap)) -> Connection -> [B.ByteString] -> MVar Pipeline -> IO Reply
requestPipelined refreshShardmapAction conn nextRequest pipelineVar = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
requestPipelined :: (Maybe NodeConnection -> IO (ShardMap, NodeConnectionMap)) -> Connection -> [B.ByteString] -> MVar Pipeline -> Maybe String -> IO Reply
requestPipelined refreshShardmapAction conn nextRequest pipelineVar podZone = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
(newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case
Pending requests | isMulti nextRequest -> do
replies <- evaluatePipeline refreshShardmapAction conn requests
replies <- evaluatePipeline refreshShardmapAction conn requests podZone
s' <- newMVar $ TransactionPending [nextRequest]
return (Executed replies, (s', 0))
Pending requests | length requests > 1000 -> do
replies <- evaluatePipeline refreshShardmapAction conn (nextRequest:requests)
replies <- evaluatePipeline refreshShardmapAction conn (nextRequest:requests) podZone
return (Executed replies, (stateVar, length requests))
Pending requests ->
return (Pending (nextRequest:requests), (stateVar, length requests))
Expand All @@ -204,7 +210,7 @@ requestPipelined refreshShardmapAction conn nextRequest pipelineVar = modifyMVar
Executed replies ->
return (Executed replies, replies)
Pending requests-> do
replies <- evaluatePipeline refreshShardmapAction conn requests
replies <- evaluatePipeline refreshShardmapAction conn requests podZone
return (Executed replies, replies)
TransactionPending requests-> do
replies <- evaluateTransactionPipeline refreshShardmapAction conn requests
Expand Down Expand Up @@ -243,8 +249,8 @@ rawResponse (CompletedRequest _ _ r) = r
-- step is not pipelined, there is a request per error. This is probably
-- acceptable in most cases as these errors should only occur in the case of
-- cluster reconfiguration events, which should be rare.
evaluatePipeline :: (Maybe NodeConnection -> IO (ShardMap, NodeConnectionMap)) -> Connection -> [[B.ByteString]] -> IO [Reply]
evaluatePipeline refreshShardmapAction conn@(Connection shardNodeVar infoMap _) requests = do
evaluatePipeline :: (Maybe NodeConnection -> IO (ShardMap, NodeConnectionMap)) -> Connection -> [[B.ByteString]] -> Maybe String -> IO [Reply]
evaluatePipeline refreshShardmapAction conn@(Connection shardNodeVar infoMap _) requests podZone = do
(shardMap, nodesConn) <- hasLocked $ readMVar shardNodeVar
erequestsByNode <- try $ getRequestsByNode shardMap nodesConn
requestsByNode <- case erequestsByNode of
Expand Down Expand Up @@ -293,7 +299,7 @@ evaluatePipeline refreshShardmapAction conn@(Connection shardNodeVar infoMap _)
return $ assocs $ fromListWith (++) (mconcat commandsWithNodes)
requestWithNodes :: ShardMap -> NodeConnectionMap -> Int -> [B.ByteString] -> IO [(NodeConnection, [PendingRequest])]
requestWithNodes shardMap nodeConnMap index request = do
nodeConns <- nodeConnectionForCommand shardMap nodeConnMap infoMap request
nodeConns <- nodeConnectionForCommand shardMap nodeConnMap infoMap request podZone
return $ (, [PendingRequest index request]) <$> nodeConns
executeRequests :: NodeConnection -> [PendingRequest] -> IO [CompletedRequest]
executeRequests nodeConn nodeRequests = do
Expand All @@ -303,12 +309,13 @@ evaluatePipeline refreshShardmapAction conn@(Connection shardNodeVar infoMap _)
refreshShardMapAndRetryRequest refreshedShardMapAndNodeConnsIORef refreshShardmap request = do
(newShardMap, newNodeConn) <- fromMaybeM (hasLocked refreshShardmap >>= (\new -> IOR.writeIORef refreshedShardMapAndNodeConnsIORef (Just new) >> return new )) $
IOR.readIORef refreshedShardMapAndNodeConnsIORef
nodeConns <- nodeConnectionForCommand newShardMap newNodeConn infoMap request
nodeConns <- nodeConnectionForCommand newShardMap newNodeConn infoMap request podZone
head <$> requestNode (head nodeConns) [request]

--fix multi exec
-- Like `evaluatePipeline`, except we expect to be able to run all commands
-- on a single shard. Failing to meet this expectation is an error.
-- Note: This function is not suitable for zone-aware replica node usage, as it may include both SET and GET commands within the same transaction.
evaluateTransactionPipeline :: (Maybe NodeConnection -> IO (ShardMap, NodeConnectionMap)) -> Connection -> [[B.ByteString]] -> IO [Reply]
evaluateTransactionPipeline refreshShardmapAction conn requests' = do
let requests = reverse requests'
Expand Down Expand Up @@ -431,8 +438,8 @@ nodeConnWithHostAndPort (shardMap, nodeConns) host port = do
Nothing -> return Nothing
Just node -> return (HM.lookup (nodeId node) nodeConns)

nodeConnectionForCommand :: ShardMap -> NodeConnectionMap -> CMD.InfoMap -> [B.ByteString] -> IO [NodeConnection]
nodeConnectionForCommand (ShardMap shardMap) nodeConns infoMap request =
nodeConnectionForCommand :: ShardMap -> NodeConnectionMap -> CMD.InfoMap -> [B.ByteString] -> Maybe String -> IO [NodeConnection]
nodeConnectionForCommand (ShardMap shardMap) nodeConns infoMap request podZone =
case request of
("FLUSHALL" : _) -> allNodes
("FLUSHDB" : _) -> allNodes
Expand All @@ -441,9 +448,10 @@ nodeConnectionForCommand (ShardMap shardMap) nodeConns infoMap request =
_ -> do
keys <- requestKeys infoMap request
hashSlot <- hashSlotForKeys (CrossSlotException [request]) keys
node <- case IntMap.lookup (fromEnum hashSlot) shardMap of
shardNode <- case IntMap.lookup (fromEnum hashSlot) shardMap of
Nothing -> throwIO $ MissingNodeException ("HashSlot lookup failed in nodeConnectionForCommand" : request)
Just (Shard master _) -> return master
Just shard -> return shard
node <- getMasterOrReplicaNode shardNode request podZone
maybe (throwIO $ MissingNodeException ("NodeId lookup failed in nodeConnectionForCommand" : request)) (return . return) (HM.lookup (nodeId node) nodeConns)
where
allNodes = do
Expand All @@ -452,6 +460,26 @@ nodeConnectionForCommand (ShardMap shardMap) nodeConns infoMap request =
Nothing -> throwIO $ MissingNodeException ("Master node lookup failed" : request)
Just allNodes' -> return allNodes'

-- | Command names that are treated as reads when deciding whether a request
-- may be served by a replica. The comparison against a request's first
-- element is an exact byte-wise match, so only these upper-case spellings
-- qualify.
redisReadCommands :: [B.ByteString]
redisReadCommands =
  [ "GET"
  , "MGET"
  ]

-- | Choose which node of a shard should serve a request.
--
-- A replica is considered only when all of the following hold:
--
--   * a pod zone was supplied,
--   * the request is non-empty and its first element is one of
--     'redisReadCommands',
--   * the master and every replica of the shard carry zone information.
--
-- In that case the choice is delegated to 'getRandomNode'; in every other
-- case the shard's master is returned.
getMasterOrReplicaNode :: Shard -> [B.ByteString] -> Maybe String -> IO Node
getMasterOrReplicaNode (Shard master replicas) request mPodZone =
  -- The pod zone is inspected before the request, so an absent zone
  -- short-circuits to the master without looking at the request at all.
  case (mPodZone, request) of
    (Just podZone, cmd : _)
      | allHasZoneInfo (master : replicas) && cmd `elem` redisReadCommands ->
          getRandomNode master replicas podZone
    _ -> return master

-- | 'True' when every node in the list has a known zone (vacuously 'True'
-- for the empty list).
allHasZoneInfo :: [Node] -> Bool
allHasZoneInfo = all hasZone
  where
    hasZone (Node _ _ _ _ zone) = isJust zone

-- | Pick a node for a read issued from the given pod zone.
--
-- Replicas whose zone equals the pod zone are the only candidates; one of
-- them is chosen uniformly at random. When no replica is in that zone the
-- master is returned. The master itself is never part of the random draw.
getRandomNode :: Node -> [Node] -> String -> IO Node
getRandomNode master replicas podZone =
  case [replica | replica@(Node _ _ _ _ zone) <- replicas, Just podZone == zone] of
    [] -> return master
    sameZone -> do
      -- The index is drawn from the valid range of the non-empty list, so
      -- the '!!' below cannot go out of bounds.
      idx <- randomRIO (0, length sameZone - 1)
      return (sameZone !! idx)


allMasterNodes :: ShardMap -> NodeConnectionMap -> IO (Maybe [NodeConnection])
allMasterNodes (ShardMap shardMap) nodeConns = do
return $ mapM (flip HM.lookup nodeConns) onlyMasterNodeIds
Expand Down Expand Up @@ -495,10 +523,10 @@ nodes (ShardMap shardMap) = concatMap snd $ IntMap.toList $ fmap shardNodes shar


-- | Look up a node in the shard map by address: returns the first node (in
-- the order produced by 'nodes') whose host and port both equal the given
-- ones, or 'Nothing' when no node matches. Role, id and zone are ignored.
nodeWithHostAndPort :: ShardMap -> Host -> Port -> Maybe Node
nodeWithHostAndPort shardMap host port = find (\(Node _ _ nodeHost nodePort) -> port == nodePort && host == nodeHost) (nodes shardMap)
nodeWithHostAndPort shardMap host port = find (\(Node _ _ nodeHost nodePort _) -> port == nodePort && host == nodeHost) (nodes shardMap)

-- | Project the identifier (first field) out of a 'Node'.
nodeId :: Node -> NodeID
nodeId (Node theId _ _ _) = theId
nodeId (Node theId _ _ _ _) = theId

hasLocked :: IO a -> IO a
hasLocked action =
Expand Down
30 changes: 20 additions & 10 deletions src/Database/Redis/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -258,22 +258,32 @@ connectWithAuth ConnInfo{connectTLSParams,connectAuth,connectReadOnly,connectTim
-- | Convert a duration given in seconds into a whole number of microseconds,
-- rounding to the nearest integer.
clusterConnectTimeoutinUs :: Time.NominalDiffTime -> Int
clusterConnectTimeoutinUs timeout = round (microsPerSecond * timeout)
  where
    microsPerSecond = 1000000

-- | Debugging helper: read the @REDIS_SUBNET_MAP@ environment variable, print
-- its raw value to stdout and return it parsed as a map.
--
-- Returns an empty map when the variable is unset or when its contents cannot
-- be parsed by the 'Read' instance of the map type.
-- NOTE(review): this looks like leftover debug scaffolding; consider removing
-- it (or at least the 'print') before release.
testfn :: IO (HM.HashMap String String)
testfn = do
  rawSubnetMap <- lookupEnv "REDIS_SUBNET_MAP"
  -- Debug output of the raw (unparsed) value, kept from the original.
  print rawSubnetMap
  -- Parse the value that was just read instead of querying the environment a
  -- second time, so the printed and the parsed value can never differ.
  return $ fromMaybe HM.empty (rawSubnetMap >>= readMaybe)

-- | Build a 'ShardMap' from a cluster-slots response.
--
-- Every response entry contributes one 'Shard' (its master plus, in the
-- subnet-aware form, its replicas), and that shard is registered under every
-- hash slot from the entry's start slot to its end slot inclusive.
--
-- The subnet-aware form reads the @REDIS_SUBNET_MAP@ environment variable on
-- each call and falls back to an empty map when the variable is unset or
-- cannot be parsed with 'readMaybe'.
-- NOTE(review): 'readMaybe' relies on the map type's 'Read' instance, so the
-- variable presumably has to be in that instance's textual format — confirm
-- how deployments set it.
shardMapFromClusterSlotsResponse :: ClusterSlotsResponse -> IO ShardMap
shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr mkShardMap (pure IntMap.empty) clusterSlotsResponseEntries where
mkShardMap :: ClusterSlotsResponseEntry -> IO (IntMap.IntMap Shard) -> IO (IntMap.IntMap Shard)
mkShardMap ClusterSlotsResponseEntry{..} accumulator = do
shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = do
subnetMap :: HM.HashMap String String <- fromMaybe HM.empty . (>>= readMaybe) <$> lookupEnv "REDIS_SUBNET_MAP"
ShardMap <$> foldr (mkShardMap subnetMap) (pure IntMap.empty) clusterSlotsResponseEntries where
-- Fold step: run the accumulated action first, then add this entry's slots.
mkShardMap :: HM.HashMap String String -> ClusterSlotsResponseEntry -> IO (IntMap.IntMap Shard) -> IO (IntMap.IntMap Shard)
mkShardMap subnetMap ClusterSlotsResponseEntry{..} accumulator = do
accumulated <- accumulator
let master = nodeFromClusterSlotNode True clusterSlotsResponseEntryMaster
-- let replicas = map (nodeFromClusterSlotNode False) clusterSlotsResponseEntryReplicas
let shard = Shard master []
let master = nodeFromClusterSlotNode True subnetMap clusterSlotsResponseEntryMaster
let replicas = map (nodeFromClusterSlotNode False subnetMap) clusterSlotsResponseEntryReplicas
let shard = Shard master replicas
-- The same shard value is stored under every slot of the entry's range.
let slotMap = IntMap.fromList $ map (, shard) [clusterSlotsResponseEntryStartSlot..clusterSlotsResponseEntryEndSlot]
-- IntMap.union is left-biased: on a duplicate slot this entry's shard wins
-- over what was accumulated so far.
return $ IntMap.union slotMap accumulated
-- Convert one node of the response into a 'Cluster.Node'; the Bool selects
-- the role (True = master, False = slave).
nodeFromClusterSlotNode :: Bool -> ClusterSlotsNode -> Node
nodeFromClusterSlotNode isMaster ClusterSlotsNode{..} =
nodeFromClusterSlotNode :: Bool -> HM.HashMap String String -> ClusterSlotsNode -> Node
nodeFromClusterSlotNode isMaster subnetMap ClusterSlotsNode{..} =
let hostname = Char8.unpack clusterSlotsNodeIP
role = if isMaster then Cluster.Master else Cluster.Slave
-- Zone is derived from the host string via the subnet map.
zone = Cluster.getZoneInfoFromSubnet subnetMap hostname
in
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort)
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort) zone

refreshShardMap :: ConnectInfo -> Cluster.Connection -> Maybe Cluster.NodeConnection -> IO (ShardMap, NodeConnectionMap)
refreshShardMap connectInfo@ConnInfo{connectMaxConnections,connectMaxIdleTime} (Cluster.Connection shardNodeVar _ _) nodeConn = do
Expand All @@ -286,7 +296,7 @@ refreshShardMap connectInfo@ConnInfo{connectMaxConnections,connectMaxIdleTime} (
withAuth = connectWithAuth connectInfo
updateNodeConnections :: ShardMap -> HM.HashMap Cluster.NodeID Cluster.NodeConnection -> IO (HM.HashMap Cluster.NodeID Cluster.NodeConnection)
updateNodeConnections newShardMap oldNodeConnMap = do
foldM (\acc node@(Cluster.Node nodeid _ _ _) ->
foldM (\acc node@(Cluster.Node nodeid _ _ _ _zone) ->
case HM.lookup nodeid oldNodeConnMap of
Just nodeconn -> return $ HM.insert nodeid nodeconn acc
Nothing -> do
Expand Down
7 changes: 5 additions & 2 deletions src/Database/Redis/Core.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{-# LANGUAGE OverloadedStrings, GeneralizedNewtypeDeriving, RecordWildCards,
MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, CPP,
DeriveDataTypeable, StandaloneDeriving #-}
DeriveDataTypeable, StandaloneDeriving, ScopedTypeVariables #-}

module Database.Redis.Core (
Redis(), unRedis, reRedis,
Expand All @@ -25,6 +25,8 @@ import qualified Database.Redis.ProtocolPipelining as PP
import Database.Redis.Types
import Database.Redis.Cluster(ShardMap, NodeConnectionMap, NodeConnection)
import qualified Database.Redis.Cluster as Cluster
import System.Environment (lookupEnv)
import Text.Read (readMaybe)

--------------------------------------------------------------------------------
-- The Redis Monad
Expand Down Expand Up @@ -114,13 +116,14 @@ sendRequest :: (RedisCtx m f, RedisResult a)
-- Send a request in the current Redis environment and decode the reply.
-- Non-clustered: the rendered request goes through the pipelining connection
-- and the reply is recorded via setLastReply. Clustered: the request is handed
-- to Cluster.requestPipelined and the reply is stored in clusteredLastReply.
sendRequest req = do
r' <- liftRedis $ Redis $ do
env <- ask
-- The POD_ZONE environment variable is looked up on every request.
-- NOTE(review): 'readMaybe' at type String only accepts a Haskell string
-- literal (i.e. the value must include its double quotes); an unquoted value
-- such as a bare zone name parses to Nothing, which silently disables
-- zone-aware routing — confirm how POD_ZONE is set in deployments.
podZone :: Maybe String <- (>>= readMaybe) <$> (liftIO $ lookupEnv "POD_ZONE")
case env of
NonClusteredEnv{..} -> do
r <- liftIO $ PP.request envConn (renderRequest req)
setLastReply r
return r
ClusteredEnv{..} -> do
r <- liftIO $ Cluster.requestPipelined refreshAction connection req pipeline
r <- liftIO $ Cluster.requestPipelined refreshAction connection req pipeline podZone
lift (writeIORef clusteredLastReply r)
return r
returnDecode r'
Expand Down

0 comments on commit 68245b5

Please sign in to comment.