diff --git a/__tests__/authkit-provider.spec.tsx b/__tests__/authkit-provider.spec.tsx index 5175072..f26218c 100644 --- a/__tests__/authkit-provider.spec.tsx +++ b/__tests__/authkit-provider.spec.tsx @@ -165,6 +165,34 @@ describe('AuthKitProvider', () => { }); describe('useAuth', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call getAuth when a user is not returned when ensureSignedIn is true', async () => { + // First and second calls return no user, second call returns a user + (getAuthAction as jest.Mock) + .mockResolvedValueOnce({ user: null, loading: true }) + .mockResolvedValueOnce({ user: { email: 'test@example.com' }, loading: false }); + + const TestComponent = () => { + const auth = useAuth({ ensureSignedIn: true }); + return
{auth.user?.email}
; + }; + + const { getByTestId } = render( + + + , + ); + + await waitFor(() => { + expect(getAuthAction).toHaveBeenCalledTimes(2); + expect(getAuthAction).toHaveBeenLastCalledWith(true); + expect(getByTestId('email')).toHaveTextContent('test@example.com'); + }); + }); + it('should throw error when used outside of AuthKitProvider', () => { const TestComponent = () => { const auth = useAuth(); diff --git a/src/components/authkit-provider.tsx b/src/components/authkit-provider.tsx index b61f7f4..af2171b 100644 --- a/src/components/authkit-provider.tsx +++ b/src/components/authkit-provider.tsx @@ -160,10 +160,18 @@ export const AuthKitProvider = ({ children, onSessionExpired }: AuthKitProviderP ); }; -export function useAuth() { +export function useAuth({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) { const context = useContext(AuthContext); + + useEffect(() => { + if (context && ensureSignedIn && !context.user && !context.loading) { + context.getAuth({ ensureSignedIn }); + } + }, [ensureSignedIn, context?.user, context?.loading, context?.getAuth]); + if (!context) { throw new Error('useAuth must be used within an AuthKitProvider'); } + return context; }