From 711a4d22f3c3637ab9b33db86694aeafd14d6a02 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Tue, 19 Mar 2024 17:36:34 +0000 Subject: [PATCH] add support for dumping motd on session creation This patch adds support for one of the two motd modes that I plan to support. In this basic "dump" mode, we only display the message of the day when a user first creates a session. Since we are doing it right at the beginning, we can be confident we won't be mangling the restore buffer. Directly injecting the message of the day into the output stream is much more fraught on reconnect. I was hoping that this would be a simple change, but it would up being quite a bit more involved than I had hoped. The main reason for this is the interaction between the prompt prefix injection and message of the day injection. We need to make sure that motd injection happens after the prompt prefix shell code has finished executing so that it does not get clobbered, but unfortunately this is not as easy as just doing one after the other. With the naive approach there is a race condition where first we write the prefix injection shell code to the shell process, then we write out the message of the day, then the shell finishes processing the shell code and issues the terminal reset code emitted by the `clear` command at the end of the prompt prefix shell code. To deal with this, I started scanning for the control code emitted by `clear` in the output stream. I was able to re-use the efficient trie I wrote for the keybindings engine to do this. This addresses the first part of #5, but the issue is not resolved yet. I also realized that two different config variables to control this behavior leaves too much room for weird states. In particular, I worried about doing a direct dump during reattach. I decided it is better to do everything via one variable, where the mode implies when we actually show the motd. --- .github/workflows/nightly.yml | 3 +- .github/workflows/presubmit.yml | 25 ++-- Cargo.lock | 119 +++++++++++++++++ deny.toml | 11 +- libshpool/Cargo.toml | 2 + libshpool/README.md | 11 ++ libshpool/src/config.rs | 40 ++++++ libshpool/src/daemon/control_codes.rs | 83 ++++++++++++ libshpool/src/daemon/keybindings.rs | 176 +++----------------------- libshpool/src/daemon/mod.rs | 5 +- libshpool/src/daemon/prompt.rs | 12 +- libshpool/src/daemon/server.rs | 91 +++++++++++-- libshpool/src/daemon/shell.rs | 66 +++++++++- libshpool/src/daemon/show_motd.rs | 93 ++++++++++++++ libshpool/src/daemon/trie.rs | 146 +++++++++++++++++++++ libshpool/src/protocol.rs | 2 +- shpool/Cargo.toml | 1 + shpool/src/main.rs | 2 + shpool/tests/attach.rs | 150 ++++++++++++++++------ shpool/tests/daemon.rs | 2 +- shpool/tests/data/motd_dump.toml.tmpl | 11 ++ shpool/tests/detach.rs | 2 +- shpool/tests/kill.rs | 4 +- shpool/tests/support/daemon.rs | 8 +- shpool/tests/support/line_matcher.rs | 45 +++++++ 25 files changed, 866 insertions(+), 244 deletions(-) create mode 100644 libshpool/src/daemon/control_codes.rs create mode 100644 libshpool/src/daemon/show_motd.rs create mode 100644 libshpool/src/daemon/trie.rs create mode 100644 shpool/tests/data/motd_dump.toml.tmpl diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7b9e0ee6..6b7ab5c4 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -9,10 +9,11 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' bins: cargo-deny + - run: sudo apt-get install libpam0g-dev - run: cargo deny --all-features check postsubmit: diff --git a/.github/workflows/presubmit.yml b/.github/workflows/presubmit.yml index 1e940eae..4018895a 100644 --- a/.github/workflows/presubmit.yml +++ b/.github/workflows/presubmit.yml @@ -1,5 +1,5 @@ name: presubmit -on: [push, pull_request, workflow_call, workflow_dispatch] +on: [pull_request, workflow_call, workflow_dispatch] jobs: test: @@ -7,18 +7,21 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' - - run: sudo apt-get install zsh fish - - run: cargo test --all-features - - uses: actions/upload-artifact@v4 + - run: sudo apt-get install zsh fish libpam0g-dev + - run: SHPOOL_LEAVE_TEST_LOGS=true cargo test --all-features + - name: Archive Logs + if: always() + uses: actions/upload-artifact@v4 + id: artifact-upload-step with: name: test-logs path: /tmp/shpool-test*/*.log # miri does not handle all the IO we do, disabled for now. - # + # # miri: # name: cargo +nightly miri test # runs-on: ubuntu-22.04 @@ -36,10 +39,11 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: components: rustfmt channel: nightly + - run: sudo apt-get install libpam0g-dev - run: cargo +nightly fmt -- --check cranky: @@ -47,12 +51,12 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: components: clippy bins: cargo-cranky@0.3.0 channel: nightly - - run: sudo apt-get install zsh fish + - run: sudo apt-get install zsh fish libpam0g-dev - run: cargo +nightly cranky --all-targets -- -D warnings deny: @@ -60,8 +64,9 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' bins: cargo-deny + - run: sudo apt-get install libpam0g-dev - run: cargo deny --all-features check licenses diff --git a/Cargo.lock b/Cargo.lock index c55650ee..ce53a1e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -218,6 +218,35 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "dlopen2" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1297103d2bbaea85724fcee6294c2d50b1081f9ad47d0f6f6f61eda65315a6" +dependencies = [ + "dlopen2_derive", + "libc", + "once_cell", + "winapi", +] + +[[package]] +name = "dlopen2_derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b99bf03862d7f545ebc28ddd33a665b50865f4dfd84031a393823879bd4c54" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "equivalent" version = "1.0.1" @@ -273,6 +302,15 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -346,6 +384,7 @@ dependencies = [ "lazy_static", "libc", "log", + "motd", "nix", "ntest", "serde", @@ -354,6 +393,7 @@ dependencies = [ "shpool_pty", "shpool_vt100", "signal-hook", + "termini", "toml", "tracing", "tracing-subscriber", @@ -386,6 +426,25 @@ dependencies = [ "autocfg", ] +[[package]] +name = "motd" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8b1327b3888ed1ed4b1f0ba708f1e5c4f66b6140a5816792fe05f643c3efc9d" +dependencies = [ + "dlopen2", + "lazy_static", + "libc", + "log", + "pam-sys", + "serde", + "serde_derive", + "serde_json", + "tempfile", + "walkdir", + "which", +] + [[package]] name = "nix" version = "0.26.4" @@ -447,6 +506,15 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "pam-sys" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd4858311a097f01a0006ef7d0cd50bca81ec430c949d7bf95cbefd202282434" +dependencies = [ + "libc", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -534,6 +602,15 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.197" @@ -598,6 +675,7 @@ dependencies = [ "crossbeam-channel", "lazy_static", "libshpool", + "motd", "nix", "ntest", "regex", @@ -693,6 +771,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "termini" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ad441d87dd98bc5eeb31cf2fb7e4839968763006b478efb38668a3bf9da0d59" +dependencies = [ + "home", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -849,6 +936,16 @@ dependencies = [ "quote", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -903,6 +1000,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "which" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fa5e0c10bf77f44aac573e498d1a82d5fbd5e91f6fc0a99e7be4b38e85e101c" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -919,6 +1029,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/deny.toml b/deny.toml index ac9cc415..e7b72e26 100644 --- a/deny.toml +++ b/deny.toml @@ -1,20 +1,19 @@ +version = 2 + +[graph] +all-features = true + [advisories] db-path = "~/.cargo/advisory-db" db-urls = ["https://github.com/rustsec/advisory-db"] -vulnerability = "deny" -unmaintained = "deny" yanked = "deny" -notice = "deny" ignore = [ ] [licenses] -unlicensed = "deny" allow = [ "Apache-2.0", "MIT", "Unicode-DFS-2016", ] -copyleft = "deny" -default = "deny" confidence-threshold = 1.0 diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index 9c0d8c1d..04f0cd2e 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -37,6 +37,8 @@ tracing = "0.1" # logging and performance monitoring facade bincode = "1" # serialization for the control protocol shpool_vt100 = "0.1.2" # terminal emulation for the scrollback buffer shell-words = "1" # parsing the -c/--cmd argument +motd = "0.2.0" # getting the message-of-the-day +termini = "1.0.0" # terminfo database [dependencies.tracing-subscriber] version = "0.3" diff --git a/libshpool/README.md b/libshpool/README.md index 9d87764f..63ded92c 100644 --- a/libshpool/README.md +++ b/libshpool/README.md @@ -8,3 +8,14 @@ to an internal google version of the tool, but don't believe that telemetry belongs in an open-source tool. Other potential use-cases such as incorporating a shpool daemon into an IDE that hosts remote terminals could be imagined though. + +## Integrating + +In order to call libshpool, you must keep a few things in mind. +In spirit, you just need to call `libshpool::run(libshpoo::Args::parse())`, +but you need to take care of a few things manually. + +1. Handle the `version` subcommand. Since libshpool is a library, the output + will not be very good if the library handles the versioning. +2. Depend on the `motd` crate and call `motd::handle_reexec()` in your `main` + function. diff --git a/libshpool/src/config.rs b/libshpool/src/config.rs index dcc47259..81518af7 100644 --- a/libshpool/src/config.rs +++ b/libshpool/src/config.rs @@ -110,6 +110,21 @@ pub struct Config { /// verbatim except that the string '$SHPOOL_SESSION_NAME' will /// get replaced with the actual name of the shpool session. pub prompt_prefix: Option, + + /// Control when and how shpool will display the message of the day. + pub motd: Option, + + /// Override arguments to pass to pam_motd.so when resolving the + /// message of the day. Normally, you want to leave this blank + /// so that shpool will scrape the default arguments used in + /// `/etc/pam.d/{ssh,login}` which typically produces the expected + /// result, but in some cases you may need to override the argument + /// list. You can also use this to make a custom message of the + /// day that is only displayed when using shpool. + /// + /// See https://man7.org/linux/man-pages/man8/pam_motd.8.html + /// for more info. + pub motd_args: Option>, } #[derive(Deserialize, Debug, Clone)] @@ -140,6 +155,31 @@ pub enum SessionRestoreMode { Lines(u16), } +#[derive(Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "lowercase")] +pub enum MotdDisplayMode { + /// Never display the message of the day. + #[default] + Never, + + /// Display the message of the day using the given program + /// as the pager. The pager will be invoked like `pager /tmp/motd.txt`, + /// and normal connection will only proceed once the pager has + /// exited. + /// + /// Display the message of the day each time a user attaches + /// (wether to a new session or reattaching to an existing session). + /// + /// `less` by default. + // Pager(String), + + /// Just dump the message of the day directly to the screen. + /// Dumps are only performed when a new session is created. + /// There is no safe way to dump directly when reattaching, + /// so we don't attempt it. + Dump, +} + #[cfg(test)] mod test { use super::*; diff --git a/libshpool/src/daemon/control_codes.rs b/libshpool/src/daemon/control_codes.rs new file mode 100644 index 00000000..11c1016a --- /dev/null +++ b/libshpool/src/daemon/control_codes.rs @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! The escape codes module provides an online (trie based) matcher +//! to scan for escape codes we are interested in in the output of +//! the subshell. For the moment, we just use this to scan for +//! the ClearScreen code emitted by the prompt prefix injection shell +//! code. We need to scan for this to avoid a race that can lead to +//! the motd getting clobbered when in dump mode. + +use anyhow::{anyhow, Context}; + +use super::trie::{Trie, TrieCursor}; + +#[derive(Debug, Clone, Copy)] +pub enum Code { + ClearScreen, +} + +#[derive(Debug)] +pub struct Matcher { + codes: Trie>>, + codes_cursor: TrieCursor, +} + +impl Matcher { + pub fn new(term_db: &termini::TermInfo) -> anyhow::Result { + let clear_code_bytes = match term_db.raw_string_cap(termini::StringCapability::ClearScreen) + { + Some(code) => Vec::from(code), + None => { + // If we somehow have a wacky terminfo db with no clear code, we fall + // back on xterm clear since we still need something to scan for. + let xterm_db = + termini::TermInfo::from_name("xterm").context("building fallback xterm db")?; + let code = xterm_db + .raw_string_cap(termini::StringCapability::ClearScreen) + .ok_or(anyhow!("no fallback clear screen code"))?; + Vec::from(code) + } + }; + + let raw_bindings = vec![ + // We need to scan for the clear code that gets emitted by the prompt prefix + // shell injection code so that we can make sure that the message of the day + // won't get clobbered immediately. + (clear_code_bytes, Code::ClearScreen), + ]; + let mut codes = Trie::new(); + for (raw_bytes, code) in raw_bindings.into_iter() { + codes.insert(raw_bytes.into_iter(), code); + } + + Ok(Matcher { codes, codes_cursor: TrieCursor::Start }) + } + + pub fn transition(&mut self, byte: u8) -> Option { + self.codes_cursor = self.codes.advance(self.codes_cursor, byte); + match self.codes_cursor { + TrieCursor::NoMatch => { + self.codes_cursor = TrieCursor::Start; + None + } + TrieCursor::Match { is_partial, .. } if !is_partial => { + let code = self.codes.get(self.codes_cursor).copied(); + self.codes_cursor = TrieCursor::Start; + code + } + _ => None, + } + } +} diff --git a/libshpool/src/daemon/keybindings.rs b/libshpool/src/daemon/keybindings.rs index e61504a9..b106802f 100644 --- a/libshpool/src/daemon/keybindings.rs +++ b/libshpool/src/daemon/keybindings.rs @@ -48,11 +48,13 @@ //! be singletons besides 'Ctrl' or of the form 'Ctrl-x' where //! x is some non-'Ctrl' key. -use std::{collections::HashMap, fmt, hash}; +use std::{collections::HashMap, fmt}; use anyhow::{anyhow, Context}; use serde_derive::Deserialize; +use super::trie::{Trie, TrieCursor, TrieTab}; + // // Keybindings table // @@ -96,6 +98,20 @@ pub enum BindingResult { #[derive(Eq, PartialEq, Copy, Clone, Hash)] struct ChordAtom(u8); +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: ChordAtom) -> Option<&usize> { + self[index.0 as usize].as_ref() + } + + fn set(&mut self, index: ChordAtom, elem: usize) { + self[index.0 as usize] = Some(elem) + } +} + impl Bindings { /// new builds a bindings matching engine, parsing the given binding->action /// mapping and compiling it into the pair of tries that we use to perform @@ -396,164 +412,6 @@ impl Lexer { } } -// -// Trie (used in both the parser and the execution engine) -// - -#[derive(Debug)] -struct Trie { - // The nodes which form the tree. The first node is the root - // node, afterwards the order is undefined. - nodes: Vec>, -} - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -enum TrieCursor { - /// A cursor to use to start a char-wise match - Start, - /// Represents a state in the middle or end of a match - Match { idx: usize, is_partial: bool }, - /// A terminal state indicating a failure to match - NoMatch, -} - -#[derive(Debug)] -struct TrieNode { - // We need to store a phantom symbol here so we can have the - // Sym type parameter available for the TrieTab trait constraint - // in the impl block. Apologies for the type tetris. - phantom: std::marker::PhantomData, - value: Option, - tab: TT, -} - -impl Trie -where - TT: TrieTab, - Sym: Copy, -{ - fn new() -> Self { - Trie { nodes: vec![TrieNode::new(None)] } - } - - fn insert>(&mut self, seq: Seq, value: V) { - let mut current_node = 0; - for sym in seq { - current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { - *next_node - } else { - let idx = self.nodes.len(); - self.nodes.push(TrieNode::new(None)); - self.nodes[current_node].tab.set(sym, idx); - idx - }; - } - self.nodes[current_node].value = Some(value); - } - - #[allow(dead_code)] - fn contains>(&self, seq: Seq) -> bool { - let mut match_state = TrieCursor::Start; - for sym in seq { - match_state = self.advance(match_state, sym); - if let TrieCursor::NoMatch = match_state { - return false; - } - } - if let TrieCursor::Start = match_state { - return self.nodes[0].value.is_some(); - } - - if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } - } - - fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { - let node = match cursor { - TrieCursor::Start => &self.nodes[0], - TrieCursor::Match { idx, .. } => &self.nodes[idx], - TrieCursor::NoMatch => return TrieCursor::NoMatch, - }; - - if let Some(idx) = node.tab.get(sym) { - TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } - } else { - TrieCursor::NoMatch - } - } - - fn get(&self, cursor: TrieCursor) -> Option<&V> { - if let TrieCursor::Match { idx, .. } = cursor { - self.nodes[idx].value.as_ref() - } else { - None - } - } -} - -impl TrieNode -where - TT: TrieTab, -{ - fn new(value: Option) -> Self { - TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } - } -} - -/// The backing table the trie uses to associate symbols with state -/// indexes. This is basically std::ops::IndexMut plus a new function. -/// We can't just make this a sub-trait of IndexMut because u8 does -/// not implement IndexMut for vectors. -trait TrieTab { - fn new() -> Self; - fn get(&self, index: Idx) -> Option<&usize>; - fn set(&mut self, index: Idx, elem: usize); -} - -impl TrieTab for HashMap -where - Sym: hash::Hash + Eq + PartialEq, -{ - fn new() -> Self { - HashMap::new() - } - - fn get(&self, index: Sym) -> Option<&usize> { - self.get(&index) - } - - fn set(&mut self, index: Sym, elem: usize) { - self.insert(index, elem); - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: u8) -> Option<&usize> { - self[index as usize].as_ref() - } - - fn set(&mut self, index: u8, elem: usize) { - self[index as usize] = Some(elem) - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: ChordAtom) -> Option<&usize> { - self[index.0 as usize].as_ref() - } - - fn set(&mut self, index: ChordAtom, elem: usize) { - self[index.0 as usize] = Some(elem) - } -} - // // Data Tables // diff --git a/libshpool/src/daemon/mod.rs b/libshpool/src/daemon/mod.rs index 00591da4..e1081ba6 100644 --- a/libshpool/src/daemon/mod.rs +++ b/libshpool/src/daemon/mod.rs @@ -19,14 +19,17 @@ use tracing::{info, instrument}; use super::{config, hooks}; +mod control_codes; mod etc_environment; mod exit_notify; pub mod keybindings; mod prompt; mod server; mod shell; +mod show_motd; mod signals; mod systemd; +mod trie; mod ttl_reaper; #[instrument(skip_all)] @@ -39,7 +42,7 @@ pub fn run( info!("\n\n======================== STARTING DAEMON ============================\n\n"); let config = config::read_config(&config_file)?; - let server = server::Server::new(config, hooks, runtime_dir); + let server = server::Server::new(config, hooks, runtime_dir)?; let (cleanup_socket, listener) = match systemd::activation_socket() { Ok(l) => { diff --git a/libshpool/src/daemon/prompt.rs b/libshpool/src/daemon/prompt.rs index 794e6aaa..99f722b5 100644 --- a/libshpool/src/daemon/prompt.rs +++ b/libshpool/src/daemon/prompt.rs @@ -29,9 +29,10 @@ pub fn inject_prefix( shell: &str, prompt_prefix: &str, session_name: &str, + needs_default_term: bool, ) -> anyhow::Result<()> { let prompt_prefix = prompt_prefix.replace("$SHPOOL_SESSION_NAME", session_name); - let script = if shell.ends_with("bash") { + let mut script = if shell.ends_with("bash") { format!( r#" if [[ -z "${{PROMPT_COMMAND+x}}" ]]; then @@ -46,7 +47,6 @@ pub fn inject_prefix( }} PROMPT_COMMAND=__shpool__prompt_command fi - clear "# ) } else if shell.ends_with("zsh") { @@ -62,7 +62,6 @@ pub fn inject_prefix( PROMPT="{prompt_prefix}${{PROMPT}}" }} precmd_functions+=(__shpool__prompt_command) - clear "# ) } else if shell.ends_with("fish") { @@ -70,13 +69,18 @@ pub fn inject_prefix( r#" functions --copy fish_prompt shpool__old_prompt function fish_prompt; echo -n "{prompt_prefix}"; shpool__old_prompt; end - clear "# ) } else { return Err(anyhow!("don't know how to inject a prefix for shell '{}'", shell)); }; + if needs_default_term { + script.push_str("\nTERM=xterm clear\n"); + } else { + script.push_str("\nclear\n"); + } + let mut pty_master = pty_master.is_parent().context("expected parent")?; pty_master.write_all(script.as_bytes()).context("running prefix script")?; diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index af621b2a..5f62e024 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -14,7 +14,9 @@ use std::{ collections::HashMap, - env, fs, io, net, + env, fs, io, + io::Write, + net, ops::Add, os, os::unix::{ @@ -36,7 +38,7 @@ use tracing::{error, info, instrument, span, trace, warn, Level}; use super::{ super::{config, consts, protocol, test_hooks, tty, user}, - etc_environment, hooks, prompt, shell, ttl_reaper, + etc_environment, hooks, prompt, shell, show_motd, ttl_reaper, }; use crate::daemon::exit_notify::ExitNotifier; @@ -56,6 +58,7 @@ pub struct Server { runtime_dir: PathBuf, register_new_reapable_session: crossbeam_channel::Sender<(String, Instant)>, hooks: Box, + motd_shower: Arc, } impl Server { @@ -64,7 +67,7 @@ impl Server { config: config::Config, hooks: Box, runtime_dir: PathBuf, - ) -> Arc { + ) -> anyhow::Result> { let shells = Arc::new(Mutex::new(HashMap::new())); // buffered so that we are unlikely to block when setting up a // new session @@ -76,13 +79,18 @@ impl Server { } }); - Arc::new(Server { + let motd_shower = Arc::new(show_motd::DailyMessenger::new( + config.motd.clone().unwrap_or_default(), + config.motd_args.clone(), + )?); + Ok(Arc::new(Server { config, shells, runtime_dir, register_new_reapable_session: new_sess_tx, hooks, - }) + motd_shower, + })) } #[instrument(skip_all)] @@ -241,11 +249,19 @@ impl Server { } if matches!(status, protocol::AttachStatus::Created { .. }) { + use config::MotdDisplayMode; + info!("creating new subshell"); if let Err(err) = self.hooks.on_new_session(&header.name) { warn!("new_session hook: {:?}", err); } - let session = self.spawn_subshell(conn_id, stream, &header)?; + let motd = self.config.motd.clone().unwrap_or_default(); + let session = self.spawn_subshell( + conn_id, + stream, + &header, + matches!(motd, MotdDisplayMode::Dump), + )?; shells.insert(header.name.clone(), Box::new(session)); // fallthrough to bidi streaming @@ -280,7 +296,8 @@ impl Server { } }; - let reply_status = write_reply(client_stream, protocol::AttachReplyHeader { status }); + let reply_status = + write_reply(client_stream, protocol::AttachReplyHeader { status: status.clone() }); if let Err(e) = reply_status { error!("error writing reply status: {:?}", e); } @@ -523,6 +540,7 @@ impl Server { conn_id: usize, client_stream: UnixStream, header: &protocol::AttachHeader, + dump_motd_on_new_session: bool, ) -> anyhow::Result { let user_info = user::info()?; let shell = if let Some(s) = &self.config.shell { @@ -563,7 +581,26 @@ impl Server { // to avoid breakage and vars the user has asked us to inject. .env_clear(); - self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")?; + let term = self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")?; + let term_db = Arc::new(if let Some(term) = &term { + termini::TermInfo::from_name(term).context("resolving terminfo")? + } else { + warn!("no $TERM, using default terminfo"); + match termini::TermInfo::from_env() { + Ok(db) => db, + Err(err) => { + warn!("could not get terminfo from env: {:?}", err); + match termini::TermInfo::from_name("xterm") { + Ok(db) => db, + Err(err) => { + warn!("could not get xterm terminfo: {:?}", err); + let empty_db = io::Cursor::new(vec![]); + termini::TermInfo::parse(empty_db).context("getting terminfo db")? + } + } + } + } + }); let shell_basename = if header.cmd.is_none() { // spawn the shell as a login shell by setting @@ -625,15 +662,37 @@ impl Server { info!("reaped child shell: {:?}", waitable_child); }); + let has_clear_screen = + term_db.raw_string_cap(termini::StringCapability::ClearScreen).is_some(); + let needs_default_term = !has_clear_screen || term.is_none(); + // inject the prompt prefix, if any + info!("injecting prompt prefix"); let prompt_prefix = self.config.prompt_prefix.clone().unwrap_or(String::from("")); if let Some(shell_basename) = shell_basename { if !prompt_prefix.is_empty() { - if let Err(err) = - prompt::inject_prefix(&mut fork, shell_basename, &prompt_prefix, &header.name) - { + if let Err(err) = prompt::inject_prefix( + &mut fork, + shell_basename, + &prompt_prefix, + &header.name, + needs_default_term, + ) { warn!("issue injecting prefix: {:?}", err); } + } else { + // issue a clear even if we don't have a prompt to inject for consistency + // and to simplify motd handling + let script = if needs_default_term { + "clear\n" + } else { + // If we don't have a $TERM value or we have some wacky $TERM value for which + // there is no ClearScreen code, set TERM to xterm so that we won't get a + // warning and will generate a code we can scan for. + "TERM=xterm clear\n" + }; + let mut pty_master = fork.is_parent().context("expected parent")?; + pty_master.write_all(script.as_bytes()).context("running initial clear")?; } } @@ -655,6 +714,9 @@ impl Server { client_stream: Some(client_stream), config: self.config.clone(), reader_join_h: None, + term_db, + motd_shower: Arc::clone(&self.motd_shower), + needs_initial_motd_dump: dump_motd_on_new_session, }; let child_pid = session_inner.pty_master.child_pid().ok_or(anyhow!("no child pid"))?; session_inner.reader_join_h = Some(session_inner.spawn_reader(shell::ReaderArgs { @@ -691,13 +753,14 @@ impl Server { }) } + /// Set up the environment for the shell, returning the right TERM value. #[instrument(skip_all)] fn inject_env( &self, cmd: &mut process::Command, user_info: &user::Info, header: &protocol::AttachHeader, - ) -> anyhow::Result<()> { + ) -> anyhow::Result> { cmd.env("HOME", &user_info.home_dir) .env( "PATH", @@ -753,7 +816,7 @@ impl Server { } } info!("injecting TERM into shell {:?}", term); - if let Some(t) = term { + if let Some(t) = &term { cmd.env("TERM", t); } @@ -780,7 +843,7 @@ impl Server { } } - Ok(()) + Ok(term) } fn ssh_auth_sock_symlink(&self, session_name: PathBuf) -> PathBuf { diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index 7441f0f4..989321e7 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -32,7 +32,7 @@ use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ consts, - daemon::{config, exit_notify::ExitNotifier, keybindings}, + daemon::{config, control_codes, exit_notify::ExitNotifier, keybindings, show_motd}, protocol, test_hooks, tty, }; @@ -101,6 +101,9 @@ pub struct SessionInner { pub pty_master: shpool_pty::fork::Fork, pub client_stream: Option, pub config: config::Config, + pub term_db: Arc, + pub motd_shower: Arc, + pub needs_initial_motd_dump: bool, /// The join handle for the always-on background reader thread. /// Only wrapped in an option so we can spawn the thread after @@ -189,6 +192,13 @@ impl SessionInner { ) -> anyhow::Result>> { use nix::poll; + let term_db = Arc::clone(&self.term_db); + let mut control_code_matcher = + control_codes::Matcher::new(&self.term_db).context("building control code matcher")?; + + let motd_shower = Arc::clone(&self.motd_shower); + let mut needs_initial_motd_dump = self.needs_initial_motd_dump; + let mut pty_master = self.pty_master.is_parent()?; let name = self.name.clone(); let mut closure = move || { @@ -342,7 +352,7 @@ impl SessionInner { .context("sending size change ack")?; } Err(err) => { - info!("size change: bailing due to: {:?}", err); + warn!("size change: bailing due to: {:?}", err); return Ok(()); } } @@ -452,7 +462,55 @@ impl SessionInner { } } + // scan for control codes we need to handle let mut reset_client_conn = false; + let mut snip_buf_to = None; + if needs_initial_motd_dump { + for (i, byte) in buf[..len].iter().enumerate() { + match control_code_matcher.transition(*byte) { + Some(control_codes::Code::ClearScreen) if needs_initial_motd_dump => { + debug!("detected initial ClearScreen code"); + if let ClientConnectionMsg::New(conn) = &client_conn { + let mut s = conn.sink.lock().unwrap(); + + // write the clear code ahead of time so we don't + // immediately clobber ourselves + let write_to = i + 1; + let chunk = protocol::Chunk { + kind: protocol::ChunkKind::Data, + buf: &buf[..write_to], + }; + let write_result = + chunk.write_to(&mut *s).and_then(|_| s.flush()); + if let Err(err) = write_result { + info!( + "while writing ClearScreen: client_stream write err, assuming hangup: {:?}", + err + ); + reset_client_conn = true; + } else { + test_hooks::emit("daemon-wrote-s2c-chunk"); + } + snip_buf_to = Some(write_to); + + if let Err(e) = motd_shower.dump(&mut *s, &term_db) { + warn!("Error handling clear: {:?}", e); + } + } + needs_initial_motd_dump = false; + } + _ => {} + } + } + } + if let Some(snip_to) = snip_buf_to { + if snip_to < buf.len() { + buf = Vec::from(&buf[snip_to..]); + } else { + buf.clear(); + } + } + if let ClientConnectionMsg::New(conn) = &client_conn { let chunk = protocol::Chunk { kind: protocol::ChunkKind::Data, buf: &buf[..len] }; @@ -718,7 +776,9 @@ impl SessionInner { NoMatch => { partial_keybinding.clear(); } - Partial => partial_keybinding.push(*byte), + Partial => { + partial_keybinding.push(*byte); + } Match(action) => { info!("{:?} keybinding action fired", action); let keybinding_len = partial_keybinding.len() + 1; diff --git a/libshpool/src/daemon/show_motd.rs b/libshpool/src/daemon/show_motd.rs new file mode 100644 index 00000000..f2d192d1 --- /dev/null +++ b/libshpool/src/daemon/show_motd.rs @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io; + +use anyhow::{anyhow, Context}; + +use super::super::{config, protocol}; + +/// Showers know how to show the message of the day. +#[derive(Debug, Clone)] +pub struct DailyMessenger { + motd_resolver: motd::Resolver, + mode: config::MotdDisplayMode, + args: Option>, +} + +impl DailyMessenger { + /// Make a new Shower. + pub fn new(mode: config::MotdDisplayMode, args: Option>) -> anyhow::Result { + Ok(DailyMessenger { + motd_resolver: motd::Resolver::new(motd::PamMotdResolutionStrategy::Auto) + .context("creating motd resolver")?, + mode, + args, + }) + } + + pub fn dump( + &self, + mut stream: W, + term_db: &termini::TermInfo, + ) -> anyhow::Result<()> { + assert!(matches!(self.mode, config::MotdDisplayMode::Dump)); + + let raw_motd_value = self.get_raw_motd_value(term_db)?; + + let chunk = + protocol::Chunk { kind: protocol::ChunkKind::Data, buf: raw_motd_value.as_slice() }; + + chunk.write_to(&mut stream).context("dumping motd") + } + + fn get_raw_motd_value(&self, term_db: &termini::TermInfo) -> anyhow::Result> { + let motd_value = self + .motd_resolver + .value(match &self.args { + Some(args) => { + let mut args = args.clone(); + // On debian based systems we need to set noupdate in order to get + // the motd from userspace. It should be ignored on non-debian systems. + if !args.iter().any(|a| a == "noupdate") { + args.push(String::from("noupdate")); + } + motd::ArgResolutionStrategy::Exact(args) + } + None => motd::ArgResolutionStrategy::Auto, + }) + .context("resolving motd")?; + Self::convert_to_raw(term_db, &motd_value) + } + + /// Convert the given motd into a byte buffer suitable to be written to the + /// terminal. The only real transformation we perform is injecting carrage + /// returns after newlines. + fn convert_to_raw(term_db: &termini::TermInfo, motd: &str) -> anyhow::Result> { + let carrage_return_code = term_db + .raw_string_cap(termini::StringCapability::CarriageReturn) + .ok_or(anyhow!("no carrage return code"))?; + + let mut buf: Vec = vec![]; + + let lines = motd.split('\n'); + for line in lines { + buf.extend(line.as_bytes()); + buf.push(b'\n'); + buf.extend(carrage_return_code); + } + + Ok(buf) + } +} diff --git a/libshpool/src/daemon/trie.rs b/libshpool/src/daemon/trie.rs new file mode 100644 index 00000000..faccf8f0 --- /dev/null +++ b/libshpool/src/daemon/trie.rs @@ -0,0 +1,146 @@ +use std::{collections::HashMap, hash}; + +#[derive(Debug)] +pub struct Trie { + // The nodes which form the tree. The first node is the root + // node, afterwards the order is undefined. + nodes: Vec>, +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum TrieCursor { + /// A cursor to use to start a char-wise match + Start, + /// Represents a state in the middle or end of a match + Match { idx: usize, is_partial: bool }, + /// A terminal state indicating a failure to match + NoMatch, +} + +#[derive(Debug)] +pub struct TrieNode { + // We need to store a phantom symbol here so we can have the + // Sym type parameter available for the TrieTab trait constraint + // in the impl block. Apologies for the type tetris. + phantom: std::marker::PhantomData, + value: Option, + tab: TT, +} + +impl Trie +where + TT: TrieTab, + Sym: Copy, +{ + pub fn new() -> Self { + Trie { nodes: vec![TrieNode::new(None)] } + } + + /// Insert a seq, value pair into the trie + pub fn insert>(&mut self, seq: Seq, value: V) { + let mut current_node = 0; + for sym in seq { + current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { + *next_node + } else { + let idx = self.nodes.len(); + self.nodes.push(TrieNode::new(None)); + self.nodes[current_node].tab.set(sym, idx); + idx + }; + } + self.nodes[current_node].value = Some(value); + } + + /// Check if the given sequence exists in the trie, used by tests. + #[allow(dead_code)] + pub fn contains>(&self, seq: Seq) -> bool { + let mut match_state = TrieCursor::Start; + for sym in seq { + match_state = self.advance(match_state, sym); + if let TrieCursor::NoMatch = match_state { + return false; + } + } + if let TrieCursor::Start = match_state { + return self.nodes[0].value.is_some(); + } + + if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } + } + + /// Process a single token of input, returning the current state. + /// To start a new match, use TrieCursor::Start. + pub fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { + let node = match cursor { + TrieCursor::Start => &self.nodes[0], + TrieCursor::Match { idx, .. } => &self.nodes[idx], + TrieCursor::NoMatch => return TrieCursor::NoMatch, + }; + + if let Some(idx) = node.tab.get(sym) { + TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } + } else { + TrieCursor::NoMatch + } + } + + /// Get the value for a match cursor. + pub fn get(&self, cursor: TrieCursor) -> Option<&V> { + if let TrieCursor::Match { idx, .. } = cursor { + self.nodes[idx].value.as_ref() + } else { + None + } + } +} + +impl TrieNode +where + TT: TrieTab, +{ + fn new(value: Option) -> Self { + TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } + } +} + +/// The backing table the trie uses to associate symbols with state +/// indexes. This is basically std::ops::IndexMut plus a new function. +/// We can't just make this a sub-trait of IndexMut because u8 does +/// not implement IndexMut for vectors. +pub trait TrieTab { + fn new() -> Self; + fn get(&self, index: Idx) -> Option<&usize>; + fn set(&mut self, index: Idx, elem: usize); +} + +impl TrieTab for HashMap +where + Sym: hash::Hash + Eq + PartialEq, +{ + fn new() -> Self { + HashMap::new() + } + + fn get(&self, index: Sym) -> Option<&usize> { + self.get(&index) + } + + fn set(&mut self, index: Sym, elem: usize) { + self.insert(index, elem); + } +} + +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: u8) -> Option<&usize> { + self[index as usize].as_ref() + } + + fn set(&mut self, index: u8, elem: usize) { + self[index as usize] = Some(elem) + } +} diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index da8ec5a3..2fe03aba 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -218,7 +218,7 @@ impl fmt::Display for SessionStatus { } /// AttachStatus indicates what happened during an attach attempt. -#[derive(PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(PartialEq, Eq, Serialize, Deserialize, Debug, Clone)] pub enum AttachStatus { /// Attached indicates that there was an existing shell session with /// the given name, and `shpool attach` successfully connected to it. diff --git a/shpool/Cargo.toml b/shpool/Cargo.toml index 7f5c7ecc..f2a43533 100644 --- a/shpool/Cargo.toml +++ b/shpool/Cargo.toml @@ -19,6 +19,7 @@ rust-version = "1.74" clap = { version = "4", features = ["derive"] } # cli parsing anyhow = "1" # dynamic, unstructured errors libshpool = { version = "0.5.0", path = "../libshpool" } +motd = "0.2.0" # getting the message-of-the-day [dev-dependencies] lazy_static = "1" # globals diff --git a/shpool/src/main.rs b/shpool/src/main.rs index d2852d97..fbd306d6 100644 --- a/shpool/src/main.rs +++ b/shpool/src/main.rs @@ -20,6 +20,8 @@ use clap::Parser; const VERSION: &str = env!("CARGO_PKG_VERSION"); fn main() -> anyhow::Result<()> { + motd::handle_reexec(); + let args = libshpool::Args::parse(); if args.version() { diff --git a/shpool/tests/attach.rs b/shpool/tests/attach.rs index b5050d0e..a8d1cc3a 100644 --- a/shpool/tests/attach.rs +++ b/shpool/tests/attach.rs @@ -1,7 +1,8 @@ use std::{ - fs, + env, fs, io::BufRead, - io::Read, + io::{Read, Write}, + path::PathBuf, process::{Command, Stdio}, thread, time, }; @@ -32,7 +33,7 @@ fn happy_path() -> anyhow::Result<()> { daemon_proc.await_event("daemon-about-to-listen")?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; attach_proc.run_cmd("echo ping")?; line_matcher.match_re("ping$")?; @@ -105,7 +106,7 @@ fn forward_env() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd(r#"echo "$FOO:$BAR:$BAZ" "#)?; - line_matcher.match_re("foo:bar:$")?; + line_matcher.scan_until_re("foo:bar:$")?; Ok(()) }) @@ -139,7 +140,8 @@ fn symlink_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt redraw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*sh1/ssh-auth-sock.socket ->.*ssh-auth-sock-target.fake$"#)?; + line_matcher + .scan_until_re(r#".*sh1/ssh-auth-sock.socket ->.*ssh-auth-sock-target.fake$"#)?; Ok(()) }) @@ -163,7 +165,7 @@ fn missing_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt re-draw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*No such file or directory$"#)?; + line_matcher.scan_until_re(r#".*No such file or directory$"#)?; Ok(()) }) @@ -220,7 +222,7 @@ fn config_disable_symlink_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt re-draw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*No such file or directory$"#)?; + line_matcher.scan_until_re(r#".*No such file or directory$"#)?; Ok(()) }) @@ -244,7 +246,7 @@ fn bounce() -> anyhow::Result<()> { attach_proc.run_cmd("export MYVAR=1")?; attach_proc.run_cmd("echo $MYVAR")?; - line_matcher.match_re("1$")?; + line_matcher.scan_until_re("1$")?; } // falling out of scope kills attach_proc // wait until the daemon has noticed that the connection @@ -281,10 +283,10 @@ fn two_at_once() -> anyhow::Result<()> { let mut line_matcher2 = attach_proc2.line_matcher()?; attach_proc1.run_cmd("echo proc1").context("proc1 echo")?; - line_matcher1.match_re("proc1$").context("proc1 match")?; + line_matcher1.scan_until_re("proc1$").context("proc1 match")?; attach_proc2.run_cmd("echo proc2").context("proc2 echo")?; - line_matcher2.match_re("proc2$").context("proc2 match")?; + line_matcher2.scan_until_re("proc2$").context("proc2 match")?; Ok(()) }) @@ -308,7 +310,7 @@ fn explicit_exit() -> anyhow::Result<()> { attach_proc.run_cmd("export MYVAR=first")?; attach_proc.run_cmd("echo $MYVAR")?; - line_matcher.match_re("first$")?; + line_matcher.scan_until_re("first$")?; attach_proc.run_cmd("exit")?; @@ -324,7 +326,7 @@ fn explicit_exit() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo ${MYVAR:-second}")?; - line_matcher.match_re("second$")?; + line_matcher.scan_until_re("second$")?; } Ok(()) @@ -426,7 +428,7 @@ fn force_attach() -> anyhow::Result<()> { tty1.run_cmd("echo $MYVAR")?; // read some output to make sure the var is set by the time // we force-attach - line_matcher1.match_re("set_from_tty1$")?; + line_matcher1.scan_until_re("set_from_tty1$")?; let mut tty2 = daemon_proc .attach("sh1", AttachArgs { force: true, ..Default::default() }) @@ -450,12 +452,12 @@ fn busy() -> anyhow::Result<()> { daemon_proc.attach("sh1", Default::default()).context("attaching from tty1")?; let mut line_matcher1 = tty1.line_matcher()?; tty1.run_cmd("echo foo")?; // make sure the shell is up and running - line_matcher1.match_re("foo$")?; + line_matcher1.scan_until_re("foo$")?; let mut tty2 = daemon_proc.attach("sh1", Default::default()).context("attaching from tty2")?; let mut line_matcher2 = tty2.stderr_line_matcher()?; - line_matcher2.match_re("already has a terminal attached$")?; + line_matcher2.scan_until_re("already has a terminal attached$")?; Ok(()) }) @@ -473,7 +475,7 @@ fn daemon_hangup() -> anyhow::Result<()> { // make sure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; daemon_proc.proc_kill()?; @@ -498,7 +500,7 @@ fn default_keybinding_detach() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw_cmd(vec![0, 17])?; // Ctrl-Space Ctrl-q a1.proc.wait()?; @@ -510,7 +512,7 @@ fn default_keybinding_detach() -> anyhow::Result<()> { let mut lm2 = a2.line_matcher()?; a2.run_cmd("echo $MYVAR")?; - lm2.match_re("someval$")?; + lm2.scan_until_re("someval$")?; Ok(()) }) @@ -532,7 +534,7 @@ fn keybinding_input_shear() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw(vec![0])?; // Ctrl-Space thread::sleep(time::Duration::from_millis(100)); @@ -546,7 +548,7 @@ fn keybinding_input_shear() -> anyhow::Result<()> { let mut lm2 = a2.line_matcher()?; a2.run_cmd("echo $MYVAR")?; - lm2.match_re("someval$")?; + lm2.scan_until_re("someval$")?; Ok(()) }) @@ -564,7 +566,7 @@ fn keybinding_strip_keys() -> anyhow::Result<()> { // the keybinding is 5 'a' chars in a row, so they should get stripped out a1.run_cmd("echo baaaaad")?; - lm1.match_re("bd$")?; + lm1.scan_until_re("bd$")?; Ok(()) }) @@ -586,7 +588,7 @@ fn keybinding_strip_keys_split() -> anyhow::Result<()> { a1.run_raw("aa".bytes().collect())?; thread::sleep(time::Duration::from_millis(50)); a1.run_raw("aad\n".bytes().collect())?; - lm1.match_re("bd$")?; + lm1.scan_until_re("bd$")?; Ok(()) }) @@ -604,7 +606,7 @@ fn keybinding_partial_match_nostrip() -> anyhow::Result<()> { // the keybinding is 5 'a' chars in a row, this has only 3 a1.run_cmd("echo baaad")?; - lm1.match_re("baaad$")?; + lm1.scan_until_re("baaad$")?; Ok(()) }) @@ -626,7 +628,7 @@ fn keybinding_partial_match_nostrip_split() -> anyhow::Result<()> { a1.run_raw("a".bytes().collect())?; thread::sleep(time::Duration::from_millis(50)); a1.run_raw("ad\n".bytes().collect())?; - lm1.match_re("baaad$")?; + lm1.scan_until_re("baaad$")?; Ok(()) }) @@ -646,7 +648,7 @@ fn custom_keybinding_detach() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw_cmd(vec![22, 23, 7])?; // Ctrl-v Ctrl-w Ctrl-g a1.proc.wait()?; @@ -686,7 +688,7 @@ fn injects_term_even_with_env_config() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt redraw attach_proc.run_cmd("echo $SOME_CUSTOM_ENV_VAR")?; - line_matcher.match_re("customvalue$")?; + line_matcher.scan_until_re("customvalue$")?; attach_proc.run_cmd("echo $TERM")?; line_matcher.match_re("dumb$")?; @@ -716,7 +718,7 @@ fn injects_local_env_vars() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo $DISPLAY")?; - line_matcher.match_re(":0$")?; + line_matcher.scan_until_re(":0$")?; attach_proc.run_cmd("echo $LANG")?; line_matcher.match_re("fakelang$")?; @@ -737,7 +739,7 @@ fn has_right_default_path() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo $PATH")?; - line_matcher.match_re("/usr/bin:/bin:/usr/sbin:/sbin$")?; + line_matcher.scan_until_re("/usr/bin:/bin:/usr/sbin:/sbin$")?; Ok(()) }) @@ -757,7 +759,7 @@ fn screen_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } // wait until the daemon has noticed that the connection @@ -771,7 +773,7 @@ fn screen_restore() -> anyhow::Result<()> { // the re-attach should redraw the screen for us, so we should // get a line with "foo" as part of the re-drawn screen. - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } Ok(()) @@ -792,7 +794,7 @@ fn screen_wide_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy")?; - line_matcher.match_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; + line_matcher.scan_until_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; } // wait until the daemon has noticed that the connection @@ -806,7 +808,7 @@ fn screen_wide_restore() -> anyhow::Result<()> { // the re-attach should redraw the screen for us, so we should // get a line with the full echo result as part of the re-drawn screen. - line_matcher.match_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; + line_matcher.scan_until_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; } Ok(()) @@ -827,7 +829,8 @@ fn lines_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + attach_proc.run_cmd("echo")?; + line_matcher.scan_until_re("foo$")?; } // wait until the daemon has noticed that the connection @@ -841,7 +844,7 @@ fn lines_restore() -> anyhow::Result<()> { // the re-attach should redraw the last 2 lines for us, so we should // get a line with "foo" as part of the re-drawn screen. - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } Ok(()) @@ -871,7 +874,7 @@ fn lines_big_chunk_restore() -> anyhow::Result<()> { // for a single chunk let blob = format!("echo {}", (0..max_chunk_size).map(|_| "x").collect::()); attach_proc.run_cmd(blob.as_str())?; - line_matcher.match_re("xx$")?; + line_matcher.scan_until_re("xx$")?; attach_proc.run_cmd("echo food")?; line_matcher.match_re("food$")?; @@ -939,7 +942,7 @@ fn ttl_hangup() -> anyhow::Result<()> { // ensure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; // sleep long enough for the reaper to clobber the thread thread::sleep(time::Duration::from_millis(1200)); @@ -967,7 +970,7 @@ fn ttl_no_hangup_yet() -> anyhow::Result<()> { // ensure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; let listout = daemon_proc.list()?; assert!(String::from_utf8_lossy(listout.stdout.as_slice()).contains("sh1")); @@ -997,7 +1000,7 @@ fn prompt_prefix_bash() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1037,7 +1040,7 @@ fn prompt_prefix_zsh() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1077,7 +1080,7 @@ fn prompt_prefix_fish() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1099,6 +1102,73 @@ fn prompt_prefix_fish() -> anyhow::Result<()> { }) } +#[test] +#[timeout(30000)] +fn motd_dump() -> anyhow::Result<()> { + support::dump_err(|| { + // set up the config + let tmp_dir = tempfile::TempDir::with_prefix("shpool-test-config")?; + let tmp_dir_path = if env::var("SHPOOL_LEAVE_TEST_LOGS").is_ok() { + // leave the tmp files around for later inspection if we have been asked + // to leave the logs in place. + tmp_dir.into_path() + } else { + PathBuf::from(tmp_dir.path()) + }; + eprintln!("building config in {:?}", tmp_dir_path); + let motd_file = tmp_dir_path.join("motd.txt"); + { + let mut f = fs::File::create(&motd_file)?; + f.write_all("MOTD_MSG\n".as_bytes())?; + } + let config_tmpl = fs::read_to_string(support::testdata_file("motd_dump.toml.tmpl"))?; + let config_contents = config_tmpl.replace("TMP_MOTD_MSG_FILE", motd_file.to_str().unwrap()); + let config_file = tmp_dir_path.join("motd_dump.toml"); + { + let mut f = fs::File::create(&config_file)?; + f.write_all(config_contents.as_bytes())?; + } + + // spawn a daemon based on our custom config + let daemon_proc = + support::daemon::Proc::new(&config_file, true).context("starting daemon proc")?; + + // We need to manually spawn our attach proc because + // the motd gets printed immediately, so we can't always + // attach a line matcher in time. + let mut child = Command::new(support::shpool_bin()?) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("--socket") + .arg(&daemon_proc.socket_path) + .arg("--config-file") + .arg(config_file) + .arg("attach") + .arg("sh1") + .spawn() + .context("spawning attach process")?; + + // The attach shell should be spawned and have read the + // initial prompt after half a second. + std::thread::sleep(time::Duration::from_millis(500)); + child.kill().context("killing child")?; + + let mut stderr = child.stderr.take().context("missing stderr")?; + let mut stderr_str = String::from(""); + stderr.read_to_string(&mut stderr_str).context("slurping stderr")?; + assert!(stderr_str.is_empty()); + + let mut stdout = child.stdout.take().context("missing stdout")?; + let mut stdout_str = String::from(""); + stdout.read_to_string(&mut stdout_str).context("slurping stdout")?; + let stdout_re = Regex::new(".*MOTD_MSG.*")?; + // eprintln!("stdout_str='{}'", stdout_str); + assert!(stdout_re.is_match(&stdout_str)); + + Ok(()) + }) +} + #[ignore] // TODO: re-enable, this test if flaky #[test] fn up_arrow_no_crash() -> anyhow::Result<()> { diff --git a/shpool/tests/daemon.rs b/shpool/tests/daemon.rs index e651c009..30d79d37 100644 --- a/shpool/tests/daemon.rs +++ b/shpool/tests/daemon.rs @@ -215,7 +215,7 @@ fn hooks() -> anyhow::Result<()> { // sequencing let mut sh1_matcher = sh1_proc.line_matcher()?; sh1_proc.run_cmd("echo hi")?; - sh1_matcher.match_re("hi$")?; + sh1_matcher.scan_until_re("hi$")?; // 1 busy let mut busy_proc = diff --git a/shpool/tests/data/motd_dump.toml.tmpl b/shpool/tests/data/motd_dump.toml.tmpl new file mode 100644 index 00000000..8b4feae5 --- /dev/null +++ b/shpool/tests/data/motd_dump.toml.tmpl @@ -0,0 +1,11 @@ +norc = true +noecho = true +shell = "/bin/bash" +session_restore_mode = "simple" + +motd = "dump" +motd_args = ["motd=TMP_MOTD_MSG_FILE"] + +[env] +PS1 = "prompt> " +TERM = "xterm" diff --git a/shpool/tests/detach.rs b/shpool/tests/detach.rs index abe046bc..dd139b4a 100644 --- a/shpool/tests/detach.rs +++ b/shpool/tests/detach.rs @@ -126,7 +126,7 @@ fn reattach() -> anyhow::Result<()> { let mut lm1 = sess1.line_matcher()?; sess1.run_cmd("export MYVAR=first ; echo hi")?; - lm1.match_re("hi$")?; + lm1.scan_until_re("hi$")?; let out = daemon_proc.detach(vec![String::from("sh1")])?; assert!(out.status.success(), "not successful"); diff --git a/shpool/tests/kill.rs b/shpool/tests/kill.rs index 10bb8f48..5d9176dd 100644 --- a/shpool/tests/kill.rs +++ b/shpool/tests/kill.rs @@ -120,7 +120,7 @@ fn reattach_after_kill() -> anyhow::Result<()> { let mut lm1 = sess1.line_matcher()?; sess1.run_cmd("export MYVAR=first")?; sess1.run_cmd("echo $MYVAR")?; - lm1.match_re("first$")?; + lm1.scan_until_re("first$")?; let out = daemon_proc.kill(vec![String::from("sh1")])?; assert!(out.status.success()); @@ -138,7 +138,7 @@ fn reattach_after_kill() -> anyhow::Result<()> { daemon_proc.attach("sh1", Default::default()).context("starting attach proc")?; let mut lm2 = sess2.line_matcher()?; sess2.run_cmd("echo ${MYVAR:-second}")?; - lm2.match_re("second$")?; + lm2.scan_until_re("second$")?; Ok(()) }) diff --git a/shpool/tests/support/daemon.rs b/shpool/tests/support/daemon.rs index 7ae97e27..5f3c3830 100644 --- a/shpool/tests/support/daemon.rs +++ b/shpool/tests/support/daemon.rs @@ -102,6 +102,12 @@ impl Proc { let log_file = tmp_dir.join("daemon.log"); eprintln!("spawning daemon proc with log {:?}", &log_file); + let resolved_config = if config.as_ref().exists() { + PathBuf::from(config.as_ref()) + } else { + testdata_file(config) + }; + let mut cmd = Command::new(shpool_bin()?); cmd.stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -111,7 +117,7 @@ impl Proc { .arg("--socket") .arg(&socket_path) .arg("--config-file") - .arg(testdata_file(config)) + .arg(resolved_config) .arg("daemon"); if listen_events { cmd.env("SHPOOL_TEST_HOOK_SOCKET_PATH", &test_hook_socket_path); diff --git a/shpool/tests/support/line_matcher.rs b/shpool/tests/support/line_matcher.rs index 8e8654d6..4e2a31a7 100644 --- a/shpool/tests/support/line_matcher.rs +++ b/shpool/tests/support/line_matcher.rs @@ -14,6 +14,51 @@ impl LineMatcher where R: std::io::Read, { + /// Scan lines until one matches the given regex + pub fn scan_until_re(&mut self, re: &str) -> anyhow::Result<()> { + let compiled_re = Regex::new(re)?; + let start = time::Instant::now(); + loop { + let mut line = String::new(); + match self.out.read_line(&mut line) { + Ok(0) => { + return Err(anyhow!("LineMatcher: EOF")); + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + if start.elapsed() > CMD_READ_TIMEOUT { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "timed out reading line", + ))?; + } + + std::thread::sleep(CMD_READ_SLEEP_DUR); + continue; + } + + return Err(e).context("reading line from shell output")?; + } + Ok(_) => { + if line.ends_with('\n') { + line.pop(); + if line.ends_with('\r') { + line.pop(); + } + } + } + } + + eprint!("scanning for /{}/... ", re); + if compiled_re.is_match(&line) { + eprintln!(" match"); + return Ok(()); + } else { + eprintln!(" no match"); + } + } + } + pub fn match_re(&mut self, re: &str) -> anyhow::Result<()> { match self.capture_re(re) { Ok(_) => Ok(()),