summaryrefslogtreecommitdiff
path: root/scripts/src/wasm.rs
blob: 7f8f0b51b0e71317b68cf112345df03562bb7181 (plain)
use std::path::Path;

use thiserror::Error;
use wasmtime::{Engine, Instance, Linker, Module, Mutability, Store, ValType};

/// Lowest value the script's bump pointer is ever set to.
///
/// It's a bad idea to have the memory allocator return a number that could
/// also indicate an error (such as zero), so the pointer is kept at or above
/// eight.
const MIN_BUMP_POINTER: u32 = 8;

/// Information about the script which can be used by functions it calls
///
/// One value of this type is stored as the data of each script's store.
#[derive(Debug)]
pub struct WasmScriptState {
	/// Current position of the script's bump pointer, as a 32-bit offset.
	pub bump_pointer: u32,
	/// Whether the script was marked as trusted when this state was created.
	pub trusted: bool,
}

impl WasmScriptState {
	/// Builds the state for a new store: the bump pointer starts at
	/// `MIN_BUMP_POINTER` and `trusted` is stored as given.
	pub const fn new(trusted: bool) -> Self {
		let bump_pointer = MIN_BUMP_POINTER;
		Self { bump_pointer, trusted }
	}
}

/// A script, its path in the filesystem, and some metadata
pub struct WasmScript {
	/// Location of the module file; read again when the script is reloaded.
	path: Box<Path>,
	/// The compiled module that instances are created from.
	module: Module,
	/// Store owning the current instance and its `WasmScriptState`.
	store: Store<WasmScriptState>,
	/// The current instantiation of `module`, living in `store`.
	instance: Instance,
	/// Trust flag handed to `WasmScriptState::new` whenever a store is created.
	trusted: bool,
	//state: Value,
}

/// Reasons a WebAssembly module cannot be loaded or used as a script.
///
/// Each variant's `#[error]` message is its user-facing description.
#[derive(Debug, Error)]
pub enum InvalidWasmScript {
	#[error("There is no exported memory called 'memory', which is required")]
	NoExportedMemory,
	#[error("The exported symbol, 'memory' must be a memory")]
	MemoryIsNotAMemory,
	#[error("The memory must be 32-bit, not 64-bit")]
	MemoryTooLarge,
	#[error("There is no exported global called '__heap_base', which is required")]
	NoHeapBase,
	#[error("The exported symbol, '__heap_base' must be a constant global")]
	HeapBaseIsNotGlobal,
	#[error("The exported global, '__heap_base' must be an i32")]
	HeapBaseMustBeI32,
	#[error("The exported global, '__heap_base' must be a constant")]
	HeapBaseMustBeConstant,
	/// Any error reported by wasmtime itself; `#[from]` lets `?` convert a
	/// `wasmtime::Error` into this variant automatically.
	#[error("{}", .0)]
	CompilerError(#[from] wasmtime::Error),
}

/// Confirms that the module can be used as a script
///
/// A usable script must export:
/// - a 32-bit memory named `memory`, and
/// - an immutable `i32` global named `__heap_base`.
///
/// # Errors
///
/// Returns the [`InvalidWasmScript`] variant describing the first of these
/// requirements that the module does not meet.
fn validate_module(module: &Module) -> Result<(), InvalidWasmScript> {
	// verify that memory is exported from this module and is valid
	let Some(export) = module.get_export("memory") else {
		return Err(InvalidWasmScript::NoExportedMemory);
	};
	let Some(memory) = export.memory() else {
		return Err(InvalidWasmScript::MemoryIsNotAMemory);
	};
	if memory.is_64() {
		return Err(InvalidWasmScript::MemoryTooLarge);
	}

	// verify __heap_base global
	let Some(export) = module.get_export("__heap_base") else {
		return Err(InvalidWasmScript::NoHeapBase);
	};
	let Some(heap_base) = export.global() else {
		return Err(InvalidWasmScript::HeapBaseIsNotGlobal);
	};
	// Reject the global when its type is NOT i32; its value is later read as
	// an i32 to position the bump pointer.
	if !heap_base.content().matches(&ValType::I32) {
		return Err(InvalidWasmScript::HeapBaseMustBeI32);
	}
	if heap_base.mutability() != Mutability::Const {
		return Err(InvalidWasmScript::HeapBaseMustBeConstant);
	}

	Ok(())
}

impl WasmScript {
	pub fn new(
		path: &Path,
		engine: &Engine,
		linker: &Linker<WasmScriptState>,
		trusted: bool,
	) -> Result<Self, InvalidWasmScript> {
		let module = Module::from_file(engine, path)?;
		validate_module(&module)?;
		let mut store = Store::new(engine, WasmScriptState::new(trusted));
		let instance = linker.instantiate(&mut store, &module)?;

		Ok(Self {
			path: path.into(),
			module,
			store,
			instance,
			trusted,
		})
	}

	/// Reload from the filesystem
	pub fn reload(
		&mut self,
		engine: &Engine,
		linker: &Linker<WasmScriptState>,
	) -> Result<(), InvalidWasmScript> {
		let module = Module::from_file(engine, &self.path)?;
		validate_module(&module)?;
		self.store = Store::new(engine, WasmScriptState::new(self.trusted));
		self.instance = linker.instantiate(&mut self.store, &module)?;

		Ok(())
	}

	/// Re-links the module. This doesn't load the module from the filesystem.
	pub fn relink(
		&mut self,
		engine: &Engine,
		linker: &Linker<WasmScriptState>,
	) -> Result<(), InvalidWasmScript> {
		self.store = Store::new(engine, WasmScriptState::new(self.trusted));
		self.instance = linker.instantiate(&mut self.store, &self.module)?;
		Ok(())
	}

	pub fn is_trusted(&self) -> bool {
		self.trusted
	}

	pub fn trust(&mut self) {
		self.trusted = true;
	}

	pub fn untrust(&mut self) {
		self.trusted = false;
	}

	fn run_function_if_exists(&mut self, name: &str) -> wasmtime::Result<()> {
		// set bump pointer to the start of the heap
		let heap_base = self
			.instance
			.get_global(&mut self.store, "__heap_base")
			.unwrap()
			.get(&mut self.store)
			.unwrap_i32();
		let bump_ptr = &mut self.store.data_mut().bump_pointer;
		*bump_ptr = (heap_base as u32).max(MIN_BUMP_POINTER);

		// call the given function
		let func = self.instance.get_func(&mut self.store, name);
		if let Some(func) = func {
			func.call(&mut self.store, &[], &mut [])?;
		}

		Ok(())
	}

	pub fn begin(&mut self) -> wasmtime::Result<()> {
		self.run_function_if_exists("begin")
	}

	pub fn update(&mut self) -> wasmtime::Result<()> {
		self.run_function_if_exists("update")
	}
}