diff --git a/src/address.rs b/src/address.rs index b1b5cb78c..fbdd3418b 100644 --- a/src/address.rs +++ b/src/address.rs @@ -55,7 +55,7 @@ pub trait Address:
         self.is_aligned(PAGE_SIZE)
     }
 
-    fn checked_offset(&self, off: InnerAddr) -> Option<Self> {
+    fn checked_add(&self, off: InnerAddr) -> Option<Self> {
         self.bits().checked_add(off).map(|addr| addr.into())
     }
 
@@ -63,6 +63,10 @@ pub trait Address:
         self.bits().checked_sub(off).map(|addr| addr.into())
     }
 
+    fn saturating_add(&self, off: InnerAddr) -> Self {
+        Self::from(self.bits().saturating_add(off))
+    }
+
     fn page_offset(&self) -> usize {
         self.bits() & (PAGE_SIZE - 1)
     }
@@ -258,7 +262,7 @@ impl ops::Add<usize> for VirtAddr {
 }
 
 impl Address for VirtAddr {
-    fn checked_offset(&self, off: InnerAddr) -> Option<Self> {
+    fn checked_add(&self, off: InnerAddr) -> Option<Self> {
         self.bits()
             .checked_add(off)
             .map(|addr| sign_extend(addr).into())
diff --git a/src/debug/stacktrace.rs b/src/debug/stacktrace.rs index fe22e8848..a0e94e06b 100644 --- a/src/debug/stacktrace.rs +++ b/src/debug/stacktrace.rs @@ -26,7 +26,7 @@ struct StackBounds {
 #[cfg(feature = "enable-stacktrace")]
 impl StackBounds {
     fn range_is_on_stack(&self, begin: VirtAddr, len: usize) -> bool {
-        match begin.checked_offset(len) {
+        match begin.checked_add(len) {
             Some(end) => begin >= self.bottom && end <= self.top,
             None => false,
         }
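The new `saturating_add` complements the renamed `checked_add`: the latter surfaces overflow as `None`, while the former clamps at the top of the address space. A doctest-style sketch of the difference (a minimal sketch assuming a 64-bit target; the concrete address is illustrative, and the `svsm` paths follow the doc examples in `memory_region.rs` below):

```rust
use svsm::address::{Address, VirtAddr};

fn main() {
    let near_top = VirtAddr::from(u64::MAX - 0xfff);

    // checked_add hands the overflow back to the caller...
    assert!(near_top.checked_add(0x2000).is_none());

    // ...while saturating_add clamps at the end of the address space.
    assert_eq!(near_top.saturating_add(0x2000), VirtAddr::from(u64::MAX));
}
```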
diff --git a/src/fw_cfg.rs b/src/fw_cfg.rs index bc4e41a78..271644d4c 100644 --- a/src/fw_cfg.rs +++ b/src/fw_cfg.rs @@ -6,8 +6,10 @@
 
 extern crate alloc;
 
+use crate::address::{Address, PhysAddr};
 use crate::error::SvsmError;
 use crate::mm::pagetable::max_phys_addr;
+use crate::utils::MemoryRegion;
 
 use super::io::IOPort;
 use super::string::FixedString;
@@ -67,20 +69,6 @@ impl FwCfgFile {
     }
 }
 
-#[derive(Clone, Copy, Debug)]
-pub struct MemoryRegion {
-    pub start: u64,
-    pub end: u64,
-}
-
-impl MemoryRegion {
-    /// Returns `true` if the region overlaps with another region with given
-    /// start and end.
-    pub fn overlaps(&self, start: u64, end: u64) -> bool {
-        self.start < end && start < self.end
-    }
-}
-
 impl<'a> FwCfg<'a> {
     pub fn new(driver: &'a dyn IOPort) -> Self {
         FwCfg { driver }
@@ -150,7 +138,7 @@ impl<'a> FwCfg<'a> {
         Err(SvsmError::FwCfg(FwCfgError::FileNotFound))
     }
 
-    fn find_svsm_region(&self) -> Result<MemoryRegion, SvsmError> {
+    fn find_svsm_region(&self) -> Result<MemoryRegion<PhysAddr>, SvsmError> {
         let file = self.file_selector("etc/sev/svsm")?;
 
         if file.size != 16 {
@@ -161,19 +149,19 @@ impl<'a> FwCfg<'a> {
         Ok(self.read_memory_region())
     }
 
-    fn read_memory_region(&self) -> MemoryRegion {
-        let start: u64 = self.read_le();
-        let size: u64 = self.read_le();
-        let end = start.saturating_add(size);
+    fn read_memory_region(&self) -> MemoryRegion<PhysAddr> {
+        let start = PhysAddr::from(self.read_le::<u64>());
+        let size = self.read_le::<u64>();
+        let end = start.saturating_add(size as usize);
 
         assert!(start <= max_phys_addr(), "{start:#018x} is out of range");
         assert!(end <= max_phys_addr(), "{end:#018x} is out of range");
 
-        MemoryRegion { start, end }
+        MemoryRegion::from_addresses(start, end)
     }
 
-    pub fn get_memory_regions(&self) -> Result<Vec<MemoryRegion>, SvsmError> {
-        let mut regions: Vec<MemoryRegion> = Vec::new();
+    pub fn get_memory_regions(&self) -> Result<Vec<MemoryRegion<PhysAddr>>, SvsmError> {
+        let mut regions = Vec::new();
         let file = self.file_selector("etc/e820")?;
         let entries = file.size / 20;
 
@@ -191,33 +179,35 @@ impl<'a> FwCfg<'a> {
         Ok(regions)
     }
 
-    fn find_kernel_region_e820(&self) -> Result<MemoryRegion, SvsmError> {
+    fn find_kernel_region_e820(&self) -> Result<MemoryRegion<PhysAddr>, SvsmError> {
         let regions = self.get_memory_regions()?;
-        let mut kernel_region = regions
+        let kernel_region = regions
             .iter()
-            .max_by_key(|region| region.start)
-            .copied()
+            .max_by_key(|region| region.start())
             .ok_or(SvsmError::FwCfg(FwCfgError::KernelRegion))?;
 
-        let start =
-            (kernel_region.end.saturating_sub(KERNEL_REGION_SIZE)) & KERNEL_REGION_SIZE_MASK;
+        let start = PhysAddr::from(
+            kernel_region
+                .end()
+                .bits()
+                .saturating_sub(KERNEL_REGION_SIZE as usize)
+                & KERNEL_REGION_SIZE_MASK as usize,
+        );
 
-        if start < kernel_region.start {
+        if start < kernel_region.start() {
             return Err(SvsmError::FwCfg(FwCfgError::KernelRegion));
         }
 
-        kernel_region.start = start;
-
-        Ok(kernel_region)
+        Ok(MemoryRegion::new(start, kernel_region.len()))
     }
 
-    pub fn find_kernel_region(&self) -> Result<MemoryRegion, SvsmError> {
+    pub fn find_kernel_region(&self) -> Result<MemoryRegion<PhysAddr>, SvsmError> {
         let kernel_region = self
             .find_svsm_region()
             .or_else(|_| self.find_kernel_region_e820())?;
 
         // Make sure that the kernel region doesn't overlap with the loader.
-        if kernel_region.start < 640 * 1024 {
+        if kernel_region.start() < PhysAddr::from(640 * 1024u64) {
             return Err(SvsmError::FwCfg(FwCfgError::KernelRegion));
         }
 
@@ -227,7 +217,7 @@ impl<'a> FwCfg<'a> {
     // This needs to be &mut self to prevent iterator invalidation, where the caller
     // could do fw_cfg.select() while iterating. Having a mutable reference prevents
     // other references.
-    pub fn iter_flash_regions(&mut self) -> impl Iterator<Item = MemoryRegion> + '_ {
+    pub fn iter_flash_regions(&mut self) -> impl Iterator<Item = MemoryRegion<PhysAddr>> + '_ {
         let num = match self.file_selector("etc/flash") {
             Ok(file) => {
                 self.select(file.selector);
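For context, `find_kernel_region_e820` places a fixed-size kernel window at the top of the highest E820 region, aligned down by the size mask. A standalone sketch of that arithmetic, with illustrative stand-ins for `KERNEL_REGION_SIZE` and `KERNEL_REGION_SIZE_MASK` (the real constants live elsewhere in fw_cfg.rs and are not part of this diff):

```rust
fn main() {
    // Hypothetical stand-ins for fw_cfg.rs's KERNEL_REGION_SIZE{,_MASK}.
    const KERNEL_REGION_SIZE: u64 = 16 * 1024 * 1024;
    const KERNEL_REGION_SIZE_MASK: u64 = !(KERNEL_REGION_SIZE - 1);

    // An illustrative E820 region: 2 GiB..4 GiB.
    let (region_start, region_end) = (0x8000_0000u64, 0x1_0000_0000u64);

    // Place the window at the end of the region, aligned down to its size.
    let start = region_end.saturating_sub(KERNEL_REGION_SIZE) & KERNEL_REGION_SIZE_MASK;
    assert!(start >= region_start);
    assert_eq!(start, 0xff00_0000);
}
```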
diff --git a/src/fw_meta.rs b/src/fw_meta.rs index 04c933db4..e1f5f9c69 100644 --- a/src/fw_meta.rs +++ b/src/fw_meta.rs @@ -6,59 +6,28 @@
 
 extern crate alloc;
 
-use crate::address::{Address, PhysAddr};
+use crate::address::PhysAddr;
 use crate::cpu::percpu::this_cpu_mut;
 use crate::error::SvsmError;
+use crate::kernel_launch::KernelLaunchInfo;
 use crate::mm::PerCPUPageMappingGuard;
 use crate::sev::ghcb::PageStateChangeOp;
 use crate::sev::{pvalidate, rmp_adjust, PvalidateOp, RMPFlags};
 use crate::types::{PageSize, PAGE_SIZE};
-use crate::utils::{overlap, zero_mem_region};
+use crate::utils::{zero_mem_region, MemoryRegion};
 use alloc::vec::Vec;
-use core::cmp;
 use core::fmt;
 use core::mem::{align_of, size_of, size_of_val};
 use core::str::FromStr;
 
-#[derive(Copy, Clone, Debug)]
-pub struct SevPreValidMem {
-    base: PhysAddr,
-    length: usize,
-}
-
-impl SevPreValidMem {
-    fn new(base: PhysAddr, length: usize) -> Self {
-        Self { base, length }
-    }
-
-    fn new_4k(base: PhysAddr) -> Self {
-        Self::new(base, PAGE_SIZE)
-    }
-
-    #[inline]
-    fn end(&self) -> PhysAddr {
-        self.base + self.length
-    }
-
-    fn overlap(&self, other: &Self) -> bool {
-        overlap(self.base, self.end(), other.base, other.end())
-    }
-
-    fn merge(self, other: Self) -> Self {
-        let base = cmp::min(self.base, other.base);
-        let length = cmp::max(self.end(), other.end()) - base;
-        Self::new(base, length)
-    }
-}
-
 #[derive(Clone, Debug)]
 pub struct SevFWMetaData {
     pub reset_ip: Option<PhysAddr>,
     pub cpuid_page: Option<PhysAddr>,
     pub secrets_page: Option<PhysAddr>,
     pub caa_page: Option<PhysAddr>,
-    pub valid_mem: Vec<SevPreValidMem>,
+    pub valid_mem: Vec<MemoryRegion<PhysAddr>>,
 }
 
 impl SevFWMetaData {
@@ -73,7 +42,7 @@ impl SevFWMetaData {
     }
 
     pub fn add_valid_mem(&mut self, base: PhysAddr, len: usize) {
-        self.valid_mem.push(SevPreValidMem::new(base, len));
+        self.valid_mem.push(MemoryRegion::new(base, len));
     }
 }
 
@@ -392,8 +361,8 @@ fn parse_sev_meta(
     Ok(())
 }
 
-fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> {
-    let pstart = region.base;
+fn validate_fw_mem_region(region: MemoryRegion<PhysAddr>) -> Result<(), SvsmError> {
+    let pstart = region.start();
     let pend = region.end();
 
     log::info!("Validating {:#018x}-{:#018x}", pstart, pend);
@@ -408,10 +377,7 @@ fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> {
     )
     .expect("GHCB PSC call failed to validate firmware memory");
 
-    for paddr in (pstart.bits()..pend.bits())
-        .step_by(PAGE_SIZE)
-        .map(PhysAddr::from)
-    {
+    for paddr in region.iter_pages(PageSize::Regular) {
         let guard = PerCPUPageMappingGuard::create_4k(paddr)?;
         let vaddr = guard.virt_addr();
 
@@ -430,17 +396,17 @@ fn validate_fw_mem_region(region: SevPreValidMem) -> Result<(), SvsmError> {
     Ok(())
 }
 
-fn validate_fw_memory_vec(regions: Vec<SevPreValidMem>) -> Result<(), SvsmError> {
+fn validate_fw_memory_vec(regions: Vec<MemoryRegion<PhysAddr>>) -> Result<(), SvsmError> {
     if regions.is_empty() {
         return Ok(());
     }
 
-    let mut next_vec: Vec<SevPreValidMem> = Vec::new();
+    let mut next_vec = Vec::new();
     let mut region = regions[0];
 
     for next in regions.into_iter().skip(1) {
-        if region.overlap(&next) {
-            region = region.merge(next);
+        if region.contiguous(&next) {
+            region = region.merge(&next);
         } else {
             next_vec.push(next);
         }
     }
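`validate_fw_memory_vec` coalesces a sorted region list: contiguous or overlapping neighbors are merged, and disjoint ones are deferred to a recursive pass. An equivalent iterative sketch against the new `MemoryRegion` API (doctest-style; the addresses are illustrative):

```rust
use svsm::address::PhysAddr;
use svsm::utils::MemoryRegion;

fn main() {
    // Sorted by start address; the first two are contiguous.
    let regions = [
        MemoryRegion::new(PhysAddr::from(0x1000u64), 0x1000),
        MemoryRegion::new(PhysAddr::from(0x2000u64), 0x1000),
        MemoryRegion::new(PhysAddr::from(0x5000u64), 0x1000),
    ];

    let mut merged: Vec<MemoryRegion<PhysAddr>> = Vec::new();
    for r in regions {
        match merged.last_mut() {
            // Extend the previous region when the next one touches it.
            Some(last) if last.contiguous(&r) => *last = last.merge(&r),
            _ => merged.push(r),
        }
    }

    assert_eq!(merged.len(), 2);
    assert_eq!(merged[0].len(), 0x2000);
}
```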
@@ -450,27 +416,37 @@ fn validate_fw_memory_vec(regions: Vec<SevPreValidMem>) -> Result<(), SvsmError>
     validate_fw_memory_vec(next_vec)
 }
 
-pub fn validate_fw_memory(fw_meta: &SevFWMetaData) -> Result<(), SvsmError> {
+pub fn validate_fw_memory(
+    fw_meta: &SevFWMetaData,
+    launch_info: &KernelLaunchInfo,
+) -> Result<(), SvsmError> {
     // Initalize vector with regions from the FW
     let mut regions = fw_meta.valid_mem.clone();
 
     // Add region for CPUID page if present
     if let Some(cpuid_paddr) = fw_meta.cpuid_page {
-        regions.push(SevPreValidMem::new_4k(cpuid_paddr));
+        regions.push(MemoryRegion::new(cpuid_paddr, PAGE_SIZE));
     }
 
     // Add region for Secrets page if present
     if let Some(secrets_paddr) = fw_meta.secrets_page {
-        regions.push(SevPreValidMem::new_4k(secrets_paddr));
+        regions.push(MemoryRegion::new(secrets_paddr, PAGE_SIZE));
     }
 
     // Add region for CAA page if present
     if let Some(caa_paddr) = fw_meta.caa_page {
-        regions.push(SevPreValidMem::new_4k(caa_paddr));
+        regions.push(MemoryRegion::new(caa_paddr, PAGE_SIZE));
     }
 
     // Sort regions by base address
-    regions.sort_unstable_by(|a, b| a.base.cmp(&b.base));
+    regions.sort_unstable_by_key(|a| a.start());
+
+    let kernel_region = launch_info.kernel_region();
+    for region in regions.iter() {
+        if region.overlap(&kernel_region) {
+            panic!("FwMeta region overlaps with kernel");
+        }
+    }
 
     validate_fw_memory_vec(regions)
 }
@@ -501,7 +477,7 @@ pub fn print_fw_meta(fw_meta: &SevFWMetaData) {
     for region in &fw_meta.valid_mem {
         log::info!(
             "  Pre-Validated Region {:#018x}-{:#018x}",
-            region.base,
+            region.start(),
             region.end()
         );
     }
diff --git a/src/kernel_launch.rs b/src/kernel_launch.rs index 4879be18f..cd5815a7e 100644 --- a/src/kernel_launch.rs +++ b/src/kernel_launch.rs @@ -4,6 +4,9 @@
 //
 // Author: Joerg Roedel
 
+use crate::address::PhysAddr;
+use crate::utils::MemoryRegion;
+
 #[derive(Copy, Clone, Debug)]
 #[repr(C)]
 pub struct KernelLaunchInfo {
@@ -30,4 +33,10 @@ impl KernelLaunchInfo {
     pub fn heap_area_virt_end(&self) -> u64 {
         self.heap_area_virt_start + self.heap_area_size()
     }
+
+    pub fn kernel_region(&self) -> MemoryRegion<PhysAddr> {
+        let start = PhysAddr::from(self.kernel_region_phys_start);
+        let end = PhysAddr::from(self.kernel_region_phys_end);
+        MemoryRegion::from_addresses(start, end)
+    }
 }
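`kernel_region()` wraps the raw `u64` bounds stored in the launch-info struct into a typed region via `from_addresses`. The same pattern in miniature (the `RawInfo` struct here is a hypothetical stand-in for `KernelLaunchInfo`):

```rust
use svsm::address::PhysAddr;
use svsm::utils::MemoryRegion;

// Hypothetical raw bounds, mirroring kernel_region_phys_{start,end}.
struct RawInfo {
    phys_start: u64,
    phys_end: u64,
}

fn region_of(info: &RawInfo) -> MemoryRegion<PhysAddr> {
    MemoryRegion::from_addresses(
        PhysAddr::from(info.phys_start),
        PhysAddr::from(info.phys_end),
    )
}

fn main() {
    let info = RawInfo { phys_start: 0x80_0000, phys_end: 0x100_0000 };
    assert_eq!(region_of(&info).len(), 0x80_0000);
}
```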
diff --git a/src/mm/memory.rs b/src/mm/memory.rs index 34c5f9b0b..a168caa39 100644 --- a/src/mm/memory.rs +++ b/src/mm/memory.rs @@ -9,28 +9,27 @@ extern crate alloc;
 use crate::address::{Address, PhysAddr};
 use crate::cpu::percpu::PERCPU_VMSAS;
 use crate::error::SvsmError;
-use crate::fw_cfg::{FwCfg, MemoryRegion};
+use crate::fw_cfg::FwCfg;
 use crate::kernel_launch::KernelLaunchInfo;
 use crate::locking::RWLock;
+use crate::utils::MemoryRegion;
 use alloc::vec::Vec;
 use log;
 
 use super::pagetable::LAUNCH_VMSA_ADDR;
 
-static MEMORY_MAP: RWLock<Vec<MemoryRegion>> = RWLock::new(Vec::new());
+static MEMORY_MAP: RWLock<Vec<MemoryRegion<PhysAddr>>> = RWLock::new(Vec::new());
 
 pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result<(), SvsmError> {
     let mut regions = fwcfg.get_memory_regions()?;
+    let kernel_region = launch_info.kernel_region();
 
     // Remove SVSM memory from guest memory map
     let mut i = 0;
     while i < regions.len() {
         // Check if the region overlaps with SVSM memory.
         let region = regions[i];
-        if !region.overlaps(
-            launch_info.kernel_region_phys_start,
-            launch_info.kernel_region_phys_end,
-        ) {
+        if !region.overlap(&kernel_region) {
             // Check the next region.
             i += 1;
             continue;
@@ -40,29 +39,23 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result<
         regions.remove(i);
 
         // 2. Insert a region up until the start of SVSM memory (if non-empty).
-        let region_before_start = region.start;
-        let region_before_end = launch_info.kernel_region_phys_start;
+        let region_before_start = region.start();
+        let region_before_end = kernel_region.start();
         if region_before_start < region_before_end {
             regions.insert(
                 i,
-                MemoryRegion {
-                    start: region_before_start,
-                    end: region_before_end,
-                },
+                MemoryRegion::from_addresses(region_before_start, region_before_end),
             );
             i += 1;
         }
 
         // 3. Insert a region up after the end of SVSM memory (if non-empty).
-        let region_after_start = launch_info.kernel_region_phys_end;
-        let region_after_end = region.end;
+        let region_after_start = kernel_region.end();
+        let region_after_end = region.end();
         if region_after_start < region_after_end {
             regions.insert(
                 i,
-                MemoryRegion {
-                    start: region_after_start,
-                    end: region_after_end,
-                },
+                MemoryRegion::from_addresses(region_after_start, region_after_end),
             );
             i += 1;
         }
@@ -70,7 +63,7 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result<
 
     log::info!("Guest Memory Regions:");
     for r in regions.iter() {
-        log::info!("  {:018x}-{:018x}", r.start, r.end);
+        log::info!("  {:018x}-{:018x}", r.start(), r.end());
     }
 
     let mut map = MEMORY_MAP.lock_write();
@@ -81,7 +74,6 @@ pub fn init_memory_map(fwcfg: &FwCfg, launch_info: &KernelLaunchInfo) -> Result<
 
 pub fn valid_phys_address(paddr: PhysAddr) -> bool {
     let page_addr = paddr.page_align();
-    let addr = paddr.bits() as u64;
 
     if PERCPU_VMSAS.exists(page_addr) {
         return false;
@@ -93,7 +85,7 @@ pub fn valid_phys_address(paddr: PhysAddr) -> bool {
     MEMORY_MAP
         .lock_read()
         .iter()
-        .any(|region| addr >= region.start && addr < region.end)
+        .any(|region| region.contains(paddr))
 }
 
 const ISA_RANGE_START: PhysAddr = PhysAddr::new(0xa0000);
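The remove-and-reinsert loop above effectively subtracts the kernel region from each overlapping map entry, leaving at most two fragments. A standalone sketch of that carve-out (the helper name and addresses are illustrative, not the crate's API):

```rust
use svsm::address::PhysAddr;
use svsm::utils::MemoryRegion;

// Subtract `hole` from `r`, returning the fragments on either side.
fn carve_out(
    r: MemoryRegion<PhysAddr>,
    hole: &MemoryRegion<PhysAddr>,
) -> Vec<MemoryRegion<PhysAddr>> {
    if !r.overlap(hole) {
        return vec![r];
    }
    let mut out = Vec::new();
    if r.start() < hole.start() {
        out.push(MemoryRegion::from_addresses(r.start(), hole.start()));
    }
    if hole.end() < r.end() {
        out.push(MemoryRegion::from_addresses(hole.end(), r.end()));
    }
    out
}

fn main() {
    let ram = MemoryRegion::new(PhysAddr::from(0x0u64), 0x8000_0000);
    let kernel = MemoryRegion::new(PhysAddr::from(0x4000_0000u64), 0x100_0000);
    let rest = carve_out(ram, &kernel);
    assert_eq!(rest.len(), 2);
    assert_eq!(rest[0].end(), kernel.start());
    assert_eq!(rest[1].start(), kernel.end());
}
```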
diff --git a/src/mm/pagetable.rs b/src/mm/pagetable.rs index 01919d15a..90c03ea9a 100644 --- a/src/mm/pagetable.rs +++ b/src/mm/pagetable.rs @@ -85,8 +85,8 @@ fn encrypt_mask() -> usize {
 }
 
 /// Returns the exclusive end of the physical address space.
-pub fn max_phys_addr() -> u64 {
-    *MAX_PHYS_ADDR
+pub fn max_phys_addr() -> PhysAddr {
+    PhysAddr::from(*MAX_PHYS_ADDR)
 }
 
 fn supported_flags(flags: PTEntryFlags) -> PTEntryFlags {
diff --git a/src/mm/ptguards.rs b/src/mm/ptguards.rs index 07dc88ae2..82c92b949 100644 --- a/src/mm/ptguards.rs +++ b/src/mm/ptguards.rs @@ -14,22 +14,12 @@ use crate::mm::virtualrange::{
 };
 use crate::types::{PAGE_SIZE, PAGE_SIZE_2M};
 
-#[derive(Debug)]
-struct RawPTMappingGuard {
-    start: VirtAddr,
-    end: VirtAddr,
-}
-
-impl RawPTMappingGuard {
-    pub const fn new(start: VirtAddr, end: VirtAddr) -> Self {
-        RawPTMappingGuard { start, end }
-    }
-}
+use crate::utils::MemoryRegion;
 
 #[derive(Debug)]
 #[must_use = "if unused the mapping will immediately be unmapped"]
 pub struct PerCPUPageMappingGuard {
-    mapping: Option<RawPTMappingGuard>,
+    mapping: MemoryRegion<VirtAddr>,
     huge: bool,
 }
 
@@ -72,10 +62,10 @@ impl PerCPUPageMappingGuard {
             vaddr
         };
 
-        let raw_mapping = RawPTMappingGuard::new(vaddr, vaddr + size);
+        let raw_mapping = MemoryRegion::new(vaddr, size);
 
         Ok(PerCPUPageMappingGuard {
-            mapping: Some(raw_mapping),
+            mapping: raw_mapping,
             huge,
         })
     }
@@ -85,22 +75,23 @@ impl PerCPUPageMappingGuard {
     }
 
     pub fn virt_addr(&self) -> VirtAddr {
-        self.mapping.as_ref().unwrap().start
+        self.mapping.start()
    }
 }
 
 impl Drop for PerCPUPageMappingGuard {
     fn drop(&mut self) {
-        if let Some(m) = &self.mapping {
-            let size = m.end - m.start;
-            if self.huge {
-                this_cpu_mut().get_pgtable().unmap_region_2m(m.start, m.end);
-                virt_free_range_2m(m.start, size);
-            } else {
-                this_cpu_mut().get_pgtable().unmap_region_4k(m.start, m.end);
-                virt_free_range_4k(m.start, size);
-            }
-            flush_address_sync(m.start);
+        let start = self.mapping.start();
+        let end = self.mapping.end();
+        let size = self.mapping.len();
+
+        if self.huge {
+            this_cpu_mut().get_pgtable().unmap_region_2m(start, end);
+            virt_free_range_2m(start, size);
+        } else {
+            this_cpu_mut().get_pgtable().unmap_region_4k(start, end);
+            virt_free_range_4k(start, size);
         }
+        flush_address_sync(start);
     }
 }
diff --git a/src/mm/stack.rs b/src/mm/stack.rs index b02cf0696..c7ec87049 100644 --- a/src/mm/stack.rs +++ b/src/mm/stack.rs @@ -15,6 +15,7 @@ use crate::mm::{
     STACK_PAGES, STACK_SIZE, STACK_TOTAL_SIZE, SVSM_SHARED_STACK_BASE, SVSM_SHARED_STACK_END,
 };
 use crate::types::PAGE_SIZE;
+use crate::utils::MemoryRegion;
 
 // Limit maximum number of stacks for now, address range support 2**16 8k stacks
 const MAX_STACKS: usize = 1024;
@@ -22,16 +23,15 @@ const BMP_QWORDS: usize = MAX_STACKS / 64;
 
 #[derive(Debug)]
 struct StackRange {
-    start: VirtAddr,
-    end: VirtAddr,
+    region: MemoryRegion<VirtAddr>,
     alloc_bitmap: [u64; BMP_QWORDS],
 }
 
 impl StackRange {
     pub const fn new(start: VirtAddr, end: VirtAddr) -> Self {
+        let region = MemoryRegion::from_addresses(start, end);
         StackRange {
-            start,
-            end,
+            region,
             alloc_bitmap: [0; BMP_QWORDS],
         }
     }
 
@@ -49,16 +49,16 @@ impl StackRange {
 
             self.alloc_bitmap[i] |= mask;
 
-            return Ok(self.start + ((i * 64 + idx) * STACK_TOTAL_SIZE));
+            return Ok(self.region.start() + ((i * 64 + idx) * STACK_TOTAL_SIZE));
         }
 
         Err(SvsmError::Mem)
     }
 
     pub fn dealloc(&mut self, stack: VirtAddr) {
-        assert!(stack >= self.start && stack < self.end);
+        assert!(self.region.contains(stack));
 
-        let offset = stack - self.start;
+        let offset = stack - self.region.start();
         let idx = offset / (STACK_TOTAL_SIZE);
 
         assert!((offset % (STACK_TOTAL_SIZE)) <= STACK_SIZE);
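`StackRange::alloc` (partially visible above) scans a `[u64; BMP_QWORDS]` bitmap for the first clear bit and maps the slot index back to an address. A compact sketch of that search; the real implementation may iterate differently, this one leans on `trailing_ones`:

```rust
fn first_free_slot(bitmap: &mut [u64]) -> Option<usize> {
    for (i, word) in bitmap.iter_mut().enumerate() {
        // trailing_ones() is the index of the lowest clear bit, 64 if full.
        let idx = word.trailing_ones() as usize;
        if idx < 64 {
            *word |= 1u64 << idx;
            return Some(i * 64 + idx);
        }
    }
    None
}

fn main() {
    let mut bitmap = [u64::MAX, 0b0111];
    // Word 0 is full, so the first free slot is bit 3 of word 1.
    assert_eq!(first_free_slot(&mut bitmap), Some(64 + 3));
    assert_eq!(bitmap[1], 0b1111);
}
```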
diff --git a/src/mm/validate.rs b/src/mm/validate.rs index a27287bfe..cdb8d5746 100644 --- a/src/mm/validate.rs +++ b/src/mm/validate.rs @@ -10,28 +10,29 @@ use crate::locking::SpinLock;
 use crate::mm::alloc::{allocate_pages, get_order};
 use crate::mm::virt_to_phys;
 use crate::types::{PAGE_SIZE, PAGE_SIZE_2M};
+use crate::utils::MemoryRegion;
 use core::ptr;
 
 static VALID_BITMAP: SpinLock<ValidBitmap> = SpinLock::new(ValidBitmap::new());
 
 #[inline(always)]
-fn bitmap_alloc_order(pbase: PhysAddr, pend: PhysAddr) -> usize {
-    let mem_size = (pend - pbase) / (PAGE_SIZE * 8);
+fn bitmap_alloc_order(region: MemoryRegion<PhysAddr>) -> usize {
+    let mem_size = region.len() / (PAGE_SIZE * 8);
     get_order(mem_size)
 }
 
-pub fn init_valid_bitmap_ptr(pbase: PhysAddr, pend: PhysAddr, bitmap: *mut u64) {
+pub fn init_valid_bitmap_ptr(region: MemoryRegion<PhysAddr>, bitmap: *mut u64) {
     let mut vb_ref = VALID_BITMAP.lock();
-    vb_ref.set_region(pbase, pend);
+    vb_ref.set_region(region);
     vb_ref.set_bitmap(bitmap);
 }
 
-pub fn init_valid_bitmap_alloc(pbase: PhysAddr, pend: PhysAddr) -> Result<(), SvsmError> {
-    let order: usize = bitmap_alloc_order(pbase, pend);
+pub fn init_valid_bitmap_alloc(region: MemoryRegion<PhysAddr>) -> Result<(), SvsmError> {
+    let order: usize = bitmap_alloc_order(region);
     let bitmap_addr = allocate_pages(order)?;
 
     let mut vb_ref = VALID_BITMAP.lock();
-    vb_ref.set_region(pbase, pend);
+    vb_ref.set_region(region);
     vb_ref.set_bitmap(bitmap_addr.as_mut_ptr::<u64>());
     vb_ref.clear_all();
 
@@ -95,23 +96,20 @@ pub fn valid_bitmap_valid_addr(paddr: PhysAddr) -> bool {
 
 #[derive(Debug)]
 struct ValidBitmap {
-    pbase: PhysAddr,
-    pend: PhysAddr,
+    region: MemoryRegion<PhysAddr>,
     bitmap: *mut u64,
 }
 
 impl ValidBitmap {
     pub const fn new() -> Self {
         ValidBitmap {
-            pbase: PhysAddr::null(),
-            pend: PhysAddr::null(),
+            region: MemoryRegion::from_addresses(PhysAddr::null(), PhysAddr::null()),
             bitmap: ptr::null_mut(),
         }
     }
 
-    pub fn set_region(&mut self, pbase: PhysAddr, pend: PhysAddr) {
-        self.pbase = pbase;
-        self.pend = pend;
+    pub fn set_region(&mut self, region: MemoryRegion<PhysAddr>) {
+        self.region = region;
     }
 
     pub fn set_bitmap(&mut self, bitmap: *mut u64) {
@@ -119,7 +117,7 @@ impl ValidBitmap {
     }
 
     pub fn check_addr(&self, paddr: PhysAddr) -> bool {
-        paddr >= self.pbase && paddr < self.pend
+        self.region.contains(paddr)
     }
 
     pub fn bitmap_addr(&self) -> PhysAddr {
@@ -129,7 +127,7 @@ impl ValidBitmap {
 
     #[inline(always)]
     fn index(&self, paddr: PhysAddr) -> (isize, usize) {
-        let page_offset = (paddr - self.pbase) / PAGE_SIZE;
+        let page_offset = (paddr - self.region.start()) / PAGE_SIZE;
         let index: isize = (page_offset / 64).try_into().unwrap();
         let bit: usize = page_offset % 64;
 
@@ -137,7 +135,7 @@ impl ValidBitmap {
     }
 
     pub fn clear_all(&mut self) {
-        let (mut i, bit) = self.index(self.pend);
+        let (mut i, bit) = self.index(self.region.end());
         if bit != 0 {
             i += 1;
         }
@@ -149,11 +147,11 @@ impl ValidBitmap {
     }
 
     pub fn alloc_order(&self) -> usize {
-        bitmap_alloc_order(self.pbase, self.pend)
+        bitmap_alloc_order(self.region)
     }
 
     pub fn migrate(&mut self, new_bitmap: *mut u64) {
-        let (count, _) = self.index(self.pend);
+        let (count, _) = self.index(self.region.end());
 
         unsafe {
             ptr::copy_nonoverlapping(self.bitmap, new_bitmap, count as usize);
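The `(word, bit)` split in `ValidBitmap::index` is ordinary page-granule bookkeeping: one valid-bit per 4 KiB page, 64 pages per `u64` word. The same math as a standalone sketch:

```rust
const PAGE_SIZE: usize = 4096;

// One valid-bit per page, 64 pages per u64 word.
fn index(region_start: usize, paddr: usize) -> (usize, usize) {
    let page_offset = (paddr - region_start) / PAGE_SIZE;
    (page_offset / 64, page_offset % 64)
}

fn main() {
    let base = 0x10_0000;
    assert_eq!(index(base, base), (0, 0));
    // Page 65 lands in word 1, bit 1.
    assert_eq!(index(base, base + 65 * PAGE_SIZE), (1, 1));
}
```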
diff --git a/src/stage2.rs b/src/stage2.rs index 14f262cfd..12fd2a52e 100644 --- a/src/stage2.rs +++ b/src/stage2.rs @@ -157,10 +157,9 @@ pub extern "C" fn stage2_main(launch_info: &Stage1LaunchInfo) {
 
     log::info!("COCONUT Secure Virtual Machine Service Module (SVSM) Stage 2 Loader");
 
-    let kernel_region_phys_start = PhysAddr::from(r.start);
-    let kernel_region_phys_end = PhysAddr::from(r.end);
-    init_valid_bitmap_alloc(kernel_region_phys_start, kernel_region_phys_end)
-        .expect("Failed to allocate valid-bitmap");
+    let kernel_region_phys_start = r.start();
+    let kernel_region_phys_end = r.end();
+    init_valid_bitmap_alloc(r).expect("Failed to allocate valid-bitmap");
 
     // Read the SVSM kernel's ELF file metadata.
     let kernel_elf_len = kernel_elf_end - kernel_elf_start;
diff --git a/src/svsm.rs b/src/svsm.rs index 4af0db55a..8deb05587 100644 --- a/src/svsm.rs +++ b/src/svsm.rs @@ -47,7 +47,7 @@ use svsm::sev::utils::{rmp_adjust, RMPFlags};
 use svsm::svsm_console::SVSMIOPort;
 use svsm::svsm_paging::{init_page_table, invalidate_stage2};
 use svsm::types::{PageSize, GUEST_VMPL, PAGE_SIZE};
-use svsm::utils::{halt, immut_after_init::ImmutAfterInitCell, zero_mem_region};
+use svsm::utils::{halt, immut_after_init::ImmutAfterInitCell, zero_mem_region, MemoryRegion};
 
 use svsm::mm::validate::{init_valid_bitmap_ptr, migrate_valid_bitmap};
 
@@ -221,26 +221,29 @@ fn validate_flash() -> Result<(), SvsmError> {
     let mut fw_cfg = FwCfg::new(&CONSOLE_IO);
 
     let flash_regions = fw_cfg.iter_flash_regions().collect::<Vec<_>>();
+    let kernel_region = LAUNCH_INFO.kernel_region();
+    let flash_range = {
+        let one_gib = 1024 * 1024 * 1024usize;
+        let start = PhysAddr::from(3 * one_gib);
+        MemoryRegion::new(start, one_gib)
+    };
 
     // Sanity-check flash regions.
     for region in flash_regions.iter() {
         // Make sure that the regions are between 3GiB and 4GiB.
-        if !region.overlaps(3 * 1024 * 1024 * 1024, 4 * 1024 * 1024 * 1024) {
+        if !region.overlap(&flash_range) {
             panic!("flash region in unexpected region");
         }
 
         // Make sure that no regions overlap with the kernel.
-        if region.overlaps(
-            LAUNCH_INFO.kernel_region_phys_start,
-            LAUNCH_INFO.kernel_region_phys_end,
-        ) {
+        if region.overlap(&kernel_region) {
             panic!("flash region overlaps with kernel");
         }
     }
 
     // Make sure that regions don't overlap.
     for (i, outer) in flash_regions.iter().enumerate() {
         for inner in flash_regions[..i].iter() {
-            if outer.overlaps(inner.start, inner.end) {
+            if outer.overlap(inner) {
                 panic!("flash regions overlap");
             }
         }
@@ -248,23 +251,18 @@ fn validate_flash() -> Result<(), SvsmError> {
 
     // Make sure that one regions ends at 4GiB.
     let one_region_ends_at_4gib = flash_regions
         .iter()
-        .any(|region| region.end == 4 * 1024 * 1024 * 1024);
+        .any(|region| region.end() == flash_range.end());
     assert!(one_region_ends_at_4gib);
 
     for (i, region) in flash_regions.into_iter().enumerate() {
-        let pstart = PhysAddr::from(region.start);
-        let pend = PhysAddr::from(region.end);
         log::info!(
             "Flash region {} at {:#018x} size {:018x}",
             i,
-            pstart,
-            pend - pstart
+            region.start(),
+            region.len(),
         );
 
-        for paddr in (pstart.bits()..pend.bits())
-            .step_by(PAGE_SIZE)
-            .map(PhysAddr::from)
-        {
+        for paddr in region.iter_pages(PageSize::Regular) {
             let guard = PerCPUPageMappingGuard::create_4k(paddr)?;
             let vaddr = guard.virt_addr();
             if let Err(e) = rmp_adjust(
@@ -328,11 +326,7 @@ pub extern "C" fn svsm_start(li: &KernelLaunchInfo, vb_addr: usize) {
 
     mapping_info_init(&launch_info);
 
-    init_valid_bitmap_ptr(
-        launch_info.kernel_region_phys_start.into(),
-        launch_info.kernel_region_phys_end.into(),
-        vb_ptr,
-    );
+    init_valid_bitmap_ptr(launch_info.kernel_region(), vb_ptr);
 
     load_gdt();
     early_idt_init();
@@ -459,7 +453,7 @@ pub extern "C" fn svsm_main() {
 
     print_fw_meta(&fw_meta);
 
-    if let Err(e) = validate_fw_memory(&fw_meta) {
+    if let Err(e) = validate_fw_memory(&fw_meta, &LAUNCH_INFO) {
         panic!("Failed to validate firmware memory: {:#?}", e);
     }
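The nested loop in `validate_flash` is a quadratic pairwise-overlap check over the flash regions. As a sketch, the same invariant phrased over a slice (addresses are illustrative):

```rust
use svsm::address::PhysAddr;
use svsm::utils::MemoryRegion;

// True if any two regions in the slice overlap each other.
fn any_overlap(regions: &[MemoryRegion<PhysAddr>]) -> bool {
    regions
        .iter()
        .enumerate()
        .any(|(i, outer)| regions[..i].iter().any(|inner| outer.overlap(inner)))
}

fn main() {
    let regions = [
        MemoryRegion::new(PhysAddr::from(0xc000_0000u64), 0x10_0000),
        MemoryRegion::new(PhysAddr::from(0xc010_0000u64), 0x10_0000),
    ];
    // Adjacent but not overlapping.
    assert!(!any_overlap(&regions));
}
```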
diff --git a/src/utils/memory_region.rs b/src/utils/memory_region.rs new file mode 100644 index 000000000..6ce56451a --- /dev/null +++ b/src/utils/memory_region.rs @@ -0,0 +1,206 @@
+// SPDX-License-Identifier: MIT OR Apache-2.0
+//
+// Copyright (c) 2022-2023 SUSE LLC
+//
+// Author: Carlos López
+
+use crate::address::Address;
+use crate::types::PageSize;
+
+/// An abstraction over a memory region, expressed in terms of physical
+/// ([`PhysAddr`](crate::address::PhysAddr)) or virtual
+/// ([`VirtAddr`](crate::address::VirtAddr)) addresses.
+#[derive(Clone, Copy, Debug)]
+pub struct MemoryRegion<A> {
+    start: A,
+    end: A,
+}
+
+impl<A> MemoryRegion<A>
+where
+    A: Address,
+{
+    /// Create a new memory region starting at address `start`, spanning `len`
+    /// bytes.
+    pub fn new(start: A, len: usize) -> Self {
+        let end = A::from(start.bits() + len);
+        Self { start, end }
+    }
+
+    /// Create a new memory region with overflow checks.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let start = VirtAddr::from(u64::MAX);
+    /// let region = MemoryRegion::checked_new(start, PAGE_SIZE);
+    /// assert!(region.is_none());
+    /// ```
+    pub fn checked_new(start: A, len: usize) -> Option<Self> {
+        let end = start.checked_add(len)?;
+        Some(Self { start, end })
+    }
+
+    /// Create a memory region from two raw addresses.
+    pub const fn from_addresses(start: A, end: A) -> Self {
+        Self { start, end }
+    }
+
+    /// The base address of the memory region, originally set in
+    /// [`MemoryRegion::new()`].
+    #[inline]
+    pub const fn start(&self) -> A {
+        self.start
+    }
+
+    /// The length of the memory region in bytes, originally set in
+    /// [`MemoryRegion::new()`].
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.end.bits().saturating_sub(self.start.bits())
+    }
+
+    /// Returns whether the region spans any actual memory.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), 0);
+    /// assert!(r.is_empty());
+    /// ```
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// The end address of the memory region.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let base = VirtAddr::from(0xffffff0000u64);
+    /// let region = MemoryRegion::new(base, PAGE_SIZE);
+    /// assert_eq!(region.end(), VirtAddr::from(0xffffff1000u64));
+    /// ```
+    #[inline]
+    pub const fn end(&self) -> A {
+        self.end
+    }
+
+    /// Checks whether two regions overlap. This does *not* include contiguous
+    /// regions, use [`MemoryRegion::contiguous()`] for that purpose.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE);
+    /// assert!(!r1.overlap(&r2));
+    /// ```
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE);
+    /// assert!(r1.overlap(&r2));
+    /// ```
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// // Contiguous regions do not overlap
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE);
+    /// assert!(!r1.overlap(&r2));
+    /// ```
+    pub fn overlap(&self, other: &Self) -> bool {
+        self.start() < other.end() && self.end() > other.start()
+    }
+
+    /// Checks whether two regions are contiguous or overlapping. This is a
+    /// less strict check than [`MemoryRegion::overlap()`].
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE);
+    /// assert!(r1.contiguous(&r2));
+    /// ```
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE);
+    /// assert!(!r1.contiguous(&r2));
+    /// ```
+    pub fn contiguous(&self, other: &Self) -> bool {
+        self.start() <= other.end() && self.end() >= other.start()
+    }
+
+    /// Merge two regions. It does not check whether the two regions are
+    /// contiguous in the first place, so the resulting region will cover
+    /// any non-overlapping memory between both.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::PAGE_SIZE;
+    /// # use svsm::utils::MemoryRegion;
+    /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE);
+    /// let r3 = r1.merge(&r2);
+    /// assert_eq!(r3.start(), r1.start());
+    /// assert_eq!(r3.len(), r1.len() + r2.len());
+    /// assert_eq!(r3.end(), r2.end());
+    /// ```
+    pub fn merge(&self, other: &Self) -> Self {
+        let start = self.start.min(other.start);
+        let end = self.end().max(other.end());
+        Self { start, end }
+    }
+
+    /// Iterate over the addresses covering the memory region in jumps of the
+    /// specified page size. Note that if the base address of the region is
+    /// not page aligned, returned addresses will not be aligned either.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::{PAGE_SIZE, PageSize};
+    /// # use svsm::utils::MemoryRegion;
+    /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2);
+    /// let mut iter = region.iter_pages(PageSize::Regular);
+    /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff0000u64)));
+    /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff1000u64)));
+    /// assert_eq!(iter.next(), None);
+    /// ```
+    pub fn iter_pages(&self, size: PageSize) -> impl Iterator<Item = A> {
+        let size = usize::from(size);
+        (self.start().bits()..self.end().bits())
+            .step_by(size)
+            .map(A::from)
+    }
+
+    /// Check whether an address is within this region.
+    ///
+    /// ```rust
+    /// # use svsm::address::VirtAddr;
+    /// # use svsm::types::{PAGE_SIZE, PageSize};
+    /// # use svsm::utils::MemoryRegion;
+    /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE);
+    /// assert!(region.contains(VirtAddr::from(0xffffff0000u64)));
+    /// assert!(region.contains(VirtAddr::from(0xffffff0fffu64)));
+    /// assert!(!region.contains(VirtAddr::from(0xffffff1000u64)));
+    /// ```
+    pub fn contains(&self, addr: A) -> bool {
+        self.start() <= addr && addr < self.end()
+    }
+}
diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1f90a64d0..866b5a1eb 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -6,6 +6,8 @@
 
 pub mod bitmap_allocator;
 pub mod immut_after_init;
+pub mod memory_region;
 pub mod util;
 
+pub use memory_region::MemoryRegion;
 pub use util::{align_down, align_up, halt, overlap, page_align_up, page_offset, zero_mem_region};