summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike White <botahamec@outlook.com>2021-09-06 08:02:51 -0400
committerMike White <botahamec@outlook.com>2021-09-06 08:02:51 -0400
commit744050c4b4747ac4645480e2f4a935a027b8350f (patch)
tree07d3bd4a73e08cf00379a2df3f0165f4e856e279
parent3df9c8a9a2f9274e863785fc9a7b196fe20ee87d (diff)
Alpha-beta pruning
-rw-r--r--ai/src/lib.rs13
-rw-r--r--cli/src/eval.rs4
-rw-r--r--cli/src/main.rs13
3 files changed, 23 insertions, 7 deletions
diff --git a/ai/src/lib.rs b/ai/src/lib.rs
index 583c50b..405c88d 100644
--- a/ai/src/lib.rs
+++ b/ai/src/lib.rs
@@ -26,7 +26,7 @@ fn eval_position(board: CheckersBitBoard) -> f32 {
}
}
-pub fn eval(depth: usize, board: CheckersBitBoard) -> f32 {
+pub fn eval(depth: usize, mut alpha: f32, beta: f32, board: CheckersBitBoard) -> f32 {
if depth == 0 {
eval_position(board)
} else {
@@ -35,14 +35,21 @@ pub fn eval(depth: usize, board: CheckersBitBoard) -> f32 {
for current_move in PossibleMoves::moves(board) {
let board = unsafe { current_move.apply_to(board) };
let current_eval = if board.turn() != turn {
- 1.0 - eval(depth - 1, board)
+ 1.0 - eval(depth - 1, 1.0 - beta, 1.0 - alpha, board)
} else {
- eval(depth - 1, board)
+ eval(depth - 1, alpha, beta, board)
};
+ if current_eval >= beta {
+ return beta;
+ }
+
if best_eval < current_eval {
best_eval = current_eval;
}
+ if alpha < best_eval {
+ alpha = best_eval;
+ }
}
best_eval
diff --git a/cli/src/eval.rs b/cli/src/eval.rs
index eaa2d41..d078c5f 100644
--- a/cli/src/eval.rs
+++ b/cli/src/eval.rs
@@ -1,4 +1,4 @@
use ai::CheckersBitBoard;
-pub fn eval() -> f32 {
- ai::eval(12, CheckersBitBoard::starting_position())
+pub fn eval(depth: usize) -> f32 {
+ ai::eval(depth, 0.0, 1.0, CheckersBitBoard::starting_position())
}
diff --git a/cli/src/main.rs b/cli/src/main.rs
index 57991f0..a550092 100644
--- a/cli/src/main.rs
+++ b/cli/src/main.rs
@@ -47,7 +47,16 @@ fn main() {
);
}
- if let Some(_matches) = matches.subcommand_matches("eval") {
- println!("{}", eval::eval());
+ if let Some(matches) = matches.subcommand_matches("eval") {
+ println!(
+ "{}",
+ eval::eval(
+ matches
+ .value_of("depth")
+ .unwrap()
+ .parse::<usize>()
+ .expect("Error: not a valid number")
+ )
+ );
}
}