summaryrefslogtreecommitdiff
path: root/src/lib.rs
blob: 527174618364f2a7180703ce42e4beb925243f42 (plain)
use std::{borrow::Cow, ops::Range, rc::Rc};

/// A predicate deciding whether one position of the input belongs to a
/// character set.
///
/// The predicate receives `Some(c)` for the next character of the input, or
/// `None` when no characters remain.
#[derive(Clone)]
enum CharSet {
	/// A plain function pointer; the tokenizer uses it for `.`, `$` and the
	/// backslash classes (`\d`, `\D`, `\w`, `\W`, `\s`, `\S`).
	StaticFn(fn(Option<char>) -> bool),
	/// A reference-counted closure; the tokenizer uses it for literal
	/// characters, which capture the character to compare against.
	BoxedFn(Rc<dyn Fn(Option<char>) -> bool>),
}

/// A lexical token produced by `Regex::tokenize`.
#[derive(Clone)]
enum RegexToken {
	/// `|`, introducing a further alternate.
	Pipe,
	/// `<name>`: carries the text between `<` and the next `>`.
	Label(Rc<str>),
	/// `@name>`: carries the text between `@` and the next `>`.
	Jump(Rc<str>),
	/// Anything that tests a single character: `.`, `$`, a backslash class
	/// or a literal character.
	CharacterSet(CharSet),
	/// `*`.
	KleeneStar,
	/// `~`; the matcher inverts the result of the character set it precedes.
	Tilde,
	/// `(`, opening a grouping.
	LeftParenthesis,
	/// `)`.
	RightParenthesis,
	/// `[`, opening a union of character sets.
	LeftBracket,
	/// `]`.
	RightBracket,
}

/// A bracketed list of character sets; it matches when any member matches.
#[derive(Clone)]
struct Union {
	/// The token consumed at the start of the union.
	left_bracket: RegexToken,
	/// The member sets, in pattern order.
	charsets: Vec<CharSet>,
	/// The token consumed right after the members, or `None` if the pattern
	/// ended there. The parser stores whatever token comes next without
	/// checking that it is a `]`.
	right_bracket: Option<RegexToken>,
}

/// The character-testing part of a `CharSetNode`.
enum CsetType {
	/// A single character set token.
	Literal(CharSet),
	/// A bracketed union of character sets.
	Union(Union),
}

/// A character set together with its optional `~` prefix and `*` suffix.
struct CharSetNode {
	/// The `~` token that preceded the set, if any; when present the matcher
	/// inverts the result.
	tilde: Option<RegexToken>,
	/// The set itself: a single character set or a bracketed union.
	cset: CsetType,
	/// The `*` token that followed the set, if any.
	star: Option<RegexToken>,
}

/// A parenthesised sub-pattern.
struct Grouping {
	/// The token consumed at the start of the grouping.
	left_parenthesis: RegexToken,
	/// The pattern parsed between the parentheses.
	regex: Regex,
	/// The token consumed right after the inner pattern, or `None` if the
	/// pattern ended there. The parser stores whatever token comes next
	/// without checking that it is a `)`.
	right_parenthesis: Option<RegexToken>,
}

/// One element of an alternate's sequence.
enum RegexNode {
	/// A parenthesised sub-pattern.
	Grouping(Grouping),
	/// A character set with its optional `~` and `*` modifiers.
	CharSetNode(CharSetNode),
}

/// One branch of a pattern: a sequence of nodes that must all match, one
/// after another.
struct Alternate {
	/// The `|` token that introduced this alternate, or `None` when the
	/// alternate was not preceded by one.
	pipe: Option<RegexToken>,
	/// The nodes of the sequence, in pattern order.
	nodes: Vec<RegexNode>,
}

/// A parsed pattern, built with `Regex::new`.
pub struct Regex {
	/// The branches of the pattern; matching tries them in order and takes
	/// the first one that succeeds.
	alternates: Vec<Alternate>,
}

/// A single match of a `Regex` inside a haystack.
///
/// `start` and `end` are byte offsets into the haystack; `end` is exclusive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match<'h> {
	haystack: &'h str,
	start: usize,
	end: usize,
}

impl<'h> Match<'h> {
	/// Byte offset of the first byte of the match.
	pub fn start(&self) -> usize {
		self.start
	}

