Skip to content

Commit

Permalink
Merge pull request #63 from rayliverified/main
Browse files Browse the repository at this point in the history
Fix Incorrect Column Type Crash
  • Loading branch information
xnuinside authored Sep 29, 2024
2 parents d95c69c + b5d33bd commit 3f7b067
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
13 changes: 8 additions & 5 deletions omymodels/models/dataclass/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Dict, List, Optional
from typing import List, Optional

from table_meta import TableMeta
from table_meta.model import Column

import omymodels.types as t
from omymodels.helpers import create_class_name, datetime_now_check
Expand All @@ -24,7 +27,7 @@ def add_custom_type(self, _type: str) -> str:
column_type = column_type[0]
return _type

def generate_attr(self, column: Dict, defaults_off: bool) -> str:
def generate_attr(self, column: Column, defaults_off: bool) -> str:
column_str = dt.dataclass_attr

if "." in column.type:
Expand Down Expand Up @@ -57,19 +60,19 @@ def generate_attr(self, column: Dict, defaults_off: bool) -> str:
return column_str

@staticmethod
def add_column_default(column_str: str, column: Dict) -> str:
def add_column_default(column_str: str, column: Column) -> str:
if column.type.upper() in datetime_types:
if datetime_now_check(column.default.lower()):
# todo: need to add other popular PostgreSQL & MySQL functions
column.default = dt.field_datetime_now
elif "'" not in column.default:
column.default = f"'{column['default']}'"
column.default = f"'{column.default}'"
column_str += dt.dataclass_default_attr.format(default=column.default)
return column_str

def generate_model(
self,
table: Dict,
table: TableMeta,
singular: bool = True,
exceptions: Optional[List] = None,
defaults_off: Optional[bool] = False,
Expand Down
35 changes: 21 additions & 14 deletions omymodels/models/pydantic/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional
from typing import List, Optional

from table_meta.model import Column
from table_meta.model import Column, TableMeta

import omymodels.types as t
from omymodels.helpers import create_class_name, datetime_now_check
Expand All @@ -11,29 +11,28 @@

class ModelGenerator:
def __init__(self):
self.imports = set([pt.base_model])
self.imports = {pt.base_model}
self.types_for_import = ["Json"]
self.datetime_import = False
self.typing_imports = set()
self.custom_types = {}
self.uuid_import = False
self.prefix = ""

def add_custom_type(self, target_type):
def add_custom_type(self, target_type: str) -> Optional[str]:
column_type = self.custom_types.get(target_type, None)
_type = None
if isinstance(column_type, tuple):
_type = column_type[1]
return _type

def get_not_custom_type(self, column: Column):
def get_not_custom_type(self, column: Column) -> str:
_type = None
if "." in column.type:
_type = column.type.split(".")[1]
else:
_type = column.type.lower().split("[")[0]
if _type == _type:
_type = types_mapping.get(_type, _type)
_type = types_mapping.get(_type, _type)
if _type in self.types_for_import:
self.imports.add(_type)
elif "datetime" in _type:
Expand All @@ -45,40 +44,48 @@ def get_not_custom_type(self, column: Column):
self.uuid_import = True
return _type

def generate_attr(self, column: Dict, defaults_off: bool) -> str:
def generate_attr(self, column: Column, defaults_off: bool) -> str:
_type = None

if column.nullable:
self.typing_imports.add("Optional")
column_str = pt.pydantic_optional_attr
else:
column_str = pt.pydantic_attr

if self.custom_types:
_type = self.add_custom_type(column.type)
if not _type:
_type = self.get_not_custom_type(column)

column_str = column_str.format(arg_name=column.name, type=_type)

if column.default and defaults_off is False:
if column.default is not None and not defaults_off:
column_str = self.add_default_values(column_str, column)

return column_str

@staticmethod
def add_default_values(column_str: str, column: Dict) -> str:
def add_default_values(column_str: str, column: Column) -> str:
# Handle datetime default values
if column.type.upper() in datetime_types:
if datetime_now_check(column.default.lower()):
# todo: need to add other popular PostgreSQL & MySQL functions
# Handle functions like CURRENT_TIMESTAMP
column.default = "datetime.datetime.now()"
elif "'" not in column.default:
column.default = f"'{column['default']}'"
elif column.default.upper() != "NULL" and "'" not in column.default:
column.default = f"'{column.default}'"

# If the default is 'NULL', don't set a default in Pydantic (it already defaults to None)
if column.default.upper() == "NULL":
return column_str

# Append the default value if it's not None (e.g., explicit default values like '0' or CURRENT_TIMESTAMP)
column_str += pt.pydantic_default_attr.format(default=column.default)
return column_str

def generate_model(
self,
table: Dict,
table: TableMeta,
singular: bool = True,
exceptions: Optional[List] = None,
defaults_off: Optional[bool] = False,
Expand Down
2 changes: 1 addition & 1 deletion omymodels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def process_types_after_models_parser(column_data: Column) -> Column:
return column_data


def prepare_column_data(column_data: Column) -> str:
def prepare_column_data(column_data: Column) -> Column:
if "." in column_data.type or "(":
column_data = process_types_after_models_parser(column_data)
return column_data
Expand Down

0 comments on commit 3f7b067

Please sign in to comment.