From c76d635559bfe53dd46227ea2f1497ff737ca5ad Mon Sep 17 00:00:00 2001 From: Alexander Malev Date: Mon, 13 Nov 2023 00:44:59 +0300 Subject: [PATCH] Auth username password (#34) * Auth username password --- redis_rs/__init__.py | 6 +++ redis_rs/__init__.pyi | 3 ++ src/client.rs | 5 +++ src/cluster_async.rs | 9 +++-- src/cluster_bb8.rs | 53 +++++++++++++++++++++++--- src/cluster_deadpool.rs | 16 ++++++-- src/lib.rs | 51 ++++++++++++++++++------- src/pool_manager.rs | 83 +++++++++++++++++++++-------------------- src/shards_async.rs | 9 +++-- tests/conftest.py | 12 ++++-- tests/test_client.py | 18 +++++++++ 11 files changed, 193 insertions(+), 72 deletions(-) diff --git a/redis_rs/__init__.py b/redis_rs/__init__.py index 04a536d..bee3b4e 100644 --- a/redis_rs/__init__.py +++ b/redis_rs/__init__.py @@ -18,6 +18,9 @@ def create_client( *args: str, max_size: Optional[int] = None, cluster: Optional[bool] = None, + username: Optional[str] = None, + password: Optional[str] = None, + db: Optional[int] = None, client_id: Optional[str] = None, features: Optional[List[str]] = None, ) -> Client: @@ -27,6 +30,9 @@ def create_client( *args, max_size=max_size, cluster=cluster, + username=username, + password=password, + db=db, client_id=client_id, features=features, ) diff --git a/redis_rs/__init__.pyi b/redis_rs/__init__.pyi index e6f1d9a..5655804 100644 --- a/redis_rs/__init__.pyi +++ b/redis_rs/__init__.pyi @@ -22,6 +22,9 @@ def create_client( *args: str, max_size: Optional[int] = None, cluster: Optional[bool] = None, + username: Optional[str] = None, + password: Optional[str] = None, + db: Optional[int] = None, client_id: Optional[str] = None, features: Optional[List[str]] = None, ) -> Client: ... diff --git a/src/client.rs b/src/client.rs index 791118d..4733e83 100644 --- a/src/client.rs +++ b/src/client.rs @@ -30,6 +30,7 @@ impl Client { let mut status = self.cr.status()?; let is_closed = status.remove("closed"); let is_cluster = status.remove("cluster"); + let is_auth = status.remove("auth"); let result = PyDict::new(py); for (k, v) in status.into_iter() { let value = types::to_object(py, v, "utf-8"); @@ -43,6 +44,10 @@ impl Client { let is_closed = c == 1; result.set_item("closed", is_closed.to_object(py))?; } + if let Some(redis::Value::Int(c)) = is_auth { + let is_auth = c == 1; + result.set_item("auth", is_auth.to_object(py))?; + } Ok(result.to_object(py)) } diff --git a/src/cluster_async.rs b/src/cluster_async.rs index 048f3bc..134774e 100644 --- a/src/cluster_async.rs +++ b/src/cluster_async.rs @@ -5,7 +5,7 @@ use crate::{ pool::{Connection, Pool}, }; use async_trait::async_trait; -use redis::{aio::ConnectionLike, cluster::ClusterClient, Cmd}; +use redis::{aio::ConnectionLike, cluster::ClusterClient, Cmd, IntoConnectionInfo}; use tokio::sync::Semaphore; pub struct Cluster { @@ -14,8 +14,11 @@ pub struct Cluster { } impl Cluster { - pub async fn new(initial_nodes: Vec, max_size: u32) -> Result { - let client = ClusterClient::new(initial_nodes).unwrap(); + pub async fn new(initial_nodes: Vec, max_size: u32) -> Result + where + T: IntoConnectionInfo, + { + let client = ClusterClient::new(initial_nodes)?; let semaphore = Semaphore::new(max_size as usize); let connection = client.get_async_connection().await?; Ok(Self { diff --git a/src/cluster_bb8.rs b/src/cluster_bb8.rs index 607fcdd..b78c97a 100644 --- a/src/cluster_bb8.rs +++ b/src/cluster_bb8.rs @@ -5,16 +5,59 @@ use crate::{ pool::{Connection, Pool}, }; use async_trait::async_trait; -use bb8_redis_cluster::RedisConnectionManager; -use redis::{aio::ConnectionLike, Cmd}; +use redis::{ + aio::ConnectionLike, cluster::ClusterClient, cluster_async::ClusterConnection, Cmd, ErrorKind, + IntoConnectionInfo, RedisError, +}; + +pub struct ClusterManager { + pub(crate) client: ClusterClient, +} + +impl ClusterManager { + pub fn new(initial_nodes: Vec) -> Result + where + T: IntoConnectionInfo, + { + let client = ClusterClient::new(initial_nodes)?; + Ok(Self { client }) + } +} + +#[async_trait] +impl bb8::ManageConnection for ClusterManager { + type Connection = ClusterConnection; + type Error = RedisError; + + async fn connect(&self) -> Result { + self.client.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong: String = redis::cmd("PING").query_async(conn).await?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err((ErrorKind::ResponseError, "ping request").into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} + +type Manager = ClusterManager; pub struct BB8Cluster { - pool: bb8::Pool, + pool: bb8::Pool, } impl BB8Cluster { - pub async fn new(initial_nodes: Vec, max_size: u32) -> Self { - let manager = RedisConnectionManager::new(initial_nodes).unwrap(); + pub async fn new(initial_nodes: Vec, max_size: u32) -> Self + where + T: IntoConnectionInfo, + { + let manager = Manager::new(initial_nodes).unwrap(); let pool = bb8::Pool::builder() .max_size(max_size) .build(manager) diff --git a/src/cluster_deadpool.rs b/src/cluster_deadpool.rs index 621021d..18efe55 100644 --- a/src/cluster_deadpool.rs +++ b/src/cluster_deadpool.rs @@ -4,7 +4,7 @@ use crate::{ }; use async_trait::async_trait; use deadpool_redis_cluster::{Config, PoolError, Runtime}; -use redis::{aio::ConnectionLike, Cmd}; +use redis::{aio::ConnectionLike, Cmd, IntoConnectionInfo}; use std::collections::HashMap; pub struct DeadPoolCluster { @@ -21,13 +21,21 @@ impl From for error::RedisError { } impl DeadPoolCluster { - pub fn new(initial_nodes: Vec, max_size: u32) -> Self { - let cfg = Config::from_urls(initial_nodes); + pub fn new(initial_nodes: Vec, max_size: u32) -> Result + where + T: IntoConnectionInfo, + { + let mut urls = vec![]; + for i in initial_nodes.into_iter() { + let url = i.into_connection_info()?; + urls.push(url.addr.to_string()); + } + let cfg = Config::from_urls(urls); let pool = cfg .create_pool(Some(Runtime::Tokio1)) .expect("Error with redis pool"); pool.resize(max_size as usize); - Self { pool } + Ok(Self { pool }) } } diff --git a/src/lib.rs b/src/lib.rs index 368cdd9..84b0e06 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ use pyo3::prelude::*; +use redis::IntoConnectionInfo; mod client; mod client_result; mod client_result_async; @@ -17,27 +18,49 @@ mod single_node; mod types; #[pyfunction] -#[pyo3(signature = (*initial_nodes, max_size=None, cluster=None, client_id=None, features=None))] +#[pyo3(signature = ( + *initial_nodes, + max_size=None, + cluster=None, + username = None, + password = None, + db = None, + client_id=None, + features=None, +))] +#[allow(clippy::too_many_arguments)] fn create_client( initial_nodes: Vec, max_size: Option, cluster: Option, + username: Option, + password: Option, + db: Option, client_id: Option, features: Option>, ) -> PyResult { - let is_cluster = match cluster { - None => initial_nodes.len() > 1, - Some(c) => c, - }; - let mut cm = if is_cluster { - pool_manager::PoolManager::new_cluster(initial_nodes) - } else { - let addr = initial_nodes - .get(0) - .map(String::as_str) - .unwrap_or("redis://localhost:6379"); - pool_manager::PoolManager::new(addr) - }; + let mut nodes = initial_nodes.clone(); + if nodes.is_empty() { + nodes.push("redis://localhost:6379".to_string()); + } + let mut infos = vec![]; + for i in nodes.into_iter() { + let mut info = i.into_connection_info().map_err(error::RedisError::from)?; + if password.is_some() { + info.redis.password = password.clone(); + } + if username.is_some() { + info.redis.username = username.clone(); + } + if let Some(db) = db { + info.redis.db = db; + } + infos.push(info); + } + + let mut cm = pool_manager::PoolManager::new(infos)?; + cm.is_cluster = cluster; + if let Some(features) = features { cm.features = features .into_iter() diff --git a/src/pool_manager.rs b/src/pool_manager.rs index d13b779..853697b 100644 --- a/src/pool_manager.rs +++ b/src/pool_manager.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use redis::{Cmd, FromRedisValue, IntoConnectionInfo}; +use redis::{Cmd, ConnectionInfo, FromRedisValue}; use crate::{ client::Client, @@ -31,8 +31,8 @@ impl From for Client { } pub struct PoolManager { - pub(crate) is_cluster: bool, - pub(crate) initial_nodes: Vec, + pub(crate) is_cluster: Option, + pub(crate) initial_nodes: Vec, pub(crate) max_size: u32, pub(crate) pool: Box, pub(crate) client_id: String, @@ -40,51 +40,46 @@ pub struct PoolManager { } impl PoolManager { - pub fn new_cluster(initial_nodes: Vec) -> Self { - Self { + pub fn new(initial_nodes: Vec) -> Result { + Ok(Self { initial_nodes, - is_cluster: true, - max_size: 10, - pool: Box::new(ClosedPool), - client_id: String::default(), - features: vec![], - } - } - - pub fn new(addr: &str) -> Self { - Self { - initial_nodes: vec![addr.to_string()], - is_cluster: false, + is_cluster: Some(false), max_size: 10, pool: Box::new(ClosedPool), client_id: String::default(), features: vec![], - } + }) } pub async fn init(&mut self) -> Result<&Self, error::RedisError> { - let nodes = self.initial_nodes.clone(); + let mut nodes = self.initial_nodes.clone(); let ms = self.max_size; - let is_cluster = self.is_cluster; - if is_cluster { - self.pool = match self.features.as_slice() { - [types::Feature::BB8, ..] => Box::new(BB8Cluster::new(nodes, ms).await), - [types::Feature::DeadPool, ..] => Box::new(DeadPoolCluster::new(nodes, ms)), - [types::Feature::Shards, ..] => { - Box::new(AsyncShards::new(nodes, ms, Some(true)).await?) - } - _ => Box::new(Cluster::new(nodes, ms).await?), - }; - } else { - let info = nodes.clone().remove(0).into_connection_info()?; - self.pool = match self.features.as_slice() { - [types::Feature::BB8, ..] => Box::new(BB8Pool::new(info, ms).await?), - [types::Feature::DeadPool, ..] => Box::new(DeadPool::new(info, ms).await?), - [types::Feature::Shards, ..] => { - Box::new(AsyncShards::new(nodes, ms, Some(false)).await?) - } - _ => Box::new(Node::new(info, ms).await?), - }; + match self.is_cluster { + None => { + self.pool = Box::new(AsyncShards::new(nodes, ms, self.is_cluster).await?); + } + Some(true) => { + self.pool = match self.features.as_slice() { + [types::Feature::BB8, ..] => Box::new(BB8Cluster::new(nodes, ms).await), + [types::Feature::DeadPool, ..] => Box::new(DeadPoolCluster::new(nodes, ms)?), + [types::Feature::Shards, ..] => { + Box::new(AsyncShards::new(nodes, ms, Some(true)).await?) + } + _ => Box::new(Cluster::new(nodes, ms).await?), + }; + } + Some(false) => { + self.pool = match self.features.as_slice() { + [types::Feature::BB8, ..] => Box::new(BB8Pool::new(nodes.remove(0), ms).await?), + [types::Feature::DeadPool, ..] => { + Box::new(DeadPool::new(nodes.remove(0), ms).await?) + } + [types::Feature::Shards, ..] => { + Box::new(AsyncShards::new(nodes, ms, Some(false)).await?) + } + _ => Box::new(Node::new(nodes.remove(0), ms).await?), + }; + } }; Ok(self) } @@ -99,7 +94,15 @@ impl PoolManager { let initial_nodes = self .initial_nodes .iter() - .map(|s| redis::Value::Data(s.as_bytes().to_vec())) + .map(|s| { + if let Some(username) = s.redis.username.clone() { + result.insert("username", redis::Value::Data(username.as_bytes().to_vec())); + } + if s.redis.password.is_some() { + result.insert("auth", redis::Value::Int(1)); + } + redis::Value::Data(s.addr.to_string().as_bytes().to_vec()) + }) .collect(); result.insert("initial_nodes", redis::Value::Bulk(initial_nodes)); result.insert("max_size", redis::Value::Int(self.max_size as i64)); diff --git a/src/shards_async.rs b/src/shards_async.rs index d2852e1..e9bbea1 100644 --- a/src/shards_async.rs +++ b/src/shards_async.rs @@ -24,11 +24,14 @@ pub struct AsyncShards { } impl AsyncShards { - pub async fn new( - nodes: Vec, + pub async fn new( + nodes: Vec, max_size: u32, is_cluster: Option, - ) -> RedisResult { + ) -> RedisResult + where + T: IntoConnectionInfo, + { let mut result = Self { max_size, ..Default::default() diff --git a/tests/conftest.py b/tests/conftest.py index ebb601c..5d5e1b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,11 +35,17 @@ async def get_redis_version(nodes: list) -> str: @pytest.fixture -async def async_client(): - async with redis_rs.create_client( +def client_factory(): + return lambda **kwargs: redis_rs.create_client( *NODES, cluster=IS_CLUSTER, - ) as c: + **kwargs, + ) + + +@pytest.fixture +async def async_client(client_factory): + async with client_factory() as c: yield c diff --git a/tests/test_client.py b/tests/test_client.py index e85901b..5eb71f9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,23 @@ +import pytest + import redis_rs async def test_client_id(async_client: redis_rs.AsyncClient): assert async_client.client_id + + +@pytest.mark.redis(single=True) +async def test_password(async_client: redis_rs.AsyncClient, client_factory): + user = "test" + password = await async_client.fetch_str("ACL", "GENPASS") + assert await async_client.execute("ACL", "SETUSER", user, "nopass") + assert await async_client.execute("ACL", "SETUSER", user, "on", ">" + password, "+acl|whoami", "+cluster|slots") + assert await async_client.execute("AUTH", user, password) + assert user == await async_client.fetch_str("ACL", "WHOAMI") + + async with client_factory(username=user, password=password) as client: + assert user == await client.fetch_str("ACL", "WHOAMI") + status = client.status() + assert status.get("username") + assert status.get("auth")