Skip to content

Commit

Permalink
Support float32 and float64 data type
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Feb 19, 2024
1 parent 2eede65 commit bf1a0cb
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 66 deletions.
35 changes: 14 additions & 21 deletions bustubx/src/buffer/buffer_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,15 @@ impl BufferPoolManager {
#[cfg(test)]
mod tests {
use crate::{buffer::BufferPoolManager, storage::DiskManager};
use std::{fs::remove_file, sync::Arc};
use std::sync::Arc;
use tempfile::TempDir;

#[test]
pub fn test_buffer_pool_manager_new_page() {
let db_path = "./test_buffer_pool_manager_new_page.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let mut buffer_pool_manager = BufferPoolManager::new(3, Arc::new(disk_manager));
let page = buffer_pool_manager.new_page().unwrap().clone();
assert_eq!(page.read().unwrap().page_id, 1);
Expand All @@ -237,16 +238,14 @@ mod tests {
buffer_pool_manager.unpin_page(1, false).unwrap();
let page = buffer_pool_manager.new_page().unwrap();
assert_eq!(page.read().unwrap().page_id, 4);

let _ = remove_file(db_path);
}

#[test]
pub fn test_buffer_pool_manager_unpin_page() {
let db_path = "./test_buffer_pool_manager_unpin_page.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let mut buffer_pool_manager = BufferPoolManager::new(3, Arc::new(disk_manager));

let page = buffer_pool_manager.new_page().unwrap();
Expand All @@ -258,16 +257,14 @@ mod tests {
buffer_pool_manager.unpin_page(1, true).unwrap();
let page = buffer_pool_manager.new_page().unwrap();
assert_eq!(page.read().unwrap().page_id, 4);

let _ = remove_file(db_path);
}

#[test]
pub fn test_buffer_pool_manager_fetch_page() {
let db_path = "./test_buffer_pool_manager_fetch_page.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let mut buffer_pool_manager = BufferPoolManager::new(3, Arc::new(disk_manager));

let page1 = buffer_pool_manager.new_page().unwrap();
Expand All @@ -291,16 +288,14 @@ mod tests {
buffer_pool_manager.unpin_page(page2_id, false).unwrap();

assert_eq!(buffer_pool_manager.replacer.size(), 3);

let _ = remove_file(db_path);
}

#[test]
pub fn test_buffer_pool_manager_delete_page() {
let db_path = "./test_buffer_pool_manager_delete_page.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let mut buffer_pool_manager = BufferPoolManager::new(3, Arc::new(disk_manager));

let page1 = buffer_pool_manager.new_page().unwrap();
Expand All @@ -324,7 +319,5 @@ mod tests {

let page = buffer_pool_manager.fetch_page(page1_id).unwrap();
assert_eq!(page.read().unwrap().page_id, page1_id);

let _ = remove_file(db_path);
}
}
2 changes: 1 addition & 1 deletion bustubx/src/buffer/replacer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct LRUKReplacer {
// 可置换的frame数上限
replacer_size: usize,
k: usize,
pub node_store: HashMap<FrameId, LRUKNode>,
node_store: HashMap<FrameId, LRUKNode>,
// 当前时间戳(从0递增)
current_timestamp: u64,
}
Expand Down
19 changes: 8 additions & 11 deletions bustubx/src/catalog/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ impl Catalog {

#[cfg(test)]
mod tests {
use std::{fs::remove_file, sync::Arc};
use std::sync::Arc;
use tempfile::TempDir;

use crate::common::TableReference;
use crate::{
Expand All @@ -145,10 +146,10 @@ mod tests {

#[test]
pub fn test_catalog_create_table() {
let db_path = "./test_catalog_create_table.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let buffer_pool_manager = BufferPoolManager::new(1000, Arc::new(disk_manager));
let mut catalog = super::Catalog::new(buffer_pool_manager);

Expand Down Expand Up @@ -183,16 +184,14 @@ mod tests {
let table_info = catalog.table(&table_ref2).unwrap();
assert_eq!(table_info.name, table_ref2.table());
assert_eq!(table_info.schema.column_count(), 3);

let _ = remove_file(db_path);
}

#[test]
pub fn test_catalog_create_index() {
let db_path = "./test_catalog_create_index.db";
let _ = remove_file(db_path);
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().join("test.db");

let disk_manager = DiskManager::try_new(&db_path).unwrap();
let disk_manager = DiskManager::try_new(&temp_path).unwrap();
let buffer_pool_manager = BufferPoolManager::new(1000, Arc::new(disk_manager));
let mut catalog = super::Catalog::new(buffer_pool_manager);

Expand Down Expand Up @@ -262,7 +261,5 @@ mod tests {
assert!(index_info.is_some());
let index_info = index_info.unwrap();
assert_eq!(index_info.name, index_name1);

let _ = remove_file(db_path);
}
}
17 changes: 4 additions & 13 deletions bustubx/src/catalog/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,8 @@ pub enum DataType {
Int32,
Int64,
UInt64,
}

impl DataType {
pub fn type_size(&self) -> usize {
match self {
DataType::Boolean => 1,
DataType::Int8 => 1,
DataType::Int16 => 2,
DataType::Int32 => 4,
DataType::Int64 => 8,
DataType::UInt64 => 8,
}
}
Float32,
Float64,
}

impl TryFrom<&sqlparser::ast::DataType> for DataType {
Expand All @@ -34,6 +23,8 @@ impl TryFrom<&sqlparser::ast::DataType> for DataType {
sqlparser::ast::DataType::Int(_) => Ok(DataType::Int32),
sqlparser::ast::DataType::BigInt(_) => Ok(DataType::Int64),
sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(DataType::UInt64),
sqlparser::ast::DataType::Float(_) => Ok(DataType::Float32),
sqlparser::ast::DataType::Double => Ok(DataType::Float32),
_ => Err(BustubxError::NotSupport(format!(
"Not support datatype {}",
value
Expand Down
4 changes: 0 additions & 4 deletions bustubx/src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ impl Schema {
Ok(idx)
}

pub fn fixed_len(&self) -> usize {
self.columns.iter().map(|c| c.data_type.type_size()).sum()
}

pub fn column_count(&self) -> usize {
self.columns.len()
}
Expand Down
46 changes: 45 additions & 1 deletion bustubx/src/common/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub enum ScalarValue {
Int32(Option<i32>),
Int64(Option<i64>),
UInt64(Option<u64>),
Float32(Option<f32>),
Float64(Option<f64>),
}

impl ScalarValue {
Expand All @@ -21,6 +23,8 @@ impl ScalarValue {
DataType::Int32 => Self::Int32(None),
DataType::Int64 => Self::Int64(None),
DataType::UInt64 => Self::UInt64(None),
DataType::Float32 => Self::Float32(None),
DataType::Float64 => Self::Float64(None),
}
}

Expand All @@ -32,6 +36,8 @@ impl ScalarValue {
ScalarValue::Int32(_) => DataType::Int32,
ScalarValue::Int64(_) => DataType::Int64,
ScalarValue::UInt64(_) => DataType::UInt64,
ScalarValue::Float32(_) => DataType::Float32,
ScalarValue::Float64(_) => DataType::Float64,
}
}

Expand All @@ -43,11 +49,14 @@ impl ScalarValue {
ScalarValue::Int32(v) => v.is_none(),
ScalarValue::Int64(v) => v.is_none(),
ScalarValue::UInt64(v) => v.is_none(),
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
}
}

/// Try to cast this value to a ScalarValue of type `data_type`
pub fn cast_to(&self, data_type: &DataType) -> BustubxResult<Self> {
// TODO use macro
match data_type {
DataType::Boolean => match self {
ScalarValue::Boolean(v) => Ok(ScalarValue::Boolean(v.clone())),
Expand All @@ -64,6 +73,15 @@ impl ScalarValue {
self, data_type
))),
},
DataType::Float32 => match self {
ScalarValue::Int8(v) => Ok(ScalarValue::Float32(v.map(|v| v as f32))),
ScalarValue::Int64(v) => Ok(ScalarValue::Float32(v.map(|v| v as f32))),
ScalarValue::Float64(v) => Ok(ScalarValue::Float32(v.map(|v| v as f32))),
_ => Err(BustubxError::NotSupport(format!(
"Failed to cast {} to {} type",
self, data_type
))),
},
_ => Err(BustubxError::NotSupport(format!(
"Not support cast to {} type",
data_type
Expand Down Expand Up @@ -96,6 +114,16 @@ impl PartialEq for ScalarValue {
(Int64(_), _) => false,
(UInt64(v1), UInt64(v2)) => v1.eq(v2),
(UInt64(_), _) => false,
(Float32(v1), Float32(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float32(_), _) => false,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float64(_), _) => false,
}
}
}
Expand All @@ -118,6 +146,16 @@ impl PartialOrd for ScalarValue {
(Int64(_), _) => None,
(UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2),
(UInt64(_), _) => None,
(Float32(v1), Float32(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float32(_), _) => None,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float64(_), _) => None,
}
}
}
Expand All @@ -137,6 +175,10 @@ impl std::fmt::Display for ScalarValue {
ScalarValue::Int64(Some(v)) => write!(f, "{v}"),
ScalarValue::UInt64(None) => write!(f, "NULL"),
ScalarValue::UInt64(Some(v)) => write!(f, "{v}"),
ScalarValue::Float32(None) => write!(f, "NULL"),
ScalarValue::Float32(Some(v)) => write!(f, "{v}"),
ScalarValue::Float64(None) => write!(f, "NULL"),
ScalarValue::Float64(Some(v)) => write!(f, "{v}"),
}
}
}
Expand All @@ -157,9 +199,11 @@ macro_rules! impl_from_for_scalar {
};
}

impl_from_for_scalar!(bool, Boolean);
impl_from_for_scalar!(i8, Int8);
impl_from_for_scalar!(i16, Int16);
impl_from_for_scalar!(i32, Int32);
impl_from_for_scalar!(i64, Int64);
impl_from_for_scalar!(u64, UInt64);
impl_from_for_scalar!(bool, Boolean);
impl_from_for_scalar!(f32, Float32);
impl_from_for_scalar!(f64, Float64);
13 changes: 9 additions & 4 deletions bustubx/src/planner/logical_planner/bind_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ impl LogicalPlanner<'_> {
pub fn bind_value(&self, value: &sqlparser::ast::Value) -> BustubxResult<Expr> {
match value {
sqlparser::ast::Value::Number(s, _) => {
let num: i64 = s.parse::<i64>().map_err(|e| {
BustubxError::Internal("Failed to parse literal as i64".to_string())
})?;
Ok(Expr::Literal(Literal { value: num.into() }))
if let Ok(num) = s.parse::<i64>() {
return Ok(Expr::Literal(Literal { value: num.into() }));
}
if let Ok(num) = s.parse::<f64>() {
return Ok(Expr::Literal(Literal { value: num.into() }));
}
Err(BustubxError::Internal(
"Failed to parse sql number value".to_string(),
))
}
sqlparser::ast::Value::Boolean(b) => Ok(Expr::Literal(Literal { value: (*b).into() })),
sqlparser::ast::Value::Null => Ok(Expr::Literal(Literal {
Expand Down
48 changes: 48 additions & 0 deletions bustubx/src/storage/codec/common.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::storage::codec::DecodedData;
use crate::{BustubxError, BustubxResult};
use core::f32;
use std::f64;

pub struct CommonCodec;

Expand Down Expand Up @@ -152,6 +154,40 @@ impl CommonCodec {
];
Ok((i64::from_be_bytes(data), 8))
}

pub fn encode_f32(data: f32) -> Vec<u8> {
data.to_be_bytes().to_vec()
}

pub fn decode_f32(bytes: &[u8]) -> BustubxResult<DecodedData<f32>> {
if bytes.len() < 4 {
return Err(BustubxError::Storage(format!(
"bytes length {} is less than {}",
bytes.len(),
4
)));
}
let data = [bytes[0], bytes[1], bytes[2], bytes[3]];
Ok((f32::from_be_bytes(data), 4))
}

pub fn encode_f64(data: f64) -> Vec<u8> {
data.to_be_bytes().to_vec()
}

pub fn decode_f64(bytes: &[u8]) -> BustubxResult<DecodedData<f64>> {
if bytes.len() < 8 {
return Err(BustubxError::Storage(format!(
"bytes length {} is less than {}",
bytes.len(),
8
)));
}
let data = [
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
];
Ok((f64::from_be_bytes(data), 8))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -221,5 +257,17 @@ mod tests {
.unwrap()
.0
);
assert_eq!(
5.0f32,
CommonCodec::decode_f32(&CommonCodec::encode_f32(5.0f32))
.unwrap()
.0
);
assert_eq!(
5.0f64,
CommonCodec::decode_f64(&CommonCodec::encode_f64(5.0f64))
.unwrap()
.0
);
}
}
Loading

0 comments on commit bf1a0cb

Please sign in to comment.