diff --git a/.cargo/config.toml b/.cargo/config.toml index 3bab4294e..0777ed2ed 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,6 +3,8 @@ target = "x86_64-unknown-none" rustflags = [ "-C", "force-frame-pointers", "-C", "linker-flavor=ld", + "--cfg", "aes_force_soft", + "--cfg", "polyval_force_soft", ] [target.x86_64-unknown-linux-gnu] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 760af1f95..33e45fd4b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,3 +48,6 @@ before allowing them to be committed. It can be installed by running ``` from the projects root directory. + +For detailed instructions on documentation guidelines please have a look at +[DOC-GUIDELINES.md](Documentation/DOC-GUIDELINES.md). diff --git a/Cargo.lock b/Cargo.lock index 1ce56a400..b946fbc64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,41 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -16,15 +51,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cfg-if" @@ -32,6 +67,44 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + +[[package]] +name = "cpufeatures" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "gdbstub" version = "0.6.6" @@ -56,6 +129,35 @@ dependencies = [ "num-traits", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "ghash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d930750de5717d2dd0b8c0d42c076c0e884c81a73e6cab859bbd2339c71e3e40" +dependencies = [ + "opaque-debug", + "polyval", +] + +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "intrusive-collections" version = "0.9.6" @@ -65,11 +167,17 @@ dependencies = [ "memoffset", ] +[[package]] +name = "libc" +version = "0.2.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" + [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "managed" @@ -88,50 +196,75 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "packit" version = "0.1.0" -source = "git+https://github.com/coconut-svsm/packit#fffebdc18a3f559f0a01425b17cf41b1c249fbe0" +source = "git+https://github.com/coconut-svsm/packit#540b471ee8da1d28fee8d9490888c84a48da04a8" dependencies = [ "zerocopy", ] [[package]] name = "paste" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4b27ab7be369122c218afc2079489cdcb4b517c0a3fc386ff11e1fedfcc2b35" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "polyval" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52cff9d1d4dee5fe6d03729099f4a310a41179e0a10dbf542039873f2e826fb" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] [[package]] name = "proc-macro2" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.29" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = 
"5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "svsm" version = "0.1.0" dependencies = [ - "bitflags 2.4.0", + "aes-gcm", + "bitflags 2.4.1", "gdbstub", "gdbstub_arch", "intrusive-collections", @@ -142,9 +275,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.109" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", @@ -155,17 +288,39 @@ dependencies = [ name = "test" version = "0.1.0" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" -version = "1.0.10" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "zerocopy" -version = "0.6.1" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332f188cc1bcf1fe1064b8c58d150f497e697f49774aa846f2dc949d9a25f236" +checksum = "96f8f25c15a0edc9b07eb66e7e6e97d124c0505435c382fde1ab7ceb188aa956" dependencies = [ "byteorder", "zerocopy-derive", @@ -173,9 +328,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.3.2" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6505e6815af7de1746a08f69c69606bb45695a17149517680f3b2149713b19a3" +checksum = "855e0f6af9cd72b87d8a6c586f3cb583f5cdcc62c2c80869d8cd7e96fdf7ee20" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 568ceed01..6e65257b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ gdbstub_arch = { version = "0.2.4", optional = true } intrusive-collections = "0.9.6" log = { version = "0.4.17", features = ["max_level_info", "release_max_level_info"] } packit = { git = "https://github.com/coconut-svsm/packit", version = "0.1.0" } +aes-gcm = { version = "0.10.3", default-features = false, features = ["aes", "alloc"] } [target."x86_64-unknown-none".dev-dependencies] test = { version = "0.1.0", path = "test" } diff --git a/Documentation/DOC-GUIDELINES.md b/Documentation/DOC-GUIDELINES.md new file mode 100644 index 000000000..dbba2bc78 --- /dev/null +++ b/Documentation/DOC-GUIDELINES.md @@ -0,0 +1,126 @@ +Documentation Style +=================== + +In this project, code documentation is generated using Rustdoc, which +automatically generates interactive web documentation. 
Here are some +guidelines for documenting code effectively: + +- Follow [Rust's official guidelines.](https://doc.rust-lang.org/rustdoc/how-to-write-documentation.html) + +- Follow standard Markdown format, e.g. put variable names between backticks. + +- When adding doc comments to your code, use triple slashes (`///`) + to document items; use `//!` to document modules or crates, and + `#[doc = ""]` for documenting fields or expressions. + +
```rust +/// This function does A and takes a parameter of type [`M`]. +/// It returns [`B`]; keep in mind C. +fn main(a: M) -> B { + // Some code here +} +``` + +
- Documenting trait implementations is optional since the Rust core + library documentation already covers them. The exception would be if your + implementation does something counterintuitive to the trait's general + definition. + +
- When mentioning a type (e.g. \`RWLock\`, \`WriteLockGuard\`) it's good to + add a link to the type with square brackets (e.g. [\`RWLock\`], + [\`WriteLockGuard\`]). + +
- When documenting a function, usage examples in code blocks + can help readers understand how to use your code. However, keep in mind that + said code will be built and run during tests, so it also needs to be + maintained -- keep it simple. Here is an example of function + documentation with Arguments, Returns and Examples: + +
```rust + +/// Compares two [`Elf64AddrRange`] instances for partial ordering. It returns +/// [`Some`] if there is a partial order, and [`None`] if there is no +/// order (i.e., if the ranges overlap without being equal). +/// +/// # Arguments +/// +/// * `other` - The other [`Elf64AddrRange`] to compare to. +/// +/// # Returns +/// +/// - `Some(Ordering::Less)` if [`self`] is less than `other`. +/// - `Some(Ordering::Greater)` if [`self`] is greater than `other`. +/// - `Some(Ordering::Equal)` if [`self`] is equal to `other`. +/// - [`None`] if there is no partial order (i.e., ranges overlap but are not equal). +/// +/// # Examples +/// +/// ```rust +/// use svsm::elf::Elf64AddrRange; +/// use core::cmp::Ordering; +/// +/// let range1 = Elf64AddrRange { vaddr_begin: 0x1000, vaddr_end: 0x1100 }; +/// let range2 = Elf64AddrRange { vaddr_begin: 0x1100, vaddr_end: 0x1200 }; +/// +/// assert_eq!(range1.partial_cmp(&range2), Some(Ordering::Less)); +/// ``` +impl cmp::PartialOrd for Elf64AddrRange { + fn partial_cmp(&self, other: &Elf64AddrRange) -> Option<cmp::Ordering> { + //(...) +``` + +
- Add a "Safety" section if necessary to clarify what is unsafe, especially in + public (`pub`) interfaces, when using `unsafe` blocks or in cases where + undefined behavior may arise. For example: + +
```rust +/// # Safety +/// +/// This function is marked as `unsafe` because it uses inline assembly. +/// It is the responsibility of the caller to ensure the following: +/// +/// 1. `src` and `dest` must point to valid memory. +/// 2. The length `len` must accurately represent the number of bytes to copy. +/// 3. `src` must be correctly initialized. +/// +pub unsafe fn example_memcpy<T>(dest: *mut T, src: *const T, len: usize) { + // Ensure the pointers are not null + assert!(!dest.is_null() && !src.is_null()); + + unsafe { + asm!( + "rep movsb", + inout("rcx") len => _, + inout("rdi") dest => _, + inout("rsi") src => _, + options(nostack), + ); + } +} +``` +- We can't have a "Panics" section for every place the SVSM may panic, but + one should be included if your code checks assertions or uses the + `unwrap()` method. For instance: + +
```rust +/// # Panics +/// +/// The function will panic if the provided length exceeds the buffer's capacity.
+/// +pub fn my_function(buffer: &mut Vec, len: usize) { + if len > buffer.capacity() { + panic!("Length exceeds allocated capacity!"); + } +``` + +- Remember that if you update code, you also have to update its related + documentation to ensure maintainability. + +- Be aware that your documentation comments have the potential to break the + documentation generation process (cargo doc), which can delay the merging + of your changes. Your new documentation should be warning-free. + +In general, even imperfect documentation is better than none at all. +Prioritize documenting functions that are publicly exported, especially +API calls, over internal helper functions. diff --git a/Makefile b/Makefile index 990dd3d38..780abfbae 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,12 @@ else TARGET_PATH=debug endif +ifeq ($(V), 1) +CARGO_ARGS += -v +else ifeq ($(V), 2) +CARGO_ARGS += -vv +endif + STAGE2_ELF = "target/x86_64-unknown-none/${TARGET_PATH}/stage2" KERNEL_ELF = "target/x86_64-unknown-none/${TARGET_PATH}/svsm" TEST_KERNEL_ELF = target/x86_64-unknown-none/${TARGET_PATH}/svsm-test diff --git a/README.md b/README.md index 1de316c26..79a516758 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ Contributing Contributing to the project is as easy as sending a pull-request via GitHub. For detailed instructions on patch formatting and contribution guidelines please have a look at [CONTRIBUTING.md](CONTRIBUTING.md). +For documentation guidelines consult [DOC-GUIDELINES.md](Documentation/DOC-GUIDELINES.md). TODO List --------- diff --git a/src/address.rs b/src/address.rs index b1b5cb78c..fbdd3418b 100644 --- a/src/address.rs +++ b/src/address.rs @@ -55,7 +55,7 @@ pub trait Address: self.is_aligned(PAGE_SIZE) } - fn checked_offset(&self, off: InnerAddr) -> Option { + fn checked_add(&self, off: InnerAddr) -> Option { self.bits().checked_add(off).map(|addr| addr.into()) } @@ -63,6 +63,10 @@ pub trait Address: self.bits().checked_sub(off).map(|addr| addr.into()) } + fn saturating_add(&self, off: InnerAddr) -> Self { + Self::from(self.bits().saturating_add(off)) + } + fn page_offset(&self) -> usize { self.bits() & (PAGE_SIZE - 1) } @@ -258,7 +262,7 @@ impl ops::Add for VirtAddr { } impl Address for VirtAddr { - fn checked_offset(&self, off: InnerAddr) -> Option { + fn checked_add(&self, off: InnerAddr) -> Option { self.bits() .checked_add(off) .map(|addr| sign_extend(addr).into()) diff --git a/src/cpu/idt.rs b/src/cpu/idt.rs index 8d8d8d217..218e744be 100644 --- a/src/cpu/idt.rs +++ b/src/cpu/idt.rs @@ -11,7 +11,7 @@ use super::vc::handle_vc_exception; use super::{X86GeneralRegs, X86InterruptFrame}; use crate::address::{Address, VirtAddr}; use crate::cpu::extable::handle_exception_table; -use crate::debug::gdbstub::svsm_gdbstub::handle_bp_exception; +use crate::debug::gdbstub::svsm_gdbstub::handle_debug_exception; use crate::types::SVSM_CS; use core::arch::{asm, global_asm}; use core::mem; @@ -206,6 +206,7 @@ fn generic_idt_handler(ctx: &mut X86ExceptionContext) { .is_err() && !handle_exception_table(ctx) { + handle_debug_exception(ctx, ctx.vector); panic!( "Unhandled Page-Fault at RIP {:#018x} CR2: {:#018x} error code: {:#018x}", rip, cr2, err @@ -214,7 +215,7 @@ fn generic_idt_handler(ctx: &mut X86ExceptionContext) { } else if ctx.vector == VC_VECTOR { handle_vc_exception(ctx); } else if ctx.vector == BP_VECTOR { - handle_bp_exception(ctx); + handle_debug_exception(ctx, ctx.vector); } else { let err = ctx.error_code; let vec = ctx.vector; diff --git a/src/cpu/percpu.rs b/src/cpu/percpu.rs index 
9d67458f0..c597d53ce 100644 --- a/src/cpu/percpu.rs +++ b/src/cpu/percpu.rs @@ -25,6 +25,7 @@ use crate::mm::{ use crate::sev::ghcb::GHCB; use crate::sev::utils::RMPFlags; use crate::sev::vmsa::{allocate_new_vmsa, VMSASegment, VMSA}; +use crate::task::RunQueue; use crate::types::{PAGE_SHIFT, PAGE_SHIFT_2M, PAGE_SIZE, PAGE_SIZE_2M, SVSM_TR_FLAGS, SVSM_TSS}; use alloc::sync::Arc; use alloc::vec::Vec; @@ -191,19 +192,16 @@ pub struct PerCpu { pub vrange_4k: VirtualRange, /// Address allocator for per-cpu 2m temporary mappings pub vrange_2m: VirtualRange, -} -impl Default for PerCpu { - fn default() -> Self { - Self::new() - } + /// Task list that has been assigned for scheduling on this CPU + runqueue: RWLock, } impl PerCpu { - pub fn new() -> Self { + fn new(apic_id: u32) -> Self { PerCpu { online: AtomicBool::new(false), - apic_id: 0, + apic_id, pgtbl: SpinLock::::new(PageTableRef::unset()), ghcb: ptr::null_mut(), init_stack: None, @@ -215,6 +213,7 @@ impl PerCpu { vm_range: VMR::new(SVSM_PERCPU_BASE, SVSM_PERCPU_END, PTEntryFlags::GLOBAL), vrange_4k: VirtualRange::new(), vrange_2m: VirtualRange::new(), + runqueue: RWLock::new(RunQueue::new(apic_id)), } } @@ -222,8 +221,7 @@ impl PerCpu { let vaddr = allocate_zeroed_page()?; unsafe { let percpu = vaddr.as_mut_ptr::(); - (*percpu) = PerCpu::new(); - (*percpu).apic_id = apic_id; + (*percpu) = PerCpu::new(apic_id); PERCPU_AREAS.push(PerCpuInfo::new(apic_id, vaddr)); Ok(percpu) } @@ -566,6 +564,17 @@ impl PerCpu { pub fn handle_pf(&self, vaddr: VirtAddr, write: bool) -> Result<(), SvsmError> { self.vm_range.handle_page_fault(vaddr, write) } + + /// Allocate any candidate unallocated tasks from the global task list to our + /// CPU runqueue. + pub fn allocate_tasks(&mut self) { + self.runqueue.lock_write().allocate(); + } + + /// Access the PerCpu runqueue protected with a lock + pub fn runqueue(&self) -> &RWLock { + &self.runqueue + } } unsafe impl Sync for PerCpu {} diff --git a/src/cpu/smp.rs b/src/cpu/smp.rs index e660d39aa..4c4efd577 100644 --- a/src/cpu/smp.rs +++ b/src/cpu/smp.rs @@ -7,9 +7,10 @@ extern crate alloc; use crate::acpi::tables::ACPICPUInfo; -use crate::cpu::percpu::{this_cpu_mut, PerCpu}; +use crate::cpu::percpu::{this_cpu, this_cpu_mut, PerCpu}; use crate::cpu::vmsa::init_svsm_vmsa; use crate::requests::request_loop; +use crate::task::{create_task, TASK_FLAG_SHARE_PT}; fn start_cpu(apic_id: u32) { unsafe { @@ -51,7 +52,7 @@ pub fn start_secondary_cpus(cpus: &[ACPICPUInfo]) { start_cpu(c.apic_id); count += 1; } - log::info!("Brough {} AP(s) online", count); + log::info!("Brought {} AP(s) online", count); } #[no_mangle] @@ -66,8 +67,17 @@ fn start_ap() { // Set CPU online so that BSP can proceed this_cpu_mut().set_online(); - // Loop for now - request_loop(); + // Create the task making sure the task only runs on this new AP + create_task( + ap_request_loop, + TASK_FLAG_SHARE_PT, + Some(this_cpu().get_apic_id()), + ) + .expect("Failed to create AP initial task"); +} +#[no_mangle] +pub extern "C" fn ap_request_loop() { + request_loop(); panic!("Returned from request_loop!"); } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs new file mode 100644 index 000000000..b6efb0595 --- /dev/null +++ b/src/crypto/mod.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! SVSM kernel crypto API + +pub mod aead { + //! 
API for authenticated encryption with associated data + + use crate::{protocols::errors::SvsmReqError, sev::secrets_page::VMPCK_SIZE}; + + // Message Header Format (AMD SEV-SNP spec. table 98) + + /// Authentication tag size (128 bits) + pub const AUTHTAG_SIZE: usize = 16; + /// Initialization vector size (96 bits) + pub const IV_SIZE: usize = 12; + /// Key size + pub const KEY_SIZE: usize = VMPCK_SIZE; + +
/// AES-256 GCM + pub trait Aes256GcmTrait { + /// Encrypt the provided buffer using AES-256 GCM + /// + /// # Arguments + /// + /// * `iv`: Initialization vector + /// * `key`: 256-bit key + /// * `aad`: Additional authenticated data + /// * `inbuf`: Cleartext buffer to be encrypted + /// * `outbuf`: Buffer to store the encrypted data; it must be large enough to also + /// hold the authentication tag. + /// + /// # Returns + /// + /// * Success + /// * `usize`: Number of bytes written to `outbuf` + /// * Error + /// * [SvsmReqError] + fn encrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result<usize, SvsmReqError>; + +
/// Decrypt the provided buffer using AES-256 GCM + /// + /// # Arguments + /// + /// * `iv`: Initialization vector + /// * `key`: 256-bit key + /// * `aad`: Additional authenticated data + /// * `inbuf`: Ciphertext buffer to be decrypted, followed by the authentication tag + /// * `outbuf`: Buffer to store the decrypted data + /// + /// # Returns + /// + /// * Success + /// * `usize`: Number of bytes written to `outbuf` + /// * Error + /// * [SvsmReqError] + fn decrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result<usize, SvsmReqError>; + } + +
/// Aes256Gcm type + pub struct Aes256Gcm; +} + +// Crypto implementations supported. Only one of them must be compiled in. + +pub mod rustcrypto; diff --git a/src/crypto/rustcrypto.rs b/src/crypto/rustcrypto.rs new file mode 100644 index 000000000..559181432 --- /dev/null +++ b/src/crypto/rustcrypto.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//!
RustCrypto implementation + +use aes_gcm::{ + aead::{Aead, Payload}, + Aes256Gcm, Key, KeyInit, Nonce, +}; + +use crate::{ + crypto::aead::{ + Aes256Gcm as CryptoAes256Gcm, Aes256GcmTrait as CryptoAes256GcmTrait, IV_SIZE, KEY_SIZE, + }, + protocols::errors::SvsmReqError, +}; + +#[repr(u64)] +#[derive(Clone, Copy, Debug, PartialEq)] +enum AesGcmOperation { + Encrypt = 0, + Decrypt = 1, +} + +fn aes_gcm_do( + operation: AesGcmOperation, + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], +) -> Result { + let payload = Payload { msg: inbuf, aad }; + + let aes_key = Key::::from_slice(key); + let gcm = Aes256Gcm::new(aes_key); + let nonce = Nonce::from_slice(iv); + + let result = if operation == AesGcmOperation::Encrypt { + gcm.encrypt(nonce, payload) + } else { + gcm.decrypt(nonce, payload) + }; + let buffer = result.map_err(|_| SvsmReqError::invalid_format())?; + + let outbuf = outbuf + .get_mut(..buffer.len()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + outbuf.copy_from_slice(&buffer); + + Ok(buffer.len()) +} + +impl CryptoAes256GcmTrait for CryptoAes256Gcm { + fn encrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result { + aes_gcm_do(AesGcmOperation::Encrypt, iv, key, aad, inbuf, outbuf) + } + + fn decrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result { + aes_gcm_do(AesGcmOperation::Decrypt, iv, key, aad, inbuf, outbuf) + } +} diff --git a/src/debug/gdbstub.rs b/src/debug/gdbstub.rs index b5ad77024..c1c7a606b 100644 --- a/src/debug/gdbstub.rs +++ b/src/debug/gdbstub.rs @@ -11,24 +11,34 @@ // #[cfg(feature = "enable-gdb")] pub mod svsm_gdbstub { + extern crate alloc; + use crate::address::{Address, VirtAddr}; - use crate::cpu::percpu::this_cpu; - use crate::cpu::X86ExceptionContext; + use crate::cpu::control_regs::read_cr3; + use crate::cpu::idt::{X86ExceptionContext, BP_VECTOR}; + use crate::cpu::percpu::{this_cpu, this_cpu_mut}; + use crate::cpu::X86GeneralRegs; use crate::error::SvsmError; + use crate::locking::{LockGuard, SpinLock}; use crate::mm::guestmem::{read_u8, write_u8}; use crate::mm::PerCPUPageMappingGuard; use crate::serial::{SerialPort, Terminal}; use crate::svsm_console::SVSMIOPort; + use crate::task::{is_current_task, TaskContext, TaskState, INITIAL_TASK_ID, TASKLIST}; use core::arch::asm; + use core::fmt; + use core::sync::atomic::{AtomicBool, Ordering}; + use gdbstub::common::{Signal, Tid}; use gdbstub::conn::Connection; use gdbstub::stub::state_machine::GdbStubStateMachine; - use gdbstub::stub::{GdbStubBuilder, SingleThreadStopReason}; - use gdbstub::target::ext::base::singlethread::{ - SingleThreadBase, SingleThreadResume, SingleThreadResumeOps, SingleThreadSingleStep, - SingleThreadSingleStepOps, + use gdbstub::stub::{GdbStubBuilder, MultiThreadStopReason}; + use gdbstub::target::ext::base::multithread::{ + MultiThreadBase, MultiThreadResume, MultiThreadResumeOps, MultiThreadSingleStep, + MultiThreadSingleStepOps, }; use gdbstub::target::ext::base::BaseOps; use gdbstub::target::ext::breakpoints::{Breakpoints, SwBreakpoint}; + use gdbstub::target::ext::thread_extra_info::ThreadExtraInfo; use gdbstub::target::{Target, TargetError}; use gdbstub_arch::x86::reg::X86_64CoreRegs; use gdbstub_arch::x86::X86_64_SSE; @@ -45,21 +55,117 @@ pub mod svsm_gdbstub { .expect("Failed to initialise GDB stub") .run_state_machine(&mut target) .expect("Failed to start GDB state machine"); - GDB_STATE = 
Some(SvsmGdbStub { gdb, target }); + *GDB_STATE.lock() = Some(SvsmGdbStub { gdb, target }); + GDB_STACK_TOP = GDB_STACK.as_mut_ptr().offset(GDB_STACK.len() as isize - 1) as u64; } + GDB_INITIALISED.store(true, Ordering::Relaxed); Ok(()) } - pub fn handle_bp_exception(ctx: &mut X86ExceptionContext) { - handle_stop(ctx, true); + #[derive(PartialEq, Eq)] + enum ExceptionType { + Debug, + SwBreakpoint, + PageFault, + } + + pub fn handle_debug_exception(ctx: &mut X86ExceptionContext, exception: usize) { + let tp = match exception { + BP_VECTOR => ExceptionType::SwBreakpoint, + _ => ExceptionType::PageFault, + }; + handle_exception(ctx, tp); } pub fn handle_db_exception(ctx: &mut X86ExceptionContext) { - handle_stop(ctx, false); + handle_exception(ctx, ExceptionType::Debug); + } + + fn handle_exception(ctx: &mut X86ExceptionContext, exception_type: ExceptionType) { + let id = this_cpu().runqueue().lock_read().current_task_id(); + let mut task_ctx = TaskContext { + regs: X86GeneralRegs { + r15: ctx.regs.r15, + r14: ctx.regs.r14, + r13: ctx.regs.r13, + r12: ctx.regs.r12, + r11: ctx.regs.r11, + r10: ctx.regs.r10, + r9: ctx.regs.r9, + r8: ctx.regs.r8, + rbp: ctx.regs.rbp, + rdi: ctx.regs.rdi, + rsi: ctx.regs.rsi, + rdx: ctx.regs.rdx, + rcx: ctx.regs.rcx, + rbx: ctx.regs.rbx, + rax: ctx.regs.rax, + }, + rsp: ctx.frame.rsp as u64, + flags: ctx.frame.flags as u64, + ret_addr: ctx.frame.rip as u64, + }; + + if let Some(task_node) = this_cpu_mut().runqueue().lock_read().get_task(id) { + task_node.task.lock_write().rsp = &task_ctx as *const TaskContext as u64; + } + + // Locking the GDB state for the duration of the stop will cause any other + // APs that hit a breakpoint to busy-wait until the current CPU releases + // the GDB state. They will then resume and report the stop state + // to GDB. + // One thing to watch out for - if a breakpoint is inadvertently placed in + // the GDB handling code itself then this will cause a re-entrant state + // within the same CPU causing a deadlock. 
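+    // The loop below re-acquires the GDB state lock on each iteration: if a
+    // different task is currently being single-stepped, the lock is released
+    // and the loop retries so that task's stop can be serviced first.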
+ loop { + let mut gdb_state = GDB_STATE.lock(); + if let Some(stub) = gdb_state.as_ref() { + if stub.target.is_single_step != 0 && stub.target.is_single_step != id { + continue; + } + } + + unsafe { + asm!( + r#" + movq %rsp, (%rax) + movq %rax, %rsp + call handle_stop + popq %rax + movq %rax, %rsp + "#, + in("rsi") exception_type as u64, + in("rdi") &mut task_ctx, + in("rdx") &mut gdb_state, + in("rax") GDB_STACK_TOP, + options(att_syntax)); + } + + ctx.frame.rip = task_ctx.ret_addr as usize; + ctx.frame.flags = task_ctx.flags as usize; + ctx.frame.rsp = task_ctx.rsp as usize; + ctx.regs.rax = task_ctx.regs.rax; + ctx.regs.rbx = task_ctx.regs.rbx; + ctx.regs.rcx = task_ctx.regs.rcx; + ctx.regs.rdx = task_ctx.regs.rdx; + ctx.regs.rsi = task_ctx.regs.rsi; + ctx.regs.rdi = task_ctx.regs.rdi; + ctx.regs.rbp = task_ctx.regs.rbp; + ctx.regs.r8 = task_ctx.regs.r8; + ctx.regs.r9 = task_ctx.regs.r9; + ctx.regs.r10 = task_ctx.regs.r10; + ctx.regs.r11 = task_ctx.regs.r11; + ctx.regs.r12 = task_ctx.regs.r12; + ctx.regs.r13 = task_ctx.regs.r13; + ctx.regs.r14 = task_ctx.regs.r14; + ctx.regs.r15 = task_ctx.regs.r15; + + break; + } } pub fn debug_break() { - if unsafe { GDB_STATE.is_some() } { + if GDB_INITIALISED.load(Ordering::Acquire) { log::info!("***********************************"); log::info!("* Waiting for connection from GDB *"); log::info!("***********************************"); @@ -69,37 +175,105 @@ pub mod svsm_gdbstub { } } - static mut GDB_STATE: Option = None; + static GDB_INITIALISED: AtomicBool = AtomicBool::new(false); + static GDB_STATE: SpinLock> = SpinLock::new(None); static GDB_IO: SVSMIOPort = SVSMIOPort::new(); static mut GDB_SERIAL: SerialPort = SerialPort { driver: &GDB_IO, port: 0x2f8, }; static mut PACKET_BUFFER: [u8; 4096] = [0; 4096]; + // Allocate the GDB stack as an array of u64's to ensure 8 byte alignment of the stack. 
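+    // GDB_STACK_TOP is set in gdbstub_start() to the address of the last element
+    // of this array. The exception handler stores the interrupted stack pointer
+    // there and pivots onto this dedicated stack before calling handle_stop().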
+ static mut GDB_STACK: [u64; 8192] = [0; 8192]; + static mut GDB_STACK_TOP: u64 = 0; + + struct GdbTaskContext { + cr3: usize, + } + + impl GdbTaskContext { + fn switch_to_task(id: u32) -> Self { + let cr3 = if is_current_task(id) { + 0 + } else { + let tl = TASKLIST.lock(); + let cr3 = read_cr3(); + let task_node = tl.get_task(id); + if let Some(task_node) = task_node { + task_node.task.lock_write().page_table.lock().load(); + cr3.bits() + } else { + 0 + } + }; + Self { cr3 } + } + } + + impl Drop for GdbTaskContext { + fn drop(&mut self) { + if self.cr3 != 0 { + unsafe { + asm!("mov %rax, %cr3", + in("rax") self.cr3, + options(att_syntax)); + } + } + } + } struct SvsmGdbStub<'a> { gdb: GdbStubStateMachine<'a, GdbStubTarget, GdbStubConnection>, target: GdbStubTarget, } - fn handle_stop(ctx: &mut X86ExceptionContext, bp_exception: bool) { - let SvsmGdbStub { gdb, mut target } = unsafe { - GDB_STATE.take().unwrap_or_else(|| { - panic!("GDB stub not initialised!"); - }) - }; + impl<'a> fmt::Debug for SvsmGdbStub<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SvsmGdbStub") + } + } + + #[no_mangle] + fn handle_stop( + ctx: &mut TaskContext, + exception_type: ExceptionType, + gdb_state: &mut LockGuard<'_, Option>>, + ) { + let SvsmGdbStub { gdb, mut target } = gdb_state.take().unwrap_or_else(|| { + panic!("Invalid GDB state"); + }); target.set_regs(ctx); + let hardcoded_bp = (exception_type == ExceptionType::SwBreakpoint) + && !target.is_breakpoint(ctx.ret_addr as usize - 1); + // If the current address is on a breakpoint then we need to // move the IP back by one byte - if bp_exception && target.is_breakpoint(ctx.frame.rip - 1) { - ctx.frame.rip -= 1; + if (exception_type == ExceptionType::SwBreakpoint) + && target.is_breakpoint(ctx.ret_addr as usize - 1) + { + ctx.ret_addr -= 1; } + let tid = Tid::new(this_cpu().runqueue().lock_read().current_task_id() as usize) + .expect("Current task has invalid ID"); let mut new_gdb = match gdb { GdbStubStateMachine::Running(gdb_inner) => { - match gdb_inner.report_stop(&mut target, SingleThreadStopReason::SwBreak(())) { + let reason = if hardcoded_bp { + MultiThreadStopReason::SignalWithThread { + tid, + signal: Signal::SIGINT, + } + } else if exception_type == ExceptionType::PageFault { + MultiThreadStopReason::SignalWithThread { + tid, + signal: Signal::SIGSEGV, + } + } else { + MultiThreadStopReason::SwBreak(tid) + }; + match gdb_inner.report_stop(&mut target, reason) { Ok(gdb) => gdb, Err(_) => panic!("Failed to handle software breakpoint"), } @@ -128,22 +302,19 @@ pub mod svsm_gdbstub { break; } _ => { - log::info!("Invalid GDB state when handling breakpoint interrupt"); - return; + panic!("Invalid GDB state when handling breakpoint interrupt"); } }; } - if target.is_single_step { - ctx.frame.flags |= 0x100; + if target.is_single_step == tid.get() as u32 { + ctx.flags |= 0x100; } else { - ctx.frame.flags &= !0x100; + ctx.flags &= !0x100; } - unsafe { - GDB_STATE = Some(SvsmGdbStub { - gdb: new_gdb, - target, - }) - }; + **gdb_state = Some(SvsmGdbStub { + gdb: new_gdb, + target, + }); } struct GdbStubConnection; @@ -182,7 +353,7 @@ pub mod svsm_gdbstub { struct GdbStubTarget { ctx: usize, breakpoints: [GdbStubBreakpoint; MAX_BREAKPOINTS], - is_single_step: bool, + is_single_step: u32, } impl GdbStubTarget { @@ -193,11 +364,11 @@ pub mod svsm_gdbstub { addr: VirtAddr::null(), inst: 0, }; MAX_BREAKPOINTS], - is_single_step: false, + is_single_step: 0, } } - pub fn set_regs(&mut self, ctx: &X86ExceptionContext) { + pub 
fn set_regs(&mut self, ctx: &TaskContext) { self.ctx = (ctx as *const _) as usize; } @@ -227,7 +398,7 @@ pub mod svsm_gdbstub { type Error = usize; fn base_ops(&mut self) -> gdbstub::target::ext::base::BaseOps<'_, Self::Arch, Self::Error> { - BaseOps::SingleThread(self) + BaseOps::MultiThread(self) } #[inline(always)] @@ -238,10 +409,10 @@ pub mod svsm_gdbstub { } } - impl From<&X86ExceptionContext> for X86_64CoreRegs { - fn from(value: &X86ExceptionContext) -> Self { + impl From<&TaskContext> for X86_64CoreRegs { + fn from(value: &TaskContext) -> Self { let mut regs = X86_64CoreRegs::default(); - regs.rip = value.frame.rip as u64; + regs.rip = value.ret_addr; regs.regs = [ value.regs.rax as u64, value.regs.rbx as u64, @@ -250,7 +421,7 @@ pub mod svsm_gdbstub { value.regs.rsi as u64, value.regs.rdi as u64, value.regs.rbp as u64, - value.frame.rsp as u64, + value.rsp, value.regs.r8 as u64, value.regs.r9 as u64, value.regs.r10 as u64, @@ -260,34 +431,49 @@ pub mod svsm_gdbstub { value.regs.r14 as u64, value.regs.r15 as u64, ]; - regs.eflags = value.frame.flags as u32; - regs.segments.cs = value.frame.cs as u32; - regs.segments.ss = value.frame.ss as u32; + regs.eflags = value.flags as u32; regs } } - impl SingleThreadBase for GdbStubTarget { + impl MultiThreadBase for GdbStubTarget { fn read_registers( &mut self, regs: &mut ::Registers, + tid: Tid, ) -> gdbstub::target::TargetResult<(), Self> { - unsafe { - let context = (self.ctx as *mut X86ExceptionContext).as_ref().unwrap(); - *regs = X86_64CoreRegs::from(context); + if is_current_task(tid.get() as u32) { + unsafe { + let context = (self.ctx as *const TaskContext).as_ref().unwrap(); + *regs = X86_64CoreRegs::from(context); + } + } else { + let task = TASKLIST.lock().get_task(tid.get() as u32); + if let Some(task_node) = task { + // The registers are stored in the top of the task stack as part of the + // saved context. We need to switch to the task pagetable to access them. 
+ let _task_context = GdbTaskContext::switch_to_task(tid.get() as u32); + let task = task_node.task.lock_read(); + unsafe { + *regs = X86_64CoreRegs::from(&*(task.rsp as *const TaskContext)); + }; + regs.regs[7] = task.rsp; + } else { + *regs = ::Registers::default(); + } } - Ok(()) } fn write_registers( &mut self, regs: &::Registers, + _tid: Tid, ) -> gdbstub::target::TargetResult<(), Self> { unsafe { - let context = (self.ctx as *mut X86ExceptionContext).as_mut().unwrap(); + let context = (self.ctx as *mut TaskContext).as_mut().unwrap(); - context.frame.rip = regs.rip as usize; + context.ret_addr = regs.rip; context.regs.rax = regs.regs[0] as usize; context.regs.rbx = regs.regs[1] as usize; context.regs.rcx = regs.regs[2] as usize; @@ -295,7 +481,7 @@ pub mod svsm_gdbstub { context.regs.rsi = regs.regs[4] as usize; context.regs.rdi = regs.regs[5] as usize; context.regs.rbp = regs.regs[6] as usize; - context.frame.rsp = regs.regs[7] as usize; + context.rsp = regs.regs[7]; context.regs.r8 = regs.regs[8] as usize; context.regs.r9 = regs.regs[9] as usize; context.regs.r10 = regs.regs[10] as usize; @@ -304,9 +490,7 @@ pub mod svsm_gdbstub { context.regs.r13 = regs.regs[13] as usize; context.regs.r14 = regs.regs[14] as usize; context.regs.r15 = regs.regs[15] as usize; - context.frame.flags = regs.eflags as usize; - context.frame.cs = regs.segments.cs as usize; - context.frame.ss = regs.segments.ss as usize; + context.flags = regs.eflags as u64; } Ok(()) } @@ -315,7 +499,11 @@ pub mod svsm_gdbstub { &mut self, start_addr: ::Usize, data: &mut [u8], + tid: Tid, ) -> gdbstub::target::TargetResult<(), Self> { + // Switch to the task pagetable if necessary. The switch back will + // happen automatically when the variable falls out of scope + let _task_context = GdbTaskContext::switch_to_task(tid.get() as u32); let start_addr = VirtAddr::from(start_addr); for (off, dst) in data.iter_mut().enumerate() { let Ok(val) = read_u8(start_addr + off) else { @@ -330,6 +518,7 @@ pub mod svsm_gdbstub { &mut self, start_addr: ::Usize, data: &[u8], + _tid: Tid, ) -> gdbstub::target::TargetResult<(), Self> { let start_addr = VirtAddr::from(start_addr); for (off, src) in data.iter().enumerate() { @@ -341,26 +530,117 @@ pub mod svsm_gdbstub { } #[inline(always)] - fn support_resume(&mut self) -> Option> { + fn support_resume(&mut self) -> Option> { + Some(self) + } + + fn list_active_threads( + &mut self, + thread_is_active: &mut dyn FnMut(gdbstub::common::Tid), + ) -> Result<(), Self::Error> { + let mut tl = TASKLIST.lock(); + + let mut any_scheduled = false; + + if tl.list().is_empty() { + // Task list has not been initialised yet. 
Report a single thread + // for the current CPU + thread_is_active(Tid::new(INITIAL_TASK_ID as usize).unwrap()); + } else { + let mut cursor = tl.list().front_mut(); + while cursor.get().is_some() { + if cursor.get().unwrap().task.lock_read().allocation.is_some() { + any_scheduled = true; + break; + } + cursor.move_next(); + } + if any_scheduled { + let mut cursor = tl.list().front_mut(); + while cursor.get().is_some() { + thread_is_active( + Tid::new(cursor.get().unwrap().task.lock_read().id as usize).unwrap(), + ); + cursor.move_next(); + } + } else { + thread_is_active(Tid::new(INITIAL_TASK_ID as usize).unwrap()); + } + } + Ok(()) + } + + fn support_thread_extra_info( + &mut self, + ) -> Option> { Some(self) } } - impl SingleThreadResume for GdbStubTarget { - fn resume(&mut self, _signal: Option) -> Result<(), Self::Error> { - self.is_single_step = false; + impl ThreadExtraInfo for GdbStubTarget { + fn thread_extra_info(&self, tid: Tid, buf: &mut [u8]) -> Result { + // Get the current task from the stopped CPU so we can mark it as stopped + let tl = TASKLIST.lock(); + let str = match tl.get_task(tid.get() as u32) { + Some(t) => { + let t = t.task.lock_read(); + match t.state { + TaskState::RUNNING => { + if let Some(allocation) = t.allocation { + if this_cpu().get_apic_id() == allocation { + "Stopped".as_bytes() + } else { + "Running".as_bytes() + } + } else { + "Stopped".as_bytes() + } + } + TaskState::TERMINATED => "Terminated".as_bytes(), + } + } + None => "Stopped".as_bytes(), + }; + let mut count = 0; + for (dst, src) in buf.iter_mut().zip(str) { + *dst = *src; + count += 1; + } + Ok(count) + } + } + + impl MultiThreadResume for GdbStubTarget { + fn resume(&mut self) -> Result<(), Self::Error> { Ok(()) } #[inline(always)] - fn support_single_step(&mut self) -> Option> { + fn support_single_step(&mut self) -> Option> { Some(self) } + + fn clear_resume_actions(&mut self) -> Result<(), Self::Error> { + self.is_single_step = 0; + Ok(()) + } + + fn set_resume_action_continue( + &mut self, + _tid: Tid, + _signal: Option, + ) -> Result<(), Self::Error> { + Ok(()) + } } - impl SingleThreadSingleStep for GdbStubTarget { - fn step(&mut self, _signal: Option) -> Result<(), Self::Error> { - self.is_single_step = true; + impl MultiThreadSingleStep for GdbStubTarget { + fn set_resume_action_step( + &mut self, + tid: Tid, + _signal: Option, + ) -> Result<(), Self::Error> { + self.is_single_step = tid.get() as u32; Ok(()) } } @@ -437,9 +717,9 @@ pub mod svsm_gdbstub { Ok(()) } - pub fn handle_bp_exception(_regs: &mut X86ExceptionContext) {} + pub fn handle_debug_exception(_ctx: &mut X86ExceptionContext, _exception: usize) {} - pub fn handle_db_exception(_regs: &mut X86ExceptionContext) {} + pub fn handle_db_exception(_ctx: &mut X86ExceptionContext) {} pub fn debug_break() {} } diff --git a/src/debug/stacktrace.rs b/src/debug/stacktrace.rs index fe22e8848..a0e94e06b 100644 --- a/src/debug/stacktrace.rs +++ b/src/debug/stacktrace.rs @@ -26,7 +26,7 @@ struct StackBounds { #[cfg(feature = "enable-stacktrace")] impl StackBounds { fn range_is_on_stack(&self, begin: VirtAddr, len: usize) -> bool { - match begin.checked_offset(len) { + match begin.checked_add(len) { Some(end) => begin >= self.bottom && end <= self.top, None => false, } diff --git a/src/elf/mod.rs b/src/elf/mod.rs index 46d567299..d7cf7ae79 100644 --- a/src/elf/mod.rs +++ b/src/elf/mod.rs @@ -35,7 +35,8 @@ use core::mem; /// /// assert_eq!(error_message, "invalid ELF address range"); /// ``` -#[derive(Debug, Clone, Copy)] + 
+#[derive(Debug, Clone, Copy, PartialEq)] pub enum ElfError { FileTooShort, @@ -330,7 +331,8 @@ impl cmp::PartialOrd for Elf64AddrRange { /// This struct represents a parsed 64-bit ELF file. It contains information /// about the ELF file's header, load segments, dynamic section, and more. -#[derive(Default, Debug, Clone, Copy)] + +#[derive(Default, Debug, Clone, Copy, PartialEq)] pub struct Elf64FileRange { pub offset_begin: usize, pub offset_end: usize, @@ -362,7 +364,7 @@ impl convert::TryFrom<(Elf64Off, Elf64Xword)> for Elf64FileRange { /// This struct represents a parsed 64-bit ELF file. It contains information /// about the ELF file's header, load segments, dynamic section, and more. -#[derive(Default, Debug)] +#[derive(Default, Debug, PartialEq)] pub struct Elf64File<'a> { /// Buffer containing the ELF file data elf_file_buf: &'a [u8], @@ -1005,7 +1007,8 @@ impl<'a> Elf64File<'a> { /// Header of the ELF64 file, including fields describing properties such /// as type, machine architecture, entry point, etc. -#[derive(Debug, Default, Clone, Copy)] + +#[derive(Debug, Default, Clone, Copy, PartialEq)] pub struct Elf64Hdr { #[allow(unused)] /// An array of 16 bytes representing the ELF identification, including the ELF magic number @@ -1442,7 +1445,7 @@ impl Elf64Shdr { /// Represents a collection of ELF64 load segments, each associated with an /// address range and a program header index. -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq)] struct Elf64LoadSegments { segments: Vec<(Elf64AddrRange, Elf64Half)>, } @@ -1551,7 +1554,7 @@ impl Elf64LoadSegments { } /// Represents an ELF64 dynamic relocation table -#[derive(Debug)] +#[derive(Debug, PartialEq)] struct Elf64DynamicRelocTable { /// Virtual address of the relocation table (DT_RELA / DR_REL) base_vaddr: Elf64Addr, @@ -1583,7 +1586,7 @@ impl Elf64DynamicRelocTable { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] struct Elf64DynamicSymtab { /// Base virtual address of the symbol table (DT_SYMTAB) base_vaddr: Elf64Addr, @@ -1609,7 +1612,7 @@ impl Elf64DynamicSymtab { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] struct Elf64Dynamic { // No DT_REL representation: "The AMD64 ABI architectures uses only // Elf64_Rela relocation entries [...]". @@ -1856,7 +1859,7 @@ impl<'a> Iterator for Elf64ImageLoadSegmentIterator<'a> { /// Represents an ELF64 string table ([`Elf64Strtab`]) containing strings /// used within the ELF file -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq)] struct Elf64Strtab<'a> { strtab_buf: &'a [u8], } @@ -2410,3 +2413,197 @@ impl<'a, RP: Elf64RelocProcessor> Iterator for Elf64AppliedRelaIterator<'a, RP> Some(Ok(Some(reloc_op))) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_elf64_shdr_verify_methods() { + // Create a valid Elf64Shdr instance for testing. + let valid_shdr = Elf64Shdr { + sh_name: 1, + sh_type: 2, + sh_flags: Elf64ShdrFlags::WRITE | Elf64ShdrFlags::ALLOC, + sh_addr: 0x1000, + sh_offset: 0x2000, + sh_size: 0x3000, + sh_link: 3, + sh_info: 4, + sh_addralign: 8, + sh_entsize: 0, + }; + + // Verify that the valid Elf64Shdr instance passes verification. + assert!(valid_shdr.verify().is_ok()); + + // Create an invalid Elf64Shdr instance for testing. 
+ let invalid_shdr = Elf64Shdr { + sh_name: 0, + sh_type: 2, + sh_flags: Elf64ShdrFlags::from_bits(0).unwrap(), + sh_addr: 0x1000, + sh_offset: 0x2000, + sh_size: 0x3000, + sh_link: 3, + sh_info: 4, + sh_addralign: 7, // Invalid alignment + sh_entsize: 0, + }; + + // Verify that the invalid Elf64Shdr instance fails verification. + assert!(invalid_shdr.verify().is_err()); + } + + #[test] + fn test_elf64_dynamic_reloc_table_verify_valid() { + // Create a valid Elf64DynamicRelocTable instance for testing. + let reloc_table = Elf64DynamicRelocTable { + base_vaddr: 0x1000, + size: 0x2000, + entsize: 0x30, + }; + + // Verify that the valid Elf64DynamicRelocTable instance passes verification. + assert!(reloc_table.verify().is_ok()); + } + + #[test] + fn test_elf64_addr_range_methods() { + // Test Elf64AddrRange::len() and Elf64AddrRange::is_empty(). + + // Create an Elf64AddrRange instance for testing. + let addr_range = Elf64AddrRange { + vaddr_begin: 0x1000, + vaddr_end: 0x2000, + }; + + // Check that the length calculation is correct. + assert_eq!(addr_range.len(), 0x1000); + + // Check if the address range is empty. + assert!(!addr_range.is_empty()); + + // Test Elf64AddrRange::try_from(). + + // Create a valid input tuple for try_from. + let valid_input: (Elf64Addr, Elf64Xword) = (0x1000, 0x2000); + + // Attempt to create an Elf64AddrRange from the valid input. + let result = Elf64AddrRange::try_from(valid_input); + + // Verify that the result is Ok and contains the expected Elf64AddrRange. + assert!(result.is_ok()); + let valid_addr_range = result.unwrap(); + assert_eq!(valid_addr_range.vaddr_begin, 0x1000); + assert_eq!(valid_addr_range.vaddr_end, 0x3000); + } + + #[test] + fn test_elf64_file_range_try_from() { + // Valid range + let valid_range: (Elf64Off, Elf64Xword) = (0, 100); + let result: Result = valid_range.try_into(); + assert!(result.is_ok()); + let file_range = result.unwrap(); + assert_eq!(file_range.offset_begin, 0); + assert_eq!(file_range.offset_end, 100); + + // Invalid range (overflow) + let invalid_range: (Elf64Off, Elf64Xword) = (usize::MAX as Elf64Off, 100); + let result: Result = invalid_range.try_into(); + assert!(result.is_err()); + } + + #[test] + fn test_elf64_file_read() { + // In the future, you can play around with this skeleton ELF + // file to test other cases + let byte_data: [u8; 184] = [ + // ELF Header + 0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, 0x00, 0x3E, 0x00, 0x01, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // Program Header (with PT_LOAD) + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // Section Header (simplified) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // Raw Machine Code Instructions + 0xf3, 0x0f, 0x1e, 0xfa, 0x31, 0xed, 0x49, 0x89, 0xd1, 0x5e, 0x48, 0x89, 0xe2, 0x48, + 0x83, 0xe4, 0xf0, 0x50, 0x54, 0x45, 0x31, 0xc0, 0x31, 0xc9, 0x48, 0x8d, 0x3d, 0xca, + 0x00, 0x00, 0x00, 0xff, 0x15, 0x53, 0x2f, 0x00, 0x00, 0xf4, 0x66, 0x2e, 0x0f, 0x1f, + 0x84, 0x00, 0x00, 0x00, 0x00, 0x48, 0x8d, 0x3d, 0x79, 0x2f, 0x00, 0x00, 0x48, 0x8d, + 0x05, 0x72, 0x2f, 0x00, 0x00, 0x48, 0x39, 0xf8, 0x74, 0x15, 0x48, 0x8b, 0x05, 0x36, + 0x2f, 0x00, 0x00, 0x48, 0x85, 0xc0, 0x74, 0x09, 0xff, 
0xe0, 0x0f, 0x1f, 0x80, 0x00, + 0x00, 0x00, 0x00, 0xc3, + ]; + + // Use the Elf64File::read method to create an Elf64File instance + let res = Elf64File::read(&byte_data); + assert_eq!(res, Err(crate::elf::ElfError::InvalidPhdrSize)); + + // Construct an Elf64Hdr instance from the byte data + let elf_hdr = Elf64Hdr::read(&byte_data); + + // Did we fail to read the ELF header? + assert!(elf_hdr.is_ok()); + let elf_hdr = elf_hdr.unwrap(); + + let expected_type = 2; + let expected_machine = 0x3E; + let expected_version = 1; + + // Assert that the fields of the header match the expected values + assert_eq!(elf_hdr.e_type, expected_type); + assert_eq!(elf_hdr.e_machine, expected_machine); + assert_eq!(elf_hdr.e_version, expected_version); + } + + #[test] + fn test_elf64_load_segments() { + let mut load_segments = Elf64LoadSegments::new(); + let vaddr_range1 = Elf64AddrRange { + vaddr_begin: 0x1000, + vaddr_end: 0x2000, + }; + let vaddr_range2 = Elf64AddrRange { + vaddr_begin: 0x3000, + vaddr_end: 0x4000, + }; + let segment_index1 = 0; + let segment_index2 = 1; + + // Insert load segments + assert!(load_segments + .try_insert(vaddr_range1, segment_index1) + .is_ok()); + assert!(load_segments + .try_insert(vaddr_range2, segment_index2) + .is_ok()); + + // Lookup load segments by virtual address + let (index1, offset1) = load_segments + .lookup_vaddr_range(&Elf64AddrRange { + vaddr_begin: 0x1500, + vaddr_end: 0x1700, + }) + .unwrap(); + let (index2, offset2) = load_segments + .lookup_vaddr_range(&Elf64AddrRange { + vaddr_begin: 0x3500, + vaddr_end: 0x3700, + }) + .unwrap(); + + assert_eq!(index1, segment_index1); + assert_eq!(offset1, 0x500); // Offset within the segment + assert_eq!(index2, segment_index2); + assert_eq!(offset2, 0x500); // Offset within the segment + + // Total virtual address range + let total_range = load_segments.total_vaddr_range(); + assert_eq!(total_range.vaddr_begin, 0x1000); + assert_eq!(total_range.vaddr_end, 0x4000); + } +} diff --git a/src/error.rs b/src/error.rs index a86729851..51b7d2723 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ use crate::fw_cfg::FwCfgError; use crate::sev::ghcb::GhcbError; use crate::sev::msr_protocol::GhcbMsrError; use crate::sev::SevSnpError; +use crate::task::TaskError; // As a general rule, functions private to a given module may use the // leaf error types. Public functions should return an SvsmError @@ -33,4 +34,6 @@ pub enum SvsmError { Acpi, // Errors from file systems FileSystem(FsError), + // Task management errors, + Task(TaskError), } diff --git a/src/fw_cfg.rs b/src/fw_cfg.rs index bc4e41a78..271644d4c 100644 --- a/src/fw_cfg.rs +++ b/src/fw_cfg.rs @@ -6,8 +6,10 @@ extern crate alloc; +use crate::address::{Address, PhysAddr}; use crate::error::SvsmError; use crate::mm::pagetable::max_phys_addr; +use crate::utils::MemoryRegion; use super::io::IOPort; use super::string::FixedString; @@ -67,20 +69,6 @@ impl FwCfgFile { } } -#[derive(Clone, Copy, Debug)] -pub struct MemoryRegion { - pub start: u64, - pub end: u64, -} - -impl MemoryRegion { - /// Returns `true` if the region overlaps with another region with given - /// start and end. 
- pub fn overlaps(&self, start: u64, end: u64) -> bool { - self.start < end && start < self.end - } -} - impl<'a> FwCfg<'a> { pub fn new(driver: &'a dyn IOPort) -> Self { FwCfg { driver } @@ -150,7 +138,7 @@ impl<'a> FwCfg<'a> { Err(SvsmError::FwCfg(FwCfgError::FileNotFound)) } - fn find_svsm_region(&self) -> Result { + fn find_svsm_region(&self) -> Result, SvsmError> { let file = self.file_selector("etc/sev/svsm")?; if file.size != 16 { @@ -161,19 +149,19 @@ impl<'a> FwCfg<'a> { Ok(self.read_memory_region()) } - fn read_memory_region(&self) -> MemoryRegion { - let start: u64 = self.read_le(); - let size: u64 = self.read_le(); - let end = start.saturating_add(size); + fn read_memory_region(&self) -> MemoryRegion { + let start = PhysAddr::from(self.read_le::()); + let size = self.read_le::(); + let end = start.saturating_add(size as usize); assert!(start <= max_phys_addr(), "{start:#018x} is out of range"); assert!(end <= max_phys_addr(), "{end:#018x} is out of range"); - MemoryRegion { start, end } + MemoryRegion::from_addresses(start, end) } - pub fn get_memory_regions(&self) -> Result, SvsmError> { - let mut regions: Vec = Vec::new(); + pub fn get_memory_regions(&self) -> Result>, SvsmError> { + let mut regions = Vec::new(); let file = self.file_selector("etc/e820")?; let entries = file.size / 20; @@ -191,33 +179,35 @@ impl<'a> FwCfg<'a> { Ok(regions) } - fn find_kernel_region_e820(&self) -> Result { + fn find_kernel_region_e820(&self) -> Result, SvsmError> { let regions = self.get_memory_regions()?; - let mut kernel_region = regions + let kernel_region = regions .iter() - .max_by_key(|region| region.start) - .copied() + .max_by_key(|region| region.start()) .ok_or(SvsmError::FwCfg(FwCfgError::KernelRegion))?; - let start = - (kernel_region.end.saturating_sub(KERNEL_REGION_SIZE)) & KERNEL_REGION_SIZE_MASK; + let start = PhysAddr::from( + kernel_region + .end() + .bits() + .saturating_sub(KERNEL_REGION_SIZE as usize) + & KERNEL_REGION_SIZE_MASK as usize, + ); - if start < kernel_region.start { + if start < kernel_region.start() { return Err(SvsmError::FwCfg(FwCfgError::KernelRegion)); } - kernel_region.start = start; - - Ok(kernel_region) + Ok(MemoryRegion::new(start, kernel_region.len())) } - pub fn find_kernel_region(&self) -> Result { + pub fn find_kernel_region(&self) -> Result, SvsmError> { let kernel_region = self .find_svsm_region() .or_else(|_| self.find_kernel_region_e820())?; // Make sure that the kernel region doesn't overlap with the loader. - if kernel_region.start < 640 * 1024 { + if kernel_region.start() < PhysAddr::from(640 * 1024u64) { return Err(SvsmError::FwCfg(FwCfgError::KernelRegion)); } @@ -227,7 +217,7 @@ impl<'a> FwCfg<'a> { // This needs to be &mut self to prevent iterator invalidation, where the caller // could do fw_cfg.select() while iterating. Having a mutable reference prevents // other references. 
- pub fn iter_flash_regions(&mut self) -> impl Iterator + '_ { + pub fn iter_flash_regions(&mut self) -> impl Iterator> + '_ { let num = match self.file_selector("etc/flash") { Ok(file) => { self.select(file.selector); diff --git a/src/fw_meta.rs b/src/fw_meta.rs index 04c933db4..e1f5f9c69 100644 --- a/src/fw_meta.rs +++ b/src/fw_meta.rs @@ -6,59 +6,28 @@ extern crate alloc; -use crate::address::{Address, PhysAddr}; +use crate::address::PhysAddr; use crate::cpu::percpu::this_cpu_mut; use crate::error::SvsmError; +use crate::kernel_launch::KernelLaunchInfo; use crate::mm::PerCPUPageMappingGuard; use crate::sev::ghcb::PageStateChangeOp; use crate::sev::{pvalidate, rmp_adjust, PvalidateOp, RMPFlags}; use crate::types::{PageSize, PAGE_SIZE}; -use crate::utils::{overlap, zero_mem_region}; +use crate::utils::{zero_mem_region, MemoryRegion}; use alloc::vec::Vec; -use core::cmp; use core::fmt; use core::mem::{align_of, size_of, size_of_val}; use core::str::FromStr; -#[derive(Copy, Clone, Debug)] -pub struct SevPreValidMem { - base: PhysAddr, - length: usize, -} - -impl SevPreValidMem { - fn new(base: PhysAddr, length: usize) -> Self { - Self { base, length } - } - - fn new_4k(base: PhysAddr) -> Self { - Self::new(base, PAGE_SIZE) - } - - #[inline] - fn end(&self) -> PhysAddr { - self.base + self.length - } - - fn overlap(&self, other: &Self) -> bool { - overlap(self.base, self.end(), other.base, other.end()) - } - - fn merge(self, other: Self) -> Self { - let base = cmp::min(self.base, other.base); - let length = cmp::max(self.end(), other.end()) - base; - Self::new(base, length) - } -} - #[derive(Clone, Debug)] pub struct SevFWMetaData { pub reset_ip: Option, pub cpuid_page: Option, pub secrets_page: Option, pub caa_page: Option, - pub valid_mem: Vec, + pub valid_mem: Vec>, } impl SevFWMetaData { @@ -73,7 +42,7 @@ impl SevFWMetaData { } pub fn add_valid_mem(&mut self, base: PhysAddr, len: usize) { - self.valid_mem.push(SevPreValidMem::new(base, len)); + self.valid_mem.push(MemoryRegion::new(base, len)); } } @@ -392,8 +361,8 @@ fn parse_sev_meta( Ok(()) } -fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> { - let pstart = region.base; +fn validate_fw_mem_region(region: MemoryRegion) -> Result<(), SvsmError> { + let pstart = region.start(); let pend = region.end(); log::info!("Validating {:#018x}-{:#018x}", pstart, pend); @@ -408,10 +377,7 @@ fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> { ) .expect("GHCB PSC call failed to validate firmware memory"); - for paddr in (pstart.bits()..pend.bits()) - .step_by(PAGE_SIZE) - .map(PhysAddr::from) - { + for paddr in region.iter_pages(PageSize::Regular) { let guard = PerCPUPageMappingGuard::create_4k(paddr)?; let vaddr = guard.virt_addr(); @@ -430,17 +396,17 @@ fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> { Ok(()) } -fn validate_fw_memory_vec(regions: Vec) -> Result<(), SvsmError> { +fn validate_fw_memory_vec(regions: Vec>) -> Result<(), SvsmError> { if regions.is_empty() { return Ok(()); } - let mut next_vec: Vec = Vec::new(); + let mut next_vec = Vec::new(); let mut region = regions[0]; for next in regions.into_iter().skip(1) { - if region.overlap(&next) { - region = region.merge(next); + if region.contiguous(&next) { + region = region.merge(&next); } else { next_vec.push(next); } @@ -450,27 +416,37 @@ fn validate_fw_memory_vec(regions: Vec) -> Result<(), SvsmError> validate_fw_memory_vec(next_vec) } -pub fn validate_fw_memory(fw_meta: &SevFWMetaData) -> Result<(), 
SvsmError> { +pub fn validate_fw_memory( + fw_meta: &SevFWMetaData, + launch_info: &KernelLaunchInfo, +) -> Result<(), SvsmError> { // Initialize vector with regions from the FW let mut regions = fw_meta.valid_mem.clone(); // Add region for CPUID page if present if let Some(cpuid_paddr) = fw_meta.cpuid_page { - regions.push(SevPreValidMem::new_4k(cpuid_paddr)); + regions.push(MemoryRegion::new(cpuid_paddr, PAGE_SIZE)); } // Add region for Secrets page if present if let Some(secrets_paddr) = fw_meta.secrets_page { - regions.push(SevPreValidMem::new_4k(secrets_paddr)); + regions.push(MemoryRegion::new(secrets_paddr, PAGE_SIZE)); } // Add region for CAA page if present if let Some(caa_paddr) = fw_meta.caa_page { - regions.push(SevPreValidMem::new_4k(caa_paddr)); + regions.push(MemoryRegion::new(caa_paddr, PAGE_SIZE)); } // Sort regions by base address - regions.sort_unstable_by(|a, b| a.base.cmp(&b.base)); + regions.sort_unstable_by_key(|a| a.start()); + + let kernel_region = launch_info.kernel_region(); + for region in regions.iter() { + if region.overlap(&kernel_region) { + panic!("FwMeta region overlaps with kernel"); + } + } validate_fw_memory_vec(regions) } @@ -501,7 +477,7 @@ pub fn print_fw_meta(fw_meta: &SevFWMetaData) { for region in &fw_meta.valid_mem { log::info!( " Pre-Validated Region {:#018x}-{:#018x}", - region.base, + region.start(), region.end() ); } diff --git a/src/greq/driver.rs b/src/greq/driver.rs new file mode 100644 index 000000000..fad3edfce --- /dev/null +++ b/src/greq/driver.rs @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! Driver to send `SNP_GUEST_REQUEST` commands to the PSP. It can be any of the +//! request or response command types defined in the SEV-SNP spec, regardless of whether it is +//! a regular or an extended command. + +extern crate alloc; + +use alloc::boxed::Box; +use core::{cell::OnceCell, mem::size_of}; + +use crate::{ + address::VirtAddr, + cpu::percpu::this_cpu_mut, + error::SvsmError, + greq::msg::{SnpGuestRequestExtData, SnpGuestRequestMsg, SnpGuestRequestMsgType}, + locking::SpinLock, + protocols::errors::{SvsmReqError, SvsmResultCode}, + sev::{ + ghcb::GhcbError, + secrets_page::{disable_vmpck0, get_vmpck0, is_vmpck0_clear, VMPCK_SIZE}, + }, + types::PAGE_SHIFT, + BIT, +}; + +/// Global `SNP_GUEST_REQUEST` driver instance +static GREQ_DRIVER: SpinLock<OnceCell<SnpGuestRequestDriver>> = SpinLock::new(OnceCell::new()); + +// Hypervisor error codes + +/// Buffer provided is too small +const SNP_GUEST_REQ_INVALID_LEN: u64 = BIT!(32); +/// Hypervisor busy, try again +const SNP_GUEST_REQ_ERR_BUSY: u64 = BIT!(33); + +/// Class of the `SNP_GUEST_REQUEST` command: Regular or Extended +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +enum SnpGuestRequestClass { + Regular = 0, + Extended = 1, +} + +/// `SNP_GUEST_REQUEST` driver +#[derive(Debug)] +struct SnpGuestRequestDriver { + /// Shared page used for the `SNP_GUEST_REQUEST` request + request: Box<SnpGuestRequestMsg>, + /// Shared page used for the `SNP_GUEST_REQUEST` response + response: Box<SnpGuestRequestMsg>, + /// Encrypted page where we perform crypto operations + staging: Box<SnpGuestRequestMsg>, + /// Extended data buffer that will be provided to the hypervisor + /// to store the SEV-SNP certificates + ext_data: Box<SnpGuestRequestExtData>, + /// Extended data size (`certs` size) provided by the user in [`get_extended_report()`]. + /// It will be provided to the hypervisor. + user_extdata_size: usize, + /// Each `SNP_GUEST_REQUEST` message contains a sequence number per VMPCK.
+ /// The sequence number is incremented with each message sent. Messages + /// sent by the guest to the PSP and by the PSP to the guest must be + /// delivered in order. If not, the PSP will reject subsequent messages + /// by the guest when it detects that the sequence numbers are out of sync. + /// + /// NOTE: If the vmpl field of a `SNP_GUEST_REQUEST` message is set to VMPL0, + /// then it must contain the VMPL0 sequence number and be protected (encrypted) + /// with the VMPCK0 key; additionally, if this message fails, the VMPCK0 key + /// must be disabled. The same idea applies to the other VMPL levels. + /// + /// The SVSM needs to support only VMPL0 `SNP_GUEST_REQUEST` commands because + /// other layers in the software stack (e.g. OVMF and guest kernel) can send + /// non-VMPL0 commands directly to PSP. Therefore, the SVSM needs to maintain + /// the sequence number and the VMPCK only for VMPL0. + vmpck0_seqno: u64, +} + +impl Drop for SnpGuestRequestDriver { + fn drop(&mut self) { + if self.request.set_encrypted().is_err() { + let new_req = + SnpGuestRequestMsg::boxed_new().expect("GREQ: failed to allocate request"); + let old_req = core::mem::replace(&mut self.request, new_req); + log::error!("GREQ: request: failed to set page to encrypted. Memory leak!"); + Box::leak(old_req); + } + if self.response.set_encrypted().is_err() { + let new_resp = + SnpGuestRequestMsg::boxed_new().expect("GREQ: failed to allocate response"); + let old_resp = core::mem::replace(&mut self.response, new_resp); + log::error!("GREQ: response: failed to set page to encrypted. Memory leak!"); + Box::leak(old_resp); + } + if self.ext_data.set_encrypted().is_err() { + let new_data = + SnpGuestRequestExtData::boxed_new().expect("GREQ: failed to allocate ext_data"); + let old_data = core::mem::replace(&mut self.ext_data, new_data); + log::error!("GREQ: ext_data: failed to set pages to encrypted. Memory leak!"); + Box::leak(old_data); + } + } +} + +impl SnpGuestRequestDriver { + /// Create a new [`SnpGuestRequestDriver`] + pub fn new() -> Result { + let request = SnpGuestRequestMsg::boxed_new()?; + let response = SnpGuestRequestMsg::boxed_new()?; + let staging = SnpGuestRequestMsg::boxed_new()?; + let ext_data = SnpGuestRequestExtData::boxed_new()?; + + let mut driver = Self { + request, + response, + staging, + ext_data, + user_extdata_size: size_of::(), + vmpck0_seqno: 0, + }; + + driver.request.set_shared()?; + driver.response.set_shared()?; + driver.ext_data.set_shared()?; + + Ok(driver) + } + + /// Get the last VMPCK0 sequence number accounted + fn seqno_last_used(&self) -> u64 { + self.vmpck0_seqno + } + + /// Increase the VMPCK0 sequence number by two. In order to keep the + /// sequence number in-sync with the PSP, this is called only when the + /// `SNP_GUEST_REQUEST` response is received. + fn seqno_add_two(&mut self) { + self.vmpck0_seqno += 2; + } + + /// Set the user_extdata_size to `n` and clear the first `n` bytes from `ext_data` + pub fn set_user_extdata_size(&mut self, n: usize) -> Result<(), SvsmReqError> { + // At least one page + if (n >> PAGE_SHIFT) == 0 { + return Err(SvsmReqError::invalid_parameter()); + } + self.ext_data.nclear(n)?; + self.user_extdata_size = n; + + Ok(()) + } + + /// Call the GHCB layer to send the encrypted SNP_GUEST_REQUEST message + /// to the PSP. 
+ fn send(&mut self, req_class: SnpGuestRequestClass) -> Result<(), SvsmReqError> { + self.response.clear(); + + let req_page = VirtAddr::from(&mut *self.request as *mut SnpGuestRequestMsg); + let resp_page = VirtAddr::from(&mut *self.response as *mut SnpGuestRequestMsg); + let data_pages = VirtAddr::from(&mut *self.ext_data as *mut SnpGuestRequestExtData); + + if req_class == SnpGuestRequestClass::Extended { + let num_user_pages = (self.user_extdata_size >> PAGE_SHIFT) as u64; + this_cpu_mut().ghcb().guest_ext_request( + req_page, + resp_page, + data_pages, + num_user_pages, + )?; + } else { + this_cpu_mut().ghcb().guest_request(req_page, resp_page)?; + } + + self.seqno_add_two(); + + Ok(()) + } + + // Encrypt the request message from encrypted memory + fn encrypt_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + buffer: &mut [u8], + command_len: usize, + ) -> Result<(), SvsmReqError> { + // VMPL0 `SNP_GUEST_REQUEST` commands are encrypted with the VMPCK0 key + let vmpck0: [u8; VMPCK_SIZE] = get_vmpck0(); + + let inbuf = buffer + .get(..command_len) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + // For security reasons, encrypt the message in protected memory (staging) + // and then copy the result to shared memory (request) + self.staging + .encrypt_set(msg_type, msg_seqno, &vmpck0, inbuf)?; + *self.request = *self.staging; + Ok(()) + } + + // Decrypt the response message from encrypted memory + fn decrypt_response( + &mut self, + msg_seqno: u64, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + ) -> Result<usize, SvsmReqError> { + let vmpck0: [u8; VMPCK_SIZE] = get_vmpck0(); + + // For security reasons, decrypt the message in protected memory (staging) + *self.staging = *self.response; + let result = self + .staging + .decrypt_get(msg_type, msg_seqno, &vmpck0, buffer); + + if let Err(e) = result { + match e { + // The buffer provided is too small to store the unwrapped response. + // There is no need to clear the VMPCK0, just report it as invalid parameter. + SvsmReqError::RequestError(SvsmResultCode::INVALID_PARAMETER) => (), + _ => disable_vmpck0(), + } + } + + result + } + + /// Send the provided VMPL0 `SNP_GUEST_REQUEST` command to the PSP. + /// + /// The command will be encrypted using AES-256 GCM. + /// + /// # Arguments + /// + /// * `req_class`: whether this is a regular or extended `SNP_GUEST_REQUEST` command + /// * `msg_type`: type of the command stored in `buffer`, e.g. [`SNP_MSG_REPORT_REQ`] + /// * `buffer`: buffer with the `SNP_GUEST_REQUEST` command to be sent. + /// The same buffer will also be used to store the response. + /// * `command_len`: Size (in bytes) of the command stored in `buffer` + /// + /// # Returns + /// + /// * Success: + /// * `usize`: Size (in bytes) of the response stored in `buffer` + /// * Error: + /// * [`SvsmReqError`] + fn send_request( + &mut self, + req_class: SnpGuestRequestClass, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + ) -> Result<usize, SvsmReqError> { + if is_vmpck0_clear() { + return Err(SvsmReqError::invalid_request()); + } + + // Message sequence number overflow, the driver will not be able + // to send subsequent `SNP_GUEST_REQUEST` messages to the PSP. + // The sequence number is restored only when the guest is rebooted.
+ let Some(msg_seqno) = self.seqno_last_used().checked_add(1) else { + log::error!("SNP_GUEST_REQUEST: sequence number overflow"); + disable_vmpck0(); + return Err(SvsmReqError::invalid_request()); + }; + + self.encrypt_request(msg_type, msg_seqno, buffer, command_len)?; + + if let Err(e) = self.send(req_class) { + if let SvsmReqError::FatalError(SvsmError::Ghcb(GhcbError::VmgexitError(_rbx, info2))) = + e + { + // For some reason the hypervisor did not forward the request to the PSP. + // + // Because the message sequence number is used as part of the AES-GCM IV, it is important that the + // guest retry the request before allowing another request to be performed so that the IV cannot be + // reused on a new message payload. + match info2 & 0xffff_ffff_0000_0000u64 { + // The certificate buffer provided is too small. + SNP_GUEST_REQ_INVALID_LEN => { + if req_class == SnpGuestRequestClass::Extended { + if let Err(e1) = self.send(SnpGuestRequestClass::Regular) { + log::error!( + "SNP_GUEST_REQ_INVALID_LEN. Aborting, request resend failed" + ); + disable_vmpck0(); + return Err(e1); + } + return Err(e); + } else { + // We sent a regular SNP_GUEST_REQUEST, but the hypervisor returned + // an error code that is exclusive for extended SNP_GUEST_REQUEST + disable_vmpck0(); + return Err(SvsmReqError::invalid_request()); + } + } + // The hypervisor is busy. + SNP_GUEST_REQ_ERR_BUSY => { + if let Err(e2) = self.send(req_class) { + log::error!("SNP_GUEST_REQ_ERR_BUSY. Aborting, request resend failed"); + disable_vmpck0(); + return Err(e2); + } + // ... request resend worked, continue normally. + } + // Failed for unknown reason. Status codes can be found in + // the AMD SEV-SNP spec or in the linux kernel include/uapi/linux/psp-sev.h + _ => { + log::error!("SNP_GUEST_REQUEST failed, unknown error code={}\n", info2); + disable_vmpck0(); + return Err(e); + } + } + } + } + + let msg_seqno = self.seqno_last_used(); + let resp_msg_type = SnpGuestRequestMsgType::try_from(msg_type as u8 + 1)?; + + self.decrypt_response(msg_seqno, resp_msg_type, buffer) + } + + /// Send the provided regular `SNP_GUEST_REQUEST` command to the PSP + pub fn send_regular_guest_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + ) -> Result { + self.send_request(SnpGuestRequestClass::Regular, msg_type, buffer, command_len) + } + + /// Send the provided extended `SNP_GUEST_REQUEST` command to the PSP + pub fn send_extended_guest_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + certs: &mut [u8], + ) -> Result { + self.set_user_extdata_size(certs.len())?; + + let outbuf_len: usize = self.send_request( + SnpGuestRequestClass::Extended, + msg_type, + buffer, + command_len, + )?; + + // The SEV-SNP certificates can be used to verify the attestation report. At this point, a zeroed + // ext_data buffer indicates that the certificates were not imported. + // The VM owner can import them from the host using the virtee/snphost project + if self.ext_data.is_nclear(certs.len())? { + log::warn!("SEV-SNP certificates not found. Make sure they were loaded from the host."); + } else { + self.ext_data.copy_to_slice(certs)?; + } + + Ok(outbuf_len) + } +} + +/// Initialize the global `SnpGuestRequestDriver` +/// +/// # Panics +/// +/// This function panics if we fail to initialize any of the `SnpGuestRequestDriver` fields. 
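One note before the driver-initialization function below: the hypervisor reports these failures in the upper 32 bits of `SW_EXITINFO2`, which is why `send_request()` above masks `info2` with `0xffff_ffff_0000_0000` before comparing against `SNP_GUEST_REQ_INVALID_LEN` (bit 32) and `SNP_GUEST_REQ_ERR_BUSY` (bit 33). A stand-alone sketch of that decoding follows; the `classify` helper is hypothetical and exists only for illustration:

```rust
// The hypervisor error lives in bits 32+ of SW_EXITINFO2; the low 32 bits
// carry the PSP status and are masked off before the comparison.
const SNP_GUEST_REQ_INVALID_LEN: u64 = 1 << 32;
const SNP_GUEST_REQ_ERR_BUSY: u64 = 1 << 33;

// Hypothetical helper, only to show how send_request() interprets info2.
fn classify(info2: u64) -> &'static str {
    match info2 & 0xffff_ffff_0000_0000u64 {
        SNP_GUEST_REQ_INVALID_LEN => "certificate buffer too small",
        SNP_GUEST_REQ_ERR_BUSY => "hypervisor busy, resend the request once",
        _ => "unknown hypervisor error, disable VMPCK0",
    }
}
```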
+pub fn guest_request_driver_init() { + let cell = GREQ_DRIVER.lock(); + let _ = cell.get_or_init(|| { + SnpGuestRequestDriver::new().expect("SnpGuestRequestDriver failed to initialize") + }); +} + +/// Send the provided regular `SNP_GUEST_REQUEST` command to the PSP. +/// Further details can be found in the `SnpGuestRequestDriver.send_request()` documentation. +pub fn send_regular_guest_request( + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + request_len: usize, +) -> Result { + let mut cell = GREQ_DRIVER.lock(); + let driver: &mut SnpGuestRequestDriver = + cell.get_mut().ok_or_else(SvsmReqError::invalid_request)?; + driver.send_regular_guest_request(msg_type, buffer, request_len) +} + +/// Send the provided extended `SNP_GUEST_REQUEST` command to the PSP +/// Further details can be found in the `SnpGuestRequestDriver.send_request()` documentation. +pub fn send_extended_guest_request( + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + request_len: usize, + certs: &mut [u8], +) -> Result { + let mut cell = GREQ_DRIVER.lock(); + let driver: &mut SnpGuestRequestDriver = + cell.get_mut().ok_or_else(SvsmReqError::invalid_request)?; + driver.send_extended_guest_request(msg_type, buffer, request_len, certs) +} diff --git a/src/greq/mod.rs b/src/greq/mod.rs new file mode 100644 index 000000000..56aca7547 --- /dev/null +++ b/src/greq/mod.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! `SNP_GUEST_REQUEST` mechanism to communicate with the PSP + +pub mod driver; +pub mod msg; +pub mod pld_report; +pub mod services; diff --git a/src/greq/msg.rs b/src/greq/msg.rs new file mode 100644 index 000000000..0dccebde5 --- /dev/null +++ b/src/greq/msg.rs @@ -0,0 +1,609 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! Message that carries an encrypted `SNP_GUEST_REQUEST` command in the payload + +extern crate alloc; + +use alloc::{ + alloc::{alloc_zeroed, Layout}, + boxed::Box, +}; +use core::{ + mem::size_of, + slice::{from_raw_parts, from_raw_parts_mut}, +}; + +use crate::{ + address::{Address, VirtAddr}, + cpu::percpu::this_cpu_mut, + crypto::aead::{Aes256Gcm, Aes256GcmTrait, AUTHTAG_SIZE, IV_SIZE}, + mm::virt_to_phys, + protocols::errors::SvsmReqError, + sev::{ghcb::PageStateChangeOp, secrets_page::VMPCK_SIZE}, + types::{PageSize, PAGE_SIZE}, +}; + +// Message Header Format (AMD SEV-SNP spec. table 98) + +/// Version of the message header +const HDR_VERSION: u8 = 1; +/// Version of the message payload +const MSG_VERSION: u8 = 1; + +/// AEAD Algorithm Encodings (AMD SEV-SNP spec. table 99) +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum SnpGuestRequestAead { + Invalid = 0, + Aes256Gcm = 1, +} + +/// Message Type Encodings (AMD SEV-SNP spec. 
table 100) +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum SnpGuestRequestMsgType { + Invalid = 0, + ReportRequest = 5, + ReportResponse = 6, +} + +impl TryFrom<u8> for SnpGuestRequestMsgType { + type Error = SvsmReqError; + + fn try_from(v: u8) -> Result<Self, Self::Error> { + match v { + x if x == SnpGuestRequestMsgType::Invalid as u8 => Ok(SnpGuestRequestMsgType::Invalid), + x if x == SnpGuestRequestMsgType::ReportRequest as u8 => { + Ok(SnpGuestRequestMsgType::ReportRequest) + } + x if x == SnpGuestRequestMsgType::ReportResponse as u8 => { + Ok(SnpGuestRequestMsgType::ReportResponse) + } + _ => Err(SvsmReqError::invalid_parameter()), + } + } +} + +/// Message header size +const MSG_HDR_SIZE: usize = size_of::<SnpGuestRequestMsgHdr>(); +/// Message payload size +const MSG_PAYLOAD_SIZE: usize = PAGE_SIZE - MSG_HDR_SIZE; + +/// Maximum buffer size that the hypervisor takes to store the +/// SEV-SNP certificates +pub const SNP_GUEST_REQ_MAX_DATA_SIZE: usize = 4 * PAGE_SIZE; + +/// `SNP_GUEST_REQUEST` message format +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SnpGuestRequestMsg { + hdr: SnpGuestRequestMsgHdr, + pld: [u8; MSG_PAYLOAD_SIZE], +} + +/// `SNP_GUEST_REQUEST` message header format +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SnpGuestRequestMsgHdr { + /// Message authentication tag + authtag: [u8; 32], + /// The sequence number for this message + msg_seqno: u64, + /// Reserved. Must be zero. + rsvd1: [u8; 8], + /// The AEAD used to encrypt this message + algo: u8, + /// The version of the message header + hdr_version: u8, + /// The size of the message header in bytes + hdr_sz: u16, + /// The type of the payload + msg_type: u8, + /// The version of the payload + msg_version: u8, + /// The size of the payload in bytes + msg_sz: u16, + /// Reserved. Must be zero. + rsvd2: u32, + /// The ID of the VMPCK used to protect this message + msg_vmpck: u8, + /// Reserved. Must be zero. + rsvd3: [u8; 35], +} + +impl SnpGuestRequestMsgHdr { + /// Allocate a new [`SnpGuestRequestMsgHdr`] and initialize it + /// + /// # Panic + /// + /// * [`SnpGuestRequestMsgHdr`] size does not fit in a u16. + pub fn new(msg_sz: u16, msg_type: SnpGuestRequestMsgType, msg_seqno: u64) -> Self { + assert!(u16::try_from(MSG_HDR_SIZE).is_ok()); + + Self { + msg_seqno, + algo: SnpGuestRequestAead::Aes256Gcm as u8, + hdr_version: HDR_VERSION, + hdr_sz: MSG_HDR_SIZE as u16, + msg_type: msg_type as u8, + msg_version: MSG_VERSION, + msg_sz, + msg_vmpck: 0, + ..Default::default() + } + } + + /// Set the authentication tag + fn set_authtag(&mut self, new_tag: &[u8]) -> Result<(), SvsmReqError> { + self.authtag + .get_mut(..new_tag.len()) + .ok_or_else(SvsmReqError::invalid_parameter)?
+ .copy_from_slice(new_tag); + Ok(()) + } + + /// Validate the [`SnpGuestRequestMsgHdr`] fields + fn validate( + &self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + ) -> Result<(), SvsmReqError> { + let header_size = + u16::try_from(MSG_HDR_SIZE).map_err(|_| SvsmReqError::invalid_format())?; + if self.hdr_version != HDR_VERSION + || self.hdr_sz != header_size + || self.algo != SnpGuestRequestAead::Aes256Gcm as u8 + || self.msg_type != msg_type as u8 + || self.msg_vmpck != 0 + || self.msg_seqno != msg_seqno + { + return Err(SvsmReqError::invalid_format()); + } + Ok(()) + } + + /// Get a slice of the header fields used as additional authenticated data (AAD) + fn get_aad_slice(&self) -> &[u8] { + let self_gva = self as *const _ as *const u8; + let algo_gva = &self.algo as *const u8; + + let algo_offset = unsafe { algo_gva.offset_from(self_gva) } as usize; + + unsafe { from_raw_parts(algo_gva, MSG_HDR_SIZE - algo_offset) } + } + + /// Get [`SnpGuestRequestMsgHdr`] as a mutable slice reference + fn as_slice_mut(&mut self) -> &mut [u8] { + unsafe { from_raw_parts_mut(self as *mut _ as *mut u8, MSG_HDR_SIZE) } + } +} + +impl Default for SnpGuestRequestMsgHdr { + /// default() method implementation. We can't derive Default because + /// the field "rsvd3: [u8; 35]" conflicts with the Default trait, which + /// supports up to [T; 32]. + fn default() -> Self { + Self { + authtag: [0; 32], + msg_seqno: 0, + rsvd1: [0; 8], + algo: 0, + hdr_version: 0, + hdr_sz: 0, + msg_type: 0, + msg_version: 0, + msg_sz: 0, + rsvd2: 0, + msg_vmpck: 0, + rsvd3: [0; 35], + } + } +} + +impl SnpGuestRequestMsg { + /// Allocate the object in the heap without going through stack as + /// this is a large object + /// + /// # Panic + /// + /// * Memory allocated is not page aligned or Self does not + /// fit into a page + pub fn boxed_new() -> Result, SvsmReqError> { + let layout = Layout::new::(); + + // The GHCB spec says it has to fit in one page and be page aligned + assert!(layout.size() <= PAGE_SIZE); + + unsafe { + let addr = alloc_zeroed(layout); + if addr.is_null() { + return Err(SvsmReqError::invalid_request()); + } + + assert!(VirtAddr::from(addr).is_page_aligned()); + + let ptr = addr.cast::(); + Ok(Box::from_raw(ptr)) + } + } + + /// Clear the C-bit (memory encryption bit) for the Self page + /// + /// # Safety + /// + /// * The caller is responsible for setting the page back to encrypted + /// before the object is dropped. 
Shared pages should not be freed + /// (returned to the allocator) + pub fn set_shared(&mut self) -> Result<(), SvsmReqError> { + let vaddr = VirtAddr::from(self as *mut Self); + this_cpu_mut() + .get_pgtable() + .set_shared_4k(vaddr) + .map_err(|_| SvsmReqError::invalid_request())?; + + let paddr = virt_to_phys(vaddr); + this_cpu_mut() + .ghcb() + .page_state_change( + paddr, + paddr + PAGE_SIZE, + PageSize::Regular, + PageStateChangeOp::PscShared, + ) + .map_err(|_| SvsmReqError::invalid_request()) + } + + /// Set the C-bit (memory encryption bit) for the Self page + pub fn set_encrypted(&mut self) -> Result<(), SvsmReqError> { + let vaddr = VirtAddr::from(self as *mut Self); + this_cpu_mut() + .get_pgtable() + .set_encrypted_4k(vaddr) + .map_err(|_| SvsmReqError::invalid_request())?; + + let paddr = virt_to_phys(vaddr); + this_cpu_mut() + .ghcb() + .page_state_change( + paddr, + paddr + PAGE_SIZE, + PageSize::Regular, + PageStateChangeOp::PscPrivate, + ) + .map_err(|_| SvsmReqError::invalid_request()) + } + + /// Fill the [`SnpGuestRequestMsg`] fields with zeros + pub fn clear(&mut self) { + self.hdr.as_slice_mut().fill(0); + self.pld.fill(0); + } + + /// Encrypt the provided `SNP_GUEST_REQUEST` command and store the result in the actual message payload + /// + /// The command will be encrypted using AES-256 GCM and part of the message header will be + /// used as additional authenticated data (AAD). + /// + /// # Arguments + /// + /// * `msg_type`: Type of the command stored in the `command` buffer. + /// * `msg_seqno`: VMPL0 sequence number to be used in the message. The PSP will reject + /// subsequent messages when it detects that the sequence numbers are + /// out of sync. The sequence number is also used as initialization + /// vector (IV) in encryption. + /// * `vmpck0`: VMPCK0 key that will be used to encrypt the command. + /// * `command`: command slice to be encrypted. + /// + /// # Returns + /// + /// () on success and [`SvsmReqError`] on error. + /// + /// # Panic + /// + /// * The command length does not fit in a u16 + /// * The encrypted and the original command don't have the same size + pub fn encrypt_set( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + vmpck0: &[u8; VMPCK_SIZE], + command: &[u8], + ) -> Result<(), SvsmReqError> { + let payload_size_u16 = + u16::try_from(command.len()).map_err(|_| SvsmReqError::invalid_parameter())?; + + let mut msg_hdr = SnpGuestRequestMsgHdr::new(payload_size_u16, msg_type, msg_seqno); + let aad: &[u8] = msg_hdr.get_aad_slice(); + let iv: [u8; IV_SIZE] = build_iv(msg_seqno); + + self.pld.fill(0); + + // Encrypt the provided command and store the result in the message payload + let authtag_end: usize = Aes256Gcm::encrypt(&iv, vmpck0, aad, command, &mut self.pld)?; + + // In the Aes256Gcm encrypt API, the authtag is postfixed (comes after the encrypted payload) + let ciphertext_end: usize = authtag_end - AUTHTAG_SIZE; + let authtag = self + .pld + .get_mut(ciphertext_end..authtag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + + // The command should have the same size when encrypted and decrypted + assert_eq!(command.len(), ciphertext_end); + + // Move the authtag to the message header + msg_hdr.set_authtag(authtag)?; + authtag.fill(0); + + self.hdr = msg_hdr; + + Ok(()) + } + + /// Decrypt the `SNP_GUEST_REQUEST` command stored in the message and store the decrypted command in + /// the provided `outbuf`. 
+ /// + /// The command stored in the message payload is usually a response command received from the PSP. + /// It will be decrypted using AES-256 GCM and part of the message header will be used as + /// additional authenticated data (AAD). + /// + /// # Arguments + /// + /// * `msg_type`: Type of the command stored in the message payload + /// * `msg_seqno`: VMPL0 sequence number that was used in the message. + /// * `vmpck0`: VMPCK0 key, it will be used to decrypt the message + /// * `outbuf`: buffer that will be used to store the decrypted message payload + /// + /// # Returns + /// + /// * Success + /// * usize: Number of bytes written to `outbuf` + /// * Error + /// * [`SvsmReqError`] + pub fn decrypt_get( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + vmpck0: &[u8; VMPCK_SIZE], + outbuf: &mut [u8], + ) -> Result { + self.hdr.validate(msg_type, msg_seqno)?; + + let iv: [u8; IV_SIZE] = build_iv(msg_seqno); + let aad: &[u8] = self.hdr.get_aad_slice(); + + // In the Aes256Gcm decrypt API, the authtag must be provided postfix in the inbuf + let ciphertext_end = usize::from(self.hdr.msg_sz); + let tag_end: usize = ciphertext_end + AUTHTAG_SIZE; + + // The message payload must be large enough to hold the ciphertext and + // the authentication tag. + let hdr_tag = self + .hdr + .authtag + .get(..AUTHTAG_SIZE) + .ok_or_else(SvsmReqError::invalid_request)?; + let pld_tag = self + .pld + .get_mut(ciphertext_end..tag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + pld_tag.copy_from_slice(hdr_tag); + + // Payload with postfixed authtag + let inbuf = self + .pld + .get(..tag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + + let outbuf_len: usize = Aes256Gcm::decrypt(&iv, vmpck0, aad, inbuf, outbuf)?; + + Ok(outbuf_len) + } +} + +/// Build the initialization vector for AES-256 GCM +fn build_iv(msg_seqno: u64) -> [u8; IV_SIZE] { + const U64_SIZE: usize = size_of::(); + let mut iv = [0u8; IV_SIZE]; + + iv[..U64_SIZE].copy_from_slice(&msg_seqno.to_ne_bytes()); + iv +} + +/// Set to encrypted all the 4k pages of a memory range +fn set_encrypted_region_4k(start: VirtAddr, end: VirtAddr) -> Result<(), SvsmReqError> { + for addr in (start.bits()..end.bits()) + .step_by(PAGE_SIZE) + .map(VirtAddr::from) + { + this_cpu_mut() + .get_pgtable() + .set_encrypted_4k(addr) + .map_err(|_| SvsmReqError::invalid_request())?; + + let paddr = virt_to_phys(addr); + this_cpu_mut() + .ghcb() + .page_state_change( + paddr, + paddr + PAGE_SIZE, + PageSize::Regular, + PageStateChangeOp::PscPrivate, + ) + .map_err(|_| SvsmReqError::invalid_request())?; + } + Ok(()) +} + +/// Set to shared all the 4k pages of a memory range +fn set_shared_region_4k(start: VirtAddr, end: VirtAddr) -> Result<(), SvsmReqError> { + for addr in (start.bits()..end.bits()) + .step_by(PAGE_SIZE) + .map(VirtAddr::from) + { + this_cpu_mut() + .get_pgtable() + .set_shared_4k(addr) + .map_err(|_| SvsmReqError::invalid_request())?; + + let paddr = virt_to_phys(addr); + this_cpu_mut() + .ghcb() + .page_state_change( + paddr, + paddr + PAGE_SIZE, + PageSize::Regular, + PageStateChangeOp::PscShared, + ) + .map_err(|_| SvsmReqError::invalid_request())?; + } + Ok(()) +} + +/// Data page(s) the hypervisor will use to store certificate data in +/// an extended `SNP_GUEST_REQUEST` +#[repr(C, packed)] +#[derive(Debug)] +pub struct SnpGuestRequestExtData { + /// According to the GHCB spec, the data page(s) must be contiguous pages if + /// supplying more than one page and all certificate pages must be + /// assigned to the 
hypervisor (shared). + data: [u8; SNP_GUEST_REQ_MAX_DATA_SIZE], +} + +impl SnpGuestRequestExtData { + /// Allocate the object in the heap without going through stack as + /// this is a large object + pub fn boxed_new() -> Result, SvsmReqError> { + let layout = Layout::new::(); + unsafe { + let addr = alloc_zeroed(layout); + if addr.is_null() { + return Err(SvsmReqError::invalid_request()); + } + assert!(VirtAddr::from(addr).is_page_aligned()); + + let ptr = addr.cast::(); + Ok(Box::from_raw(ptr)) + } + } + + /// Clear the C-bit (memory encryption bit) for the Self pages + /// + /// # Safety + /// + /// * The caller is responsible for setting the page back to encrypted + /// before the object is dropped. Shared pages should not be freed + /// (returned to the allocator) + pub fn set_shared(&mut self) -> Result<(), SvsmReqError> { + const EXT_DATA_SIZE: usize = size_of::(); + + let start = VirtAddr::from(self as *mut Self); + let end = VirtAddr::from(start.bits() + EXT_DATA_SIZE); + set_shared_region_4k(start, end) + } + + /// Set the C-bit (memory encryption bit) for the Self pages + pub fn set_encrypted(&mut self) -> Result<(), SvsmReqError> { + const EXT_DATA_SIZE: usize = size_of::(); + + let start = VirtAddr::from(self as *mut Self); + let end = VirtAddr::from(start.bits() + EXT_DATA_SIZE); + set_encrypted_region_4k(start, end) + } + + /// Clear the first `n` bytes from data + pub fn nclear(&mut self, n: usize) -> Result<(), SvsmReqError> { + self.data + .get_mut(..n) + .ok_or_else(SvsmReqError::invalid_parameter)? + .fill(0); + Ok(()) + } + + /// Fill up the `outbuf` slice provided with bytes from data + pub fn copy_to_slice(&self, outbuf: &mut [u8]) -> Result<(), SvsmReqError> { + let data = self + .data + .get(..outbuf.len()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + outbuf.copy_from_slice(data); + Ok(()) + } + + /// Check if the first `n` bytes from data are zeroed + pub fn is_nclear(&self, n: usize) -> Result { + let data = self + .data + .get(..n) + .ok_or_else(SvsmReqError::invalid_parameter)?; + Ok(data.iter().all(|e| *e == 0)) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + greq::msg::{ + SnpGuestRequestMsg, SnpGuestRequestMsgHdr, SnpGuestRequestMsgType, MSG_HDR_SIZE, + MSG_PAYLOAD_SIZE, + }, + sev::secrets_page::VMPCK_SIZE, + }; + + #[test] + fn u16_from_guest_msg_hdr_size() { + assert!(u16::try_from(MSG_HDR_SIZE).is_ok()); + } + + #[test] + fn aad_size() { + let hdr = SnpGuestRequestMsgHdr::default(); + let aad = hdr.get_aad_slice(); + + const HDR_ALGO_OFFSET: usize = 48; + + assert_eq!(aad.len(), MSG_HDR_SIZE - HDR_ALGO_OFFSET); + } + + #[test] + fn encrypt_decrypt_payload() { + let mut msg = SnpGuestRequestMsg { + hdr: SnpGuestRequestMsgHdr::default(), + pld: [0; MSG_PAYLOAD_SIZE], + }; + + const PLAINTEXT: &[u8] = b"request-to-be-encrypted"; + let vmpck0 = [5u8; VMPCK_SIZE]; + let vmpck0_seqno: u64 = 1; + + let result = msg.encrypt_set( + SnpGuestRequestMsgType::ReportRequest, + vmpck0_seqno, + &vmpck0, + PLAINTEXT, + ); + + assert!(result.is_ok()); + + let mut outbuf = [0u8; PLAINTEXT.len()]; + + let result = msg.decrypt_get( + SnpGuestRequestMsgType::ReportRequest, + vmpck0_seqno, + &vmpck0, + &mut outbuf, + ); + + assert!(result.is_ok()); + + let outbuf_len = result.unwrap(); + assert_eq!(outbuf_len, PLAINTEXT.len()); + + assert_eq!(outbuf, PLAINTEXT); + } +} diff --git a/src/greq/pld_report.rs b/src/greq/pld_report.rs new file mode 100644 index 000000000..fafa2ce09 --- /dev/null +++ b/src/greq/pld_report.rs @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: 
MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! `SNP_GUEST_REQUEST` command to request an attestation report. + +extern crate alloc; + +use core::mem::size_of; + +use crate::protocols::errors::SvsmReqError; + +/// Size of the `SnpReportRequest.user_data` +pub const USER_DATA_SIZE: usize = 64; + +/// MSG_REPORT_REQ payload format (AMD SEV-SNP spec. table 20) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SnpReportRequest { + /// Guest-provided data to be included in the attestation report + /// REPORT_DATA (512 bits) + user_data: [u8; USER_DATA_SIZE], + /// The VMPL to put in the attestation report + vmpl: u32, + /// 31:2 - Reserved + /// 1:0 - KEY_SEL. Selects which key to use for derivation + /// 0: If VLEK is installed, sign with VLEK. Otherwise, sign with VCEK + /// 1: Sign with VCEK + /// 2: Sign with VLEK + /// 3: Reserved + flags: u32, + /// Reserved, must be zero + rsvd: [u8; 24], +} + +impl SnpReportRequest { + /// Take a slice and return a reference for Self + pub fn try_from_as_ref(buffer: &[u8]) -> Result<&Self, SvsmReqError> { + let buffer = buffer + .get(..size_of::()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + let request = unsafe { &*buffer.as_ptr().cast::() }; + + if !request.is_reserved_clear() { + return Err(SvsmReqError::invalid_parameter()); + } + Ok(request) + } + + pub fn is_vmpl0(&self) -> bool { + self.vmpl == 0 + } + + /// Check if the reserved field is clear + fn is_reserved_clear(&self) -> bool { + self.rsvd.into_iter().all(|e| e == 0) + } +} + +/// MSG_REPORT_RSP payload format (AMD SEV-SNP spec. table 23) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SnpReportResponse { + /// The status of the key derivation operation, see [SnpReportResponseStatus] + status: u32, + /// Size in bytes of the report + report_size: u32, + /// Reserved + _reserved: [u8; 24], + /// The attestation report generated by firmware + report: AttestationReport, +} + +/// Supported values for SnpReportResponse.status +#[repr(u32)] +#[derive(Clone, Copy, Debug)] +pub enum SnpReportResponseStatus { + Success = 0, + InvalidParameters = 0x16, + InvalidKeySelection = 0x27, +} + +impl SnpReportResponse { + pub fn try_from_as_ref(buffer: &[u8]) -> Result<&Self, SvsmReqError> { + let buffer = buffer + .get(..size_of::()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + let response = unsafe { &*buffer.as_ptr().cast::() }; + Ok(response) + } + + /// Validate the [SnpReportResponse] fields + /// + /// # Panic + /// + /// * The size of the struct [`AttestationReport`] must fit in a u32 + pub fn validate(&self) -> Result<(), SvsmReqError> { + if self.status != SnpReportResponseStatus::Success as u32 { + return Err(SvsmReqError::invalid_request()); + } + + const REPORT_SIZE: usize = size_of::(); + assert!(u32::try_from(REPORT_SIZE).is_ok()); + + if self.report_size != REPORT_SIZE as u32 { + return Err(SvsmReqError::invalid_format()); + } + + Ok(()) + } +} + +/// The `TCB_VERSION` contains the security version numbers of each +/// component in the trusted computing base (TCB) of the SNP firmware. +/// (AMD SEV-SNP spec. table 3) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +struct TcbVersion { + /// Version of the Microcode, SNP firmware, PSP and boot loader + raw: u64, +} + +/// Format for an ECDSA P-384 with SHA-384 signature (AMD SEV-SNP spec. 
table 115) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +struct Signature { + /// R component of this signature + r: [u8; 72], + /// S component of this signature + s: [u8; 72], + /// Reserved + reserved: [u8; 368], +} + +/// ATTESTATION_REPORT format (AMD SEV-SNP spec. table 21) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct AttestationReport { + /// Version number of this attestation report + version: u32, + /// The guest SVN + guest_svn: u32, + /// The guest policy + policy: u64, + /// The family ID provided at launch + family_id: [u8; 16], + /// The image ID provided at launch + image_id: [u8; 16], + /// The request VMPL for the attestation report + vmpl: u32, + /// The signature algorithm used to sign this report + signature_algo: u32, + /// CurrentTcb + platform_version: TcbVersion, + /// Information about the platform + platform_info: u64, + /// Flags + flags: u32, + /// Reserved, must be zero + reserved0: u32, + /// Guest-provided data + report_data: [u8; 64], + /// The measurement calculated at launch + measurement: [u8; 48], + /// Data provided by the hypervisor at launch + host_data: [u8; 32], + /// SHA-384 digest of the ID public key that signed the ID block + /// provided in `SNP_LAUNCH_FINISH` + id_key_digest: [u8; 48], + /// SHA-384 digest of the Author public key that certified the ID key, + /// if provided in `SNP_LAUNCH_FINISH`. Zeroes if `AUTHOR_KEY_EN` is 1 + author_key_digest: [u8; 48], + /// Report ID of this guest + report_id: [u8; 32], + /// Report ID of this guest's migration agent + report_id_ma: [u8; 32], + /// Report TCB version used to derive the VCEK that signed this report + reported_tcb: TcbVersion, + /// Reserved + reserved1: [u8; 24], + /// If `MaskChipId` is set to 0, Identifier unique to the chip as + /// output by `GET_ID`. Otherwise, set to 0h + chip_id: [u8; 64], + /// Reserved and some more flags + reserved2: [u8; 192], + /// Signature of bytes 0h to 29Fh inclusive of this report + signature: Signature, +} diff --git a/src/greq/services.rs b/src/greq/services.rs new file mode 100644 index 000000000..1beebc2d4 --- /dev/null +++ b/src/greq/services.rs @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! API to send `SNP_GUEST_REQUEST` commands to the PSP + +extern crate alloc; + +use crate::{ + greq::{ + driver::{send_extended_guest_request, send_regular_guest_request}, + msg::SnpGuestRequestMsgType, + pld_report::{SnpReportRequest, SnpReportResponse}, + }, + protocols::errors::SvsmReqError, +}; +use core::mem::size_of; + +const REPORT_REQUEST_SIZE: usize = size_of::(); +const REPORT_RESPONSE_SIZE: usize = size_of::(); + +fn get_report(buffer: &mut [u8], certs: Option<&mut [u8]>) -> Result { + let request: &SnpReportRequest = SnpReportRequest::try_from_as_ref(buffer)?; + // Non-VMPL0 attestation reports can be requested by the guest kernel + // directly to the PSP. + if !request.is_vmpl0() { + return Err(SvsmReqError::invalid_parameter()); + } + let response_len = if certs.is_none() { + send_regular_guest_request( + SnpGuestRequestMsgType::ReportRequest, + buffer, + REPORT_REQUEST_SIZE, + )? + } else { + send_extended_guest_request( + SnpGuestRequestMsgType::ReportRequest, + buffer, + REPORT_REQUEST_SIZE, + certs.unwrap(), + )? 
+ }; + if REPORT_RESPONSE_SIZE > response_len { + return Err(SvsmReqError::invalid_request()); + } + let response: &SnpReportResponse = SnpReportResponse::try_from_as_ref(buffer)?; + response.validate()?; + + Ok(response_len) +} + +/// Request a regular VMPL0 attestation report from the PSP. +/// +/// Use the `SNP_GUEST_REQUEST` driver to send the provided `MSG_REPORT_REQ` command to +/// the PSP. The VMPL field of the command must be set to zero. +/// +/// The VMPCK0 is disabled for subsequent calls if this function fails in a way that +/// the VM state can be compromised. +/// +/// # Arguments +/// +/// * `buffer`: Buffer with the [`MSG_REPORT_REQ`](SnpReportRequest) command that will be +/// sent to the PSP. It must be large enough to hold the +/// [`MSG_REPORT_RESP`](SnpReportResponse) received from the PSP. +/// +/// # Returns +/// +/// * Success +/// * `usize`: Number of bytes written to `buffer`. It should match the +/// [`MSG_REPORT_RESP`](SnpReportResponse) size. +/// * Error +/// * [`SvsmReqError`] +pub fn get_regular_report(buffer: &mut [u8]) -> Result<usize, SvsmReqError> { + get_report(buffer, None) +} + +/// Request an extended VMPL0 attestation report from the PSP. +/// +/// We say that it is extended because it requests a VMPL0 attestation report +/// from the PSP (as in [`get_regular_report()`]) and also requests from the hypervisor +/// the certificates required to verify the attestation report. +/// +/// The VMPCK0 is disabled for subsequent calls if this function fails in a way that +/// the VM state can be compromised. +/// +/// # Arguments +/// +/// * `buffer`: Buffer with the [`MSG_REPORT_REQ`](SnpReportRequest) command that will be +/// sent to the PSP. It must be large enough to hold the +/// [`MSG_REPORT_RESP`](SnpReportResponse) received from the PSP. +/// * `certs`: Buffer to store the SEV-SNP certificates received from the hypervisor. +/// +/// # Return codes +/// +/// * Success +/// * `usize`: Number of bytes written to `buffer`. It should match +/// the [`MSG_REPORT_RESP`](SnpReportResponse) size. +/// * Error +/// * [`SvsmReqError`] +/// * `SvsmReqError::FatalError(SvsmError::Ghcb(GhcbError::VmgexitError(certs_buffer_size, psp_rc)))`: +/// * `certs` is not large enough to hold the certificates. +/// * `certs_buffer_size`: number of bytes required.
+/// * `psp_rc`: PSP return code +pub fn get_extended_report(buffer: &mut [u8], certs: &mut [u8]) -> Result { + get_report(buffer, Some(certs)) +} diff --git a/src/kernel_launch.rs b/src/kernel_launch.rs index 4879be18f..cd5815a7e 100644 --- a/src/kernel_launch.rs +++ b/src/kernel_launch.rs @@ -4,6 +4,9 @@ // // Author: Joerg Roedel +use crate::address::PhysAddr; +use crate::utils::MemoryRegion; + #[derive(Copy, Clone, Debug)] #[repr(C)] pub struct KernelLaunchInfo { @@ -30,4 +33,10 @@ impl KernelLaunchInfo { pub fn heap_area_virt_end(&self) -> u64 { self.heap_area_virt_start + self.heap_area_size() } + + pub fn kernel_region(&self) -> MemoryRegion { + let start = PhysAddr::from(self.kernel_region_phys_start); + let end = PhysAddr::from(self.kernel_region_phys_end); + MemoryRegion::from_addresses(start, end) + } } diff --git a/src/lib.rs b/src/lib.rs index c8625b5cc..7428140ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,12 +16,14 @@ pub mod acpi; pub mod address; pub mod console; pub mod cpu; +pub mod crypto; pub mod debug; pub mod elf; pub mod error; pub mod fs; pub mod fw_cfg; pub mod fw_meta; +pub mod greq; pub mod io; pub mod kernel_launch; pub mod locking; diff --git a/src/locking/rwlock.rs b/src/locking/rwlock.rs index b8dfb3509..386a72648 100644 --- a/src/locking/rwlock.rs +++ b/src/locking/rwlock.rs @@ -9,33 +9,45 @@ use core::fmt::Debug; use core::ops::{Deref, DerefMut}; use core::sync::atomic::{AtomicU64, Ordering}; +/// A guard that provides read access to the data protected by [`RWLock`] #[derive(Debug)] #[must_use = "if unused the RWLock will immediately unlock"] pub struct ReadLockGuard<'a, T: Debug> { + /// Reference to the associated `AtomicU64` in the [`RWLock`] rwlock: &'a AtomicU64, + /// Reference to the protected data data: &'a T, } +/// Implements the behavior of the [`ReadLockGuard`] when it is dropped impl<'a, T: Debug> Drop for ReadLockGuard<'a, T> { + /// Release the read lock fn drop(&mut self) { self.rwlock.fetch_sub(1, Ordering::Release); } } +/// Implements the behavior of dereferencing the [`ReadLockGuard`] to +/// access the protected data. impl<'a, T: Debug> Deref for ReadLockGuard<'a, T> { type Target = T; + /// Allow reading the protected data through deref fn deref(&self) -> &T { self.data } } +/// A guard that provides exclusive write access to the data protected by [`RWLock`] #[derive(Debug)] #[must_use = "if unused the RWLock will immediately unlock"] pub struct WriteLockGuard<'a, T: Debug> { + /// Reference to the associated `AtomicU64` in the [`RWLock`] rwlock: &'a AtomicU64, + /// Reference to the protected data (mutable) data: &'a mut T, } +/// Implements the behavior of the [`WriteLockGuard`] when it is dropped impl<'a, T: Debug> Drop for WriteLockGuard<'a, T> { fn drop(&mut self) { // There are no readers - safe to just set lock to 0 @@ -43,6 +55,8 @@ impl<'a, T: Debug> Drop for WriteLockGuard<'a, T> { } } +/// Implements the behavior of dereferencing the [`WriteLockGuard`] to +/// access the protected data. impl<'a, T: Debug> Deref for WriteLockGuard<'a, T> { type Target = T; fn deref(&self) -> &T { @@ -50,31 +64,93 @@ impl<'a, T: Debug> Deref for WriteLockGuard<'a, T> { } } +/// Implements the behavior of dereferencing the [`WriteLockGuard`] to +/// access the protected data in a mutable way. impl<'a, T: Debug> DerefMut for WriteLockGuard<'a, T> { fn deref_mut(&mut self) -> &mut T { self.data } } +/// A simple Read-Write Lock (RWLock) that allows multiple readers or +/// one exclusive writer. 
#[derive(Debug)] pub struct RWLock { + /// An atomic 64-bit integer used for synchronization rwlock: AtomicU64, + /// An UnsafeCell for interior mutability data: UnsafeCell, } +/// Implements the trait `Sync` for the [`RWLock`], allowing safe +/// concurrent access across threads. unsafe impl Sync for RWLock {} +/// Splits a 64-bit value into two parts: readers (low 32 bits) and +/// writers (high 32 bits). +/// +/// # Parameters +/// +/// - `val`: A 64-bit unsigned integer value to be split. +/// +/// # Returns +/// +/// A tuple containing two 32-bit unsigned integer values. The first +/// element of the tuple is the lower 32 bits of input value, and the +/// second is the upper 32 bits. +/// #[inline] fn split_val(val: u64) -> (u64, u64) { (val & 0xffff_ffffu64, val >> 32) } +/// Composes a 64-bit value by combining the number of readers (low 32 +/// bits) and writers (high 32 bits). This function is used to create a +/// 64-bit synchronization value that represents the current state of the +/// RWLock, including the count of readers and writers. +/// +/// # Parameters +/// +/// - `readers`: The number of readers (low 32 bits) currently holding read locks. +/// - `writers`: The number of writers (high 32 bits) currently holding write locks. +/// +/// # Returns +/// +/// A 64-bit value representing the combined state of readers and writers in the RWLock. +/// #[inline] fn compose_val(readers: u64, writers: u64) -> u64 { (readers & 0xffff_ffffu64) | (writers << 32) } +/// A reader-writer lock that allows multiple readers or a single writer +/// to access the protected data. [`RWLock`] provides exclusive access for +/// writers and shared access for readers, for efficient synchronization. +/// impl RWLock { + /// Creates a new [`RWLock`] instance with the provided initial data. + /// + /// # Parameters + /// + /// - `data`: The initial data to be protected by the [`RWLock`]. + /// + /// # Returns + /// + /// A new [`RWLock`] instance with the specified initial data. + /// + /// # Example + /// + /// ```rust + /// use svsm::locking::RWLock; + /// + /// #[derive(Debug)] + /// struct MyData { + /// value: i32, + /// } + /// + /// let data = MyData { value: 42 }; + /// let rwlock = RWLock::new(data); + /// ``` pub const fn new(data: T) -> Self { RWLock { rwlock: AtomicU64::new(0), @@ -82,6 +158,14 @@ impl RWLock { } } + /// This function is used to wait until all writers have finished their + /// operations and retrieve the current state of the [`RWLock`]. + /// + /// # Returns + /// + /// A 64-bit value representing the current state of the [`RWLock`], + /// including the count of readers and writers. + /// #[inline] fn wait_for_writers(&self) -> u64 { loop { @@ -95,6 +179,14 @@ impl RWLock { } } + /// This function is used to wait until all readers have finished their + /// operations and retrieve the current state of the [`RWLock`]. + /// + /// # Returns + /// + /// A 64-bit value representing the current state of the [`RWLock`], + /// including the count of readers and writers. + /// #[inline] fn wait_for_readers(&self) -> u64 { loop { @@ -108,6 +200,12 @@ impl RWLock { } } + /// This function allows multiple readers to access the data concurrently. + /// + /// # Returns + /// + /// A [`ReadLockGuard`] that provides read access to the protected data. 
+ /// pub fn lock_read(&self) -> ReadLockGuard<'_, T> { loop { let val = self.wait_for_writers(); @@ -130,6 +228,13 @@ impl RWLock { } } + /// This function ensures exclusive access for a single writer and waits + /// for all readers to finish before granting access to the writer. + /// + /// # Returns + /// + /// A [`WriteLockGuard`] that provides write access to the protected data. + /// pub fn lock_write(&self) -> WriteLockGuard<'_, T> { // Waiting for current writer to finish loop { @@ -156,4 +261,79 @@ impl RWLock { data: unsafe { &mut *self.data.get() }, } } + + /// Waits then locks the RWLock, returning a mutable pointer to the + /// protected item. The lock must be released with a call to + /// [`Self::unlock_write_direct()`] when access to the protected resource is + /// no longer exclusively required. + pub fn lock_write_direct(&self) -> *mut T { + let guard = self.lock_write(); + core::mem::forget(guard); + self.data.get() + } + + /// Unlocks the RWLock, relinquishing access to the raw pointer + /// that was gained by a previous call to [`Self::lock_write_direct()`]. + /// + /// # Safety + /// + /// The caller must ensure that the raw pointer returned by a + /// previous call to [`Self::lock_write_direct()`] is not used after + /// calling this function. Although the pointer may still point + /// to a valid object there is no guarantee of this and use of + /// the pointer is undefined behaviour. + /// + /// In order to gain mutable or immutable access to the object + /// the caller must again re-establish the RWLock. + pub unsafe fn unlock_write_direct(&self) { + // There are no readers - safe to just set lock to 0 + self.rwlock.store(0, Ordering::Release); + } +} + +mod tests { + + #[test] + fn test_lock_rw() { + use crate::locking::*; + let rwlock = RWLock::new(42); + + // Acquire a read lock and check the initial value + let read_guard = rwlock.lock_read(); + assert_eq!(*read_guard, 42); + + drop(read_guard); + + let read_guard2 = rwlock.lock_read(); + assert_eq!(*read_guard2, 42); + + // Create another RWLock instance for modification + let rwlock_modify = RWLock::new(0); + + let mut write_guard = rwlock_modify.lock_write(); + *write_guard = 99; + assert_eq!(*write_guard, 99); + + drop(write_guard); + + let read_guard = rwlock.lock_read(); + assert_eq!(*read_guard, 42); + } + + #[test] + fn test_concurrent_readers() { + use crate::locking::*; + // Let's test two concurrent readers on a new RWLock instance + let rwlock_concurrent = RWLock::new(123); + + let read_guard1 = rwlock_concurrent.lock_read(); + let read_guard2 = rwlock_concurrent.lock_read(); + + // Assert that both readers can access the same value (123) + assert_eq!(*read_guard1, 123); + assert_eq!(*read_guard2, 123); + + drop(read_guard1); + drop(read_guard2); + } } diff --git a/src/locking/spinlock.rs index aae1ff28b..67c91c44c 100644 --- a/src/locking/spinlock.rs +++ b/src/locking/spinlock.rs @@ -9,6 +9,23 @@ use core::fmt::Debug; use core::ops::{Deref, DerefMut}; use core::sync::atomic::{AtomicU64, Ordering}; +/// A lock guard obtained from a [`SpinLock`]. This lock guard +/// provides exclusive access to the data protected by a [`SpinLock`], +/// ensuring that the lock is released when it goes out of scope. +/// +/// # Examples +/// +/// ``` +/// use svsm::locking::SpinLock; +/// +/// let data = 42; +/// let spin_lock = SpinLock::new(data); +/// +/// { +/// let mut guard = spin_lock.lock(); +/// *guard += 1; // Modify the protected data.
+/// }; // Lock is automatically released when `guard` goes out of scope. +/// ``` #[derive(Debug)] #[must_use = "if unused the SpinLock will immediately unlock"] pub struct LockGuard<'a, T: Debug> { @@ -16,29 +33,66 @@ pub struct LockGuard<'a, T: Debug> { data: &'a mut T, } +/// Implements the behavior of the [`LockGuard`] when it is dropped impl<'a, T: Debug> Drop for LockGuard<'a, T> { + /// Automatically releases the lock when the guard is dropped fn drop(&mut self) { self.holder.fetch_add(1, Ordering::Release); } } +/// Implements the behavior of dereferencing the [`LockGuard`] to +/// access the protected data. impl<'a, T: Debug> Deref for LockGuard<'a, T> { type Target = T; + /// Provides read-only access to the protected data fn deref(&self) -> &T { self.data } } +/// Implements the behavior of dereferencing the [`LockGuard`] to +/// access the protected data in a mutable way. impl<'a, T: Debug> DerefMut for LockGuard<'a, T> { + /// Provides mutable access to the protected data fn deref_mut(&mut self) -> &mut T { self.data } } +/// A simple spinlock implementation for protecting concurrent data access. +/// +/// # Examples +/// +/// ``` +/// use svsm::locking::SpinLock; +/// +/// let data = 42; +/// let spin_lock = SpinLock::new(data); +/// +/// // Acquire the lock and modify the protected data. +/// { +/// let mut guard = spin_lock.lock(); +/// *guard += 1; +/// }; // Lock is automatically released when `guard` goes out of scope. +/// +/// // Try to acquire the lock without blocking +/// if let Some(mut guard) = spin_lock.try_lock() { +/// *guard += 2; +/// }; +/// ``` #[derive(Debug)] pub struct SpinLock { + /// This atomic counter is incremented each time a thread attempts to + /// acquire the lock. It helps to determine the order in which threads + /// acquire the lock. current: AtomicU64, + /// This counter represents the thread that currently holds the lock + /// and has access to the protected data. holder: AtomicU64, + /// This `UnsafeCell` is used to provide interior mutability of the + /// protected data. That is, it allows the data to be accessed/modified + /// while enforcing the locking mechanism. data: UnsafeCell, } @@ -46,6 +100,16 @@ unsafe impl Send for SpinLock {} unsafe impl Sync for SpinLock {} impl SpinLock { + /// Creates a new SpinLock instance with the specified initial data. + /// + /// # Examples + /// + /// ``` + /// use svsm::locking::SpinLock; + /// + /// let data = 42; + /// let spin_lock = SpinLock::new(data); + /// ``` pub const fn new(data: T) -> Self { SpinLock { current: AtomicU64::new(0), @@ -54,6 +118,21 @@ impl SpinLock { } } + /// Acquires the lock, providing access to the protected data. + /// + /// # Examples + /// + /// ``` + /// use svsm::locking::SpinLock; + /// + /// let spin_lock = SpinLock::new(42); + /// + /// // Acquire the lock and modify the protected data. + /// { + /// let mut guard = spin_lock.lock(); + /// *guard += 1; + /// }; // Lock is automatically released when `guard` goes out of scope. + /// ``` pub fn lock(&self) -> LockGuard { let ticket = self.current.fetch_add(1, Ordering::Relaxed); loop { @@ -69,6 +148,10 @@ impl SpinLock { } } + /// This method tries to acquire the lock without blocking. If the + /// lock is not available, it returns `None`. If the lock is + /// successfully acquired, it returns a [`LockGuard`] that automatically + /// releases the lock when it goes out of scope. 
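The `current`/`holder` pair documented above implements a ticket lock: `lock()` draws a ticket from `current` and spins until `holder` catches up, and dropping the [`LockGuard`] advances `holder`. A compact, stand-alone sketch of that invariant (simplified to a single thread, so the spin never actually waits):

```rust
// Ticket lock in miniature: a ticket is granted access once `holder` reaches
// it; releasing the lock advances `holder` so the next ticket may proceed.
use core::sync::atomic::{AtomicU64, Ordering};

fn ticket_lock_demo() {
    let current = AtomicU64::new(0);
    let holder = AtomicU64::new(0);

    let ticket = current.fetch_add(1, Ordering::Relaxed); // take ticket 0
    while holder.load(Ordering::Acquire) != ticket {}     // acquired immediately here
    // ... critical section ...
    holder.fetch_add(1, Ordering::Release);               // unlock
}
```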
pub fn try_lock(&self) -> Option> { let current = self.current.load(Ordering::Relaxed); let holder = self.holder.load(Ordering::Acquire); @@ -91,3 +174,23 @@ impl SpinLock { None } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_spin_lock() { + let spin_lock = SpinLock::new(0); + + let mut guard = spin_lock.lock(); + *guard += 1; + + // Ensure the locked data is updated. + assert_eq!(*guard, 1); + + // Try to lock again; it should fail and return None. + let try_lock_result = spin_lock.try_lock(); + assert!(try_lock_result.is_none()); + } +} diff --git a/src/mm/address_space.rs b/src/mm/address_space.rs index ab5bcd4bb..3b683c7c3 100644 --- a/src/mm/address_space.rs +++ b/src/mm/address_space.rs @@ -83,13 +83,7 @@ pub const SIZE_LEVEL1: usize = 1usize << ((9 * 1) + 12); pub const SIZE_LEVEL0: usize = 1usize << ((9 * 0) + 12); // Stack definitions -// The GDB stub requires a larger stack. -#[cfg(feature = "enable-gdb")] -pub const STACK_PAGES_GDB: usize = 8; -#[cfg(not(feature = "enable-gdb"))] -pub const STACK_PAGES_GDB: usize = 0; - -pub const STACK_PAGES: usize = 8 + STACK_PAGES_GDB; +pub const STACK_PAGES: usize = 8; pub const STACK_SIZE: usize = PAGE_SIZE * STACK_PAGES; pub const STACK_GUARD_SIZE: usize = STACK_SIZE; pub const STACK_TOTAL_SIZE: usize = STACK_SIZE + STACK_GUARD_SIZE; diff --git a/src/mm/memory.rs b/src/mm/memory.rs index 34c5f9b0b..a168caa39 100644 --- a/src/mm/memory.rs +++ b/src/mm/memory.rs @@ -9,28 +9,27 @@ extern crate alloc; use crate::address::{Address, PhysAddr}; use crate::cpu::percpu::PERCPU_VMSAS; use crate::error::SvsmError; -use crate::fw_cfg::{FwCfg, MemoryRegion}; +use crate::fw_cfg::FwCfg; use crate::kernel_launch::KernelLaunchInfo; use crate::locking::RWLock; +use crate::utils::MemoryRegion; use alloc::vec::Vec; use log; use super::pagetable::LAUNCH_VMSA_ADDR; -static MEMORY_MAP: RWLock> = RWLock::new(Vec::new()); +static MEMORY_MAP: RWLock>> = RWLock::new(Vec::new()); pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result<(), SvsmError> { let mut regions = fwcfg.get_memory_regions()?; + let kernel_region = launch_info.kernel_region(); // Remove SVSM memory from guest memory map let mut i = 0; while i < regions.len() { // Check if the region overlaps with SVSM memory. let region = regions[i]; - if !region.overlaps( - launch_info.kernel_region_phys_start, - launch_info.kernel_region_phys_end, - ) { + if !region.overlap(&kernel_region) { // Check the next region. i += 1; continue; @@ -40,29 +39,23 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result< regions.remove(i); // 2. Insert a region up until the start of SVSM memory (if non-empty). - let region_before_start = region.start; - let region_before_end = launch_info.kernel_region_phys_start; + let region_before_start = region.start(); + let region_before_end = kernel_region.start(); if region_before_start < region_before_end { regions.insert( i, - MemoryRegion { - start: region_before_start, - end: region_before_end, - }, + MemoryRegion::from_addresses(region_before_start, region_before_end), ); i += 1; } // 3. Insert a region up after the end of SVSM memory (if non-empty). 
- let region_after_start = launch_info.kernel_region_phys_end; - let region_after_end = region.end; + let region_after_start = kernel_region.end(); + let region_after_end = region.end(); if region_after_start < region_after_end { regions.insert( i, - MemoryRegion { - start: region_after_start, - end: region_after_end, - }, + MemoryRegion::from_addresses(region_after_start, region_after_end), ); i += 1; } @@ -70,7 +63,7 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result< log::info!("Guest Memory Regions:"); for r in regions.iter() { - log::info!(" {:018x}-{:018x}", r.start, r.end); + log::info!(" {:018x}-{:018x}", r.start(), r.end()); } let mut map = MEMORY_MAP.lock_write(); @@ -81,7 +74,6 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result< pub fn valid_phys_address(paddr: PhysAddr) -> bool { let page_addr = paddr.page_align(); - let addr = paddr.bits() as u64; if PERCPU_VMSAS.exists(page_addr) { return false; @@ -93,7 +85,7 @@ pub fn valid_phys_address(paddr: PhysAddr) -> bool { MEMORY_MAP .lock_read() .iter() - .any(|region| addr >= region.start && addr < region.end) + .any(|region| region.contains(paddr)) } const ISA_RANGE_START: PhysAddr = PhysAddr::new(0xa0000); diff --git a/src/mm/pagetable.rs b/src/mm/pagetable.rs index 01919d15a..90c03ea9a 100644 --- a/src/mm/pagetable.rs +++ b/src/mm/pagetable.rs @@ -85,8 +85,8 @@ fn encrypt_mask() -> usize { } /// Returns the exclusive end of the physical address space. -pub fn max_phys_addr() -> u64 { - *MAX_PHYS_ADDR +pub fn max_phys_addr() -> PhysAddr { + PhysAddr::from(*MAX_PHYS_ADDR) } fn supported_flags(flags: PTEntryFlags) -> PTEntryFlags { diff --git a/src/mm/ptguards.rs b/src/mm/ptguards.rs index 07dc88ae2..82c92b949 100644 --- a/src/mm/ptguards.rs +++ b/src/mm/ptguards.rs @@ -14,22 +14,12 @@ use crate::mm::virtualrange::{ }; use crate::types::{PAGE_SIZE, PAGE_SIZE_2M}; -#[derive(Debug)] -struct RawPTMappingGuard { - start: VirtAddr, - end: VirtAddr, -} - -impl RawPTMappingGuard { - pub const fn new(start: VirtAddr, end: VirtAddr) -> Self { - RawPTMappingGuard { start, end } - } -} +use crate::utils::MemoryRegion; #[derive(Debug)] #[must_use = "if unused the mapping will immediately be unmapped"] pub struct PerCPUPageMappingGuard { - mapping: Option, + mapping: MemoryRegion, huge: bool, } @@ -72,10 +62,10 @@ impl PerCPUPageMappingGuard { vaddr }; - let raw_mapping = RawPTMappingGuard::new(vaddr, vaddr + size); + let raw_mapping = MemoryRegion::new(vaddr, size); Ok(PerCPUPageMappingGuard { - mapping: Some(raw_mapping), + mapping: raw_mapping, huge, }) } @@ -85,22 +75,23 @@ impl PerCPUPageMappingGuard { } pub fn virt_addr(&self) -> VirtAddr { - self.mapping.as_ref().unwrap().start + self.mapping.start() } } impl Drop for PerCPUPageMappingGuard { fn drop(&mut self) { - if let Some(m) = &self.mapping { - let size = m.end - m.start; - if self.huge { - this_cpu_mut().get_pgtable().unmap_region_2m(m.start, m.end); - virt_free_range_2m(m.start, size); - } else { - this_cpu_mut().get_pgtable().unmap_region_4k(m.start, m.end); - virt_free_range_4k(m.start, size); - } - flush_address_sync(m.start); + let start = self.mapping.start(); + let end = self.mapping.end(); + let size = self.mapping.len(); + + if self.huge { + this_cpu_mut().get_pgtable().unmap_region_2m(start, end); + virt_free_range_2m(start, size); + } else { + this_cpu_mut().get_pgtable().unmap_region_4k(start, end); + virt_free_range_4k(start, size); } + flush_address_sync(start); } } diff --git a/src/mm/stack.rs 
b/src/mm/stack.rs index b02cf0696..c7ec87049 100644 --- a/src/mm/stack.rs +++ b/src/mm/stack.rs @@ -15,6 +15,7 @@ use crate::mm::{ STACK_PAGES, STACK_SIZE, STACK_TOTAL_SIZE, SVSM_SHARED_STACK_BASE, SVSM_SHARED_STACK_END, }; use crate::types::PAGE_SIZE; +use crate::utils::MemoryRegion; // Limit maximum number of stacks for now, address range support 2**16 8k stacks const MAX_STACKS: usize = 1024; @@ -22,16 +23,15 @@ const BMP_QWORDS: usize = MAX_STACKS / 64; #[derive(Debug)] struct StackRange { - start: VirtAddr, - end: VirtAddr, + region: MemoryRegion, alloc_bitmap: [u64; BMP_QWORDS], } impl StackRange { pub const fn new(start: VirtAddr, end: VirtAddr) -> Self { + let region = MemoryRegion::from_addresses(start, end); StackRange { - start, - end, + region, alloc_bitmap: [0; BMP_QWORDS], } } @@ -49,16 +49,16 @@ impl StackRange { self.alloc_bitmap[i] |= mask; - return Ok(self.start + ((i * 64 + idx) * STACK_TOTAL_SIZE)); + return Ok(self.region.start() + ((i * 64 + idx) * STACK_TOTAL_SIZE)); } Err(SvsmError::Mem) } pub fn dealloc(&mut self, stack: VirtAddr) { - assert!(stack >= self.start && stack < self.end); + assert!(self.region.contains(stack)); - let offset = stack - self.start; + let offset = stack - self.region.start(); let idx = offset / (STACK_TOTAL_SIZE); assert!((offset % (STACK_TOTAL_SIZE)) <= STACK_SIZE); diff --git a/src/mm/validate.rs b/src/mm/validate.rs index a27287bfe..cdb8d5746 100644 --- a/src/mm/validate.rs +++ b/src/mm/validate.rs @@ -10,28 +10,29 @@ use crate::locking::SpinLock; use crate::mm::alloc::{allocate_pages, get_order}; use crate::mm::virt_to_phys; use crate::types::{PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::MemoryRegion; use core::ptr; static VALID_BITMAP: SpinLock = SpinLock::new(ValidBitmap::new()); #[inline(always)] -fn bitmap_alloc_order(pbase: PhysAddr, pend: PhysAddr) -> usize { - let mem_size = (pend - pbase) / (PAGE_SIZE * 8); +fn bitmap_alloc_order(region: MemoryRegion) -> usize { + let mem_size = region.len() / (PAGE_SIZE * 8); get_order(mem_size) } -pub fn init_valid_bitmap_ptr(pbase: PhysAddr, pend: PhysAddr, bitmap: *mut u64) { +pub fn init_valid_bitmap_ptr(region: MemoryRegion, bitmap: *mut u64) { let mut vb_ref = VALID_BITMAP.lock(); - vb_ref.set_region(pbase, pend); + vb_ref.set_region(region); vb_ref.set_bitmap(bitmap); } -pub fn init_valid_bitmap_alloc(pbase: PhysAddr, pend: PhysAddr) -> Result<(), SvsmError> { - let order: usize = bitmap_alloc_order(pbase, pend); +pub fn init_valid_bitmap_alloc(region: MemoryRegion) -> Result<(), SvsmError> { + let order: usize = bitmap_alloc_order(region); let bitmap_addr = allocate_pages(order)?; let mut vb_ref = VALID_BITMAP.lock(); - vb_ref.set_region(pbase, pend); + vb_ref.set_region(region); vb_ref.set_bitmap(bitmap_addr.as_mut_ptr::()); vb_ref.clear_all(); @@ -95,23 +96,20 @@ pub fn valid_bitmap_valid_addr(paddr: PhysAddr) -> bool { #[derive(Debug)] struct ValidBitmap { - pbase: PhysAddr, - pend: PhysAddr, + region: MemoryRegion, bitmap: *mut u64, } impl ValidBitmap { pub const fn new() -> Self { ValidBitmap { - pbase: PhysAddr::null(), - pend: PhysAddr::null(), + region: MemoryRegion::from_addresses(PhysAddr::null(), PhysAddr::null()), bitmap: ptr::null_mut(), } } - pub fn set_region(&mut self, pbase: PhysAddr, pend: PhysAddr) { - self.pbase = pbase; - self.pend = pend; + pub fn set_region(&mut self, region: MemoryRegion) { + self.region = region; } pub fn set_bitmap(&mut self, bitmap: *mut u64) { @@ -119,7 +117,7 @@ impl ValidBitmap { } pub fn check_addr(&self, paddr: PhysAddr) -> bool { - paddr 
>= self.pbase && paddr < self.pend + self.region.contains(paddr) } pub fn bitmap_addr(&self) -> PhysAddr { @@ -129,7 +127,7 @@ impl ValidBitmap { #[inline(always)] fn index(&self, paddr: PhysAddr) -> (isize, usize) { - let page_offset = (paddr - self.pbase) / PAGE_SIZE; + let page_offset = (paddr - self.region.start()) / PAGE_SIZE; let index: isize = (page_offset / 64).try_into().unwrap(); let bit: usize = page_offset % 64; @@ -137,7 +135,7 @@ impl ValidBitmap { } pub fn clear_all(&mut self) { - let (mut i, bit) = self.index(self.pend); + let (mut i, bit) = self.index(self.region.end()); if bit != 0 { i += 1; } @@ -149,11 +147,11 @@ impl ValidBitmap { } pub fn alloc_order(&self) -> usize { - bitmap_alloc_order(self.pbase, self.pend) + bitmap_alloc_order(self.region) } pub fn migrate(&mut self, new_bitmap: *mut u64) { - let (count, _) = self.index(self.pend); + let (count, _) = self.index(self.region.end()); unsafe { ptr::copy_nonoverlapping(self.bitmap, new_bitmap, count as usize); diff --git a/src/sev/ghcb.rs b/src/sev/ghcb.rs index bed8b0725..3d0ac6067 100644 --- a/src/sev/ghcb.rs +++ b/src/sev/ghcb.rs @@ -123,6 +123,8 @@ enum GHCBExitCode {} impl GHCBExitCode { pub const IOIO: u64 = 0x7b; pub const SNP_PSC: u64 = 0x8000_0010; + pub const GUEST_REQUEST: u64 = 0x8000_0011; + pub const GUEST_EXT_REQUEST: u64 = 0x8000_0012; pub const AP_CREATE: u64 = 0x80000013; pub const RUN_VMPL: u64 = 0x80000018; } @@ -478,6 +480,61 @@ impl GHCB { Ok(()) } + pub fn guest_request( + &mut self, + req_page: VirtAddr, + resp_page: VirtAddr, + ) -> Result<(), SvsmError> { + self.clear(); + + let info1: u64 = u64::from(virt_to_phys(req_page)); + let info2: u64 = u64::from(virt_to_phys(resp_page)); + + self.vmgexit(GHCBExitCode::GUEST_REQUEST, info1, info2)?; + + if !self.is_valid(OFF_SW_EXIT_INFO_2) { + return Err(GhcbError::VmgexitInvalid.into()); + } + + if self.sw_exit_info_2 != 0 { + return Err(GhcbError::VmgexitError(self.sw_exit_info_1, self.sw_exit_info_2).into()); + } + + Ok(()) + } + + pub fn guest_ext_request( + &mut self, + req_page: VirtAddr, + resp_page: VirtAddr, + data_pages: VirtAddr, + data_size: u64, + ) -> Result<(), SvsmError> { + self.clear(); + + let info1: u64 = u64::from(virt_to_phys(req_page)); + let info2: u64 = u64::from(virt_to_phys(resp_page)); + let rax: u64 = u64::from(virt_to_phys(data_pages)); + + self.set_rax(rax); + self.set_rbx(data_size); + + self.vmgexit(GHCBExitCode::GUEST_EXT_REQUEST, info1, info2)?; + + if !self.is_valid(OFF_SW_EXIT_INFO_2) { + return Err(GhcbError::VmgexitInvalid.into()); + } + + // On error, RBX and exit_info_2 are returned for proper error handling. 
+ // For an extended request, if the buffer provided is too small, the hypervisor + // will return in RBX the number of contiguous pages required + if self.sw_exit_info_2 != 0 { + return Err(GhcbError::VmgexitError(self.rbx, self.sw_exit_info_2).into()); + } + + Ok(()) + } + pub fn run_vmpl(&mut self, vmpl: u64) -> Result<(), SvsmError> { self.clear(); self.vmgexit(GHCBExitCode::RUN_VMPL, vmpl, 0)?; diff --git a/src/sev/secrets_page.rs b/src/sev/secrets_page.rs index a09ab7f02..9e8198b5b 100644 --- a/src/sev/secrets_page.rs +++ b/src/sev/secrets_page.rs @@ -7,6 +7,12 @@ use crate::address::VirtAddr; use crate::sev::vmsa::VMPL_MAX; +pub const VMPCK_SIZE: usize = 32; + +extern "C" { + pub static mut SECRETS_PAGE: SecretsPage; +} + #[derive(Copy, Clone, Debug)] #[repr(C, packed)] pub struct SecretsPage { @@ -15,7 +21,7 @@ pub struct SecretsPage { pub fms: u32, reserved_00c: u32, pub gosvw: [u8; 16], - pub vmpck: [[u8; 32]; VMPL_MAX], + pub vmpck: [[u8; VMPCK_SIZE]; VMPL_MAX], reserved_0a0: [u8; 96], pub vmsa_tweak_bmp: [u64; 8], pub svsm_base: u64, @@ -35,3 +41,16 @@ pub fn copy_secrets_page(target: &mut SecretsPage, source: VirtAddr) { *target = *table; } } + +pub fn is_vmpck0_clear() -> bool { + unsafe { SECRETS_PAGE.vmpck[0].iter().all(|e| *e == 0) } +} + +pub fn disable_vmpck0() { + unsafe { SECRETS_PAGE.vmpck[0].iter_mut().for_each(|e| *e = 0) }; + log::warn!("VMPCK0 disabled!"); +} + +pub fn get_vmpck0() -> [u8; VMPCK_SIZE] { + unsafe { SECRETS_PAGE.vmpck[0] } +} diff --git a/src/stage2.rs b/src/stage2.rs index 14f262cfd..12fd2a52e 100644 --- a/src/stage2.rs +++ b/src/stage2.rs @@ -157,10 +157,9 @@ pub extern "C" fn stage2_main(launch_info: &Stage1LaunchInfo) { log::info!("COCONUT Secure Virtual Machine Service Module (SVSM) Stage 2 Loader"); - let kernel_region_phys_start = PhysAddr::from(r.start); - let kernel_region_phys_end = PhysAddr::from(r.end); - init_valid_bitmap_alloc(kernel_region_phys_start, kernel_region_phys_end) - .expect("Failed to allocate valid-bitmap"); + let kernel_region_phys_start = r.start(); + let kernel_region_phys_end = r.end(); + init_valid_bitmap_alloc(r).expect("Failed to allocate valid-bitmap"); // Read the SVSM kernel's ELF file metadata. 
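Stepping back to the `guest_ext_request()` addition above: on failure `SW_EXIT_INFO_2` is non-zero and, for a too-small certificate buffer, RBX reports the number of contiguous pages required; both values are surfaced through `GhcbError::VmgexitError(rbx, info2)`. A hypothetical caller-side retry loop built only on that contract might look like the sketch below. The `alloc_cert_buffer` helper, the page-count bookkeeping, and the assumption that `SvsmError::Ghcb` is the wrapper produced by the `.into()` conversions above are illustrative only and not part of this patch; the stage2 ELF-loading code continues after this aside.

```rust
// Hypothetical caller of guest_ext_request(). alloc_cert_buffer() stands in
// for whatever allocates `npages` of shared memory and returns the buffer
// address plus its size in the units guest_ext_request() expects.
fn ext_request_with_retry(
    ghcb: &mut GHCB,
    req_page: VirtAddr,
    resp_page: VirtAddr,
) -> Result<(), SvsmError> {
    let mut npages: u64 = 1;
    loop {
        let (data, data_size) = alloc_cert_buffer(npages)?; // hypothetical helper
        match ghcb.guest_ext_request(req_page, resp_page, data, data_size) {
            Ok(()) => return Ok(()),
            // Buffer too small: RBX (first payload of VmgexitError) carries the
            // number of contiguous pages the hypervisor needs. Grow and retry.
            Err(SvsmError::Ghcb(GhcbError::VmgexitError(rbx, _))) if rbx > npages => {
                npages = rbx;
            }
            Err(e) => return Err(e),
        }
    }
}
```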
let kernel_elf_len = kernel_elf_end - kernel_elf_start; diff --git a/src/svsm.rs b/src/svsm.rs index 4af0db55a..e75ac4648 100644 --- a/src/svsm.rs +++ b/src/svsm.rs @@ -12,11 +12,11 @@ extern crate alloc; use alloc::vec::Vec; use svsm::fw_meta::{parse_fw_meta_data, print_fw_meta, validate_fw_memory, SevFWMetaData}; -use core::arch::{asm, global_asm}; +use core::arch::global_asm; use core::panic::PanicInfo; use core::slice; use svsm::acpi::tables::load_acpi_cpu_info; -use svsm::address::{Address, PhysAddr, VirtAddr}; +use svsm::address::{PhysAddr, VirtAddr}; use svsm::console::{init_console, install_console_logger, WRITER}; use svsm::cpu::control_regs::{cr0_init, cr4_init}; use svsm::cpu::cpuid::{dump_cpuid_table, register_cpuid_table, SnpCpuidTable}; @@ -32,6 +32,7 @@ use svsm::elf; use svsm::error::SvsmError; use svsm::fs::{initialize_fs, populate_ram_fs}; use svsm::fw_cfg::FwCfg; +use svsm::greq::driver::guest_request_driver_init; use svsm::kernel_launch::KernelLaunchInfo; use svsm::mm::alloc::{memory_info, print_memory_info, root_mem_init}; use svsm::mm::memory::init_memory_map; @@ -41,13 +42,14 @@ use svsm::mm::{init_kernel_mapping_info, PerCPUPageMappingGuard, SIZE_1G}; use svsm::requests::{request_loop, update_mappings}; use svsm::serial::SerialPort; use svsm::serial::SERIAL_PORT; -use svsm::sev::secrets_page::{copy_secrets_page, SecretsPage}; +use svsm::sev::secrets_page::{copy_secrets_page, disable_vmpck0, SecretsPage}; use svsm::sev::sev_status_init; use svsm::sev::utils::{rmp_adjust, RMPFlags}; use svsm::svsm_console::SVSMIOPort; use svsm::svsm_paging::{init_page_table, invalidate_stage2}; +use svsm::task::{create_task, TASK_FLAG_SHARE_PT}; use svsm::types::{PageSize, GUEST_VMPL, PAGE_SIZE}; -use svsm::utils::{halt, immut_after_init::ImmutAfterInitCell, zero_mem_region}; +use svsm::utils::{halt, immut_after_init::ImmutAfterInitCell, zero_mem_region, MemoryRegion}; use svsm::mm::validate::{init_valid_bitmap_ptr, migrate_valid_bitmap}; @@ -221,26 +223,29 @@ fn validate_flash() -> Result<(), SvsmError> { let mut fw_cfg = FwCfg::new(&CONSOLE_IO); let flash_regions = fw_cfg.iter_flash_regions().collect::>(); + let kernel_region = LAUNCH_INFO.kernel_region(); + let flash_range = { + let one_gib = 1024 * 1024 * 1024usize; + let start = PhysAddr::from(3 * one_gib); + MemoryRegion::new(start, one_gib) + }; // Sanity-check flash regions. for region in flash_regions.iter() { // Make sure that the regions are between 3GiB and 4GiB. - if !region.overlaps(3 * 1024 * 1024 * 1024, 4 * 1024 * 1024 * 1024) { + if !region.overlap(&flash_range) { panic!("flash region in unexpected region"); } // Make sure that no regions overlap with the kernel. - if region.overlaps( - LAUNCH_INFO.kernel_region_phys_start, - LAUNCH_INFO.kernel_region_phys_end, - ) { + if region.overlap(&kernel_region) { panic!("flash region overlaps with kernel"); } } // Make sure that regions don't overlap. for (i, outer) in flash_regions.iter().enumerate() { for inner in flash_regions[..i].iter() { - if outer.overlaps(inner.start, inner.end) { + if outer.overlap(inner) { panic!("flash regions overlap"); } } @@ -248,23 +253,18 @@ fn validate_flash() -> Result<(), SvsmError> { // Make sure that one regions ends at 4GiB. 
let one_region_ends_at_4gib = flash_regions .iter() - .any(|region| region.end == 4 * 1024 * 1024 * 1024); + .any(|region| region.end() == flash_range.end()); assert!(one_region_ends_at_4gib); for (i, region) in flash_regions.into_iter().enumerate() { - let pstart = PhysAddr::from(region.start); - let pend = PhysAddr::from(region.end); log::info!( "Flash region {} at {:#018x} size {:018x}", i, - pstart, - pend - pstart + region.start(), + region.len(), ); - for paddr in (pstart.bits()..pend.bits()) - .step_by(PAGE_SIZE) - .map(PhysAddr::from) - { + for paddr in region.iter_pages(PageSize::Regular) { let guard = PerCPUPageMappingGuard::create_4k(paddr)?; let vaddr = guard.virt_addr(); if let Err(e) = rmp_adjust( @@ -328,11 +328,7 @@ pub extern "C" fn svsm_start(li: &KernelLaunchInfo, vb_addr: usize) { mapping_info_init(&launch_info); - init_valid_bitmap_ptr( - launch_info.kernel_region_phys_start.into(), - launch_info.kernel_region_phys_end.into(), - vb_ptr, - ); + init_valid_bitmap_ptr(launch_info.kernel_region(), vb_ptr); load_gdt(); early_idt_init(); @@ -407,24 +403,21 @@ pub extern "C" fn svsm_start(li: &KernelLaunchInfo, vb_addr: usize) { log::info!("BSP Runtime stack starts @ {:#018x}", bp); - // Enable runtime stack and jump to main function - unsafe { - asm!("movq %rax, %rsp - jmp svsm_main", - in("rax") bp.bits(), - options(att_syntax)); - } + // Create the root task that runs the entry point then handles the request loop + create_task( + svsm_main, + TASK_FLAG_SHARE_PT, + Some(this_cpu().get_apic_id()), + ) + .expect("Failed to create initial task"); + + panic!("SVSM entry point terminated unexpectedly"); } #[no_mangle] pub extern "C" fn svsm_main() { - // The GDB stub can be started earlier, just after the console is initialised - // in svsm_start() above. It uses a lot of stack though so if you want to move - // it earlier then you need to set bsp_stack to 64K in the inline assembler - // above: - // - // bsp_stack: - // .fill 65536, 1, 0 + // If required, the GDB stub can be started earlier, just after the console + // is initialised in svsm_start() above. gdbstub_start().expect("Could not start GDB stub"); // Uncomment the line below if you want to wait for // a remote GDB connection @@ -459,7 +452,7 @@ pub extern "C" fn svsm_main() { print_fw_meta(&fw_meta); - if let Err(e) = validate_fw_memory(&fw_meta) { + if let Err(e) = validate_fw_memory(&fw_meta, &LAUNCH_INFO) { panic!("Failed to validate firmware memory: {:#?}", e); } @@ -471,6 +464,8 @@ pub extern "C" fn svsm_main() { panic!("Failed to validate flash memory: {:#?}", e); } + guest_request_driver_init(); + prepare_fw_launch(&fw_meta).expect("Failed to setup guest VMSA"); virt_log_usage(); @@ -489,6 +484,8 @@ pub extern "C" fn svsm_main() { #[panic_handler] fn panic(info: &PanicInfo) -> ! 
{ + disable_vmpck0(); + log::error!("Panic: CPU[{}] {}", this_cpu().get_apic_id(), info); print_stack(3); diff --git a/src/task/mod.rs b/src/task/mod.rs index 0a06371ab..bc6232b56 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -4,6 +4,10 @@ // // Author: Roy Hopkins +mod schedule; mod tasks; -pub use tasks::{Task, TaskContext, TaskState, INITIAL_TASK_ID, TASK_FLAG_SHARE_PT}; +pub use schedule::{ + create_task, is_current_task, schedule, RunQueue, TaskNode, TaskPointer, TASKLIST, +}; +pub use tasks::{Task, TaskContext, TaskError, TaskState, INITIAL_TASK_ID, TASK_FLAG_SHARE_PT}; diff --git a/src/task/schedule.rs b/src/task/schedule.rs new file mode 100644 index 000000000..4135f76cd --- /dev/null +++ b/src/task/schedule.rs @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +extern crate alloc; + +use core::ptr::null_mut; + +use super::Task; +use super::{tasks::TaskRuntime, TaskState, INITIAL_TASK_ID}; +use crate::cpu::percpu::{this_cpu, this_cpu_mut}; +use crate::error::SvsmError; +use crate::locking::{RWLock, SpinLock}; +use alloc::boxed::Box; +use alloc::sync::Arc; +use intrusive_collections::{ + intrusive_adapter, Bound, KeyAdapter, LinkedList, LinkedListAtomicLink, RBTree, + RBTreeAtomicLink, +}; + +pub type TaskPointer = Arc; + +#[derive(Debug)] +pub struct TaskNode { + tree_link: RBTreeAtomicLink, + list_link: LinkedListAtomicLink, + pub task: RWLock>, +} + +// SAFETY: Send + Sync is required for Arc to implement Send. The `task` +// member is Send + Sync but the intrusive_collection links are only Send. The only +// access to these is via the intrusive_adapter! generated code which does not use +// them concurrently across threads. +unsafe impl Sync for TaskNode {} + +intrusive_adapter!(pub TaskTreeAdapter = TaskPointer: TaskNode { tree_link: RBTreeAtomicLink }); +intrusive_adapter!(pub TaskListAdapter = TaskPointer: TaskNode { list_link: LinkedListAtomicLink }); + +impl<'a> KeyAdapter<'a> for TaskTreeAdapter { + type Key = u64; + fn get_key(&self, node: &'a TaskNode) -> u64 { + node.task.lock_read().runtime.value() + } +} + +#[derive(Debug)] +struct TaskSwitch { + previous_task: Option, + next_task: Option, +} + +/// A RunQueue implementation that uses an RBTree to efficiently sort the priority +/// of tasks within the queue. +#[derive(Debug)] +pub struct RunQueue { + tree: Option>, + current_task: Option, + terminated_task: Option, + id: u32, + task_switch: TaskSwitch, +} + +impl RunQueue { + /// Create a new runqueue for an id. The id would normally be set + /// to the APIC ID of the CPU that owns the runqueue and is used to + /// determine the affinity of tasks. + pub const fn new(id: u32) -> Self { + Self { + tree: None, + current_task: None, + terminated_task: None, + id, + task_switch: TaskSwitch { + previous_task: None, + next_task: None, + }, + } + } + + fn tree(&mut self) -> &mut RBTree { + self.tree + .get_or_insert_with(|| RBTree::new(TaskTreeAdapter::new())) + } + + pub fn get_task(&self, id: u32) -> Option { + if let Some(task_tree) = &self.tree { + let mut cursor = task_tree.front(); + while let Some(task_node) = cursor.get() { + if task_node.task.lock_read().id == id { + return cursor.clone_pointer(); + } + cursor.move_next(); + } + } + None + } + + pub fn current_task_id(&self) -> u32 { + self.current_task + .as_ref() + .map_or(INITIAL_TASK_ID, |t| t.task.lock_read().id) + } + + /// Determine the next task to run on the vCPU that owns this instance. 
+    /// Populates self.task_switch with the next task and the previous task. If both
+    /// are None then the existing task remains in scope.
+    ///
+    /// Note that this function does not actually perform the task switch. This is
+    /// because it holds a mutable reference to self that must be released before
+    /// the task switch occurs. Call this function from a global function that releases
+    /// the reference before performing the task switch.
+    ///
+    /// # Returns
+    ///
+    /// Pointers to the next task and the previous task.
+    ///
+    /// If the next task pointer is null_mut() then no task switch is required and the
+    /// caller must release the runqueue lock.
+    ///
+    /// If the next task pointer is not null_mut() then the caller must call
+    /// next_task->set_current(prev_task) with the runqueue lock still held.
+    fn schedule(&mut self) -> (*mut Task, *mut Task) {
+        self.task_switch.previous_task = None;
+        self.task_switch.next_task = None;
+
+        // Update the state of the current task. This will change the runtime value which
+        // is used as a key in the RB tree, therefore we need to remove and reinsert the
+        // task.
+        let prev_task_node = self.update_current_task();
+
+        // Find the task with the lowest runtime. The tree only contains running tasks that
+        // are to be scheduled on this vCPU.
+        let cursor = self.tree().lower_bound(Bound::Unbounded);
+
+        // The cursor will now be on the next task to schedule. There should always be
+        // a candidate task unless the current CPU task terminated. For now, don't support
+        // termination of the initial thread, which means there will always be a task to schedule.
+        let next_task_node = cursor.clone_pointer().expect("No task to schedule on CPU");
+        self.current_task = Some(next_task_node.clone());
+
+        // Lock the current and next tasks and keep track of the lock state by adding references
+        // into the structure itself. This allows us to retain the lock over the context switch
+        // and unlock the tasks before returning to the new context.
+        let prev_task_ptr = if let Some(prev_task_node) = prev_task_node {
+            // If the next task is the same as the current one then we have nothing to do.
+            if prev_task_node.task.lock_read().id == next_task_node.task.lock_read().id {
+                return (null_mut(), null_mut());
+            }
+            self.task_switch.previous_task = Some(prev_task_node.clone());
+            unsafe { (*prev_task_node.task.lock_write_direct()).as_mut() }
+        } else {
+            null_mut()
+        };
+        self.task_switch.next_task = Some(next_task_node.clone());
+        let next_task_ptr = unsafe { (*next_task_node.task.lock_write_direct()).as_mut() };
+
+        (next_task_ptr, prev_task_ptr)
+    }
+
+    fn update_current_task(&mut self) -> Option<TaskPointer> {
+        let task_node = self.current_task.take()?;
+        let task_state = {
+            let mut task = task_node.task.lock_write();
+            task.runtime.schedule_out();
+            task.state
+        };
+
+        if task_state == TaskState::TERMINATED {
+            // The current task has terminated. Make sure it doesn't get added back
+            // into the runtime tree, but also we need to make sure we keep a
+            // reference to the task because the current stack is owned by it.
+            // Put it in a holding location which will be cleared by the next
+            // active task.
+            unsafe {
+                self.deallocate(task_node.clone());
+            }
+            self.terminated_task = Some(task_node);
+            None
+        } else {
+            // Reinsert the node into the tree so the position is updated with the new runtime.
+            let mut task_cursor = unsafe { self.tree().cursor_mut_from_ptr(task_node.as_ref()) };
+            task_cursor.remove();
+            self.tree().insert(task_node.clone());
+            Some(task_node)
+        }
+    }
+
+    /// Helper function that determines if a task is a candidate for allocating
+    /// to a CPU
+    fn is_cpu_candidate(&self, t: &Task) -> bool {
+        (t.state == TaskState::RUNNING)
+            && t.allocation.is_none()
+            && t.affinity.map_or(true, |a| a == self.id)
+    }
+
+    /// Iterate through all unallocated tasks and find suitable candidates
+    /// for allocating to this queue
+    pub fn allocate(&mut self) {
+        let mut tl = TASKLIST.lock();
+        let lowest_runtime = if let Some(t) = self.tree().lower_bound(Bound::Unbounded).get() {
+            t.task.lock_read().runtime.value()
+        } else {
+            0
+        };
+        let mut cursor = tl.list().cursor_mut();
+        while !cursor.peek_next().is_null() {
+            cursor.move_next();
+            // Filter on running, unallocated tasks that either have no affinity
+            // or have an affinity for this CPU ID.
+            if let Some(task_node) = cursor
+                .get()
+                .filter(|task_node| self.is_cpu_candidate(task_node.task.lock_read().as_ref()))
+            {
+                {
+                    let mut t = task_node.task.lock_write();
+                    // Now we have the lock, check again that the task has not been allocated
+                    // to another runqueue between the filter above and us taking the lock.
+                    if t.allocation.is_some() {
+                        continue;
+                    }
+                    t.allocation = Some(self.id);
+                    t.runtime.set(lowest_runtime);
+                }
+                self.tree()
+                    .insert(cursor.as_cursor().clone_pointer().unwrap());
+            }
+        }
+    }
+
+    /// Release the spinlock on the previous and next tasks following a task switch.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that the previous and next tasks are no longer accessed
+    /// via the pointers returned by [`Self::schedule()`] after calling this
+    /// function. The RWLocks protecting the pointers are released by this function,
+    /// meaning that further access to the pointers will cause undefined behaviour.
+    unsafe fn unlock_tasks(&mut self) {
+        if let Some(previous_task) = self.task_switch.previous_task.as_ref() {
+            unsafe {
+                previous_task.task.unlock_write_direct();
+            }
+            self.task_switch.previous_task = None;
+        }
+        if let Some(next_task) = self.task_switch.next_task.as_ref() {
+            unsafe {
+                next_task.task.unlock_write_direct();
+            }
+            self.task_switch.next_task = None;
+        }
+    }
+
+    /// Deallocate a task from a per-CPU runqueue but leave it in the global task list
+    /// where it can be reallocated if still in the RUNNING state.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that the function is passed a valid task_node as
+    /// this function dereferences the pointer contained within the task_node. A
+    /// [`TaskPointer`] uses an [`Arc`] to manage the lifetime of the contained pointer,
+    /// making it difficult to pass an invalid pointer to this function.
+    unsafe fn deallocate(&mut self, task_node: TaskPointer) {
+        let mut cursor = self.tree().cursor_mut_from_ptr(task_node.as_ref());
+        cursor.remove();
+        task_node.task.lock_write().allocation = None;
+    }
+}
+
+/// Global task list
+/// This contains every task regardless of affinity or run state.
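One detail of `update_current_task()` above is worth calling out: the task's accumulated runtime is the RB-tree key, so once `schedule_out()` changes it, the node must be removed and reinserted rather than updated in place. The same rule, illustrated with a plain `BTreeMap` keyed by runtime (an analogy only, not the intrusive tree used here, and ignoring the possibility of two tasks sharing a runtime value); the `TaskList` struct documented just above follows after this aside.

```rust
use std::collections::BTreeMap;

// Analogy only: task IDs keyed by accumulated runtime. Mutating a key in
// place would silently break the ordering invariant, so the entry is removed
// and reinserted with its new runtime, just like update_current_task() does.
fn charge_runtime(queue: &mut BTreeMap<u64, u32>, old_runtime: u64, delta: u64) {
    if let Some(task_id) = queue.remove(&old_runtime) {
        queue.insert(old_runtime + delta, task_id);
    }
}

// Picking the task with the lowest runtime mirrors lower_bound(Bound::Unbounded).
fn next_task(queue: &BTreeMap<u64, u32>) -> Option<u32> {
    queue.iter().next().map(|(_, &id)| id)
}
```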
+#[derive(Debug)] +pub struct TaskList { + list: Option>, +} + +impl TaskList { + pub const fn new() -> Self { + Self { list: None } + } + + pub fn list(&mut self) -> &mut LinkedList { + self.list + .get_or_insert_with(|| LinkedList::new(TaskListAdapter::new())) + } + + pub fn get_task(&self, id: u32) -> Option { + let task_list = &self.list.as_ref()?; + let mut cursor = task_list.front(); + while let Some(task_node) = cursor.get() { + if task_node.task.lock_read().id == id { + return cursor.clone_pointer(); + } + cursor.move_next(); + } + None + } + + fn terminate(&mut self, task_node: TaskPointer) { + // Set the task state as terminated. If the task being terminated is the + // current task then the task context will still need to be in scope until + // the next schedule() has completed. Schedule will keep a reference to this + // task until some time after the context switch. + task_node.task.lock_write().state = TaskState::TERMINATED; + let mut cursor = unsafe { self.list().cursor_mut_from_ptr(task_node.as_ref()) }; + cursor.remove(); + } +} + +pub static TASKLIST: SpinLock = SpinLock::new(TaskList::new()); + +fn task_switch_hook(_: &Task) { + // Then unlock the spinlocks that protect the previous and new tasks. + + // SAFETY: Unlocking the tasks is a safe operation at this point because + // we do not use the task pointers beyond the task switch itself which + // is complete at the time of this hook. + unsafe { + this_cpu_mut().runqueue().lock_write().unlock_tasks(); + } +} + +pub fn create_task( + entry: extern "C" fn(), + flags: u16, + affinity: Option, +) -> Result { + let mut task = Task::create(entry, flags)?; + task.set_affinity(affinity); + task.set_on_switch_hook(Some(task_switch_hook)); + let node = Arc::new(TaskNode { + tree_link: RBTreeAtomicLink::default(), + list_link: LinkedListAtomicLink::default(), + task: RWLock::new(task), + }); + { + // Ensure the tasklist lock is released before schedule() is called + // otherwise the lock will be held when switching to a new context + let mut tl = TASKLIST.lock(); + tl.list().push_front(node.clone()); + } + schedule(); + + Ok(node) +} + +/// Check to see if the task scheduled on the current processor has the given id +pub fn is_current_task(id: u32) -> bool { + match &this_cpu().runqueue().lock_read().current_task { + Some(current_task) => current_task.task.lock_read().id == id, + None => id == INITIAL_TASK_ID, + } +} + +pub unsafe fn current_task_terminated() { + let mut rq = this_cpu().runqueue().lock_write(); + let task_node = rq + .current_task + .as_mut() + .expect("Task termination handler called when there is no current task"); + TASKLIST.lock().terminate(task_node.clone()); +} + +pub fn schedule() { + this_cpu_mut().allocate_tasks(); + + let (next_task, prev_task) = this_cpu().runqueue().lock_write().schedule(); + if !next_task.is_null() { + unsafe { + (*next_task).set_current(prev_task); + } + } + + // We're now in the context of the new task. If the previous task had terminated + // then we can release it's reference here. 
+ let _ = this_cpu_mut() + .runqueue() + .lock_write() + .terminated_task + .take(); +} diff --git a/src/task/tasks.rs b/src/task/tasks.rs index 99da39975..dc7516e9e 100644 --- a/src/task/tasks.rs +++ b/src/task/tasks.rs @@ -23,16 +23,31 @@ use crate::mm::pagetable::{get_init_pgtable_locked, PTEntryFlags, PageTableRef}; use crate::mm::vm::{Mapping, VMKernelStack, VMR}; use crate::mm::{SVSM_PERTASK_BASE, SVSM_PERTASK_END, SVSM_PERTASK_STACK_BASE}; +use super::schedule::{current_task_terminated, schedule}; + pub const INITIAL_TASK_ID: u32 = 1; #[derive(PartialEq, Debug, Copy, Clone, Default)] pub enum TaskState { RUNNING, - SCHEDULED, #[default] TERMINATED, } +#[derive(Clone, Copy, Debug)] +pub enum TaskError { + // Attempt to close a non-terminated task + NotTerminated, + // A closed task could not be removed from the task list + CloseFailed, +} + +impl From for SvsmError { + fn from(e: TaskError) -> Self { + Self::Task(e) + } +} + pub const TASK_FLAG_SHARE_PT: u16 = 0x01; #[derive(Debug, Default)] @@ -73,18 +88,10 @@ pub trait TaskRuntime { /// update the runtime calculation at this point. fn schedule_out(&mut self); - /// Returns whether this is the first time a task has been - /// considered for scheduling. - fn first(&self) -> bool; - /// Overrides the calculated runtime value with the given value. /// This can be used to set or adjust the runtime of a task. fn set(&mut self, runtime: u64); - /// Flag the runtime as terminated so the scheduler does not - /// find terminated tasks before running tasks. - fn terminated(&mut self); - /// Returns a value that represents the amount of CPU the task /// has been allocated fn value(&self) -> u64; @@ -106,18 +113,10 @@ impl TaskRuntime for TscRuntime { self.runtime += rdtsc() - self.runtime; } - fn first(&self) -> bool { - self.runtime == 0 - } - fn set(&mut self, runtime: u64) { self.runtime = runtime; } - fn terminated(&mut self) { - self.runtime = u64::MAX; - } - fn value(&self) -> u64 { self.runtime } @@ -138,18 +137,10 @@ impl TaskRuntime for CountRuntime { fn schedule_out(&mut self) {} - fn first(&self) -> bool { - self.count == 0 - } - fn set(&mut self, runtime: u64) { self.count = runtime; } - fn terminated(&mut self) { - self.count = u64::MAX; - } - fn value(&self) -> u64 { self.count } @@ -161,6 +152,7 @@ type TaskRuntimeImpl = CountRuntime; #[repr(C)] #[derive(Default, Debug, Clone, Copy)] pub struct TaskContext { + pub rsp: u64, pub regs: X86GeneralRegs, pub flags: u64, pub ret_addr: u64, @@ -184,11 +176,19 @@ pub struct Task { /// u32: The APIC ID of the CPU that the task must run on pub affinity: Option, + // APIC ID of the CPU that task has been assigned to. 
If 'None' then + // the task is not currently assigned to a CPU + pub allocation: Option, + /// ID of the task pub id: u32, /// Amount of CPU resource the task has consumed pub runtime: TaskRuntimeImpl, + + /// Optional hook that is called immediately after switching to this task + /// before the context is restored + pub on_switch_hook: Option, } impl fmt::Debug for Task { @@ -225,8 +225,10 @@ impl Task { vm_kernel_range, state: TaskState::RUNNING, affinity: None, + allocation: None, id: TASK_ID_ALLOCATOR.next_id(), runtime: TaskRuntimeImpl::default(), + on_switch_hook: None, }); Ok(task) } @@ -258,6 +260,10 @@ impl Task { self.vm_kernel_range.handle_page_fault(vaddr, write) } + pub fn set_on_switch_hook(&mut self, hook: Option) { + self.on_switch_hook = hook; + } + fn allocate_stack(entry: extern "C" fn()) -> Result<(Arc, VirtAddr), SvsmError> { let stack = VMKernelStack::new()?; let offset = stack.top_of_stack(VirtAddr::from(0u64)); @@ -294,7 +300,10 @@ impl Task { } extern "C" fn task_exit() { - panic!("Current task has exited"); + unsafe { + current_task_terminated(); + } + schedule(); } #[allow(unused)] @@ -307,6 +316,14 @@ extern "C" fn apply_new_context(new_task: *mut Task) -> u64 { } } +#[allow(unused)] +#[no_mangle] +extern "C" fn on_switch(new_task: &mut Task) { + if let Some(hook) = new_task.on_switch_hook { + hook(new_task); + } +} + global_asm!( r#" .text @@ -329,6 +346,7 @@ global_asm!( pushq %r13 pushq %r14 pushq %r15 + pushq %rsp // Save the current stack pointer testq %rsi, %rsi @@ -344,6 +362,13 @@ global_asm!( // Switch to the new task stack movq (%rbx), %rsp + // We've already restored rsp + addq $8, %rsp + + mov %rbx, %rdi + call on_switch + + // Restore the task context popq %r15 popq %r14 popq %r13 diff --git a/src/utils/memory_region.rs b/src/utils/memory_region.rs new file mode 100644 index 000000000..6ce56451a --- /dev/null +++ b/src/utils/memory_region.rs @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Carlos López + +use crate::address::Address; +use crate::types::PageSize; + +/// An abstraction over a memory region, expressed in terms of physical +/// ([`PhysAddr`](crate::address::PhysAddr)) or virtual +/// ([`VirtAddr`](crate::address::VirtAddr)) addresses. +#[derive(Clone, Copy, Debug)] +pub struct MemoryRegion { + start: A, + end: A, +} + +impl MemoryRegion +where + A: Address, +{ + /// Create a new memory region starting at address `start`, spanning `len` + /// bytes. + pub fn new(start: A, len: usize) -> Self { + let end = A::from(start.bits() + len); + Self { start, end } + } + + /// Create a new memory region with overflow checks. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let start = VirtAddr::from(u64::MAX); + /// let region = MemoryRegion::checked_new(start, PAGE_SIZE); + /// assert!(region.is_none()); + /// ``` + pub fn checked_new(start: A, len: usize) -> Option { + let end = start.checked_add(len)?; + Some(Self { start, end }) + } + + /// Create a memory region from two raw addresses. + pub const fn from_addresses(start: A, end: A) -> Self { + Self { start, end } + } + + /// The base address of the memory region, originally set in + /// [`MemoryRegion::new()`]. + #[inline] + pub const fn start(&self) -> A { + self.start + } + + /// The length of the memory region in bytes, originally set in + /// [`MemoryRegion::new()`]. 
+ #[inline] + pub fn len(&self) -> usize { + self.end.bits().saturating_sub(self.start.bits()) + } + + /// Returns whether the region spans any actual memory. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::utils::MemoryRegion; + /// let r = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), 0); + /// assert!(r.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The end address of the memory region. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let base = VirtAddr::from(0xffffff0000u64); + /// let region = MemoryRegion::new(base, PAGE_SIZE); + /// assert_eq!(region.end(), VirtAddr::from(0xffffff1000u64)); + /// ``` + #[inline] + pub const fn end(&self) -> A { + self.end + } + + /// Checks whether two regions overlap. This does *not* include contiguous + /// regions, use [`MemoryRegion::contiguous()`] for that purpose. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE); + /// assert!(!r1.overlap(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(r1.overlap(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// // Contiguous regions do not overlap + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(!r1.overlap(&r2)); + /// ``` + pub fn overlap(&self, other: &Self) -> bool { + self.start() < other.end() && self.end() > other.start() + } + + /// Checks whether two regions are contiguous or overlapping. This is a + /// less strict check than [`MemoryRegion::overlap()`]. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(r1.contiguous(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE); + /// assert!(!r1.contiguous(&r2)); + /// ``` + pub fn contiguous(&self, other: &Self) -> bool { + self.start() <= other.end() && self.end() >= other.start() + } + + /// Merge two regions. It does not check whether the two regions are + /// contiguous in the first place, so the resulting region will cover + /// any non-overlapping memory between both. 
+ /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// let r3 = r1.merge(&r2); + /// assert_eq!(r3.start(), r1.start()); + /// assert_eq!(r3.len(), r1.len() + r2.len()); + /// assert_eq!(r3.end(), r2.end()); + /// ``` + pub fn merge(&self, other: &Self) -> Self { + let start = self.start.min(other.start); + let end = self.end().max(other.end()); + Self { start, end } + } + + /// Iterate over the addresses covering the memory region in jumps of the + /// specified page size. Note that if the base address of the region is not + /// page aligned, returned addresses will not be aligned either. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::{PAGE_SIZE, PageSize}; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2); + /// let mut iter = region.iter_pages(PageSize::Regular); + /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff0000u64))); + /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff1000u64))); + /// assert_eq!(iter.next(), None); + /// ``` + pub fn iter_pages(&self, size: PageSize) -> impl Iterator { + let size = usize::from(size); + (self.start().bits()..self.end().bits()) + .step_by(size) + .map(A::from) + } + + /// Check whether an address is within this region. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::{PAGE_SIZE, PageSize}; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// assert!(region.contains(VirtAddr::from(0xffffff0000u64))); + /// assert!(region.contains(VirtAddr::from(0xffffff0fffu64))); + /// assert!(!region.contains(VirtAddr::from(0xffffff1000u64))); + /// ``` + pub fn contains(&self, addr: A) -> bool { + self.start() <= addr && addr < self.end() + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1f90a64d0..866b5a1eb 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -6,6 +6,8 @@ pub mod bitmap_allocator; pub mod immut_after_init; +pub mod memory_region; pub mod util; +pub use memory_region::MemoryRegion; pub use util::{align_down, align_up, halt, overlap, page_align_up, page_offset, zero_mem_region}; diff --git a/src/utils/util.rs b/src/utils/util.rs index 98783577c..3173ca595 100644 --- a/src/utils/util.rs +++ b/src/utils/util.rs @@ -47,6 +47,14 @@ pub fn zero_mem_region(start: VirtAddr, end: VirtAddr) { unsafe { start.as_mut_ptr::().write_bytes(0, size) } } +/// Obtain bit for a given position +#[macro_export] +macro_rules! BIT { + ($x: expr) => { + (1 << ($x)) + }; +} + #[cfg(test)] mod tests {
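The new `BIT!` macro at the end of the hunk above simply expands to `1 << position`. A hypothetical unit test (not part of the patch) showing typical uses; the existing tests in this module continue below.

```rust
#[test]
fn test_bit_macro() {
    assert_eq!(BIT!(0), 1);
    assert_eq!(BIT!(3), 0b1000);
    assert_eq!(BIT!(12) - 1, 0xfff); // mask covering the low 12 bits
    assert_eq!(BIT!(5) | BIT!(7), 0xa0); // combining individual flag bits
}
```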