Skip to content

Commit

Permalink
refactor(query): refactor license manager (#16492)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhang2014 authored Sep 23, 2024
1 parent 4be656a commit 3b49735
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 73 deletions.
48 changes: 24 additions & 24 deletions src/binaries/query/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,30 @@ pub async fn start_services(conf: &InnerConfig) -> Result<(), MainError> {

info!("Databend Query start with config: {:?}", conf);

// Cluster register.
{
ClusterDiscovery::instance()
.register_to_metastore(conf)
.await
.with_context(make_error)?;
info!(
"Databend query has been registered:{:?} to metasrv:{:?}.",
conf.query.cluster_id, conf.meta.endpoints
);
}

// RPC API service.
{
let address = conf.query.flight_api_address.clone();
let mut srv = FlightService::create(conf.clone()).with_context(make_error)?;
let listening = srv
.start(address.parse().with_context(make_error)?)
.await
.with_context(make_error)?;
shutdown_handle.add_service("RPCService", srv);
info!("Listening for RPC API (interserver): {}", listening);
}

// MySQL handler.
{
let hostname = conf.query.mysql_handler_host.clone();
Expand Down Expand Up @@ -229,30 +253,6 @@ pub async fn start_services(conf: &InnerConfig) -> Result<(), MainError> {
info!("Listening for FlightSQL API: {}", listening);
}

// RPC API service.
{
let address = conf.query.flight_api_address.clone();
let mut srv = FlightService::create(conf.clone()).with_context(make_error)?;
let listening = srv
.start(address.parse().with_context(make_error)?)
.await
.with_context(make_error)?;
shutdown_handle.add_service("RPCService", srv);
info!("Listening for RPC API (interserver): {}", listening);
}

// Cluster register.
{
ClusterDiscovery::instance()
.register_to_metastore(conf)
.await
.with_context(make_error)?;
info!(
"Databend query has been registered:{:?} to metasrv:{:?}.",
conf.query.cluster_id, conf.meta.endpoints
);
}

// Print information to users.
println!("Databend Query");

Expand Down
1 change: 1 addition & 0 deletions src/common/exception/src/exception_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ build_exceptions! {
/// For example: license key is expired
LicenseKeyInvalid(1402),
EnterpriseFeatureNotEnable(1403),
LicenseKeyExpired(1404),

BackgroundJobAlreadyExists(1501),
UnknownBackgroundJob(1502),
Expand Down
21 changes: 13 additions & 8 deletions src/common/license/src/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::fmt;

use databend_common_base::display::display_option::DisplayOptionExt;
use databend_common_base::display::display_slice::DisplaySliceExt;
use databend_common_exception::ErrorCode;
use serde::Deserialize;
use serde::Serialize;

Expand Down Expand Up @@ -124,31 +125,35 @@ impl fmt::Display for Feature {
}

impl Feature {
pub fn verify(&self, feature: &Feature) -> bool {
pub fn verify_default(&self, message: impl Into<String>) -> Result<(), ErrorCode> {
Err(ErrorCode::LicenseKeyInvalid(message.into()))
}

pub fn verify(&self, feature: &Feature) -> Result<bool, ErrorCode> {
match (self, feature) {
(Feature::ComputeQuota(c), Feature::ComputeQuota(v)) => {
if let Some(thread_num) = c.threads_num {
if thread_num <= v.threads_num.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

if let Some(max_memory_usage) = c.memory_usage {
if max_memory_usage <= v.memory_usage.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

true
Ok(true)
}
(Feature::StorageQuota(c), Feature::StorageQuota(v)) => {
if let Some(max_storage_usage) = c.storage_usage {
if max_storage_usage <= v.storage_usage.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

true
Ok(true)
}
(Feature::Test, Feature::Test)
| (Feature::AggregateIndex, Feature::AggregateIndex)
Expand All @@ -161,8 +166,8 @@ impl Feature {
| (Feature::InvertedIndex, Feature::InvertedIndex)
| (Feature::VirtualColumn, Feature::VirtualColumn)
| (Feature::AttacheTable, Feature::AttacheTable)
| (Feature::StorageEncryption, Feature::StorageEncryption) => true,
(_, _) => false,
| (Feature::StorageEncryption, Feature::StorageEncryption) => Ok(true),
(_, _) => Ok(false),
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/common/license/src/license_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@ impl LicenseManager for OssLicenseManager {
GlobalInstance::get()
}

fn check_enterprise_enabled(&self, _license_key: String, _feature: Feature) -> Result<()> {
Err(ErrorCode::LicenseKeyInvalid(
"Need Commercial License".to_string(),
))
fn check_enterprise_enabled(&self, _license_key: String, feature: Feature) -> Result<()> {
// oss ignore license key.
feature.verify_default("Need Commercial License".to_string())
}

fn parse_license(&self, _raw: &str) -> Result<JWTClaims<LicenseInfo>> {
Expand Down
94 changes: 57 additions & 37 deletions src/query/ee/src/license/license_mgr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use jwt_simple::algorithms::ES256PublicKey;
use jwt_simple::claims::JWTClaims;
use jwt_simple::prelude::Clock;
use jwt_simple::prelude::ECDSAP256PublicKeyLike;
use jwt_simple::JWTError;

const LICENSE_PUBLIC_KEY: &str = r#"-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEGsKCbhXU7j56VKZ7piDlLXGhud0a
Expand Down Expand Up @@ -60,34 +61,41 @@ impl LicenseManager for RealLicenseManager {

fn check_enterprise_enabled(&self, license_key: String, feature: Feature) -> Result<()> {
if license_key.is_empty() {
return Err(ErrorCode::LicenseKeyInvalid(format!(
"use of {feature} requires an enterprise license. license key is not found for {}",
return feature.verify_default(format!(
"The use of this feature requires a Databend Enterprise Edition license. No license key found for tenant: {}. To unlock enterprise features, please contact Databend to obtain a license. Learn more at https://docs.databend.com/guides/overview/editions/dee/",
self.tenant
)));
));
}

if let Some(v) = self.cache.get(&license_key) {
return Self::verify_feature(v.value(), feature);
return self.verify_feature(v.value(), feature);
}

let license = self.parse_license(&license_key).map_err_to_code(
ErrorCode::LicenseKeyInvalid,
|| format!("use of {feature} requires an enterprise license. current license is invalid for {}", self.tenant),
)?;
Self::verify_feature(&license, feature)?;
self.cache.insert(license_key, license);
Ok(())
match self.parse_license(&license_key) {
Ok(license) => {
self.verify_feature(&license, feature)?;
self.cache.insert(license_key, license);
Ok(())
}
Err(e) => match e.code() == ErrorCode::LICENSE_KEY_EXPIRED {
true => self.verify_if_expired(feature),
false => Err(e),
},
}
}

fn parse_license(&self, raw: &str) -> Result<JWTClaims<LicenseInfo>> {
let public_key = ES256PublicKey::from_pem(self.public_key.as_str())
.map_err_to_code(ErrorCode::LicenseKeyParseError, || "public key load failed")?;
public_key
.verify_token::<LicenseInfo>(raw, None)
.map_err_to_code(
ErrorCode::LicenseKeyParseError,
|| "jwt claim decode failed",
)
match public_key.verify_token::<LicenseInfo>(raw, None) {
Ok(v) => Ok(v),
Err(cause) => match cause.downcast_ref::<JWTError>() {
Some(JWTError::TokenHasExpired) => {
Err(ErrorCode::LicenseKeyExpired("license key is expired."))
}
_ => Err(ErrorCode::LicenseKeyParseError("jwt claim decode failed")),
},
}
}

fn get_storage_quota(&self, license_key: String) -> Result<StorageQuota> {
Expand All @@ -96,15 +104,26 @@ impl LicenseManager for RealLicenseManager {
}

if let Some(v) = self.cache.get(&license_key) {
Self::verify_license(v.value())?;
if Self::verify_license_expired(v.value())? {
return Err(ErrorCode::LicenseKeyExpired(format!(
"license key expired in {:?}",
v.value().expires_at,
)));
}
return Ok(v.custom.get_storage_quota());
}

let license = self.parse_license(&license_key).map_err_to_code(
ErrorCode::LicenseKeyInvalid,
|| format!("use of storage requires an enterprise license. current license is invalid for {}", self.tenant),
)?;
Self::verify_license(&license)?;

if Self::verify_license_expired(&license)? {
return Err(ErrorCode::LicenseKeyExpired(format!(
"license key expired in {:?}",
license.expires_at,
)));
}

let quota = license.custom.get_storage_quota();
self.cache.insert(license_key, license);
Expand All @@ -123,36 +142,28 @@ impl RealLicenseManager {
}
}

fn verify_license(l: &JWTClaims<LicenseInfo>) -> Result<()> {
fn verify_license_expired(l: &JWTClaims<LicenseInfo>) -> Result<bool> {
let now = Clock::now_since_epoch();
match l.expires_at {
Some(expire_at) => {
if now > expire_at {
return Err(ErrorCode::LicenseKeyInvalid(format!(
"license key expired in {:?}",
expire_at
)));
}
}
None => {
return Err(ErrorCode::LicenseKeyInvalid(
"cannot find valid expire time",
));
}
Some(expire_at) => Ok(now > expire_at),
None => Err(ErrorCode::LicenseKeyInvalid(
"cannot find valid expire time",
)),
}
Ok(())
}

fn verify_feature(l: &JWTClaims<LicenseInfo>, feature: Feature) -> Result<()> {
Self::verify_license(l)?;
fn verify_feature(&self, l: &JWTClaims<LicenseInfo>, feature: Feature) -> Result<()> {
if Self::verify_license_expired(l)? {
return self.verify_if_expired(feature);
}

if l.custom.features.is_none() {
return Ok(());
}

let verify_features = l.custom.features.as_ref().unwrap();
for verify_feature in verify_features {
if verify_feature.verify(&feature) {
if verify_feature.verify(&feature)? {
return Ok(());
}
}
Expand All @@ -163,4 +174,13 @@ impl RealLicenseManager {
l.custom.display_features()
)))
}

fn verify_if_expired(&self, feature: Feature) -> Result<()> {
feature.verify_default("").map_err(|_|
ErrorCode::LicenseKeyExpired(format!(
"The use of this feature requires a Databend Enterprise Edition license. License key has expired for tenant: {}. To unlock enterprise features, please contact Databend to obtain a license. Learn more at https://docs.databend.com/guides/overview/editions/dee/",
self.tenant
))
)
}
}

0 comments on commit 3b49735

Please sign in to comment.