diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d09373b..abb1fe1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -57,7 +57,11 @@ jobs: ls -lah whoami env + echo "apk add:" + apk add python3-dev + echo "apk update:" apk update + echo "create venv:" python3 -m venv .env . .env/bin/activate && pip install -r requirements.txt . .env/bin/activate && pip install patchelf diff --git a/Cargo.toml b/Cargo.toml index e05b92f..3f2e24f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ log = "0.4" # pin mio until all dependencies are also on windows-sys 0.48 # https://github.com/microsoft/windows-rs/issues/2410#issuecomment-1490802715 mio = { version = "=0.8.6" } -ngrok = { version = "=0.14.0-pre.12" } +ngrok = { version = "=0.14.0-pre.13" } pyo3 = { version = "0.18.1", features = ["abi3", "abi3-py37", "extension-module", "multiple-pymethods"]} pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"] } pyo3-log = { version = "0.8.1" } diff --git a/README.md b/README.md index b5d10a5..7b1cd83 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,10 @@ listener = ngrok.forward( authtoken_from_env=True, app_protocol="http2", session_metadata="Online in One Line", + # advanced session connection configuration + server_addr="example.com:443", + root_cas="trusted", + session_ca_cert=load_file("ca.pem"), # listener configuration metadata="example listener metadata from python", domain="", diff --git a/requirements.txt b/requirements.txt index 98c3ffb..e94ea06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp==3.8.4 +aiohttp==3.9.5 black==23.3.0 furo==2022.12.7 maturin==0.14.16 diff --git a/src/connect.rs b/src/connect.rs index de30a07..a901576 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -252,14 +252,24 @@ fn configure_session(options: &Py) -> Result { plumb!(B, s_builder, cfg, authtoken); plumb_bool!(B, s_builder, cfg, authtoken_from_env); plumb!(B, s_builder, cfg, metadata, session_metadata); + plumb_vec!(B, s_builder, cfg, ca_cert, session_ca_cert, vecu8); + plumb!(B, s_builder, cfg, root_cas, root_cas); + plumb!(B, s_builder, cfg, server_addr, server_addr); Ok(s_builder.replace(SessionBuilder::new())) }) } async fn do_connect(options: Py) -> PyResult { + let force_new_session = Python::with_gil(|py| -> PyResult { + if let Some(v) = options.as_ref(py).get_item("force_new_session") { + return get_bool(v); + } + Ok(false) + })?; + // Using a singleton session for connect use cases let mut opt = SESSION.lock().await; - if opt.is_none() { + if opt.is_none() || force_new_session { opt.replace(configure_session(&options)?.async_connect().await?); } let session = opt.as_ref().unwrap(); diff --git a/src/session.rs b/src/session.rs index d661e61..272a544 100644 --- a/src/session.rs +++ b/src/session.rs @@ -272,7 +272,24 @@ impl SessionBuilder { /// .. _server_addr parameter in the ngrok docs: https://ngrok.com/docs/ngrok-agent/config#server_addr pub fn server_addr(self_: PyRefMut, addr: String) -> PyRefMut { self_.set(|b| { - b.server_addr(addr).expect("fixme"); + b.server_addr(&addr) + .unwrap_or_else(|_| panic!("failed to parse addr: {addr}")); + }); + self_ + } + + /// Sets the file path to a default certificate in PEM format to validate ngrok Session TLS connections. + /// Setting to "trusted" is the default, using the ngrok CA certificate. + /// Setting to "host" will verify using the certificates on the host operating system. + /// A client config set via tls_config after calling root_cas will override this value. + /// + /// Corresponds to the `root_cas parameter in the ngrok docs`_ + /// + /// .. _root_cas parameter in the ngrok docs: https://ngrok.com/docs/ngrok-agent/config#root_cas + pub fn root_cas(self_: PyRefMut, root_cas: String) -> PyRefMut { + self_.set(|b| { + b.root_cas(&root_cas) + .unwrap_or_else(|_| panic!("failed to invoke root_cas: {root_cas}")); }); self_ } diff --git a/test/test_connect.py b/test/test_connect.py index 5449aee..b4f0858 100644 --- a/test/test_connect.py +++ b/test/test_connect.py @@ -242,6 +242,38 @@ async def test_invalid_connect_policy(self): error = err self.assertIsInstance(error, ValueError) self.assertTrue("parse policy" in f"{error}") + shutdown(None, http_server) + + def test_root_cas(self): + http_server = test.make_http() + error = None + # tls error connecting to marketing site + try: + listener = ngrok.connect( + http_server.listen_to, + authtoken_from_env=True, + force_new_session=True, + root_cas="trusted", + server_addr="ngrok.com:443", + ) + except ValueError as err: + error = err + self.assertIsInstance(error, ValueError) + self.assertTrue("tls handshake" in f"{error}", error) + + # non-tls error connecting to marketing site with "host" root_cas + try: + listener = ngrok.connect( + http_server.listen_to, + authtoken_from_env=True, + force_new_session=True, + root_cas="host", + server_addr="ngrok.com:443", + ) + except ValueError as err: + error = err + self.assertIsInstance(error, ValueError) + self.assertFalse("tls handshake" in f"{error}", error) if __name__ == "__main__":