Skip to content

Commit

Permalink
wip: extension loader
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Feb 8, 2024
1 parent 8155d25 commit 4c527ef
Show file tree
Hide file tree
Showing 25 changed files with 648 additions and 132 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 24 additions & 1 deletion crates/catalog/src/session_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ impl SessionCatalog {
.schema_objects
.get(&self.resolve_conf.default_schema_oid)?;
let obj_id = obj.objects.get(name)?;

let ent = self
.state
.entries
Expand All @@ -233,6 +232,30 @@ impl SessionCatalog {
}
}

/// Resolve a builtin scalar function by name within the default schema.
///
/// Returns `None` if the default schema is unknown, if no object with the
/// given name exists in it, or if the resolved entry is anything other than
/// a builtin function whose type is `FunctionType::Scalar`.
///
/// # Panics
///
/// Panics if the name maps to an object id that has no catalog entry.
pub fn resolve_scalar_function(&self, name: &str) -> Option<FunctionEntry> {
    let default_schema = self
        .schema_objects
        .get(&self.resolve_conf.default_schema_oid)?;
    let oid = default_schema.objects.get(name)?;

    // A name that resolves to an id must always have a backing entry.
    let entry = self
        .state
        .entries
        .get(oid)
        .expect("object name points to invalid function");

    if let CatalogEntry::Function(func) = entry {
        if func.meta.builtin && func.func_type == FunctionType::Scalar {
            return Some(func.clone());
        }
    }

    None
}

