Skip to content

Commit

Permalink
fix #64 兼容表结构主键列名不为id时引发的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
zy7y committed May 17, 2024
1 parent 972e306 commit 198ee90
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 191 deletions.
9 changes: 7 additions & 2 deletions dfs_generate/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, table_name, columns, uri):
self.table_name = table_name
self.columns = columns
self.uri = uri
self.pk = "{'id': id}"

@property
def table(self):
Expand Down Expand Up @@ -193,6 +194,8 @@ def model(self):
fields = []
for column in self.columns:
field = _sqlmodel_field_repr(column, imports)
if column["COLUMN_KEY"] == "PRI":
self.pk = "{" + "'" + column["COLUMN_NAME"] + "': id}"
if " " + field not in fields:
fields.append(" " + field)
return "\n".join(imports) + "\n\n" + head + "\n" + "\n".join(fields)
Expand All @@ -204,7 +207,7 @@ def dao(self):
"import model",
"import schema",
}
content = SQLMODEL_DAO.format(table=self.table)
content = SQLMODEL_DAO.format(table=self.table, pk=self.pk)
return "\n".join(imports) + "\n\n" + content

def router(self):
Expand Down Expand Up @@ -273,6 +276,8 @@ def model(self):
fields = []
for column in self.columns:
field = _tortoise_field_repr(column)
if column["COLUMN_KEY"] == "PRI":
self.pk = f'{column["COLUMN_NAME"]}=id'
if " " + field not in fields:
fields.append(" " + field)
return (
Expand All @@ -286,7 +291,7 @@ def model(self):

def dao(self):
imports = {"from typing import List, Optional", "import model", "import schema"}
content = TORTOISE_DAO.format(table=self.table)
content = TORTOISE_DAO.format(table=self.table, pk=self.pk)
return "\n".join(imports) + "\n\n" + content

def main(self):
Expand Down
4 changes: 2 additions & 2 deletions dfs_generate/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create(session: Session, obj_in: schema.{table}) -> model.{table}:
return obj
def query_by_id(session: Session, id: int) -> Optional[model.{table}]:
return session.get(model.{table}, id)
return session.get(model.{table}, {pk})
def update(session: Session, id: int, obj_in: schema.{table}) -> Optional[model.{table}]:
obj = query_by_id(session, id)
Expand Down Expand Up @@ -168,7 +168,7 @@ async def create(obj_in: schema.{table}) -> model.{table}:
return obj
async def query_by_id(id: int) -> Optional[model.{table}]:
return await model.{table}.get_or_none(id=id)
return await model.{table}.get_or_none({pk})
async def update(id: int, obj_in: schema.{table}) -> Optional[model.{table}]:
obj= await query_by_id(id)
Expand Down
37 changes: 15 additions & 22 deletions docs/sqlmodel/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@
from sqlmodel import Session, func, select


def create(session: Session, obj_in: schema.SysMenu) -> model.SysMenu:
obj = model.SysMenu(**obj_in.model_dump(exclude_unset=True))
def create(session: Session, obj_in: schema.Aerich) -> model.Aerich:
obj = model.Aerich(**obj_in.model_dump(exclude_unset=True))
session.add(obj)
session.commit()
session.refresh(obj)
return obj


def query_by_id(session: Session, id: int) -> Optional[model.SysMenu]:
return session.get(model.SysMenu, id)
def query_by_id(session: Session, id: int) -> Optional[model.Aerich]:
return session.get(model.Aerich, {'aerich_id': id})


def update(
session: Session, id: int, obj_in: schema.SysMenu
) -> Optional[model.SysMenu]:
def update(session: Session, id: int,
obj_in: schema.Aerich) -> Optional[model.Aerich]:
obj = query_by_id(session, id)
if obj:
for field, value in obj_in.model_dump(exclude_unset=True).items():
Expand All @@ -30,8 +29,8 @@ def update(
return obj


def delete_by_id(session: Session, id: int) -> Optional[model.SysMenu]:
obj = session.get(model.SysMenu, id)
def delete_by_id(session: Session, id: int) -> Optional[model.Aerich]:
obj = session.get(model.Aerich, id)
if obj:
session.delete(obj)
session.commit()
Expand All @@ -40,17 +39,11 @@ def delete_by_id(session: Session, id: int) -> Optional[model.SysMenu]:

def count(session: Session, **kwargs) -> int:
return session.scalar(
select(func.count()).select_from(model.SysMenu).filter_by(**kwargs)
)


def query_all_by_limit(
session: Session, page_number: int, page_size: int, **kwargs
) -> List[model.SysMenu]:
stmt = (
select(model.SysMenu)
.filter_by(**kwargs)
.offset((page_number - 1) * page_size)
.limit(page_size)
)
select(func.count()).select_from(model.Aerich).filter_by(**kwargs))


def query_all_by_limit(session: Session, page_number: int, page_size: int,
**kwargs) -> List[model.Aerich]:
stmt = select(model.Aerich).filter_by(**kwargs).offset(
(page_number - 1) * page_size).limit(page_size)
return session.exec(stmt).all()
15 changes: 6 additions & 9 deletions docs/sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from fastapi import FastAPI
from router import sys_menu
from router import aerich
from starlette.middleware.cors import CORSMiddleware

app = FastAPI(
title="DFS - FastAPI SQLModel CRUD",
description="""
app = FastAPI(title="DFS - FastAPI SQLModel CRUD",
description='''
[![](https://img.shields.io/github/stars/zy7y/dfs-generate)](https://github.com/zy7y/dfs-generate)
[![](https://img.shields.io/github/forks/zy7y/dfs-generate)](https://github.com/zy7y/dfs-generate)
[![](https://img.shields.io/github/repo-size/zy7y/dfs-generate?style=social)](https://github.com/zy7y/dfs-generate)
Expand All @@ -13,8 +12,7 @@
支持ORM:[SQLModel](https://sqlmodel.tiangolo.com/)、[Tortoise ORM](https://tortoise.github.io/)
支持前端: [Vue](https://cn.vuejs.org/)
""",
)
''')

app.add_middleware(
CORSMiddleware,
Expand All @@ -24,9 +22,8 @@
allow_headers=["*"],
)

app.include_router(sys_menu)
app.include_router(aerich)

if __name__ == "__main__":
if __name__ == '__main__':
import uvicorn

uvicorn.run("main:app", reload=True, port=5000)
52 changes: 7 additions & 45 deletions docs/sqlmodel/model.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,11 @@
from datetime import datetime
from typing import Optional

from sqlmodel import Column, DateTime, Field, SQLModel, func
from sqlmodel import JSON, Field, SQLModel


class SysMenu(SQLModel, table=True):
__tablename__ = 'sys_menu'
id: Optional[int] = Field(default=None, primary_key=True, description="主键")
status: int = Field(default=1, description="状态 1有效 9 删除 5选中")
created: datetime = Field(nullable=True,
description="创建时间",
default_factory=datetime.utcnow)
modified: datetime = Field(default=None,
description="更新时间",
sa_column=Column(DateTime(),
onupdate=func.now()))
name: Optional[str] = Field(default=None,
max_length=20,
nullable=True,
description="名称")
icon: Optional[str] = Field(default=None,
max_length=100,
nullable=True,
description="菜单图标")
path: Optional[str] = Field(default=None,
max_length=128,
nullable=True,
description="菜单url")
type: int = Field(default=...,
description="菜单类型 0目录 1组件 2按钮 3数据",
index=True)
component: Optional[str] = Field(default=None,
max_length=128,
nullable=True,
description="组件地址")
pid: Optional[int] = Field(default=None, nullable=True, description="父id")
identifier: Optional[str] = Field(default=None,
max_length=30,
nullable=True,
description="权限标识 user:add")
api: Optional[str] = Field(default=None,
max_length=128,
nullable=True,
description="接口地址")
method: Optional[str] = Field(default=None,
max_length=10,
nullable=True,
description="接口请求方式")
class Aerich(SQLModel, table=True):
__tablename__ = 'aerich'
aerich_id: Optional[int] = Field(default=None, primary_key=True)
version: str = Field(default=..., max_length=255)
app: str = Field(default=..., max_length=100)
content: dict = Field(default=..., sa_type=JSON)
39 changes: 18 additions & 21 deletions docs/sqlmodel/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,42 @@
from fastapi import APIRouter, Depends
from sqlmodel import Session

sys_menu = APIRouter(prefix="/SysMenu", tags=["SysMenu"])
aerich = APIRouter(prefix="/Aerich", tags=["Aerich"])


@sys_menu.get("/{id}", summary="通过ID查询详情")
def query_sys_menu_by_id(id: int) -> schema.Result[schema.SysMenu]:
@aerich.get("/{id}", summary="通过ID查询详情")
def query_aerich_by_id(id: int) -> schema.Result[schema.Aerich]:
with Session(engine) as session:
return schema.Result.ok(dao.query_by_id(session, id))


@sys_menu.get("", summary="分页条件查询")
def query_sys_menu_all_by_limit(
query: schema.SysMenu = Depends(), page: schema.PageParam = Depends()
) -> schema.PageResult[schema.SysMenu]:
@aerich.get("", summary="分页条件查询")
def query_aerich_all_by_limit(query: schema.Aerich = Depends(),
page: schema.PageParam = Depends()
) -> schema.PageResult[schema.Aerich]:
with Session(engine) as session:
total = dao.count(session, **query.model_dump(exclude_none=True))
data = dao.query_all_by_limit(
session,
**query.model_dump(exclude_none=True),
page_number=page.page_number,
page_size=page.page_size,
)
data = dao.query_all_by_limit(session,
**query.model_dump(exclude_none=True),
page_number=page.page_number,
page_size=page.page_size)
return schema.PageResult.ok(data=data, total=total)


@sys_menu.post("", summary="新增数据")
def create_sys_menu(instance: schema.SysMenu) -> schema.Result[schema.SysMenu]:
@aerich.post("", summary="新增数据")
def create_aerich(instance: schema.Aerich) -> schema.Result[schema.Aerich]:
with Session(engine) as session:
return schema.Result.ok(dao.create(session, instance))


@sys_menu.patch("/{id}", summary="更新数据")
def update_sys_menu_by_id(
id: int, instance: schema.SysMenu
) -> schema.Result[schema.SysMenu]:
@aerich.patch("/{id}", summary="更新数据")
def update_aerich_by_id(
id: int, instance: schema.Aerich) -> schema.Result[schema.Aerich]:
with Session(engine) as session:
return schema.Result.ok(dao.update(session, id, instance))


@sys_menu.delete("/{id}", summary="删除数据")
def delete_sys_menu_by_id(id: int) -> schema.Result[schema.SysMenu]:
@aerich.delete("/{id}", summary="删除数据")
def delete_aerich_by_id(id: int) -> schema.Result[schema.Aerich]:
with Session(engine) as session:
return schema.Result.ok(dao.delete_by_id(session, id))
22 changes: 6 additions & 16 deletions docs/sqlmodel/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Generic, List, Optional, TypeVar

from pydantic import BaseModel, Field
from pydantic.alias_generators import to_camel

T = TypeVar("T")
T = TypeVar('T')


class Result(BaseModel, Generic[T]):
Expand Down Expand Up @@ -37,18 +36,9 @@ class PageParam(BaseModel):
model_config = {"alias_generator": to_camel, "populate_by_name": True}


class SysMenu(BaseModel):
id: Optional[int] = Field(None, description="主键")
status: Optional[int] = Field(None, description="状态 1有效 9 删除 5选中")
created: Optional[datetime] = Field(None, description="创建时间")
modified: Optional[datetime] = Field(None, description="更新时间")
name: Optional[str] = Field(None, description="名称")
icon: Optional[str] = Field(None, description="菜单图标")
path: Optional[str] = Field(None, description="菜单url")
type: Optional[int] = Field(None, description="菜单类型 0目录 1组件 2按钮 3数据")
component: Optional[str] = Field(None, description="组件地址")
pid: Optional[int] = Field(None, description="父id")
identifier: Optional[str] = Field(None, description="权限标识 user:add")
api: Optional[str] = Field(None, description="接口地址")
method: Optional[str] = Field(None, description="接口请求方式")
class Aerich(BaseModel):
aerich_id: Optional[int] = None
version: Optional[str] = None
app: Optional[str] = None
content: Optional[dict] = None
model_config = {"alias_generator": to_camel, "populate_by_name": True}
22 changes: 11 additions & 11 deletions docs/tortoise-orm/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
import schema


async def create(obj_in: schema.SysMenu) -> model.SysMenu:
obj = model.SysMenu(**obj_in.model_dump(exclude_unset=True))
async def create(obj_in: schema.Aerich) -> model.Aerich:
obj = model.Aerich(**obj_in.model_dump(exclude_unset=True))
await obj.save()
return obj


async def query_by_id(id: int) -> Optional[model.SysMenu]:
return await model.SysMenu.get_or_none(id=id)
async def query_by_id(id: int) -> Optional[model.Aerich]:
return await model.Aerich.get_or_none(aerich_id=id)


async def update(id: int, obj_in: schema.SysMenu) -> Optional[model.SysMenu]:
async def update(id: int, obj_in: schema.Aerich) -> Optional[model.Aerich]:
obj = await query_by_id(id)
if obj:
for field, value in obj_in.model_dump(exclude_unset=True).items():
Expand All @@ -23,20 +23,20 @@ async def update(id: int, obj_in: schema.SysMenu) -> Optional[model.SysMenu]:
return obj


async def delete_by_id(id: int) -> Optional[model.SysMenu]:
async def delete_by_id(id: int) -> Optional[model.Aerich]:
obj = await query_by_id(id)
if obj:
await obj.delete()
return obj


async def count(**kwargs) -> int:
return await model.SysMenu.filter(**kwargs).count()
return await model.Aerich.filter(**kwargs).count()


async def query_all_by_limit(
page_number: int, page_size: int, **kwargs
) -> List[model.SysMenu]:
async def query_all_by_limit(page_number: int, page_size: int,
**kwargs) -> List[model.Aerich]:
offset = (page_number - 1) * page_size
limit = page_size
return await model.SysMenu.filter(**kwargs).offset(offset).limit(limit).all()
return await model.Aerich.filter(**kwargs
).offset(offset).limit(limit).all()
Loading

0 comments on commit 198ee90

Please sign in to comment.