Skip to content

Commit

Permalink
Fix issues with signup and signin (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiemh authored Feb 1, 2025
1 parent 67196f9 commit b8c097f
Show file tree
Hide file tree
Showing 17 changed files with 243 additions and 191 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: unit-tests
name: Unit tests

on:
push:
Expand Down
62 changes: 34 additions & 28 deletions src/surrealdb/connections/async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,30 @@ def set_token(self, token: str) -> None:
"""
self.token = token

async def authenticate(self) -> None:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
self.token = response["result"]
return response["result"]

async def signin(self, vars: dict) -> dict:
message = RequestMessage(
self.id,
Expand All @@ -112,9 +136,16 @@ async def signin(self, vars: dict) -> dict:
response = await self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
package = dict()
package["token"] = self.token
return package
return response["result"]

async def info(self) -> dict:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = await self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
Expand Down Expand Up @@ -187,15 +218,6 @@ async def delete(
self.check_response_for_result(response, "delete")
return response["result"]

async def info(self) -> dict:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = await self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

async def insert(
self, table: Union[str, Table], data: Union[List[dict], dict]
) -> Union[List[dict], dict]:
Expand All @@ -222,11 +244,6 @@ async def insert_relation(
self.check_response_for_result(response, "insert_relation")
return response["result"]

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def let(self, key: str, value: Any) -> None:
self.vars[key] = value

Expand Down Expand Up @@ -306,17 +323,6 @@ async def upsert(
self.check_response_for_result(response, "upsert")
return response["result"]

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
self.token = response["result"]
return response["result"]

async def __aenter__(self) -> "AsyncHttpSurrealConnection":
"""
Asynchronous context manager entry.
Expand Down
38 changes: 19 additions & 19 deletions src/surrealdb/connections/async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ async def use(self, namespace: str, database: str) -> None:
"""
raise NotImplementedError(f"query not implemented for: {self}")

async def authenticate(self, token: str) -> None:
"""Authenticate the current connection with a JWT token.
Args:
token: The JWT authentication token.
Example:
await db.authenticate('insert token here')
"""
raise NotImplementedError(f"authenticate not implemented for: {self}")

async def invalidate(self) -> None:
"""Invalidate the authentication for the current connection.
Example:
await db.invalidate()
"""
raise NotImplementedError(f"invalidate not implemented for: {self}")

async def signup(self, vars: Dict) -> str:
"""Sign this connection up to a specific authentication scope.
[See the docs](https://surrealdb.com/docs/sdk/python/methods/signup)
Expand Down Expand Up @@ -77,25 +96,6 @@ async def signin(self, vars: Dict) -> str:
"""
raise NotImplementedError(f"query not implemented for: {self}")

async def invalidate(self) -> None:
"""Invalidate the authentication for the current connection.
Example:
await db.invalidate()
"""
raise NotImplementedError(f"invalidate not implemented for: {self}")

async def authenticate(self, token: str) -> None:
"""Authenticate the current connection with a JWT token.
Args:
token: The JWT authentication token.
Example:
await db.authenticate('insert token here')
"""
raise NotImplementedError(f"authenticate not implemented for: {self}")

async def let(self, key: str, value: Any) -> None:
"""Assign a value as a variable for this connection.
Expand Down
85 changes: 41 additions & 44 deletions src/surrealdb/connections/async_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,28 @@ async def connect(self, url: Optional[str] = None, max_size: Optional[int] = Non
subprotocols=[websockets.Subprotocol("cbor")]
)

# async def signup(self, vars: Dict[str, Any]) -> str:
async def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")
self.token = None

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

async def signin(self, vars: Dict[str, Any]) -> str:
message = RequestMessage(
Expand All @@ -96,9 +117,25 @@ async def signin(self, vars: Dict[str, Any]) -> str:
response = await self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
if response.get("id") is None:
raise Exception(f"no id signing in: {response}")
self.id = response["id"]
return response["result"]

async def info(self) -> Optional[dict]:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
outcome = await self._send(message, "getting database information")
self.check_response_for_result(outcome, "getting database information")
return outcome["result"]

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
self.id,
RequestMethod.USE,
namespace=namespace,
database=database,
)
await self._send(message, "use")

async def query(self, query: str, params: Optional[dict] = None) -> dict:
if params is None:
Expand All @@ -125,24 +162,6 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict:
response = await self._send(message, "query", bypass=True)
return response

async def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
self.id,
RequestMethod.USE,
namespace=namespace,
database=database,
)
await self._send(message, "use")

async def info(self) -> Optional[dict]:
message = RequestMessage(
self.id,
RequestMethod.INFO
)
outcome = await self._send(message, "getting database information")
self.check_response_for_result(outcome, "getting database information")
return outcome["result"]

async def version(self) -> str:
message = RequestMessage(
self.id,
Expand All @@ -152,18 +171,6 @@ async def version(self) -> str:
self.check_response_for_result(response, "getting database version")
return response["result"]

async def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return await self._send(message, "authenticating")

async def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
await self._send(message, "invalidating")

async def let(self, key: str, value: Any) -> None:
message = RequestMessage(
self.id,
Expand Down Expand Up @@ -331,16 +338,6 @@ async def kill(self, query_uuid: Union[str, UUID]) -> None:
)
await self._send(message, "kill")

async def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = await self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

async def upsert(
self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None
) -> Union[List[dict], dict]:
Expand Down
50 changes: 33 additions & 17 deletions src/surrealdb/connections/blocking_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ def _send(self, message: RequestMessage, operation: str, bypass: bool = False) -
def set_token(self, token: str) -> None:
self.token = token

def authenticate(self, token: str) -> dict:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
)
return self._send(message, "authenticating")

def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
self._send(message, "invalidating")
self.token = None

def signup(self, vars: Dict) -> str:
message = RequestMessage(
self.id,
RequestMethod.SIGN_UP,
data=vars
)
response = self._send(message, "signup")
self.check_response_for_result(response, "signup")
return response["result"]

def signin(self, vars: dict) -> dict:
message = RequestMessage(
self.id,
Expand All @@ -65,9 +88,16 @@ def signin(self, vars: dict) -> dict:
response = self._send(message, "signing in")
self.check_response_for_result(response, "signing in")
self.token = response["result"]
package = dict()
package["token"] = self.token
return package
return response["result"]

def info(self):
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

def use(self, namespace: str, database: str) -> None:
message = RequestMessage(
Expand Down Expand Up @@ -140,15 +170,6 @@ def delete(
self.check_response_for_result(response, "delete")
return response["result"]

def info(self):
message = RequestMessage(
self.id,
RequestMethod.INFO
)
response = self._send(message, "getting database information")
self.check_response_for_result(response, "getting database information")
return response["result"]

def insert(
self, table: Union[str, Table], data: Union[List[dict], dict]
) -> Union[List[dict], dict]:
Expand All @@ -175,11 +196,6 @@ def insert_relation(
self.check_response_for_result(response, "insert_relation")
return response["result"]

def invalidate(self) -> None:
message = RequestMessage(self.id, RequestMethod.INVALIDATE)
self._send(message, "invalidating")
self.token = None

def let(self, key: str, value: Any) -> None:
self.vars[key] = value

Expand Down
Loading

0 comments on commit b8c097f

Please sign in to comment.