diff --git a/omymodels/models/dataclass/core.py b/omymodels/models/dataclass/core.py index 15bb499..cf6dcb0 100644 --- a/omymodels/models/dataclass/core.py +++ b/omymodels/models/dataclass/core.py @@ -1,4 +1,7 @@ -from typing import Dict, List, Optional +from typing import List, Optional + +from table_meta import TableMeta +from table_meta.model import Column import omymodels.types as t from omymodels.helpers import create_class_name, datetime_now_check @@ -24,7 +27,7 @@ def add_custom_type(self, _type: str) -> str: column_type = column_type[0] return _type - def generate_attr(self, column: Dict, defaults_off: bool) -> str: + def generate_attr(self, column: Column, defaults_off: bool) -> str: column_str = dt.dataclass_attr if "." in column.type: @@ -57,19 +60,19 @@ def generate_attr(self, column: Dict, defaults_off: bool) -> str: return column_str @staticmethod - def add_column_default(column_str: str, column: Dict) -> str: + def add_column_default(column_str: str, column: Column) -> str: if column.type.upper() in datetime_types: if datetime_now_check(column.default.lower()): # todo: need to add other popular PostgreSQL & MySQL functions column.default = dt.field_datetime_now elif "'" not in column.default: - column.default = f"'{column['default']}'" + column.default = f"'{column.default}'" column_str += dt.dataclass_default_attr.format(default=column.default) return column_str def generate_model( self, - table: Dict, + table: TableMeta, singular: bool = True, exceptions: Optional[List] = None, defaults_off: Optional[bool] = False, diff --git a/omymodels/models/pydantic/core.py b/omymodels/models/pydantic/core.py index e719a8f..a1ea0ae 100644 --- a/omymodels/models/pydantic/core.py +++ b/omymodels/models/pydantic/core.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional +from typing import List, Optional -from table_meta.model import Column +from table_meta.model import Column, TableMeta import omymodels.types as t from omymodels.helpers import create_class_name, datetime_now_check @@ -11,7 +11,7 @@ class ModelGenerator: def __init__(self): - self.imports = set([pt.base_model]) + self.imports = {pt.base_model} self.types_for_import = ["Json"] self.datetime_import = False self.typing_imports = set() @@ -19,21 +19,20 @@ def __init__(self): self.uuid_import = False self.prefix = "" - def add_custom_type(self, target_type): + def add_custom_type(self, target_type: str) -> Optional[str]: column_type = self.custom_types.get(target_type, None) _type = None if isinstance(column_type, tuple): _type = column_type[1] return _type - def get_not_custom_type(self, column: Column): + def get_not_custom_type(self, column: Column) -> str: _type = None if "." in column.type: _type = column.type.split(".")[1] else: _type = column.type.lower().split("[")[0] - if _type == _type: - _type = types_mapping.get(_type, _type) + _type = types_mapping.get(_type, _type) if _type in self.types_for_import: self.imports.add(_type) elif "datetime" in _type: @@ -45,7 +44,7 @@ def get_not_custom_type(self, column: Column): self.uuid_import = True return _type - def generate_attr(self, column: Dict, defaults_off: bool) -> str: + def generate_attr(self, column: Column, defaults_off: bool) -> str: _type = None if column.nullable: @@ -53,6 +52,7 @@ def generate_attr(self, column: Dict, defaults_off: bool) -> str: column_str = pt.pydantic_optional_attr else: column_str = pt.pydantic_attr + if self.custom_types: _type = self.add_custom_type(column.type) if not _type: @@ -60,25 +60,32 @@ def generate_attr(self, column: Dict, defaults_off: bool) -> str: column_str = column_str.format(arg_name=column.name, type=_type) - if column.default and defaults_off is False: + if column.default is not None and not defaults_off: column_str = self.add_default_values(column_str, column) return column_str @staticmethod - def add_default_values(column_str: str, column: Dict) -> str: + def add_default_values(column_str: str, column: Column) -> str: + # Handle datetime default values if column.type.upper() in datetime_types: if datetime_now_check(column.default.lower()): - # todo: need to add other popular PostgreSQL & MySQL functions + # Handle functions like CURRENT_TIMESTAMP column.default = "datetime.datetime.now()" - elif "'" not in column.default: - column.default = f"'{column['default']}'" + elif column.default.upper() != "NULL" and "'" not in column.default: + column.default = f"'{column.default}'" + + # If the default is 'NULL', don't set a default in Pydantic (it already defaults to None) + if column.default.upper() == "NULL": + return column_str + + # Append the default value if it's not None (e.g., explicit default values like '0' or CURRENT_TIMESTAMP) column_str += pt.pydantic_default_attr.format(default=column.default) return column_str def generate_model( self, - table: Dict, + table: TableMeta, singular: bool = True, exceptions: Optional[List] = None, defaults_off: Optional[bool] = False, diff --git a/omymodels/types.py b/omymodels/types.py index 8eac93f..1a30870 100644 --- a/omymodels/types.py +++ b/omymodels/types.py @@ -116,7 +116,7 @@ def process_types_after_models_parser(column_data: Column) -> Column: return column_data -def prepare_column_data(column_data: Column) -> str: +def prepare_column_data(column_data: Column) -> Column: if "." in column_data.type or "(": column_data = process_types_after_models_parser(column_data) return column_data