diff --git a/Router/item.py b/Router/item.py index c913ba4..8529f2f 100644 --- a/Router/item.py +++ b/Router/item.py @@ -2,8 +2,6 @@ from fastapi.responses import JSONResponse -from Data.user import LoginUser, UserCreate -from Database.database import get_db from fastapi import APIRouter from starlette.status import * from fastapi import Query, Request @@ -13,6 +11,7 @@ from Service.embedding_service import EmbeddingService from Service.item_service import ItemService from Service.purchase_service import PurchaseService +from Service.user_service import UserService from Service.useritem_service import UserItemService import numpy as np @@ -28,11 +27,15 @@ 500: {"description": "실패"} }) def add_items(request: Request, itemadd: ItemAdd): - UserItemService.add_userItems(itemadd) - purchase_history_list = PurchaseService.purchase_history_list_db(itemadd) - PurchaseService.purchase_history_list_save(purchase_history_list) - return JSONResponse(status_code=HTTP_200_OK, content={"message": "Item added successfully"}) - + try: + # 존재하지 않는 유저 아이디 들어오면 새로 생성 + user = UserService.get_user_or_create(itemadd.items[0].user_id) + UserItemService.add_userItems(itemadd) + purchase_history_list = PurchaseService.purchase_history_list_db(itemadd) + PurchaseService.purchase_history_list_save(purchase_history_list) + return JSONResponse(status_code=HTTP_200_OK, content={"message": "Item added successfully"}) + except Exception as e: + return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(e)}) """ 유저가 보유한 모든 아이템을 가져온다. @@ -42,6 +45,9 @@ def add_items(request: Request, itemadd: ItemAdd): 500: {"description": "실패"} }, response_model=List[ItemRead]) def get_userItem_all(request : Request, user_id: str): + user = UserService.get_user_by_id(user_id) + if user is None: + return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": "User not found"}) user_item_list = UserItemService.get_all_userItem(user_id) user_item_list_dict = UserItemService.to_userItem_dict(user_item_list) return JSONResponse(status_code=HTTP_200_OK, content={"items": user_item_list_dict}) @@ -75,13 +81,16 @@ def consume_item(request : Request, user_item_consume: UserItemConsume): }) def add_item(request : Request, userItemAdd: UserItemAdd): try: + user = UserService.get_user_or_create(userItemAdd.user_id) + userItemAdd.user_id = user.user_id UserItemService.add_userItem(userItemAdd) purchaseHistory = PurchaseService.purchase_history_db(userItemAdd) PurchaseService.purchase_history_save(purchaseHistory) return JSONResponse(status_code=HTTP_200_OK, content={"message": "Item added successfully"}) except OverflowError as e: return JSONResponse(status_code=HTTP_400_BAD_REQUEST, content={"message": "Too many items added"}) - + except Exception as e: + return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(e)}) @router.post("/add/{item_name}/{base_consume_expectation}/{base_price}", summary="아이템 추가") def add_item(request : Request, item_name: str, base_consume_expectation: int, base_price: int): diff --git a/Router/login.py b/Router/login.py index 85b8233..0b047f3 100644 --- a/Router/login.py +++ b/Router/login.py @@ -6,6 +6,7 @@ from starlette.status import * from fastapi import Query, Request from Database.models import User +from Service.user_service import UserService from Service.useritem_service import UserItemService from Utils.swagger import signin_response_example @@ -17,17 +18,8 @@ 500: {"description": "실패"} }) def signup(request: Request, userCreate: UserCreate): - user = User( - name=userCreate.name, - password=userCreate.password - ) - with get_db() as db: - db.add(user) - db.commit() - db.refresh(user) - - UserItemService.init_userItem(user.user_id) - return JSONResponse(status_code=HTTP_200_OK, content={"message": "User added successfully"}) + user = UserService.get_user_or_create("", userCreate.name, userCreate.password) + return JSONResponse(status_code=HTTP_200_OK, content={"message": "User added successfully", "user_id": user.user_id}) @router.post("/signin", summary="로그인", description="로그인 패스워드 확인안함", responses={ diff --git a/Router/price.py b/Router/price.py index 2e0c3db..864b9de 100644 --- a/Router/price.py +++ b/Router/price.py @@ -1,11 +1,8 @@ from fastapi.responses import JSONResponse -from Data.user import LoginUser, UserCreate -from Database.database import get_db from fastapi import APIRouter from starlette.status import * from fastapi import Query, Request -from Database.models import User from Service.item_service import ItemService from Service.purchase_service import PurchaseService from Service.useritem_service import UserItemService diff --git a/Service/user_service.py b/Service/user_service.py new file mode 100644 index 0000000..d5160cd --- /dev/null +++ b/Service/user_service.py @@ -0,0 +1,22 @@ +from Database.database import get_db +from Database.models import User +from Service.useritem_service import UserItemService + + +class UserService: + @staticmethod + def get_user_by_id(user_id: str): + with get_db() as db: + return db.query(User).filter(User.user_id == user_id).first() + + @staticmethod + def get_user_or_create(user_id: str, name=None, password="1234"): + user = UserService.get_user_by_id(user_id) + if user is None: + with get_db() as db: + user = User(name=user_id[:8] if name is None else name, password=password) + db.add(user) + db.commit() + db.refresh(user) + UserItemService.init_userItem(user.user_id) + return user \ No newline at end of file diff --git a/Service/useritem_service.py b/Service/useritem_service.py index 53b1955..e4ab3d6 100644 --- a/Service/useritem_service.py +++ b/Service/useritem_service.py @@ -2,7 +2,6 @@ from select import select from typing import Optional, List, Annotated -from sqlalchemy.orm import Session from sqlalchemy.sql.expression import extract from Data.item import UserItemAdd, ItemAdd, UserItemConsume @@ -53,6 +52,8 @@ def to_userItem_dict(userItemList: List[UserItem]): def add_userItem(itemAdd: UserItemAdd): with get_db() as db: useritem = db.query(UserItem).filter(UserItem.user_id == itemAdd.user_id, UserItem.item_name == itemAdd.item_name).first() + if useritem is None: + raise Exception("UserItem not found") if useritem.consume_date is None or useritem.count == 0: useritem.consume_date = itemAdd.purchase_date @@ -62,10 +63,11 @@ def add_userItem(itemAdd: UserItemAdd): db.commit() return True + @staticmethod def add_userItems(itemAdd : ItemAdd): - for addItem in itemAdd.items: - with get_db() as db: + with get_db() as db: + for addItem in itemAdd.items: useritem = db.query(UserItem).filter(UserItem.user_id == addItem.user_id, UserItem.item_name == addItem.item_name).first() if useritem.consume_date is None or useritem.count == 0: @@ -73,7 +75,7 @@ def add_userItems(itemAdd : ItemAdd): useritem.count += addItem.count useritem.consume_date += timedelta(days=useritem.consume_expectation * addItem.count) - db.commit() + db.commit() return True def consume_userItem(userItemConsume: UserItemConsume):