diff --git a/sense_energy/asyncsenseable.py b/sense_energy/asyncsenseable.py index fb781ce..7825014 100644 --- a/sense_energy/asyncsenseable.py +++ b/sense_energy/asyncsenseable.py @@ -1,6 +1,7 @@ import asyncio import ssl import sys +from functools import lru_cache from time import time import aiohttp @@ -15,6 +16,21 @@ else: from asyncio import timeout as asyncio_timeout + +@lru_cache(maxsize=None) +def get_ssl_context(ssl_verify: bool, ssl_cafile: str) -> ssl.SSLContext: + """Create or set the SSL context. Use custom ssl verification, if specified.""" + if not ssl_verify: + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ssl_cafile: + ssl_context = ssl.create_default_context(cafile=ssl_cafile) + else: + ssl_context = ssl.create_default_context() + return ssl_context + + class ASyncSenseable(SenseableBase): def __init__( self, @@ -42,14 +58,7 @@ def __init__( def set_ssl_context(self, ssl_verify, ssl_cafile): """Create or set the SSL context. Use custom ssl verification, if specified.""" - if not ssl_verify: - self.ssl_context = ssl.create_default_context() - self.ssl_context.check_hostname = False - self.ssl_context.verify_mode = ssl.CERT_NONE - elif ssl_cafile: - self.ssl_context = ssl.create_default_context(cafile=ssl_cafile) - else: - self.ssl_context = ssl.create_default_context() + self.ssl_context = get_ssl_context(ssl_verify, ssl_cafile) async def authenticate(self, username, password, ssl_verify=True, ssl_cafile=""): """Authenticate with username (email) and password. Optionally set SSL context as well. @@ -59,9 +68,11 @@ async def authenticate(self, username, password, ssl_verify=True, ssl_cafile="") # Get auth token async with self._client_session.post( - API_URL + "authenticate", headers=self.headers, timeout=self.api_timeout, data=auth_data + API_URL + "authenticate", + headers=self.headers, + timeout=self.api_timeout, + data=auth_data, ) as resp: - # check MFA code required if resp.status == 401: data = await resp.json() @@ -91,9 +102,11 @@ async def validate_mfa(self, code): # Get auth token async with self._client_session.post( - API_URL + "authenticate/mfa", headers=self.headers, timeout=self.api_timeout, data=mfa_data + API_URL + "authenticate/mfa", + headers=self.headers, + timeout=self.api_timeout, + data=mfa_data, ) as resp: - # check for 200 return if resp.status != 200: raise SenseAuthenticationException(f"API Return Code: {resp.status}") @@ -111,9 +124,11 @@ async def renew_auth(self): # Get auth token async with self._client_session.post( - API_URL + "renew", headers=self.headers, timeout=self.api_timeout, data=renew_data + API_URL + "renew", + headers=self.headers, + timeout=self.api_timeout, + data=renew_data, ) as resp: - # check for 200 return if resp.status != 200: raise SenseAuthenticationException(f"API Return Code: {resp.status}") @@ -145,7 +160,8 @@ async def update_realtime(self, retry=True): async def async_realtime_stream(self, callback=None, single=False): """Reads realtime data from websocket. Data is passed to callback if available. - Continues reading realtime stream data forever unless 'single' is set to True.""" + Continues reading realtime stream data forever unless 'single' is set to True. + """ url = WS_URL % (self.sense_monitor_id, self.sense_access_token) # hello, features, [updates,] data async with websockets.connect(url, ssl=self.ssl_context) as ws: