diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5575591..46cbbc8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: OCEANBASE_USER: 'root@test' OCEANBASE_PASS: '' run: | - docker exec ob433 obclient -h $OCEANBASE_HOST -P $OCEANBASE_PORT -u $OCEANBASE_USER -p$OCEANBASE_PASS -e "ALTER SYSTEM ob_vector_memory_limit_percentage = 30" + docker exec ob433 obclient -h $OCEANBASE_HOST -P $OCEANBASE_PORT -u $OCEANBASE_USER -p$OCEANBASE_PASS -e "ALTER SYSTEM ob_vector_memory_limit_percentage = 30; CREATE USER 'jtuser'@'%'; GRANT SELECT, INSERT, UPDATE, DELETE ON test.* TO 'jtuser'@'%'; FLUSH PRIVILEGES;" - name: Run tests run: | diff --git a/pyobvector/__init__.py b/pyobvector/__init__.py index c4b3664..4477273 100644 --- a/pyobvector/__init__.py +++ b/pyobvector/__init__.py @@ -51,10 +51,13 @@ st_dwithin, st_astext, ) +from .client import ObVecJsonTableClient +from .json_table import OceanBase __all__ = [ "ObVecClient", "MilvusLikeClient", + "ObVecJsonTableClient", "VecIndexType", "IndexParam", "IndexParams", @@ -85,4 +88,5 @@ "st_distance", "st_dwithin", "st_astext", + "OceanBase", ] diff --git a/pyobvector/client/__init__.py b/pyobvector/client/__init__.py index 21cf73c..64d5dab 100644 --- a/pyobvector/client/__init__.py +++ b/pyobvector/client/__init__.py @@ -30,6 +30,7 @@ """ from .ob_vec_client import ObVecClient from .milvus_like_client import MilvusLikeClient +from .ob_vec_json_table_client import ObVecJsonTableClient from .index_param import VecIndexType, IndexParam, IndexParams from .schema_type import DataType from .collection_schema import FieldSchema, CollectionSchema @@ -38,6 +39,7 @@ __all__ = [ "ObVecClient", "MilvusLikeClient", + "ObVecJsonTableClient", "VecIndexType", "IndexParam", "IndexParams", diff --git a/pyobvector/client/ob_vec_json_table_client.py b/pyobvector/client/ob_vec_json_table_client.py new file mode 100644 index 0000000..17defb2 --- /dev/null +++ b/pyobvector/client/ob_vec_json_table_client.py @@ -0,0 +1,809 @@ +import json +import logging +import re +from typing import Dict, List, Optional + +from sqlalchemy import Column, Integer, String, JSON, Engine, select, text, func, CursorResult +from sqlalchemy.dialects.mysql import TINYINT +from sqlalchemy.orm import declarative_base, sessionmaker, Session +from sqlglot import parse_one, exp, Expression + +from .ob_vec_client import ObVecClient +from ..json_table import ( + OceanBase, # imported for its side effect: registering the "oceanbase" sqlglot dialect + ChangeColumn, + JsonTableBool, + JsonTableTimestamp, + JsonTableVarcharFactory, + JsonTableDecimalFactory, + JsonTableInt, + val2json, + json_value +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +JSON_TABLE_META_TABLE_NAME = "_meta_json_t" +JSON_TABLE_DATA_TABLE_NAME = "_data_json_t" + +class ObVecJsonTableClient(ObVecClient): + """OceanBase Vector Store Client with JSON Table.""" + + Base = declarative_base() + + class JsonTableMetaTBL(Base): + __tablename__ = JSON_TABLE_META_TABLE_NAME + + user_id = Column(Integer, primary_key=True) + jtable_name = Column(String(512), primary_key=True) + jcol_id = Column(Integer, primary_key=True) + jcol_name = Column(String(512), primary_key=True) + jcol_type = Column(String(128), nullable=False) + jcol_nullable = Column(TINYINT, nullable=False) + jcol_has_default = Column(TINYINT, nullable=False) + jcol_default = Column(JSON) + + class JsonTableDataTBL(Base): + __tablename__ = JSON_TABLE_DATA_TABLE_NAME + + user_id = Column(Integer, primary_key=True) + jtable_name = Column(String(512), primary_key=True) + jdata_id = Column(Integer,
primary_key=True, autoincrement=True, nullable=False) + jdata = Column(JSON) + + class JsonTableMetadata: + """Per-user cache of the virtual table schemas stored in the meta table.""" + + def __init__(self, user_id: int): + self.user_id = user_id + self.meta_cache: Dict[str, List] = {} + + @classmethod + def _parse_col_type(cls, col_type: str): + # Map a column type string from the meta table to its pydantic value model. + if col_type.startswith('TINYINT'): + return JsonTableBool + elif col_type.startswith('TIMESTAMP'): + return JsonTableTimestamp + elif col_type.startswith('INT'): + return JsonTableInt + elif col_type.startswith('VARCHAR'): + if col_type == 'VARCHAR': + factory = JsonTableVarcharFactory(255) + else: + varchar_pattern = r'VARCHAR\((\d+)\)' + varchar_matches = re.findall(varchar_pattern, col_type) + factory = JsonTableVarcharFactory(int(varchar_matches[0])) + model = factory.get_json_table_varchar_type() + return model + elif col_type.startswith('DECIMAL'): + if col_type == 'DECIMAL': + factory = JsonTableDecimalFactory(10, 0) + else: + decimal_pattern = r'DECIMAL\((\d+),\s*(\d+)\)' + decimal_matches = re.findall(decimal_pattern, col_type) + x, y = decimal_matches[0] + factory = JsonTableDecimalFactory(int(x), int(y)) + model = factory.get_json_table_decimal_type() + return model + raise ValueError(f"Invalid column type string: {col_type}") + + def reflect(self, engine: Engine): + self.meta_cache = {} + with engine.connect() as conn: + with conn.begin(): + stmt = select(ObVecJsonTableClient.JsonTableMetaTBL).filter( + ObVecJsonTableClient.JsonTableMetaTBL.user_id == self.user_id + ) + res = conn.execute(stmt) + for r in res: + if r[1] not in self.meta_cache: + self.meta_cache[r[1]] = [] + self.meta_cache[r[1]].append({ + 'jcol_id': r[2], + 'jcol_name': r[3], + 'jcol_type': r[4], + 'jcol_nullable': bool(r[5]), + 'jcol_has_default': bool(r[6]), + 'jcol_default': ( + r[7]['default'] + if isinstance(r[7], dict) else + json.loads(r[7])['default'] + ), + 'jcol_model': ObVecJsonTableClient.JsonTableMetadata._parse_col_type(r[4]) + }) + for k in self.meta_cache: + self.meta_cache[k].sort(key=lambda x: x['jcol_id']) + + for k, v in self.meta_cache.items(): + logger.debug(f"LOAD TABLE --- {k}: {v}") + + + def __init__( + self, + user_id: int, + uri: str = "127.0.0.1:2881", + user: str = "root@test", + password: str = "", + db_name: str = "test", + **kwargs, + ): + super().__init__(uri, user, password, db_name, **kwargs) + self.Base.metadata.create_all(self.engine) + self.session = sessionmaker(bind=self.engine) + self.user_id = user_id + self.jmetadata = ObVecJsonTableClient.JsonTableMetadata(self.user_id) + self.jmetadata.reflect(self.engine) + + def _reset(self): + # Only for test + self.perform_raw_text_sql(f"TRUNCATE TABLE {JSON_TABLE_DATA_TABLE_NAME}") + self.perform_raw_text_sql(f"TRUNCATE TABLE {JSON_TABLE_META_TABLE_NAME}") + self.jmetadata = ObVecJsonTableClient.JsonTableMetadata(self.user_id) + + def refresh_metadata(self) -> None: + self.jmetadata.reflect(self.engine)
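+ + # Usage sketch (illustrative only, mirroring the tests; table and column names are made up): + # client = ObVecJsonTableClient(user_id=1) + # client.perform_json_table_sql("create table t1 (c1 int DEFAULT 10, c2 varchar(30))") + # client.perform_json_table_sql("insert into t1 (c2) values ('hello')") + # rows = client.perform_json_table_sql("select * from t1")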
+ + def perform_json_table_sql(self, sql: str) -> Optional[CursorResult]: + """Perform common SQL that operates on JSON Table.""" + ast = parse_one(sql, dialect="oceanbase") + if isinstance(ast, exp.Create): + if ast.kind == 'TABLE': + self._handle_create_json_table(ast) + else: + raise ValueError(f"Create {ast.kind} is not supported") + return None + elif isinstance(ast, exp.Alter): + self._handle_alter_json_table(ast) + return None + elif isinstance(ast, exp.Insert): + self._handle_jtable_dml_insert(ast) + return None + elif isinstance(ast, exp.Update): + self._handle_jtable_dml_update(ast) + return None + elif isinstance(ast, exp.Delete): + self._handle_jtable_dml_delete(ast) + return None + elif isinstance(ast, exp.Select): + return self._handle_jtable_dml_select(ast) + else: + raise ValueError(f"{type(ast)} not supported") + + def _parse_datatype_to_str(self, datatype): + if datatype == exp.DataType.Type.INT: + return "INT" + if datatype == exp.DataType.Type.TINYINT: + return "TINYINT" + if datatype == exp.DataType.Type.TIMESTAMP: + return "TIMESTAMP" + if datatype == exp.DataType.Type.VARCHAR: + return "VARCHAR" + if datatype == exp.DataType.Type.DECIMAL: + return "DECIMAL" + raise ValueError(f"{datatype} not supported") + + def _calc_default_value(self, default_val): + if default_val is None: + return None + # Let the server evaluate the default expression (e.g. CURRENT_TIMESTAMP or 1+2). + with self.engine.connect() as conn: + res = conn.execute(text(f"SELECT {default_val}")) + for r in res: + logger.debug(f"============== Calculate default value: {r[0]}") + return r[0] + + def _handle_create_json_table(self, ast: Expression): + logger.debug("HANDLE CREATE JSON TABLE") + + if not isinstance(ast.this, exp.Schema): + raise ValueError("Invalid create table statement") + schema = ast.this + if not isinstance(schema.this, exp.Table): + raise ValueError("Invalid create table statement") + jtable = schema.this + if not isinstance(jtable.this, exp.Identifier): + raise ValueError("Invalid create table statement") + jtable_name = jtable.this.this + + if jtable_name == JSON_TABLE_META_TABLE_NAME or jtable_name == JSON_TABLE_DATA_TABLE_NAME: + raise ValueError(f"Invalid table name: {jtable_name}") + if jtable_name in self.jmetadata.meta_cache: + raise ValueError("Table name duplicated") + + session = self.session() + new_meta_cache_items = [] + col_id = 16 # user-defined column ids are numbered from 16 upwards + for col_def in ast.find_all(exp.ColumnDef): + col_name = col_def.this.this + col_type_str = self._parse_datatype_to_str(col_def.kind.this) + col_type_params = col_def.kind.expressions + col_type_params_list = [] + col_nullable = True + col_has_default = False + col_default_val = None + for param in col_type_params: + if param.is_string: + col_type_params_list.append(f"'{param.this}'") + else: + col_type_params_list.append(f"{param.this}") + if len(col_type_params_list) > 0: + col_type_str += '(' + ','.join(col_type_params_list) + ')' + col_type_model = ObVecJsonTableClient.JsonTableMetadata._parse_col_type(col_type_str) + + for cons in col_def.constraints: + if isinstance(cons.kind, exp.DefaultColumnConstraint): + col_has_default = True + logger.debug(f"############ create jtable ########### {str(cons.kind.this)}") + col_default_val = str(cons.kind.this) + if col_default_val.upper() == "NULL": + col_default_val = None + elif isinstance(cons.kind, exp.NotNullColumnConstraint): + col_nullable = False + else: + raise ValueError(f"{cons.kind} constraint is not supported.") + + if col_has_default and (col_default_val is not None): + # check default value is valid + col_type_model(val=self._calc_default_value(col_default_val)) + + if (not col_nullable) and col_has_default and (col_default_val is None): + raise ValueError(f"Invalid default value for '{col_name}'") + + logger.debug( + f"col_name={col_name}, col_id={col_id}, " + f"col_type_str={col_type_str}, col_nullable={col_nullable}, " + f"col_has_default={col_has_default}, col_default_val={col_default_val}" + ) + new_meta_cache_items.append({ + 'jcol_id': col_id, + 'jcol_name': col_name, + 'jcol_type': col_type_str, + 'jcol_nullable': col_nullable, + 'jcol_has_default': col_has_default, + 'jcol_default': col_default_val, + 'jcol_model': col_type_model, + })
+ session.add(ObVecJsonTableClient.JsonTableMetaTBL( + user_id = self.user_id, + jtable_name = jtable_name, + jcol_id = col_id, + jcol_name = col_name, + jcol_type = col_type_str, + jcol_nullable = col_nullable, + jcol_has_default = col_has_default, + jcol_default = { + 'default': col_default_val, + } + )) + + col_id += 1 + + try: + session.commit() + self.jmetadata.meta_cache[jtable_name] = new_meta_cache_items + logger.debug(f"ADD METADATA CACHE ---- {jtable_name}: {new_meta_cache_items}") + except Exception as e: + session.rollback() + logger.error(f"Error occurred: {e}") + finally: + session.close() + + def _check_table_exists(self, jtable_name: str) -> bool: + return jtable_name in self.jmetadata.meta_cache + + def _check_col_exists(self, jtable_name: str, col_name: str) -> Optional[Dict]: + if not self._check_table_exists(jtable_name): + return None + for col_meta in self.jmetadata.meta_cache[jtable_name]: + if col_meta['jcol_name'] == col_name: + return col_meta + return None + + def _parse_col_datatype(self, expr: Expression) -> str: + col_type_str = self._parse_datatype_to_str(expr.this) + col_type_params_list = [] + for param in expr.expressions: + if param.is_string: + col_type_params_list.append(f"'{param.this}'") + else: + col_type_params_list.append(f"{param.this}") + if len(col_type_params_list) > 0: + col_type_str += '(' + ','.join(col_type_params_list) + ')' + return col_type_str + + def _parse_col_constraints(self, expr: Expression) -> Dict: + col_has_default = False + col_nullable = True + col_default_val = None # ensure a value even when no DEFAULT constraint is present + for cons in expr: + if isinstance(cons.kind, exp.DefaultColumnConstraint): + col_has_default = True + logger.debug(f"############ column constraints ########### {str(cons.kind.this)}") + col_default_val = str(cons.kind.this) + if col_default_val.upper() == "NULL": + col_default_val = None + elif isinstance(cons.kind, exp.NotNullColumnConstraint): + col_nullable = False + else: + raise ValueError(f"{cons.kind} constraint is not supported.") + return { + 'jcol_nullable': col_nullable, + 'jcol_has_default': col_has_default, + 'jcol_default': col_default_val, + } + + def _handle_alter_jtable_change_column( + self, + session: Session, + jtable_name: str, + change_col: Expression, + ): + logger.debug("HANDLE ALTER CHANGE COLUMN") + origin_col_name = change_col.origin_col_name.this + if not self._check_col_exists(jtable_name, origin_col_name): + raise ValueError(f"Column {origin_col_name} does not exist in {jtable_name}") + + new_col_name = change_col.this + if self._check_col_exists(jtable_name, new_col_name): + raise ValueError(f"Column {new_col_name} already exists!") + + col_type_str = self._parse_col_datatype(change_col.dtype) + + session.query(ObVecJsonTableClient.JsonTableMetaTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + jcol_name=origin_col_name + ).update({ + ObVecJsonTableClient.JsonTableMetaTBL.jcol_name: new_col_name, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_type: col_type_str, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_nullable: True, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_has_default: True, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_default: { + 'default': None + }, + })
+ + # First pass: move the value to the new key, keeping its raw JSON form. + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_insert( + func.json_remove( + ObVecJsonTableClient.JsonTableDataTBL.jdata, f'$.{origin_col_name}' + ), + f'$.{new_col_name}', + func.json_value( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{origin_col_name}', + ), + ) + }) + + # Second pass: re-type the moved value via json_value ... RETURNING. + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_replace( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{new_col_name}', + json_value( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{new_col_name}', + col_type_str, + ), + ) + }) + + def _handle_alter_jtable_drop_column( + self, + session: Session, + jtable_name: str, + drop_col: Expression, + ): + logger.debug("HANDLE ALTER DROP COLUMN") + if not isinstance(drop_col.this, exp.Column): + raise ValueError(f"Drop {drop_col.kind} is not supported") + col_name = drop_col.this.this.this + if not self._check_col_exists(jtable_name, col_name): + raise ValueError(f"Column {col_name} does not exist in {jtable_name}") + + session.query(ObVecJsonTableClient.JsonTableMetaTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + jcol_name=col_name + ).delete() + + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_remove( + ObVecJsonTableClient.JsonTableDataTBL.jdata, f'$.{col_name}' + ) + }) + + def _handle_alter_jtable_add_column( + self, + session: Session, + jtable_name: str, + add_col: Expression, + ): + logger.debug("HANDLE ALTER ADD COLUMN") + new_col_name = add_col.this.this + if self._check_col_exists(jtable_name, new_col_name): + raise ValueError(f"Column {new_col_name} already exists!") + + col_type_str = self._parse_col_datatype(add_col.kind) + model = ObVecJsonTableClient.JsonTableMetadata._parse_col_type(col_type_str) + constraints = self._parse_col_constraints(add_col.constraints) + if (not constraints['jcol_nullable']) and constraints['jcol_has_default'] and (constraints['jcol_default'] is None): + raise ValueError(f"Invalid default value for '{new_col_name}'") + if constraints['jcol_has_default'] and (constraints['jcol_default'] is not None): + model(val=self._calc_default_value(constraints['jcol_default'])) + cur_col_id = max([meta['jcol_id'] for meta in self.jmetadata.meta_cache[jtable_name]]) + 1 + + session.add(ObVecJsonTableClient.JsonTableMetaTBL( + user_id = self.user_id, + jtable_name = jtable_name, + jcol_id = cur_col_id, + jcol_name = new_col_name, + jcol_type = col_type_str, + jcol_nullable = constraints['jcol_nullable'], + jcol_has_default = constraints['jcol_has_default'], + jcol_default = { + 'default': constraints['jcol_default'], + } + )) + + if constraints['jcol_default'] is None: + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_insert( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{new_col_name}', + None, + ) + }) + else: + datum = model(val=self._calc_default_value(constraints['jcol_default'])) + json_val = val2json(datum.val) + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_insert( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{new_col_name}', + json_val, + ) + }) + + def _handle_alter_jtable_modify_column( + self, + session: Session, + jtable_name: str, + modify_col: Expression, + ): + logger.debug("HANDLE ALTER MODIFY COLUMN")
+ col_def = modify_col.this + col_name = col_def.this.this + if not self._check_col_exists(jtable_name, col_name): + raise ValueError(f"Column {col_name} does not exist in {jtable_name}") + + col_type_str = self._parse_col_datatype(col_def.kind) + model = ObVecJsonTableClient.JsonTableMetadata._parse_col_type(col_type_str) + constraints = self._parse_col_constraints(col_def.constraints) + if (not constraints['jcol_nullable']) and constraints['jcol_has_default'] and (constraints['jcol_default'] is None): + raise ValueError(f"Invalid default value for '{col_name}'") + if constraints['jcol_has_default'] and (constraints['jcol_default'] is not None): + model(val=self._calc_default_value(constraints['jcol_default'])) + + session.query(ObVecJsonTableClient.JsonTableMetaTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + jcol_name=col_name + ).update({ + ObVecJsonTableClient.JsonTableMetaTBL.jcol_name: col_name, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_type: col_type_str, + ObVecJsonTableClient.JsonTableMetaTBL.jcol_nullable: constraints['jcol_nullable'], + ObVecJsonTableClient.JsonTableMetaTBL.jcol_has_default: constraints['jcol_has_default'], + ObVecJsonTableClient.JsonTableMetaTBL.jcol_default: { + 'default': constraints['jcol_default'] + }, + }) + + if constraints['jcol_default'] is None: + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_replace( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{col_name}', + json_value( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{col_name}', + col_type_str, + ), + ) + }) + else: + datum = model(val=self._calc_default_value(constraints['jcol_default'])) + json_val = val2json(datum.val) + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jdata: func.json_replace( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{col_name}', + func.ifnull( + json_value( + ObVecJsonTableClient.JsonTableDataTBL.jdata, + f'$.{col_name}', + col_type_str, + ), + json_val, + ), + ) + }) + + def _handle_alter_jtable_rename_table( + self, + session: Session, + jtable_name: str, + rename: Expression, + ): + if not self._check_table_exists(jtable_name): + raise ValueError(f"Table {jtable_name} does not exist") + + new_table_name = rename.this.this.this + if self._check_table_exists(new_table_name): + raise ValueError(f"Table {new_table_name} already exists!") + + session.query(ObVecJsonTableClient.JsonTableMetaTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableMetaTBL.jtable_name: new_table_name, + }) + + # Keep stored rows attached to the renamed table as well. + session.query(ObVecJsonTableClient.JsonTableDataTBL).filter_by( + user_id=self.user_id, + jtable_name=jtable_name, + ).update({ + ObVecJsonTableClient.JsonTableDataTBL.jtable_name: new_table_name, + }) + + def _handle_alter_json_table(self, ast: Expression): + if not isinstance(ast.this, exp.Table): + raise ValueError("Invalid alter table statement") + if not isinstance(ast.this.this, exp.Identifier): + raise ValueError("Invalid alter table statement") + jtable_name = ast.this.this.this + if not self._check_table_exists(jtable_name): + raise ValueError(f"Table {jtable_name} does not exist") + + session = self.session() + for action in ast.actions: + if isinstance(action, ChangeColumn): + self._handle_alter_jtable_change_column( + session, + jtable_name, + action, + ) + elif isinstance(action, exp.Drop): + self._handle_alter_jtable_drop_column( + session, + jtable_name, + action, + )
+ elif isinstance(action, exp.AlterColumn): + self._handle_alter_jtable_modify_column( + session, + jtable_name, + action, + ) + elif isinstance(action, exp.ColumnDef): + self._handle_alter_jtable_add_column( + session, + jtable_name, + action, + ) + elif isinstance(action, exp.AlterRename): + self._handle_alter_jtable_rename_table( + session, + jtable_name, + action, + ) + + try: + session.commit() + self.jmetadata.reflect(self.engine) + except Exception as e: + session.rollback() + logger.error(f"Error occurred: {e}") + finally: + session.close() + + def _handle_jtable_dml_insert(self, ast: Expression): + if isinstance(ast.this, exp.Schema): + table_name = ast.this.this.this.this + else: + table_name = ast.this.this.this + if not self._check_table_exists(table_name): + raise ValueError(f"Table {table_name} does not exist") + + table_col_names = [meta['jcol_name'] for meta in self.jmetadata.meta_cache[table_name]] + cols = { + meta['jcol_name']: meta + for meta in self.jmetadata.meta_cache[table_name] + } + if isinstance(ast.this, exp.Schema): + insert_col_names = [expr.this for expr in ast.this.expressions] + for col_name in insert_col_names: + if col_name not in table_col_names: + raise ValueError(f"Unknown column {col_name} in field list") + for meta in self.jmetadata.meta_cache[table_name]: + if ((meta['jcol_name'] not in insert_col_names) and + (not meta['jcol_nullable']) and (not meta['jcol_has_default'])): + raise ValueError(f"Field {meta['jcol_name']} does not have a default value") + elif isinstance(ast.this, exp.Table): + insert_col_names = table_col_names + else: + raise ValueError(f"Invalid ast type {ast.this}") + + session = self.session() + for row in ast.expression.expressions: + expr_list = row.expressions + if len(expr_list) != len(insert_col_names): + raise ValueError("Values tuple length does not match the number of insert columns") + kv = {} + for col_name, expr in zip(insert_col_names, expr_list): + model = cols[col_name]['jcol_model'] + datum = model(val=self._calc_default_value(str(expr))) + kv[col_name] = val2json(datum.val) + for col_name in table_col_names: + if col_name not in insert_col_names: + model = cols[col_name]['jcol_model'] + datum = model(val=self._calc_default_value(cols[col_name]['jcol_default'])) + kv[col_name] = val2json(datum.val) + + logger.debug(f"================= [INSERT] =============== {kv}") + + session.add(ObVecJsonTableClient.JsonTableDataTBL( + user_id = self.user_id, + jtable_name = table_name, + jdata = kv, + )) + + try: + session.commit() + except Exception as e: + session.rollback() + logger.error(f"Error occurred: {e}") + finally: + session.close() + + def _handle_jtable_dml_update(self, ast: Expression): + table_name = ast.this.this.this + if not self._check_table_exists(table_name): + raise ValueError(f"Table {table_name} does not exist") + + path_settings = [] + for expr in ast.expressions: + col_name = expr.this.this.this + if not self._check_col_exists(table_name, col_name): + raise ValueError(f"Column {col_name} does not exist") + col_expr = expr.expression + path_settings.append(f"'$.{col_name}', {str(col_expr)}") + + # Always scope the statement to this user's virtual table so that rows + # belonging to other users or tables are never touched. + where_clause = f"{JSON_TABLE_DATA_TABLE_NAME}.user_id = {self.user_id} AND {JSON_TABLE_DATA_TABLE_NAME}.jtable_name = '{table_name}'" + if 'where' in ast.args.keys(): + for column in ast.args['where'].find_all(exp.Column): + where_col_name = column.this.this + if not self._check_col_exists(table_name, where_col_name): + raise ValueError(f"Column {where_col_name} does not exist") + # NOTE: assumes the filter column sits in its parent's 'this' slot. + column.parent.args['this'] = parse_one( + f"JSON_VALUE({JSON_TABLE_DATA_TABLE_NAME}.jdata, '$.{where_col_name}')" + ) + where_clause += f" AND ({str(ast.args['where'].this)})" + + update_sql = f"UPDATE {JSON_TABLE_DATA_TABLE_NAME} SET jdata = JSON_REPLACE({JSON_TABLE_DATA_TABLE_NAME}.jdata, {', '.join(path_settings)}) WHERE {where_clause}" + + logger.debug(f"===================== do update: {update_sql}") + self.perform_raw_text_sql(update_sql) + + def _handle_jtable_dml_delete(self, ast: Expression): + table_name = ast.this.this.this + if not self._check_table_exists(table_name): + raise ValueError(f"Table {table_name} does not exist") + + # Same scoping rule as UPDATE: never emit an unscoped DELETE. + where_clause = f"{JSON_TABLE_DATA_TABLE_NAME}.user_id = {self.user_id} AND {JSON_TABLE_DATA_TABLE_NAME}.jtable_name = '{table_name}'" + if 'where' in ast.args.keys(): + for column in ast.args['where'].find_all(exp.Column): + where_col_name = column.this.this + if not self._check_col_exists(table_name, where_col_name): + raise ValueError(f"Column {where_col_name} does not exist") + column.parent.args['this'] = parse_one( + f"JSON_VALUE({JSON_TABLE_DATA_TABLE_NAME}.jdata, '$.{where_col_name}')" + ) + where_clause += f" AND ({str(ast.args['where'].this)})" + + delete_sql = f"DELETE FROM {JSON_TABLE_DATA_TABLE_NAME} WHERE {where_clause}" + + logger.debug(f"===================== do delete: {delete_sql}") + self.perform_raw_text_sql(delete_sql) + + def _get_full_datatype(self, jdata_type: str): + if jdata_type.upper() == "VARCHAR": + return "VARCHAR(255)" + if jdata_type.upper() == "DECIMAL": + return "DECIMAL(10, 0)" + return jdata_type + + def _handle_jtable_dml_select(self, ast: Expression): + table_name = ast.args['from'].this.this.this + if not self._check_table_exists(table_name): + raise ValueError(f"Table {table_name} does not exist") + + ast.args['from'].args['this'].args['this'].args['this'] = JSON_TABLE_DATA_TABLE_NAME + + col_meta = self.jmetadata.meta_cache[table_name] + json_table_meta_str = [] + all_jcol_names = [] + for meta in col_meta: + json_table_meta_str.append( + f"{meta['jcol_name']} {self._get_full_datatype(meta['jcol_type'])} " + f"PATH '$.{meta['jcol_name']}'" + ) + all_jcol_names.append(meta['jcol_name']) + + need_replace_select_exprs = False + new_select_exprs = [] + for select_expr in ast.args['expressions']: + if isinstance(select_expr, exp.Star): + need_replace_select_exprs = True + for jcol_name in all_jcol_names: + col_expr = exp.Column() + identifier = exp.Identifier() + identifier.args['this'] = jcol_name + identifier.args['quoted'] = False + col_expr.args['this'] = identifier + new_select_exprs.append(col_expr) + else: + new_select_exprs.append(select_expr) + if need_replace_select_exprs: + ast.args['expressions'] = new_select_exprs + + tmp_table_name = "__tmp" + json_table_str = f"json_table({JSON_TABLE_DATA_TABLE_NAME}.jdata, '$' COLUMNS ({', '.join(json_table_meta_str)})) {tmp_table_name}" + + for col in ast.find_all(exp.Column): + if 'table' in col.args.keys(): + col.args['table'].args['this'] = tmp_table_name + else: + identifier = exp.Identifier() + identifier.args['this'] = tmp_table_name + identifier.args['quoted'] = False + col.args['table'] = identifier + + # Parse a throwaway FROM clause just to harvest the comma-join node for json_table. + join_clause = parse_one(f"from t1, {json_table_str}") + join_node = join_clause.args['joins'][0]
+ if 'joins' in ast.args.keys(): + ast.args['joins'].append(join_node) + else: + ast.args['joins'] = [join_node] + + extra_filter_str = f"{JSON_TABLE_DATA_TABLE_NAME}.user_id = {self.user_id} AND {JSON_TABLE_DATA_TABLE_NAME}.jtable_name = '{table_name}'" + if 'where' in ast.args.keys(): + filter_str = str(ast.args['where'].args['this']) + new_filter_str = f"{extra_filter_str} AND ({filter_str})" + ast.args['where'].args['this'] = parse_one(new_filter_str) + else: + where_clause = exp.Where() + where_clause.args['this'] = parse_one(extra_filter_str) + ast.args['where'] = where_clause + + select_sql = str(ast) + logger.debug(f"===================== do select: {select_sql}") + return self.perform_raw_text_sql(select_sql) diff --git a/pyobvector/json_table/__init__.py b/pyobvector/json_table/__init__.py new file mode 100644 index 0000000..5579d62 --- /dev/null +++ b/pyobvector/json_table/__init__.py @@ -0,0 +1,25 @@ +from .oceanbase_dialect import OceanBase, ChangeColumn +from .virtual_data_type import ( + JType, + JsonTableDataType, + JsonTableBool, + JsonTableTimestamp, + JsonTableVarcharFactory, + JsonTableDecimalFactory, + JsonTableInt, + val2json, +) +from .json_value_returning_func import json_value + +__all__ = [ + "OceanBase", "ChangeColumn", + "JType", + "JsonTableDataType", + "JsonTableBool", + "JsonTableTimestamp", + "JsonTableVarcharFactory", + "JsonTableDecimalFactory", + "JsonTableInt", + "val2json", + "json_value" +] \ No newline at end of file diff --git a/pyobvector/json_table/json_value_returning_func.py b/pyobvector/json_table/json_value_returning_func.py new file mode 100644 index 0000000..18f7f9f --- /dev/null +++ b/pyobvector/json_table/json_value_returning_func.py @@ -0,0 +1,51 @@ +import logging +import re + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import FunctionElement +from sqlalchemy import Text + +logger = logging.getLogger(__name__) + +class json_value(FunctionElement): + type = Text() + inherit_cache = True + + def __init__(self, *args): + super().__init__() + self.args = args + +@compiles(json_value) +def compile_json_value(element, compiler, **kwargs): + args = [] + if len(element.args) != 3: + raise ValueError("Number of args for json_value should be 3") + args.append(compiler.process(element.args[0])) + if not (isinstance(element.args[1], str) and isinstance(element.args[2], str)): + raise ValueError("Invalid args for json_value") + + if element.args[2].startswith('TINYINT'): + returning_type = "SIGNED" + elif element.args[2].startswith('TIMESTAMP'): + returning_type = "DATETIME" + elif element.args[2].startswith('INT'): + returning_type = "SIGNED" + elif element.args[2].startswith('VARCHAR'): + if element.args[2] == 'VARCHAR': + returning_type = "CHAR(255)" + else: + varchar_pattern = r'VARCHAR\((\d+)\)' + varchar_matches = re.findall(varchar_pattern, element.args[2]) + returning_type = f"CHAR({int(varchar_matches[0])})" + elif element.args[2].startswith('DECIMAL'): + if element.args[2] == 'DECIMAL': + returning_type = "DECIMAL(10, 0)" + else: + decimal_pattern = r'DECIMAL\((\d+),\s*(\d+)\)' + decimal_matches = re.findall(decimal_pattern, element.args[2]) + x, y = decimal_matches[0] + returning_type = f"DECIMAL({x}, {y})" + else: + raise ValueError(f"Unsupported column type for json_value: {element.args[2]}") + args.append(f"'{element.args[1]}' RETURNING {returning_type}") + args = ", ".join(args) + return f"json_value({args})"
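+ +# Compilation sketch (illustrative, not part of the module's API): for a VARCHAR(30) +# column, json_value(JsonTableDataTBL.jdata, '$.c2', 'VARCHAR(30)') renders roughly as +# json_value(_data_json_t.jdata, '$.c2' RETURNING CHAR(30)), following the mapping above.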
diff --git a/pyobvector/json_table/oceanbase_dialect.py b/pyobvector/json_table/oceanbase_dialect.py new file mode 100644 index 0000000..45e7839 --- /dev/null +++ b/pyobvector/json_table/oceanbase_dialect.py @@ -0,0 +1,116 @@ +import typing as t +from sqlglot import parser, exp, Expression +from sqlglot.dialects.mysql import MySQL +from sqlglot.tokens import TokenType + +class ChangeColumn(Expression): + arg_types = { + "this": True, + "origin_col_name": True, + "dtype": True, + } + + @property + def origin_col_name(self) -> Expression: + return self.args.get("origin_col_name") + + @property + def dtype(self) -> Expression: + return self.args.get("dtype") + +class OceanBase(MySQL): + # Subclassing MySQL registers this dialect with sqlglot under the name "oceanbase". + class Parser(MySQL.Parser): + ALTER_PARSERS = { + **parser.Parser.ALTER_PARSERS, + "MODIFY": lambda self: self._parse_alter_table_alter(), + "CHANGE": lambda self: self._parse_change_table_column(), + } + + def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]: + if self._match_texts(self.ALTER_ALTER_PARSERS): + return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self) + + self._match(TokenType.COLUMN) + column = self._parse_field_def() + + if self._match_pair(TokenType.DROP, TokenType.DEFAULT): + return self.expression(exp.AlterColumn, this=column, drop=True) + if self._match_pair(TokenType.SET, TokenType.DEFAULT): + return self.expression(exp.AlterColumn, this=column, default=self._parse_assignment()) + if self._match(TokenType.COMMENT): + return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) + if self._match_text_seq("DROP", "NOT", "NULL"): + return self.expression( + exp.AlterColumn, + this=column, + drop=True, + allow_null=True, + ) + if self._match_text_seq("SET", "NOT", "NULL"): + return self.expression( + exp.AlterColumn, + this=column, + allow_null=False, + ) + self._match_text_seq("SET", "DATA") + self._match_text_seq("TYPE") + return self.expression( + exp.AlterColumn, + this=column, + dtype=self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_assignment(), + ) + + def _parse_drop(self, exists: bool = False) -> t.Union[exp.Drop, exp.Command]: + temporary = self._match(TokenType.TEMPORARY) + materialized = self._match_text_seq("MATERIALIZED") + + kind = self._match_set(self.CREATABLES) and self._prev.text.upper() + if not kind: + # Unlike the base parser, default to COLUMN so "ALTER TABLE t DROP col" parses. + kind = "COLUMN" + + concurrently = self._match_text_seq("CONCURRENTLY") + if_exists = exists or self._parse_exists() + + if kind == "COLUMN": + this = self._parse_column() + else: + this = self._parse_table_parts( + schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA + ) + + cluster = self._parse_on_property() if self._match(TokenType.ON) else None + + if self._match(TokenType.L_PAREN, advance=False): + expressions = self._parse_wrapped_csv(self._parse_types) + else: + expressions = None + + return self.expression( + exp.Drop, + exists=if_exists, + this=this, + expressions=expressions, + kind=self.dialect.CREATABLE_KIND_MAPPING.get(kind) or kind, + temporary=temporary, + materialized=materialized, + cascade=self._match_text_seq("CASCADE"), + constraints=self._match_text_seq("CONSTRAINTS"), + purge=self._match_text_seq("PURGE"), + cluster=cluster, + concurrently=concurrently, + ) + + def _parse_change_table_column(self) -> t.Optional[exp.Expression]: + self._match(TokenType.COLUMN) + origin_col = self._parse_field(any_token=True) + column = self._parse_field() + return self.expression( + ChangeColumn, + this=column, + origin_col_name=origin_col, + dtype=self._parse_types(), + ) \ No newline at end of file
diff --git a/pyobvector/json_table/virtual_data_type.py b/pyobvector/json_table/virtual_data_type.py new file mode 100644 index 0000000..4745b14 --- /dev/null +++ b/pyobvector/json_table/virtual_data_type.py @@ -0,0 +1,114 @@ +from datetime import datetime +from decimal import Decimal, InvalidOperation, ROUND_DOWN +from enum import Enum +from typing import Optional +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, AfterValidator, create_model + + +class IntEnum(int, Enum): + """Int type enumerate definition.""" + +class JType(IntEnum): + """Type tags for the virtual JSON-table column models.""" + J_BOOL = 1 + J_TIMESTAMP = 2 + J_VARCHAR = 3 + J_DECIMAL = 4 + J_INT = 5 + +class JsonTableDataType(BaseModel): + type: JType + +class JsonTableBool(JsonTableDataType): + type: JType = Field(default=JType.J_BOOL) + val: Optional[bool] + +class JsonTableTimestamp(JsonTableDataType): + type: JType = Field(default=JType.J_TIMESTAMP) + val: Optional[datetime] + +def check_varchar_len_with_length(length: int): + def check_varchar_len(x: Optional[str]): + if x is None: + return None + if len(x) > length: + raise ValueError(f'{x} is longer than {length}') + return x + + return check_varchar_len + +class JsonTableVarcharFactory: + def __init__(self, length: int): + self.length = length + + def get_json_table_varchar_type(self): + model_name = f"JsonTableVarchar{self.length}" + fields = { + 'type': (JType, JType.J_VARCHAR), + 'val': (Annotated[Optional[str], AfterValidator(check_varchar_len_with_length(self.length))], ...) + } + return create_model( + model_name, + __base__=JsonTableDataType, + **fields + ) + +def check_and_parse_decimal(x: int, y: int): + # Validate a value against DECIMAL(x, y) and truncate surplus fraction digits (ROUND_DOWN). + def check_float(v): + if v is None: + return None + try: + decimal_value = Decimal(v) + except InvalidOperation: + raise ValueError(f"Value {v} cannot be converted to Decimal.") + + decimal_str = str(decimal_value).strip() + + if '.' in decimal_str: + integer_part, decimal_part = decimal_str.split('.') + else: + integer_part, decimal_part = decimal_str, '' + + integer_count = len(integer_part.lstrip('-')) # digit count without the sign + decimal_count = len(decimal_part) + + if integer_count + min(decimal_count, y) > x: + raise ValueError(f"'{v}' Range out of Decimal({x}, {y})") + + if decimal_count > y: + quantize_str = '1.' + '0' * y + decimal_value = decimal_value.quantize(Decimal(quantize_str), rounding=ROUND_DOWN) + return decimal_value + return check_float + +class JsonTableDecimalFactory: + def __init__(self, ndigits: int, decimal_p: int): + self.ndigits = ndigits + self.decimal_p = decimal_p + + def get_json_table_decimal_type(self): + model_name = f"JsonTableDecimal_{self.ndigits}_{self.decimal_p}" + fields = { + 'type': (JType, JType.J_DECIMAL), + 'val': (Annotated[Optional[float], AfterValidator(check_and_parse_decimal(self.ndigits, self.decimal_p))], ...)
+ } + return create_model( + model_name, + __base__=JsonTableDataType, + **fields + ) + +class JsonTableInt(JsonTableDataType): + type: JType = Field(default=JType.J_INT) + val: Optional[int] + +def val2json(val): + if val is None: + return None + if isinstance(val, (bool, int, float, str)): + return val + if isinstance(val, datetime): + return val.isoformat() + if isinstance(val, Decimal): + return float(val) + raise TypeError(f"Cannot convert {type(val)} to a JSON value") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index de363d1..d401f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,8 @@ numpy = ">=1.26.0,<2.0.0" sqlalchemy = ">=1.4,<2.0.36" pymysql = "^1.1.1" aiomysql = "^0.2.0" +sqlglot = "^26.0.1" +pydantic = "^2.10.4" [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" diff --git a/tests/test_json_table.py b/tests/test_json_table.py new file mode 100644 index 0000000..21c5e9a --- /dev/null +++ b/tests/test_json_table.py @@ -0,0 +1,366 @@ +import unittest +import datetime +from decimal import Decimal +from pyobvector import * +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +def sub_dict(d_list, keys): + new_metas = [] + for meta in d_list: + tmp = { + k: meta[k] for k in keys + } + new_metas.append(tmp) + return new_metas + +def get_all_rows(res): + rows = [] + for r in res: + rows.append(r) + return rows + +class ObVecJsonTableTest(unittest.TestCase): + def setUp(self) -> None: + self.root_client = ObVecJsonTableClient(user_id=0) + self.client = ObVecJsonTableClient(user_id=1, user="jtuser@test") + + def test_create_and_alter_jtable(self): + self.root_client._reset() + self.client.refresh_metadata() + keys_to_check = ['jcol_id', 'jcol_name', 'jcol_type', 'jcol_nullable', 'jcol_has_default', 'jcol_default'] + self.client.perform_json_table_sql( + "create table `t2` (c1 int NOT NULL DEFAULT 10, c2 varchar(30) DEFAULT 'ca', c3 varchar not null, c4 decimal(10, 2));" + ) + tmp_client = ObVecJsonTableClient(user_id=1) + self.assertEqual(sub_dict(tmp_client.jmetadata.meta_cache['t2'], keys_to_check), + [ + {'jcol_id': 16,
'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '10'}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t2 ADD COLUMN email VARCHAR(100) default 'example@example.com'" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t2'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '10'}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'email', 'jcol_type': 'VARCHAR(100)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'example@example.com'"} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t2 MODIFY COLUMN c4 INT NOT NULL DEFAULT 100" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t2'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '10'}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '100'}, + {'jcol_id': 20, 'jcol_name': 'email', 'jcol_type': 'VARCHAR(100)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'example@example.com'"} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t2 RENAME TO alter_test" + ) + self.assertEqual(self.client.jmetadata.meta_cache.get('t2', []), []) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['alter_test'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '10'}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'INT', 'jcol_nullable': False, 'jcol_has_default': True, 'jcol_default': '100'}, + {'jcol_id': 20, 'jcol_name': 'email', 'jcol_type': 'VARCHAR(100)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'example@example.com'"} + ] + ) + + def test_create_and_alter_jtable_evil(self): + self.root_client._reset() + self.client.refresh_metadata() + keys_to_check = ['jcol_id', 'jcol_name', 'jcol_type', 'jcol_nullable', 'jcol_has_default', 'jcol_default'] + self.client.perform_json_table_sql( + "create table `t1` (c1 int DEFAULT NULL, c2 varchar(30) DEFAULT 'ca', c3 varchar not null, c4 decimal(10, 2), c5 TIMESTAMP DEFAULT CURRENT_TIMESTAMP);" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'c2', 'jcol_type': 'VARCHAR(30)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'ca'"}, + {'jcol_id': 18, 
'jcol_name': 'c3', 'jcol_type': 'VARCHAR', 'jcol_nullable': False, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'c5', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t1 CHANGE COLUMN c2 changed_col DECIMAL" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'DECIMAL', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 18, 'jcol_name': 'c3', 'jcol_type': 'VARCHAR', 'jcol_nullable': False, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'c5', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t1 ADD COLUMN date timestamp default current_timestamp" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'DECIMAL', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 18, 'jcol_name': 'c3', 'jcol_type': 'VARCHAR', 'jcol_nullable': False, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'c5', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'}, + {'jcol_id': 21, 'jcol_name': 'date', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'} + ] + ) + + self.client.perform_json_table_sql( + "ALTER TABLE t1 MODIFY COLUMN c4 INT DEFAULT NULL" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'changed_col', 'jcol_type': 'DECIMAL', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 18, 'jcol_name': 'c3', 'jcol_type': 'VARCHAR', 'jcol_nullable': False, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'c5', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'}, + {'jcol_id': 21, 'jcol_name': 'date', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': 'CURRENT_TIMESTAMP()'} + ] + ) + + def test_dml(self): + self.root_client._reset() + self.client.refresh_metadata() + keys_to_check = ['jcol_id', 'jcol_name', 'jcol_type', 
'jcol_nullable', 'jcol_has_default', 'jcol_default'] + self.client.perform_json_table_sql( + "create table `t1` (c1 int DEFAULT NULL, c2 varchar(30) DEFAULT 'ca', c3 varchar not null, c4 decimal(10, 2), c5 TIMESTAMP DEFAULT '2024-12-30T03:35:30');" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'c1', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'c2', 'jcol_type': 'VARCHAR(30)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'ca'"}, + {'jcol_id': 18, 'jcol_name': 'c3', 'jcol_type': 'VARCHAR', 'jcol_nullable': False, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 19, 'jcol_name': 'c4', 'jcol_type': 'DECIMAL(10,2)', 'jcol_nullable': True, 'jcol_has_default': False, 'jcol_default': None}, + {'jcol_id': 20, 'jcol_name': 'c5', 'jcol_type': 'TIMESTAMP', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'2024-12-30T03:35:30'"} + ] + ) + + self.client.perform_json_table_sql( + "insert into t1 (c2, c3) values ('hello', 'foo'), ('world', 'bar')" + ) + self.client.perform_json_table_sql( + "insert into t1 values (1+2, 'baz', 'oceanbase', 12.3+45.6, '2024-12-30T06:56:00')" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (None, 'hello', 'foo', None, datetime.datetime(2024, 12, 30, 3, 35, 30)), + (None, 'world', 'bar', None, datetime.datetime(2024, 12, 30, 3, 35, 30)), + (3, 'baz', 'oceanbase', Decimal('57.89'), datetime.datetime(2024, 12, 30, 6, 56)), + ] + ) + + self.client.perform_json_table_sql( + "update t1 set c1=10+10, c2='updated' where c3='oceanbase'" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (None, 'hello', 'foo', None, datetime.datetime(2024, 12, 30, 3, 35, 30)), + (None, 'world', 'bar', None, datetime.datetime(2024, 12, 30, 3, 35, 30)), + (20, 'updated', 'oceanbase', Decimal('57.89'), datetime.datetime(2024, 12, 30, 6, 56)), + ] + ) + + self.client.perform_json_table_sql( + "delete from t1 where c1 is NULL" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (20, 'updated', 'oceanbase', Decimal('57.89'), datetime.datetime(2024, 12, 30, 6, 56)), + ] + ) + + res = self.client.perform_json_table_sql( + "select c1, c2, t1.c3 from t1 where c1 > 21" + ) + self.assertEqual(get_all_rows(res), []) + + self.client.perform_json_table_sql( + "alter table t1 drop column c3" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (20, 'updated', Decimal('57.89'), datetime.datetime(2024, 12, 30, 6, 56)), + ] + ) + + self.client.perform_json_table_sql( + "alter table t1 add column new_col TIMESTAMP default '2024-12-30T02:44:17'" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (20, 'updated', Decimal('57.89'), + datetime.datetime(2024, 12, 30, 6, 56), + datetime.datetime(2024, 12, 30, 2, 44, 17)), + ] + ) + + self.client.perform_json_table_sql( + "alter table t1 modify column c4 INT DEFAULT 10" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (20, 'updated', 58, + datetime.datetime(2024, 12, 30, 6, 56), + datetime.datetime(2024, 12, 30, 2, 44, 17)), + ] + ) + + self.client.perform_json_table_sql(
+ "alter table t1 change column c1 change_col DECIMAL(10,2)" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (20.00, 'updated', 58, + datetime.datetime(2024, 12, 30, 6, 56), + datetime.datetime(2024, 12, 30, 2, 44, 17)), + ] + ) + + self.client.perform_json_table_sql( + "insert into t1 (change_col, c2, c4) values (12, 'pyobvector is good', 50), (90, 'oceanbase is good', 20)" + ) + res = self.client.perform_json_table_sql( + "select sum(c4) as c4_sum from t1 where CHAR_LENGTH(c2) > 10" + ) + self.assertEqual( + get_all_rows(res), + [ + (Decimal('70'),), + ] + ) + + res = self.client.perform_json_table_sql( + "select * from t1 where CHAR_LENGTH(c2) > 10 or c4 > 50 order by c4" + ) + self.assertEqual( + get_all_rows(res), + [ + (Decimal('90.00'), 'oceanbase is good', 20, datetime.datetime(2024, 12, 30, 3, 35, 30), datetime.datetime(2024, 12, 30, 2, 44, 17)), + (Decimal('12.00'), 'pyobvector is good', 50, datetime.datetime(2024, 12, 30, 3, 35, 30), datetime.datetime(2024, 12, 30, 2, 44, 17)), + (Decimal('20.00'), 'updated', 58, datetime.datetime(2024, 12, 30, 6, 56), datetime.datetime(2024, 12, 30, 2, 44, 17)) + ] + ) + + def test_col_name_conflict(self): + self.root_client._reset() + self.client.refresh_metadata() + keys_to_check = ['jcol_id', 'jcol_name', 'jcol_type', 'jcol_nullable', 'jcol_has_default', 'jcol_default'] + self.client.perform_json_table_sql( + "create table `t1` (user_id int DEFAULT NULL, jtable_name varchar(30) DEFAULT 'jtable');" + ) + self.assertEqual(sub_dict(self.client.jmetadata.meta_cache['t1'], keys_to_check), + [ + {'jcol_id': 16, 'jcol_name': 'user_id', 'jcol_type': 'INT', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': None}, + {'jcol_id': 17, 'jcol_name': 'jtable_name', 'jcol_type': 'VARCHAR(30)', 'jcol_nullable': True, 'jcol_has_default': True, 'jcol_default': "'jtable'"}, + ] + ) + + self.client.perform_json_table_sql( + "insert into t1 values (1, 'alice'), (2, 'bob')" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (1, 'alice'), + (2, 'bob'), + ] + ) + + res = self.client.perform_json_table_sql( + "select * from t1 where user_id > 1" + ) + self.assertEqual( + get_all_rows(res), + [ + (2, 'bob'), + ] + ) + + res = self.client.perform_json_table_sql( + "update t1 set user_id=15 where jtable_name='alice'" + ) + res = self.client.perform_json_table_sql( + "select * from t1" + ) + self.assertEqual( + get_all_rows(res), + [ + (15, 'alice'), + (2, 'bob'), + ] + ) diff --git a/tests/test_oceanbase_dialect.py b/tests/test_oceanbase_dialect.py new file mode 100644 index 0000000..fd38306 --- /dev/null +++ b/tests/test_oceanbase_dialect.py @@ -0,0 +1,43 @@ +import unittest +from pyobvector import * +import logging + +from sqlglot import parse_one + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +class OceanBaseDialectTest(unittest.TestCase): + def setUp(self) -> None: + return super().setUp() + + def test_drop_column(self): + sql = "ALTER TABLE users DROP COLUMN age" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + sql = "ALTER TABLE users DROP age" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + def test_modify_column(self): + sql = "ALTER TABLE users MODIFY COLUMN email VARCHAR(100) NOT NULL" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + sql = "ALTER TABLE users MODIFY email 
VARCHAR(100) NOT NULL" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + sql = "ALTER TABLE users MODIFY COLUMN email VARCHAR(100) NOT NULL DEFAULT 'ca'" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + def test_change_column(self): + sql = "ALTER TABLE users CHANGE COLUMN username user_name VARCHAR(50)" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}") + + sql = "ALTER TABLE users CHANGE username user_name VARCHAR(50)" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}")
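+ + def test_rename_table(self): + # Added sketch (not in the original patch): RENAME TO should come through the + # inherited MySQL grammar as exp.AlterRename, the action consumed by + # ObVecJsonTableClient._handle_alter_jtable_rename_table. + sql = "ALTER TABLE users RENAME TO people" + ob_ast = parse_one(sql, dialect="oceanbase") + logger.info(f"\n{repr(ob_ast)}")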