diff --git a/tests/test_conversion.py b/tests/test_conversion.py index 8507d26..dd8138a 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -16,12 +16,23 @@ "IS_NULLABLE": "NO", "COLUMN_KEY": "PRI", "COLUMN_COMMENT": "主键ID", + "CHARACTER_MAXIMUM_LENGTH": "", + "NUMERIC_PRECISION": "", + "NUMERIC_SCALE": "", + "COLUMN_DEFAULT": "", + "EXTRA": "", }, { "COLUMN_NAME": "name", - "DATA_TYPE": "varchar(100)", + "DATA_TYPE": "varchar", "IS_NULLABLE": "YES", "COLUMN_COMMENT": "姓名", + "CHARACTER_MAXIMUM_LENGTH": "100", + "NUMERIC_PRECISION": "", + "NUMERIC_SCALE": "", + "COLUMN_DEFAULT": "", + "EXTRA": "", + "COLUMN_KEY": "", }, ] @@ -64,8 +75,11 @@ def test_tortoise_conversion_model(tortoise_conversion_fixture): """测试TortoiseConversion的model方法输出格式""" model_code = tortoise_conversion_fixture.model() assert "class Users(Model):" in model_code - assert "id = fields.Int(pk=True)" in model_code - assert "name = fields.CharField(max_length=100)" in model_code + assert 'id = fields.IntField(description="主键ID", pk=True)' in model_code + assert ( + 'name = fields.CharField(null=True, max_length=100, description="姓名")' + in model_code + ) def test_pydantic_field(): @@ -79,7 +93,10 @@ def test_sqlmodel_field_repr(): """测试_sqlmodel_field_repr函数的输出""" column = MOCK_COLUMNS[0] # 使用id字段作为测试 imports, field_code = set(), _sqlmodel_field_repr(column, set()) - assert "id: Optional[int] = Field(nullable=False)" in field_code + assert ( + 'id: Optional[int] = Field(default=None,primary_key=True,description="主键ID")' + == field_code + ) assert ( "from datetime import datetime" not in imports ) # id字段不应触发默认时间戳逻辑 @@ -89,4 +106,7 @@ def test_tortoise_field_repr(): """测试_tortoise_field_repr函数的输出""" column = MOCK_COLUMNS[1] field_code = _tortoise_field_repr(column) - assert "name = fields.CharField(max_length=100, description='姓名')" in field_code + assert ( + 'name = fields.CharField(null=True, max_length=100, description="姓名")' + == field_code + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index fb9b472..045bfc5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,9 +1,6 @@ -import pymysql import pytest from dfs_generate.tools import tran, to_pascal, to_snake -from dfs_generate.tools import MySQLConf, MySQLHelper -from unittest.mock import MagicMock -from pymysql.err import OperationalError +from dfs_generate.tools import MySQLConf # 测试 tran 函数 @@ -16,7 +13,7 @@ ], ) def test_tran(t, mode, expected): - assert tran(t, mode) == expected + assert tran(t, mode)["type"] == expected["type"] # 测试 to_pascal 函数 @@ -75,51 +72,3 @@ def test_mysqlconf_json(): "charset": "utf8mb4", } assert conf.json() == expected_json - - -@pytest.fixture -def mysql_helper_mock(monkeypatch): - """Fixture to create a mocked MySQLHelper instance.""" - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - monkeypatch.setattr("pymysql.connect", lambda *args, **kwargs: mock_conn) - helper = MySQLHelper( - MySQLConf(host="localhost", user="test", password="pwd", db="test_db") - ) - return helper, mock_conn, mock_cursor - - -def test_mysqlhelper_init(mysql_helper_mock): - helper, mock_conn, _ = mysql_helper_mock - mock_conn.assert_called_once() - assert helper.conn == mock_conn - assert helper.cursor == mock_conn.cursor.return_value - - -def test_mysqlhelper_set_conn(mysql_helper_mock): - helper, mock_conn, _ = mysql_helper_mock - new_conf = MySQLConf( - host="new_host", user="new_user", password="new_pwd", db="new_db" - ) - helper.set_conn(new_conf) - mock_conn.assert_called_with( - **new_conf.json(), cursorclass=pymysql.cursors.DictCursor - ) - - -def test_mysqlhelper_close(mysql_helper_mock): - _, mock_conn, mock_cursor = mysql_helper_mock - helper = MySQLHelper( - MySQLConf(host="localhost", user="test", password="pwd", db="test_db") - ) - helper.close() - mock_cursor.close.assert_called_once() - mock_conn.close.assert_called_once() - - -def test_mysqlhelper_get_tables_error(mysql_helper_mock): - helper, _, mock_cursor = mysql_helper_mock - mock_cursor.execute.side_effect = OperationalError - with pytest.raises(OperationalError): - helper.get_tables()