diff --git a/examples/fib-vm.rs b/examples/fib-vm.rs index bc66cdb..8271ffe 100644 --- a/examples/fib-vm.rs +++ b/examples/fib-vm.rs @@ -52,7 +52,7 @@ fn main() { let mut context = VmContext::new(&vm, &mut guard); context - .declare_function(dbg!(Function::new("fib").when(1, fib))) + .declare_function(Function::new("fib").when(1, fib)) .unwrap(); dbg!(context.execute(&code).unwrap()); } diff --git a/src/compiler.rs b/src/compiler.rs index df7c22f..c9c13b9 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -23,7 +23,8 @@ use syntax::{ use crate::runtime::symbol::Symbol; use crate::vm::bitcode::{ - BinaryKind, BitcodeBlock, BitcodeFunction, FaultKind, Label, Op, OpDestination, ValueOrSource, + Access, BinaryKind, BitcodeBlock, BitcodeFunction, FaultKind, Label, Op, OpDestination, + ValueOrSource, }; use crate::vm::{Code, Register, Stack}; @@ -471,6 +472,10 @@ impl<'a> Scope<'a> { } } + fn is_module_root(&self) -> bool { + self.module && self.depth == 0 + } + fn function_root(compiler: &'a mut Compiler) -> Self { compiler.scopes.push(ScopeInfo { kind: ScopeKind::Function, @@ -770,10 +775,17 @@ impl<'a> Scope<'a> { self.compiler .code .load_module(instance, OpDestination::Stack(stack)); - if module.publish.is_some() { + if self.is_module_root() || module.publish.is_some() { self.ensure_in_module(module.name.1); - self.compiler.code.declare(name.clone(), false, stack, dest); + let access = if module.publish.is_some() { + Access::Public + } else { + Access::Private + }; + self.compiler + .code + .declare(name.clone(), false, access, stack, dest); } else { self.declare_local(name.clone(), false, stack, dest); } @@ -958,22 +970,28 @@ impl<'a> Scope<'a> { } } - match (&decl.name, decl.publish.is_some()) { - (Some(name), true) => { - self.ensure_in_module(name.1); - self.compiler.code.declare(name.0.clone(), false, fun, dest); + match (&decl.name, decl.publish.is_some(), self.is_module_root()) { + (Some(name), true, _) | (Some(name), _, true) => { + let access = if decl.publish.is_some() { + Access::Public + } else { + Access::Private + }; + self.compiler + .code + .declare(name.0.clone(), false, access, fun, dest); } - (Some(name), false) => { + (Some(name), false, _) => { let stack = self.new_temporary(); self.compiler.code.copy(fun, stack); self.declare_local(name.0.clone(), false, stack, dest); } - (None, true) => { + (None, true, _) => { self.compiler .errors .push(Ranged::new(range, Error::PublicFunctionRequiresName)); } - (None, false) => { + (None, false, _) => { self.compiler.code.copy(fun, dest); } } @@ -1034,12 +1052,18 @@ impl<'a> Scope<'a> { self.compiler.code.invoke(matches, Symbol::get_symbol(), 1); self.compiler.code.copy(Register(0), variable); - if bindings.publish { + if self.is_module_root() || bindings.publish { self.ensure_in_module(range); + let access = if bindings.publish { + Access::Public + } else { + Access::Private + }; self.compiler.code.declare( name.clone(), bindings.mutable, + access, variable, (), ); @@ -1090,10 +1114,19 @@ impl<'a> Scope<'a> { if let Some(name) = name { self.check_bound_name(Ranged::new(pattern.range(), name.clone()), bindings); - if bindings.publish { - self.compiler - .code - .declare(name.clone(), bindings.mutable, source, ()); + if self.is_module_root() || bindings.publish { + let access = if bindings.publish { + Access::Public + } else { + Access::Private + }; + self.compiler.code.declare( + name.clone(), + bindings.mutable, + access, + source, + (), + ); } else { let stack = self.new_temporary(); self.compiler.code.copy(source, stack); @@ -2161,6 +2194,7 @@ impl<'a> Scope<'a> { let base = self.compile_source(&base.expression); self.compile_function_args(&call.parameters.enclosed, arity); + self.compiler.code.set_current_source_range(range); self.compiler .code .invoke(base, lookup.name.0.clone(), arity); diff --git a/src/runtime/value.rs b/src/runtime/value.rs index 6cdb469..16d30ec 100644 --- a/src/runtime/value.rs +++ b/src/runtime/value.rs @@ -1684,6 +1684,15 @@ where } } +impl PartialEq for Dynamic +where + T: CustomType + Trace, +{ + fn eq(&self, other: &Self) -> bool { + self.0.as_any() == other.0.as_any() + } +} + impl From> for AnyRef where T: CustomType + Trace, diff --git a/src/tests.rs b/src/tests.rs index d4de13f..1c89132 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -81,7 +81,7 @@ fn module_budgeting() { } println!("Executed in {ops} steps"); assert!(ops > 6); - assert!(ops < MAX_OPS); + assert!(ops <= MAX_OPS); } #[test] diff --git a/src/vm.rs b/src/vm.rs index 3d25078..47f9811 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -26,7 +26,6 @@ use std::{array, task}; use ahash::AHashMap; use crossbeam_utils::sync::{Parker, Unparker}; -use kempt::map::Entry; use kempt::Map; use parking_lot::{Mutex, MutexGuard}; use refuse::{CollectionGuard, ContainsNoRefs, NoMapping, Root, Trace}; @@ -35,7 +34,7 @@ use serde::{Deserialize, Serialize}; #[cfg(not(feature = "dispatched"))] use self::bitcode::trusted_loaded_source_to_value; use self::bitcode::{ - BinaryKind, BitcodeFunction, FaultKind, Label, Op, OpDestination, ValueOrSource, + Access, BinaryKind, BitcodeFunction, FaultKind, Label, Op, OpDestination, ValueOrSource, }; use crate::compiler::syntax::token::RegexLiteral; use crate::compiler::syntax::{BitwiseKind, CompareKind, SourceCode, SourceRange}; @@ -277,7 +276,7 @@ impl Vm { value: Value, guard: &mut CollectionGuard<'_>, ) -> Result, Fault> { - VmContext::new(self, guard).declare_inner(name, value, false) + VmContext::new(self, guard).declare(name, value) } /// Declares an mutable variable with `name` containing `value`. @@ -442,6 +441,16 @@ impl<'context, 'guard> VmContext<'context, 'guard> { &mut self.vm } + /// Returns the access to allow the caller of the current function. + pub fn caller_access_level(&self, module: &Dynamic) -> Access { + let current_module = &self.modules[self.frames[self.current_frame].module]; + if current_module == module { + Access::Private + } else { + Access::Public + } + } + fn budget_and_yield(&mut self) -> Result<(), Fault> { let next_count = self.counter - 1; if next_count > 0 { @@ -589,18 +598,25 @@ impl<'context, 'guard> VmContext<'context, 'guard> { ) -> Result { let arity = params.load(self)?; - let mut module_dynamic = self.modules[0] + let mut module_dynamic = self.modules[self.frames[self.current_frame].module] .as_rooted(self.guard) .expect("module missing"); let mut module_declarations = module_dynamic.declarations(); let function = if let Some(decl) = module_declarations.get(name) { - decl.value + if decl.access == Access::Public { + decl.value + } else { + return Err(ExecutionError::new(Fault::UnknownSymbol, self)); + } } else { let name = name.try_load(self.guard)?; let mut parts = name.split('.').peekable(); while let Some(part) = parts.next() { let part = SymbolRef::from(part); - let Some(decl) = module_declarations.get(&part).map(|decl| decl.value) else { + let Some(decl) = module_declarations + .get(&part) + .and_then(|decl| (decl.access == Access::Public).then_some(decl.value)) + else { break; }; if parts.peek().is_some() { @@ -786,7 +802,7 @@ impl<'context, 'guard> VmContext<'context, 'guard> { name: impl Into, value: Value, ) -> Result, Fault> { - self.declare_inner(name, value, false) + self.declare_inner(name, value, false, Access::Public) } /// Declares an mutable variable with `name` containing `value`. @@ -795,7 +811,7 @@ impl<'context, 'guard> VmContext<'context, 'guard> { name: impl Into, value: Value, ) -> Result, Fault> { - self.declare_inner(name, value, true) + self.declare_inner(name, value, true, Access::Public) } fn declare_inner( @@ -803,22 +819,21 @@ impl<'context, 'guard> VmContext<'context, 'guard> { name: impl Into, value: Value, mutable: bool, + access: Access, ) -> Result, Fault> { - match self.modules[self.frames[self.current_frame].module] + Ok(self.modules[self.frames[self.current_frame].module] .load(self.guard) .ok_or(Fault::ValueFreed)? .declarations() - .entry(name.into()) - { - Entry::Occupied(mut field) if field.mutable => { - Ok(Some(std::mem::replace(&mut field.value, value))) - } - Entry::Occupied(_) => Err(Fault::NotMutable), - Entry::Vacant(entry) => { - entry.insert(ModuleDeclaration { mutable, value }); - Ok(None) - } - } + .insert( + name.into(), + ModuleDeclaration { + mutable, + value, + access, + }, + ) + .map(|d| d.value.value)) } /// Declares a compiled function. @@ -831,7 +846,7 @@ impl<'context, 'guard> VmContext<'context, 'guard> { }; function.module = Some(0); - self.declare_inner(name, Value::dynamic(function, &self), true) + self.declare_inner(name, Value::dynamic(function, &self), true, Access::Public) } /// Resolves the value at `path`. @@ -853,10 +868,15 @@ impl<'context, 'guard> VmContext<'context, 'guard> { let name = Symbol::from(name); if path.peek().is_some() { let declarations = module.declarations(); - let value = &declarations + let decl = &declarations .get(&name.downgrade()) - .ok_or(Fault::UnknownSymbol)? - .value; + .ok_or(Fault::UnknownSymbol)?; + let value = if decl.access >= self.caller_access_level(&module_dynamic) { + decl.value + } else { + return Err(Fault::UnknownSymbol); // TODO accessd error + }; + let Some(inner) = value.as_dynamic::() else { return Err(Fault::NotAModule); }; @@ -919,6 +939,8 @@ impl<'context, 'guard> VmContext<'context, 'guard> { let module = &vm.modules[vm.frames[vm.current_frame].module]; if let Some(decl) = module.try_load(self.guard)?.declarations().get_mut(name) { if decl.mutable { + let name = name.try_load(self.guard()).expect("missing symbol"); + println!("Set {name} from {:?} to {value:?}", decl.value); decl.value = value; Ok(()) } else { @@ -1169,9 +1191,10 @@ impl VmContext<'_, '_> { LoadedOp::Declare { name, mutable, + access, value, dest, - } => self.op_declare(code_index, *name, *mutable, *value, *dest), + } => self.op_declare(code_index, *name, *mutable, *access, *value, *dest), LoadedOp::Call { name, arity } => self.op_call(code_index, *name, *arity), LoadedOp::Invoke { target, @@ -1395,6 +1418,7 @@ impl VmContext<'_, '_> { code_index: usize, name: usize, mutable: bool, + access: Access, value: LoadedSource, dest: OpDestination, ) -> Result<(), Fault> { @@ -1408,7 +1432,7 @@ impl VmContext<'_, '_> { .ok_or(Fault::InvalidOpcode)?; self.op_store(code_index, value, dest)?; - self.declare_inner(name, value, mutable)?; + self.declare_inner(name, value, mutable, access)?; Ok(()) } @@ -1477,6 +1501,7 @@ impl VmContext<'_, '_> { }; let resolved = self.resolve(&name)?; + println!("Resolved {name} to {resolved:?}"); self.op_store(code_index, resolved, dest) } @@ -2127,6 +2152,7 @@ impl CodeData { Op::Declare { name, mutable, + access, value, dest, } => { @@ -2137,6 +2163,7 @@ impl CodeData { LoadedOp::Declare { name, mutable: *mutable, + access: *access, value, dest, }, @@ -2366,6 +2393,7 @@ impl Module { ModuleDeclaration { mutable: false, value: Value::dynamic(core, guard), + access: Access::Public, }, ); @@ -2382,6 +2410,7 @@ impl Module { ModuleDeclaration { mutable: false, value: Value::Dynamic(crate::runtime::map::MAP_TYPE.as_any_dynamic()), + access: Access::Public, }, ); declarations.insert( @@ -2389,6 +2418,7 @@ impl Module { ModuleDeclaration { mutable: false, value: Value::Dynamic(crate::runtime::list::LIST_TYPE.as_any_dynamic()), + access: Access::Public, }, ); declarations.insert( @@ -2396,6 +2426,7 @@ impl Module { ModuleDeclaration { mutable: false, value: Value::Dynamic(crate::runtime::string::STRING_TYPE.as_any_dynamic()), + access: Access::Public, }, ); drop(declarations); @@ -2422,7 +2453,12 @@ impl CustomType for Module { let value = vm[Register(1)].take(); match this.declarations().get_mut(sym) { - Some(decl) if decl.mutable => { + Some(decl) + if decl.mutable + && decl.access + >= vm + .caller_access_level(&this.downgrade()) => + { Ok(std::mem::replace(&mut decl.value, value)) } Some(_) => Err(Fault::NotMutable), @@ -2433,17 +2469,24 @@ impl CustomType for Module { let field = vm[Register(0)].take(); let sym = field.as_symbol_ref().ok_or(Fault::ExpectedSymbol)?; - this.declarations() - .get(sym) - .map(|decl| decl.value) - .ok_or(Fault::UnknownSymbol) + let declarations = this.declarations(); + let decl = declarations.get(sym).ok_or(Fault::UnknownSymbol)?; + if decl.access >= vm.caller_access_level(&this.downgrade()) { + Ok(decl.value) + } else { + Err(Fault::UnknownSymbol) + } }) }); let declarations = this.declarations(); if let Some(decl) = declarations.get(name) { - let possible_invoke = decl.value; - drop(declarations); - possible_invoke.call(vm, arity) + if decl.access >= vm.caller_access_level(&this.downgrade()) { + let possible_invoke = decl.value; + drop(declarations); + possible_invoke.call(vm, arity) + } else { + Err(Fault::UnknownSymbol) + } } else { drop(declarations); FUNCTIONS.invoke(vm, name, arity, &this) @@ -2472,6 +2515,7 @@ impl Trace for Module { #[derive(Debug)] struct ModuleDeclaration { + access: Access, mutable: bool, value: Value, } @@ -2530,6 +2574,7 @@ enum LoadedOp { Declare { name: usize, mutable: bool, + access: Access, value: LoadedSource, dest: OpDestination, }, diff --git a/src/vm/bitcode.rs b/src/vm/bitcode.rs index 2d4db0e..fb57ef4 100644 --- a/src/vm/bitcode.rs +++ b/src/vm/bitcode.rs @@ -118,6 +118,15 @@ impl_from!(OpDestination, Register, Register); impl_from!(OpDestination, Stack, Stack); impl_from!(OpDestination, Label, Label); +/// The level of access of a member. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)] +pub enum Access { + /// The member is accessible by any code. + Private, + /// The member is only accessible to code in the same module. + Public, +} + /// A virtual machine operation. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Op { @@ -138,6 +147,8 @@ pub enum Op { name: Symbol, /// If true, the value will be able to be updated with an assignment. mutable: bool, + /// The access level to allow for this declaration. + access: Access, /// The initial value of the declaration. value: ValueOrSource, /// The destination to store a copy of `value`. @@ -243,11 +254,13 @@ impl BitcodeBlock { LoadedOp::Declare { name, mutable, + access, value, dest, } => Op::Declare { name: code.data.symbols[*name].clone(), mutable: *mutable, + access: *access, value: trusted_loaded_source_to_value(value, &code.data), dest: *dest, }, @@ -386,6 +399,7 @@ impl BitcodeBlock { &mut self, name: Symbol, mutable: bool, + access: Access, value: impl Into, dest: impl Into, ) { @@ -393,6 +407,7 @@ impl BitcodeBlock { self.push(Op::Declare { name, mutable, + access, value, dest: dest.into(), }); diff --git a/src/vm/dispatched.rs b/src/vm/dispatched.rs index 7d90cd4..31b5b5a 100644 --- a/src/vm/dispatched.rs +++ b/src/vm/dispatched.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use refuse::CollectionGuard; use super::bitcode::{ - trusted_loaded_source_to_value, BinaryKind, BitcodeFunction, FaultKind, Label, Op, + trusted_loaded_source_to_value, Access, BinaryKind, BitcodeFunction, FaultKind, Label, Op, OpDestination, ValueOrSource, }; use super::{ @@ -36,6 +36,7 @@ impl CodeData { LoadedOp::Declare { name, mutable, + access, value, dest, } => { @@ -47,6 +48,7 @@ impl CodeData { guard, &name, mutable, + access, ); } LoadedOp::Truthy(loaded) => match_truthy( @@ -1213,7 +1215,7 @@ decode_sd_simple!(match_logical_not, compile_logical_not, LogicalNot); decode_sd_simple!(match_bitwise_not, compile_bitwise_not, BitwiseNot); decode_sd_simple!(match_negate, compile_negate, Negate); -decode_sd!(match_declare_function, compile_declare_function, name: &Symbol, mutable: bool); +decode_sd!(match_declare_function, compile_declare_function, name: &Symbol, mutable: bool, access: Access); fn compile_declare_function( _dest: &OpDestination, @@ -1222,6 +1224,7 @@ fn compile_declare_function( f: Value, name: &Symbol, mutable: bool, + access: Access, dest: Dest, ) where Value: Source, @@ -1230,6 +1233,7 @@ fn compile_declare_function( code.push_dispatched(Declare { name: name.clone(), mutable, + access, declaration: f, dest, }); @@ -1526,6 +1530,7 @@ where vm.frames[vm.current_frame].module = module_index.get(); vm.frames[executing_frame].loading_module = Some(module_index); let _init_result = context.resume_async_inner(context.current_frame)?; + context.vm.frames[executing_frame].loading_module = None; module_index }; @@ -1548,6 +1553,7 @@ where struct Declare { name: Symbol, mutable: bool, + access: Access, declaration: Value, dest: Dest, } @@ -1559,7 +1565,7 @@ where { fn execute(&self, vm: &mut VmContext<'_, '_>) -> Result, Fault> { let value = self.declaration.load(vm)?; - vm.declare_inner(self.name.downgrade(), value, self.mutable)?; + vm.declare_inner(self.name.downgrade(), value, self.mutable, self.access)?; self.dest.store(vm, value)?; @@ -1570,6 +1576,7 @@ where Op::Declare { name: self.name.clone(), mutable: self.mutable, + access: self.access, value: self.declaration.as_source(guard), dest: self.dest.as_dest(), }