	/// Byte offset one past the last byte of the match.
	pub fn end(&self) -> usize {
		self.end
	}

	/// Returns `true` when the match covers no bytes.
	pub fn is_empty(&self) -> bool {
		self.len() == 0
	}

	/// Length of the match in bytes.
	pub fn len(&self) -> usize {
		self.end - self.start
	}

	/// The matched byte range within the haystack.
	pub fn range(&self) -> Range<usize> {
		self.start..self.end
	}

	/// The matched text.
	///
	/// The returned slice borrows from the haystack, not from this `Match`,
	/// so it stays usable after the `Match` value itself has been dropped.
	pub fn as_str(&self) -> &'h str {
		&self.haystack[self.start..self.end]
	}
}

impl Regex {
	/// Splits the pattern text into a flat list of tokens.
	///
	/// Recognised syntax:
	/// - `|`, `~`, `*`, `(`, `)`, `[`, `]` become their dedicated tokens.
	/// - `<name>` becomes a label and `@name>` a jump; both run up to the
	///   next `>`.
	/// - `.` accepts anything (including end of input); `$` accepts only end
	///   of input.
	/// - `\d`, `\w`, `\s` are the ASCII digit, word (alphanumeric or `_`) and
	///   whitespace classes; `\D`, `\W`, `\S` are their negations. None of
	///   them accept end of input.
	/// - A backslash followed by any other character, and every remaining
	///   character, is a literal.
	///
	/// # Panics
	///
	/// Panics if a label or jump has no closing `>`, or if the pattern ends
	/// with a lone backslash.
	fn tokenize(mut regex: &str) -> Vec<RegexToken> {
		/// Consumes everything up to the first `pattern` (which is dropped)
		/// and returns the consumed text; leaves `regex` untouched on a miss.
		fn upto<'a>(regex: &mut &'a str, pattern: &str) -> Option<&'a str> {
			let text: &'a str = *regex;
			let (consumed, rest) = text.split_once(pattern)?;
			*regex = rest;
			Some(consumed)
		}

		/// Removes and returns the first character of `regex`, if any.
		fn next_char(regex: &mut &str) -> Option<char> {
			let text = *regex;
			let mut chars = text.chars();
			let first = chars.next()?;
			*regex = chars.as_str();
			Some(first)
		}

		/// Wraps a non-capturing predicate as a character-set token.
		fn class(accepts: fn(Option<char>) -> bool) -> RegexToken {
			RegexToken::CharacterSet(CharSet::StaticFn(accepts))
		}

		/// Builds a character-set token accepting exactly `expected`.
		fn literal(expected: char) -> RegexToken {
			RegexToken::CharacterSet(CharSet::BoxedFn(Rc::new(
				move |c: Option<char>| c == Some(expected),
			)))
		}

		let mut tokens = Vec::new();
		while let Some(first) = next_char(&mut regex) {
			let token = match first {
				'|' => RegexToken::Pipe,
				'<' => {
					let label = upto(&mut regex, ">").expect("labels should terminate");
					RegexToken::Label(Rc::from(label))
				}
				'@' => {
					let label = upto(&mut regex, ">").expect("jumps should terminate");
					RegexToken::Jump(Rc::from(label))
				}
				'~' => RegexToken::Tilde,
				'[' => RegexToken::LeftBracket,
				']' => RegexToken::RightBracket,
				'(' => RegexToken::LeftParenthesis,
				')' => RegexToken::RightParenthesis,
				'*' => RegexToken::KleeneStar,
				'.' => class(|_| true),
				'$' => class(|c| c.is_none()),
				// An escape: either one of the predefined classes or a literal.
				'\\' => match next_char(&mut regex).expect("at least one character") {
					'd' => class(|c| matches!(c, Some(c) if c.is_ascii_digit())),
					'D' => class(|c| matches!(c, Some(c) if !c.is_ascii_digit())),
					'w' => class(|c| matches!(c, Some(c) if c.is_ascii_alphanumeric() || c == '_')),
					'W' => {
						class(|c| matches!(c, Some(c) if !(c.is_ascii_alphanumeric() || c == '_')))
					}
					's' => class(|c| matches!(c, Some(c) if c.is_ascii_whitespace())),
					'S' => class(|c| matches!(c, Some(c) if !c.is_ascii_whitespace())),
					escaped => literal(escaped),
				},
				other => literal(other),
			};
			tokens.push(token);
		}

		tokens
	}

	/// Returns the next token without consuming it, or `None` when no tokens
	/// remain.
	fn peek_token(tokens: &[RegexToken]) -> Option<&RegexToken> {
		match tokens {
			[head, ..] => Some(head),
			[] => None,
		}
	}

	/// Removes the first token from the stream and returns it, or returns
	/// `None` (leaving the stream untouched) when no tokens remain.
	fn next_token<'a>(tokens: &mut &'a [RegexToken]) -> Option<&'a RegexToken> {
		let stream: &'a [RegexToken] = *tokens;
		let (head, rest) = stream.split_first()?;
		*tokens = rest;
		Some(head)
	}

	/// Runs `parser` only when the next token satisfies `predicate`.
	///
	/// Returns `None` (consuming nothing) when the stream is empty or the
	/// next token is rejected; otherwise returns the parser's result.
	fn parse_if<T>(
		tokens: &mut &[RegexToken],
		predicate: impl Fn(&RegexToken) -> bool,
		parser: impl Fn(&mut &[RegexToken]) -> Result<T, Rc<str>>,
	) -> Option<Result<T, Rc<str>>> {
		match tokens.first() {
			Some(head) if predicate(head) => Some(parser(tokens)),
			_ => None,
		}
	}

	/// Runs `parser` repeatedly for as long as the next token satisfies
	/// `predicate`, collecting the results in order.
	///
	/// Stops at the first rejected token or at the end of the stream; the
	/// first parser error is returned immediately. `parser` must consume at
	/// least one token per call, otherwise this loop never ends.
	fn parse_while<T>(
		tokens: &mut &[RegexToken],
		predicate: impl Fn(&RegexToken) -> bool,
		parser: impl Fn(&mut &[RegexToken]) -> Result<T, Rc<str>>,
	) -> Result<Vec<T>, Rc<str>> {
		let mut parsed = Vec::new();
		while let Some(head) = tokens.first() {
			if !predicate(head) {
				break;
			}
			parsed.push(parser(tokens)?);
		}
		Ok(parsed)
	}

	/// Parses a bracketed union: an opening token, any number of character
	/// set tokens, and the token that follows them.
	///
	/// The first token is taken as the left bracket (callers only invoke this
	/// after seeing a `[`). Whatever token follows the members is stored as
	/// the right bracket without checking its kind; it is `None` when the
	/// pattern ends first.
	///
	/// # Errors
	///
	/// Returns an error when the stream is empty.
	fn parse_union(tokens: &mut &[RegexToken]) -> Result<Union, Rc<str>> {
		let Some(left_bracket) = Self::next_token(tokens).cloned() else {
			return Err("expected a left bracket".into());
		};
		let charsets = Self::parse_while(
			tokens,
			|token| matches!(token, RegexToken::CharacterSet(..)),
			// Each member must be consumed, not merely peeked at: otherwise
			// the stream never advances and `parse_while` loops forever.
			|tokens| match Self::next_token(tokens) {
				Some(RegexToken::CharacterSet(set)) => Ok(set.clone()),
				_ => unreachable!("just saw a set"),
			},
		)?;
		let right_bracket = Self::next_token(tokens).cloned();

		Ok(Union {
			left_bracket,
			charsets,
			right_bracket,
		})
	}

	/// Parses the character-testing part of a node: a bracketed union when
	/// the next token is `[`, or a single character set token.
	///
	/// # Errors
	///
	/// Returns an error when the next token is neither of those (or the
	/// stream is empty), and forwards any error from parsing the union.
	fn parse_cset_type(tokens: &mut &[RegexToken]) -> Result<CsetType, Rc<str>> {
		match Self::peek_token(tokens) {
			Some(RegexToken::LeftBracket) => Self::parse_union(tokens).map(CsetType::Union),
			Some(RegexToken::CharacterSet(_)) => match Self::next_token(tokens) {
				Some(RegexToken::CharacterSet(cset)) => Ok(CsetType::Literal(cset.clone())),
				_ => unreachable!("just saw a char set"),
			},
			_ => Err("invalid char set".into()),
		}
	}

	/// Parses a grouping: the opening token, the inner pattern, and the token
	/// that follows it.
	///
	/// Whatever token follows the inner pattern is stored as the right
	/// parenthesis without checking its kind; it is `None` when the pattern
	/// ends first.
	///
	/// # Panics
	///
	/// Panics when the stream is empty; callers only invoke this after
	/// peeking a `(`.
	fn parse_grouping(tokens: &mut &[RegexToken]) -> Result<Grouping, Rc<str>> {
		let opener = Self::next_token(tokens)
			.cloned()
			.expect("a left parenthesis");
		let inner = Self::parse_regex(tokens)?;
		let closer = Self::next_token(tokens).cloned();

		Ok(Grouping {
			left_parenthesis: opener,
			regex: inner,
			right_parenthesis: closer,
		})
	}

	/// Parses a character-set node: an optional `~`, the character set or
	/// union itself, and an optional `*`.
	///
	/// # Errors
	///
	/// Forwards the error from `parse_cset_type` when no valid character set
	/// follows the optional `~`.
	fn parse_cset_node(tokens: &mut &[RegexToken]) -> Result<CharSetNode, Rc<str>> {
		let tilde = if matches!(Self::peek_token(tokens), Some(RegexToken::Tilde)) {
			let token = Self::next_token(tokens)
				.ok_or_else(|| Rc::<str>::from("bug: should have a tilde here"))?;
			Some(token.clone())
		} else {
			None
		};

		let cset = Self::parse_cset_type(tokens)?;

		let star = if matches!(Self::peek_token(tokens), Some(RegexToken::KleeneStar)) {
			let token = Self::next_token(tokens).expect("at least one token");
			Some(token.clone())
		} else {
			None
		};

		Ok(CharSetNode { tilde, cset, star })
	}

	/// Parses one node: a grouping when the next token is `(`, otherwise a
	/// character-set node.
	///
	/// # Errors
	///
	/// Returns an error when the stream is empty, and forwards any error from
	/// the chosen sub-parser.
	fn parse_node(tokens: &mut &[RegexToken]) -> Result<RegexNode, Rc<str>> {
		match Self::peek_token(tokens) {
			None => Err("expected a token".into()),
			Some(RegexToken::LeftParenthesis) => {
				Self::parse_grouping(tokens).map(RegexNode::Grouping)
			}
			Some(_) => Self::parse_cset_node(tokens).map(RegexNode::CharSetNode),
		}
	}

	/// Parses one alternate: an optional leading `|` followed by nodes.
	///
	/// Nodes are collected until a `|`, `)`, `*` or `]` token, or the end of
	/// the stream, is reached; that token is left unconsumed.
	fn parse_alternate(tokens: &mut &[RegexToken]) -> Result<Alternate, Rc<str>> {
		/// Tokens that cannot start a node and therefore end the sequence.
		fn ends_alternate(token: &RegexToken) -> bool {
			matches!(
				token,
				RegexToken::Pipe
					| RegexToken::RightParenthesis
					| RegexToken::KleeneStar
					| RegexToken::RightBracket
			)
		}

		let pipe = match Self::peek_token(tokens) {
			Some(RegexToken::Pipe) => {
				Some(Self::next_token(tokens).cloned().ok_or(Rc::<str>::from(""))?)
			}
			_ => None,
		};
		let nodes = Self::parse_while(tokens, |token| !ends_alternate(token), Self::parse_node)?;

		Ok(Alternate { pipe, nodes })
	}

	/// Parses a sequence of alternates into a `Regex`.
	///
	/// Parsing stops, leaving the token unconsumed, at a `)`, `*` or `]`
	/// token or at the end of the stream.
	fn parse_regex(tokens: &mut &[RegexToken]) -> Result<Regex, Rc<str>> {
		/// Tokens that cannot start an alternate and therefore end the pattern.
		fn ends_regex(token: &RegexToken) -> bool {
			matches!(
				token,
				RegexToken::RightParenthesis | RegexToken::KleeneStar | RegexToken::RightBracket
			)
		}

		Self::parse_while(tokens, |token| !ends_regex(token), Self::parse_alternate)
			.map(|alternates| Regex { alternates })
	}

	/// Tests the first character of `input` against `cset`.
	///
	/// The predicate is called with `None` when `input` is empty.
	fn find_cset(cset: &CharSet, input: &str) -> bool {
		let next = input.chars().next();
		match cset {
			CharSet::StaticFn(accepts) => accepts(next),
			CharSet::BoxedFn(accepts) => accepts(next),
		}
	}

	/// Returns `true` when any member of `union` accepts the first character
	/// of `input`; an empty union accepts nothing.
	fn find_union(union: &Union, input: &str) -> bool {
		union
			.charsets
			.iter()
			.any(|member| Self::find_cset(member, input))
	}

	/// Matches a character-set node at the start of `input` and returns the
	/// number of bytes it consumes, or `None` when it does not match.
	///
	/// Each character is tested against the set, with the result inverted
	/// when the node carries a `~`. Without a `*` exactly one character is
	/// consumed; with a `*` the longest run of accepted characters is
	/// consumed.
	///
	/// At least one character must be accepted: an empty `input`, or a
	/// rejected first character, yields `None` even for a starred node.
	// NOTE(review): a true Kleene star would also accept zero repetitions.
	// That is deliberately not done here, because it would introduce empty
	// matches that code iterating over matches has to be prepared for.
	fn find_cset_node(cset: &CharSetNode, input: &str) -> Option<usize> {
		let negate = cset.tilde.is_some();
		let repeat = cset.star.is_some();
		let accepts = |rest: &str| {
			let hit = match &cset.cset {
				CsetType::Literal(set) => Self::find_cset(set, rest),
				CsetType::Union(union) => Self::find_union(union, rest),
			};
			hit != negate
		};

		let mut consumed = 0;
		for (offset, ch) in input.char_indices() {
			if !accepts(&input[offset..]) {
				break;
			}
			// Step past the accepted character so a starred node makes
			// progress instead of re-testing the same position forever.
			consumed = offset + ch.len_utf8();
			if !repeat {
				break;
			}
		}

		(consumed > 0).then_some(consumed)
	}

	/// Matches `node` at the start of `input` and returns the number of bytes
	/// it consumes, or `None` when it does not match there.
	fn find_node(node: &RegexNode, input: &str) -> Option<usize> {
		match node {
			RegexNode::CharSetNode(cset) => Self::find_cset_node(cset, input),
			// The sub-pattern must match exactly here; an unanchored search
			// would report the length of a match found later in the input as
			// if it started at this position.
			RegexNode::Grouping(grouping) => grouping.regex.find_at(input, 0).map(|m| m.len()),
		}
	}

	/// Matches every node of `alternate` in sequence, starting at the
	/// beginning of `input`.
	///
	/// Returns a match spanning from offset 0 to the end of the last node, or
	/// `None` as soon as one node fails. An alternate without nodes yields an
	/// empty match.
	fn find_alternate<'s>(alternate: &Alternate, input: &'s str) -> Option<Match<'s>> {
		let mut consumed = 0;
		for node in &alternate.nodes {
			consumed += Self::find_node(node, &input[consumed..])?;
		}

		Some(Match {
			haystack: input,
			start: 0,
			end: consumed,
		})
	}

	/// Compiles `regex` into a matcher.
	///
	/// # Errors
	///
	/// Returns an error when the pattern cannot be parsed, including when
	/// tokens are left over after parsing stops (for example a stray `)`,
	/// `]` or `*`). Such tokens used to be dropped silently, leaving a
	/// matcher for only the leading part of the pattern.
	///
	/// # Panics
	///
	/// Panics when a label (`<...`) or jump (`@...`) has no closing `>`, or
	/// when the pattern ends with a lone backslash.
	pub fn new(regex: &str) -> Result<Self, Rc<str>> {
		let tokens = Self::tokenize(regex);
		let mut remaining = tokens.as_slice();
		let parsed = Self::parse_regex(&mut remaining)?;

		if remaining.is_empty() {
			Ok(parsed)
		} else {
			Err(format!(
				"unexpected token: {} token(s) left unparsed at the end of the pattern",
				remaining.len()
			)
			.into())
		}
	}

	/// Returns the leftmost match in `input`, if any.
	///
	/// Every character position of `input` is tried in order as a start
	/// position. Only character boundaries are tried, so input containing
	/// multi-byte characters is never sliced in the middle of a character.
	pub fn find<'s>(&self, input: &'s str) -> Option<Match<'s>> {
		input
			.char_indices()
			.find_map(|(start, _)| self.find_at(input, start))
	}

	/// Returns all non-overlapping matches in `input`, from left to right.
	///
	/// After a match the search resumes at its end; after a failed attempt,
	/// or an empty match, it resumes at the next character.
	pub fn find_all<'s>(&self, input: &'s str) -> impl IntoIterator<Item = Match<'s>> {
		let mut matches = Vec::new();
		let mut i = 0;
		while i < input.len() {
			// `i` always sits on a character boundary, so the next candidate
			// position is one whole character further on, never one byte.
			let next_char_start = i + input[i..].chars().next().map_or(1, char::len_utf8);
			match self.find_at(input, i) {
				Some(found) => {
					// An empty match must still move the cursor, otherwise
					// the same position would be matched forever.
					i = if found.end > i { found.end } else { next_char_start };
					matches.push(found);
				}
				None => i = next_char_start,
			}
		}

		matches
	}

	/// Tries to match the pattern starting exactly at byte offset `start`.
	///
	/// The alternates are tried in order and the first one that matches wins.
	/// The returned match has offsets relative to the whole of `input`.
	///
	/// Returns `None` when nothing matches at `start`, and also when `start`
	/// is past the end of `input` or not on a character boundary (rather
	/// than panicking on the slice).
	pub fn find_at<'s>(&self, input: &'s str, start: usize) -> Option<Match<'s>> {
		let tail = input.get(start..)?;
		self.alternates
			.iter()
			.find_map(|alternate| Self::find_alternate(alternate, tail))
			.map(|found| Match {
				haystack: input,
				start,
				end: start + found.len(),
			})
	}

	/// Returns `true` exactly when `find_at(input, start)` yields a match,
	/// i.e. when the pattern matches starting at byte offset `start`.
	pub fn is_match_at(&self, input: &str, start: usize) -> bool {
		self.find_at(input, start).is_some()
	}

	/// Returns `true` exactly when `find(input)` yields a match, i.e. when
	/// the pattern matches somewhere in `input`.
	pub fn is_match(&self, input: &str) -> bool {
		self.find(input).is_some()
	}

	/// Replaces the first match in `input` with `replacement`.
	///
	/// Only the matched byte range is rewritten; other occurrences of the
	/// same text elsewhere in `input` are left alone. When nothing matches,
	/// `input` is returned borrowed and unchanged.
	pub fn replace<'s>(&self, input: &'s str, replacement: impl AsRef<str>) -> Cow<'s, str> {
		let Some(found) = self.find(input) else {
			return Cow::Borrowed(input);
		};

		let replacement = replacement.as_ref();
		let mut output = String::with_capacity(input.len() - found.len() + replacement.len());
		output.push_str(&input[..found.start()]);
		output.push_str(replacement);
		output.push_str(&input[found.end()..]);
		Cow::Owned(output)
	}

	/// Replaces every match returned by `find_all` with `replacement`.
	///
	/// The output is assembled from the unmatched stretches of `input` and
	/// one copy of `replacement` per match, so text outside the match ranges
	/// is never touched and inserted replacement text is never rewritten.
	/// When nothing matches, `input` is returned borrowed and unchanged.
	pub fn replace_all<'s>(&self, input: &'s str, replacement: impl AsRef<str>) -> Cow<'s, str> {
		let replacement = replacement.as_ref();
		let mut output = String::new();
		let mut copied_up_to = 0;
		let mut replaced_any = false;

		for found in self.find_all(input) {
			// Copy the untouched text between the previous match and this one.
			output.push_str(&input[copied_up_to..found.start()]);
			output.push_str(replacement);
			copied_up_to = found.end();
			replaced_any = true;
		}

		if !replaced_any {
			return Cow::Borrowed(input);
		}

		output.push_str(&input[copied_up_to..]);
		Cow::Owned(output)
	}
}