/// Resolve an entry by schema name and object name.
///
/// Note that this will never return a schema entry.
Expand Down
4 changes: 2 additions & 2 deletions crates/glaredb/src/highlighter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fn colorize_sql(query: &str, st: &mut StyledText, is_hint: bool) {
st.push((new_style().fg(Color::LightGreen), format!("{w}")))
}
// Functions
other if FUNCTION_REGISTRY.contains(other) => {
other if FUNCTION_REGISTRY.lock().contains(other) => {
st.push((colorize_function(), format!("{w}")));
}
_ => {
Expand All @@ -201,7 +201,7 @@ fn colorize_sql(query: &str, st: &mut StyledText, is_hint: bool) {
},
// TODO: add more keywords
_ => {
if FUNCTION_REGISTRY.contains(&w.value) {
if FUNCTION_REGISTRY.lock().contains(&w.value) {
st.push((colorize_function(), format!("{w}")));
} else {
st.push((new_style(), format!("{w}")))
Expand Down
89 changes: 75 additions & 14 deletions crates/metastore/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use sqlbuiltins::builtins::{
DEFAULT_SCHEMA,
FIRST_NON_STATIC_OID,
};
use sqlbuiltins::functions::{BuiltinFunction, FUNCTION_REGISTRY};
use sqlbuiltins::functions::{BuiltinFunction, FUNCTION_REGISTRY, SENTENCE_TRANSFORMER_EXTENSION};
use sqlbuiltins::validation::{
validate_database_tunnel_support,
validate_object_name,
Expand Down Expand Up @@ -342,6 +342,8 @@ struct State {
schema_names: HashMap<String, u32>,
/// Map schema IDs to objects in the schema.
schema_objects: HashMap<u32, SchemaObjects>,

extension_names: HashMap<String, u32>,
}

impl State {
Expand All @@ -358,7 +360,6 @@ impl State {
let mut credentials_names = HashMap::new();
let mut schema_names = HashMap::new();
let mut schema_objects = HashMap::new();

// Sanity check to ensure we didn't accidentally persist builtin
// objects.
for (oid, ent) in &state.entries {
Expand All @@ -375,6 +376,7 @@ impl State {
schema_names.extend(builtin.schema_names);
schema_objects.extend(builtin.schema_objects);


// Rebuild name maps for user objects.
//
// All non-database objects are checked to ensure they have non-zero
Expand Down Expand Up @@ -486,6 +488,9 @@ impl State {
credentials_names,
schema_names,
schema_objects,
// Should extensions be persisted?
// AFAICT, other databases make you load them every time, acting as temporary objects.
extension_names: HashMap::new(),
};

Ok(internal_state)
Expand Down Expand Up @@ -1025,6 +1030,66 @@ impl State {
// Update the new storage size
self.deployment.storage_size = update_deployment_storage.new_storage_size;
}
// Load an extension: register its functions as catalog entries in the
// default schema and add them to the global function registry.
//
// Errors with `UnsupportedExtension` for any extension other than
// "sentence_transformers". Loading an already loaded extension is a
// no-op.
Mutation::LoadExtension(load_extension) => {
    // We only have the one extension for now
    if load_extension.extension != "sentence_transformers" {
        return Err(MetastoreError::UnsupportedExtension(
            load_extension.extension,
        ));
    }

    // Check for a previous load *before* allocating anything so that a
    // repeated load does not consume an OID.
    if self.extension_names.contains_key(&load_extension.extension) {
        println!("Extension already loaded");
        return Ok(());
    }

    // Create new entries
    let mut oid = self.next_oid();
    let schema_id = self.get_schema_id(DEFAULT_SCHEMA)?;
    // NOTE(review): only the first OID comes from `self.next_oid()`; the
    // following ones come from this local counter. Confirm the state's
    // OID counter cannot hand out the same values again later.
    let mut oid_gen = || {
        let curr_oid = oid;
        oid += 1;
        curr_oid
    };

    // The extension itself takes the first OID.
    self.extension_names
        .insert(load_extension.extension.clone(), oid_gen());

    let entries = BuiltinCatalog::builtin_function_to_entries(
        oid_gen,
        schema_id,
        SENTENCE_TRANSFORMER_EXTENSION.udfs.iter(),
    );

    for ent in entries {
        let name = ent.meta.name.to_string();
        let oid = ent.meta.id;

        // Ensures no duplicate OIDs. An existing entry under this OID is
        // removed and reported together with the new one.
        if let Some(old) = self.entries.0.remove(&oid) {
            return Err(MetastoreError::BuiltinRepeatedOid {
                oid,
                ent1: old,
                ent2: CatalogEntry::Function(ent),
            });
        }
        self.entries.0.insert(oid, CatalogEntry::Function(ent));

        self.schema_objects
            .get_mut(&schema_id)
            .expect("default schema should exist")
            .functions
            .insert(name, oid);
    }

    FUNCTION_REGISTRY
        .lock()
        .register_extension(&SENTENCE_TRANSFORMER_EXTENSION);
}
};

Ok(())
Expand Down Expand Up @@ -1275,22 +1340,18 @@ impl BuiltinCatalog {
oid += 1;
curr_oid
};

let table_func_ents = Self::builtin_function_to_entries(
&mut oid_gen,
schema_id,
FUNCTION_REGISTRY.table_funcs_iter(),
);
let registry = FUNCTION_REGISTRY.lock();
let table_func_ents =
Self::builtin_function_to_entries(&mut oid_gen, schema_id, registry.table_funcs_iter());
let scalar_func_ents = Self::builtin_function_to_entries(
&mut oid_gen,
schema_id,
FUNCTION_REGISTRY.scalar_funcs_iter(),
);
let scalar_udf_ents = Self::builtin_function_to_entries(
&mut oid_gen,
schema_id,
FUNCTION_REGISTRY.scalar_udfs_iter(),
registry.scalar_funcs_iter(),
);
let scalar_udf_ents =
Self::builtin_function_to_entries(&mut oid_gen, schema_id, registry.scalar_udfs_iter());

drop(registry);

for func_ent in table_func_ents
.into_iter()
Expand Down
3 changes: 3 additions & 0 deletions crates/metastore/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ pub enum MetastoreError {

#[error(transparent)]
Io(#[from] std::io::Error),

#[error("Unsupported Extension: {0}")]
UnsupportedExtension(String),
}

pub type Result<T, E = MetastoreError> = std::result::Result<T, E>;
Expand Down
4 changes: 4 additions & 0 deletions crates/protogen/proto/metastore/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ message Mutation {
CreateCredentials create_credentials = 15;
DropCredentials drop_credentials = 16;
UpdateDeploymentStorage update_deployment_storage = 17;
LoadExtension load_extension = 18;
}
// next: 19
}
Expand Down Expand Up @@ -174,6 +175,9 @@ message DropCredentials {
message UpdateDeploymentStorage {
uint64 new_storage_size = 1;
}
// Mutation requesting that an extension be loaded into the catalog.
message LoadExtension {
// Name of the extension to load (e.g. "sentence_transformers").
string extension = 1;
}

message MutateRequest {
// Mutate the catalog for this database.
Expand Down
25 changes: 25 additions & 0 deletions crates/protogen/src/metastore/types/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum Mutation {
DropCredentials(DropCredentials),
// Deployment metadata updates
UpdateDeploymentStorage(UpdateDeploymentStorage),
LoadExtension(LoadExtension),
}

impl TryFrom<service::Mutation> for Mutation {
Expand Down Expand Up @@ -71,6 +72,7 @@ impl TryFrom<service::mutation::Mutation> for Mutation {
service::mutation::Mutation::UpdateDeploymentStorage(v) => {
Mutation::UpdateDeploymentStorage(v.try_into()?)
}
service::mutation::Mutation::LoadExtension(v) => Mutation::LoadExtension(v.try_into()?),
})
}
}
Expand Down Expand Up @@ -105,6 +107,7 @@ impl TryFrom<Mutation> for service::mutation::Mutation {
Mutation::UpdateDeploymentStorage(v) => {
service::mutation::Mutation::UpdateDeploymentStorage(v.into())
}
Mutation::LoadExtension(v) => service::mutation::Mutation::LoadExtension(v.into()),
})
}
}
Expand Down Expand Up @@ -692,6 +695,28 @@ impl From<UpdateDeploymentStorage> for service::UpdateDeploymentStorage {
}
}

/// Catalog mutation requesting that an extension be loaded.
///
/// Converts to and from the protobuf `service::LoadExtension` message.
#[derive(Debug, Clone, Arbitrary, PartialEq, Eq)]
pub struct LoadExtension {
/// Name of the extension to load.
pub extension: String,
}

impl TryFrom<service::LoadExtension> for LoadExtension {
    type Error = ProtoConvError;

    /// Convert the protobuf message into its native representation.
    ///
    /// The extension name is carried over unchanged; no error path is
    /// taken by this conversion.
    fn try_from(value: service::LoadExtension) -> Result<Self, Self::Error> {
        let extension = value.extension;
        Ok(Self { extension })
    }
}

impl From<LoadExtension> for service::LoadExtension {
fn from(value: LoadExtension) -> Self {
Self {
extension: value.extension,
}
}
}

#[cfg(test)]
mod tests {
use proptest::arbitrary::any;
Expand Down
13 changes: 12 additions & 1 deletion crates/protogen/src/sqlexec/physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,15 @@ pub struct DeleteExec {
pub where_expr: Option<LogicalExprNode>,
}

/// Protobuf message describing a physical plan node that loads an extension.
#[derive(Clone, PartialEq, Message)]
pub struct LoadExec {
/// Name of the extension to load.
#[prost(string, tag = "1")]
pub extension: String,
/// NOTE(review): presumably the catalog version the plan was built
/// against — confirm with the planner/executor.
#[prost(uint64, tag = "2")]
pub catalog_version: u64,
}


#[derive(Clone, PartialEq, Message)]
pub struct InsertExec {
#[prost(bytes, tag = "1")]
Expand Down Expand Up @@ -341,7 +350,7 @@ pub struct AnalyzeExec {
pub struct ExecutionPlanExtension {
#[prost(
oneof = "ExecutionPlanExtensionType",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32"
)]
pub inner: Option<ExecutionPlanExtensionType>,
}
Expand Down Expand Up @@ -415,4 +424,6 @@ pub enum ExecutionPlanExtensionType {
DataSourceMetricsExecAdapter(DataSourceMetricsExecAdapter),
#[prost(message, tag = "31")]
DescribeTable(DescribeTableExec),
#[prost(message, tag = "32")]
LoadExec(LoadExec),
}
2 changes: 1 addition & 1 deletion crates/rpcsrv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ thiserror = { workspace = true }
tokio = { workspace = true }
tonic = { workspace = true }
tracing = { workspace = true }

sqlbuiltins = { path = "../sqlbuiltins" }
datafusion_ext = { path = "../datafusion_ext" }
logutil = { path = "../logutil" }
protogen = { path = "../protogen" }
Expand Down
5 changes: 4 additions & 1 deletion crates/rpcsrv/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::PhysicalPlanNode;
use protogen::metastore::types::catalog::CatalogState;
use protogen::rpcsrv::types::service::ResolvedTableReference;
use sqlbuiltins::functions::FUNCTION_REGISTRY;
use sqlexec::context::remote::RemoteSessionContext;
use sqlexec::remote::batch_stream::ExecutionBatchStream;
use uuid::Uuid;
Expand Down Expand Up @@ -57,12 +58,14 @@ impl RemoteSession {
let codec = self.session.extension_codec();
let plan = PhysicalPlanNode::try_decode(physical_plan.as_ref())?;

let registry = FUNCTION_REGISTRY.lock();
let plan = plan.try_into_physical_plan(
self.session.get_datafusion_context(),
&*registry,
self.session.get_datafusion_context().runtime_env().as_ref(),
&codec,
)?;


let stream = self.session.execute_physical(plan.clone())?;
Ok((plan, stream))
}
Expand Down
1 change: 1 addition & 0 deletions crates/sqlbuiltins/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ hf-hub = "0.3.2"
tokenizers = "0.15.1"
serde_json.workspace = true
lance-linalg = { git = "https://github.com/lancedb/lance", rev = "310d79eccf93f3c6a48c162c575918cdba13faec" }
parking_lot = "0.12.1"

Loading

0 comments on commit 4c527ef

Please sign in to comment.