1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362
use crate::node::Node;
use crate::priority_queue::{PriorityQueue, State};
use std::collections::HashMap;
use std::fmt;
/// Error types that can occur during A* pathfinding.
///
/// Every variant carries a position in the caller's `(x, y)` coordinate system,
/// i.e. the same coordinates that were passed to `AStar::find_shortest_path`.
/// The enum is a small `Copy` value, so it can be compared, hashed and copied freely.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum AStarError {
    /// The start position exists but its node is blocked.
    StartNodeBlocked((usize, usize)),
    /// The goal position exists but its node is blocked.
    GoalNodeBlocked((usize, usize)),
    /// No node exists at the given position.
    NodeNotFound((usize, usize)),
    /// The open set was exhausted without reaching the goal; carries the last position examined.
    PathNotFound((usize, usize)),
}
/// Structure implementing the A* algorithm.
///
/// Holds the grid being searched together with the per-search bookkeeping maps.
/// The bookkeeping is reset at the start of every `find_shortest_path` call.
#[derive(Debug)]
pub struct AStar {
/// Grid nodes keyed by position. The search looks positions up as `(y, x)`,
/// i.e. with the components of the caller's `(x, y)` coordinates swapped.
nodes: HashMap<(usize, usize), Node>,
/// Frontier of positions still to be expanded. Entries are pushed with their f-score
/// as `State::cost`; NOTE(review): assumes `PriorityQueue::pop` returns the lowest cost first — confirm in `priority_queue`.
open_set: PriorityQueue,
/// Best known predecessor of each discovered position, used to rebuild the final path.
came_from: HashMap<(usize, usize), (usize, usize)>,
/// Cost of the cheapest known path from the start to each position (every step costs 1).
g_score: HashMap<(usize, usize), usize>,
/// `g_score` plus the Manhattan-distance estimate to the goal, per position.
f_score: HashMap<(usize, usize), usize>,
}
impl AStar {
    /// Creates a new `AStar` instance that searches over `nodes`.
    ///
    /// # Parameters
    /// - `nodes`: The grid to search, keyed by position. `find_shortest_path` takes
    ///   `(x, y)` coordinates and looks them up in this map as `(y, x)`.
    ///
    /// # Returns
    /// A new `AStar` instance with empty search state.
    ///
    /// # Example
    /// ```ignore
    /// let nodes = HashMap::new();
    /// let astar = AStar::new(nodes);
    /// ```
    pub fn new(nodes: HashMap<(usize, usize), Node>) -> Self {
        AStar {
            nodes,
            open_set: PriorityQueue::new(),
            came_from: HashMap::new(),
            g_score: HashMap::new(),
            f_score: HashMap::new(),
        }
    }

    /// Converts between caller `(x, y)` coordinates and `nodes` keys `(y, x)`.
    ///
    /// Swapping the components is its own inverse, so this helper converts in both directions.
    fn swap_axes(position: (usize, usize)) -> (usize, usize) {
        (position.1, position.0)
    }

    /// Calculates the Manhattan distance between two points.
    ///
    /// The Manhattan distance is the sum of the absolute differences of the coordinates.
    ///
    /// # Parameters
    /// - `start`: The first point.
    /// - `goal`: The second point.
    ///
    /// # Returns
    /// `|x1 - x2| + |y1 - y2|` as a `usize`.
    fn manhattan_distance(start: (usize, usize), goal: (usize, usize)) -> usize {
        let (x1, y1) = start;
        let (x2, y2) = goal;
        // `max - min` gives the absolute difference without casting to `isize`,
        // which would wrap for coordinates above `isize::MAX`.
        (x1.max(x2) - x1.min(x2)) + (y1.max(y2) - y1.min(y2))
    }

    /// Builds fresh `g_score` and `f_score` maps seeded with the start position.
    ///
    /// # Parameters
    /// - `start`: The start position (in `nodes` key space).
    /// - `goal`: The goal position (in `nodes` key space).
    ///
    /// # Returns
    /// `(g_score, f_score)`, where the start has `g = 0` and `f = manhattan(start, goal)`.
    fn initialize_scores(
        start: (usize, usize),
        goal: (usize, usize),
    ) -> (
        HashMap<(usize, usize), usize>,
        HashMap<(usize, usize), usize>,
    ) {
        let mut g_score = HashMap::new();
        let mut f_score = HashMap::new();
        g_score.insert(start, 0);
        f_score.insert(start, Self::manhattan_distance(start, goal));
        (g_score, f_score)
    }

    /// Reconstructs the path from `start` to `goal` by walking `came_from` backwards.
    ///
    /// # Parameters
    /// - `came_from`: Predecessor of each discovered position (in `nodes` key space).
    /// - `start`: The start position (in `nodes` key space).
    /// - `goal`: The goal position (in `nodes` key space).
    ///
    /// # Returns
    /// The path from start to goal inclusive, converted back to caller `(x, y)` coordinates.
    ///
    /// # Panics
    /// If some position on the chain from `goal` has no predecessor before reaching `start`.
    /// This cannot happen for a goal that the search reached.
    fn reconstruct_path(
        came_from: &HashMap<(usize, usize), (usize, usize)>,
        start: (usize, usize),
        goal: (usize, usize),
    ) -> Vec<(usize, usize)> {
        let mut path = vec![Self::swap_axes(goal)];
        let mut current = goal;
        while current != start {
            current = *came_from
                .get(&current)
                .expect("every position reached by the search has a recorded predecessor");
            path.push(Self::swap_axes(current));
        }
        path.reverse();
        path
    }

    /// Returns the positions of `current_node`'s neighbors that exist in `nodes` and are not blocked.
    ///
    /// # Parameters
    /// - `current_node`: The node whose neighbors are inspected.
    ///
    /// # Returns
    /// The traversable neighbor positions (in `nodes` key space).
    fn find_neighbors(&self, current_node: &Node) -> Vec<(usize, usize)> {
        current_node
            .neighbors
            .values()
            // Missing neighbors (e.g. at the grid edge) are stored as `None`.
            .flatten()
            .copied()
            .filter(|pos| self.nodes.get(pos).map_or(false, |node| !node.is_blocked))
            .collect()
    }

    /// Relaxes the edge `current_position -> neighbor_pos` (unit step cost).
    ///
    /// If going through `current_position` is cheaper than any path recorded so far,
    /// this updates `came_from`, `g_score` and `f_score` for the neighbor.
    ///
    /// # Parameters
    /// - `current_position`: The position being expanded; it must already have a `g_score`.
    /// - `neighbor_pos`: The neighbor to relax.
    /// - `goal`: The goal position, used for the heuristic.
    ///
    /// # Returns
    /// `Some(new_f_score)` if the neighbor improved, otherwise `None`.
    fn calculate_scores(
        &mut self,
        current_position: (usize, usize),
        neighbor_pos: (usize, usize),
        goal: (usize, usize),
    ) -> Option<usize> {
        // Every expanded position was given a g_score before it was pushed.
        let tentative_g_score = self.g_score[&current_position] + 1;
        let improves = self
            .g_score
            .get(&neighbor_pos)
            .map_or(true, |&known| tentative_g_score < known);
        if !improves {
            return None;
        }
        self.came_from.insert(neighbor_pos, current_position);
        self.g_score.insert(neighbor_pos, tentative_g_score);
        let f_score_value = tentative_g_score + Self::manhattan_distance(neighbor_pos, goal);
        self.f_score.insert(neighbor_pos, f_score_value);
        Some(f_score_value)
    }

    /// Returns `true` if `current_position` is the goal.
    fn is_goal_reached(&self, current_position: (usize, usize), goal: (usize, usize)) -> bool {
        current_position == goal
    }

    /// Relaxes a neighbor and, if its score improved, pushes it onto the open set.
    ///
    /// # Parameters
    /// - `current_position`: The position being expanded.
    /// - `neighbor_pos`: The neighbor to process.
    /// - `goal`: The goal position.
    fn process_neighbor(
        &mut self,
        current_position: (usize, usize),
        neighbor_pos: (usize, usize),
        goal: (usize, usize),
    ) {
        if let Some(f_score_value) = self.calculate_scores(current_position, neighbor_pos, goal) {
            self.open_set.push(State {
                cost: f_score_value,
                position: neighbor_pos,
            });
        }
    }

    /// Validates that the start and goal nodes exist and are not blocked.
    ///
    /// Positions are looked up with the same `(y, x)` key convention the search uses.
    ///
    /// # Parameters
    /// - `start`: The start point in caller `(x, y)` coordinates.
    /// - `goal`: The goal point in caller `(x, y)` coordinates.
    ///
    /// # Errors
    /// The checks run in this order; any error carries the caller's coordinates.
    /// - `NodeNotFound` if the start or goal is missing.
    /// - `StartNodeBlocked` if the start is blocked.
    /// - `GoalNodeBlocked` if the goal is blocked.
    fn validate_nodes(
        &self,
        start: (usize, usize),
        goal: (usize, usize),
    ) -> Result<(), AStarError> {
        let start_node = self
            .nodes
            .get(&Self::swap_axes(start))
            .ok_or(AStarError::NodeNotFound(start))?;
        let goal_node = self
            .nodes
            .get(&Self::swap_axes(goal))
            .ok_or(AStarError::NodeNotFound(goal))?;
        if start_node.is_blocked {
            return Err(AStarError::StartNodeBlocked(start));
        }
        if goal_node.is_blocked {
            return Err(AStarError::GoalNodeBlocked(goal));
        }
        Ok(())
    }

    /// Finds the shortest path from `start` to `goal` using the A* algorithm.
    ///
    /// Each step between neighboring nodes costs 1. The Manhattan distance is used as the heuristic.
    ///
    /// # Parameters
    /// - `start`: The starting point as `(x, y)`.
    /// - `goal`: The goal point as `(x, y)`.
    ///
    /// # Returns
    /// `Ok(Some(path))` with the positions from start to goal inclusive, in `(x, y)` coordinates.
    /// This function never returns `Ok(None)`; the `Option` is kept for API compatibility.
    ///
    /// # Errors
    /// - `NodeNotFound`, `StartNodeBlocked` or `GoalNodeBlocked` if validation fails.
    /// - `PathNotFound` with the last examined `(x, y)` position if the goal is unreachable.
    ///
    /// # Example
    /// ```ignore
    /// let mut astar = AStar::new(nodes);
    /// let path = astar.find_shortest_path((0, 0), (5, 5));
    /// ```
    pub fn find_shortest_path(
        &mut self,
        start: (usize, usize),
        goal: (usize, usize),
    ) -> Result<Option<Vec<(usize, usize)>>, AStarError> {
        self.validate_nodes(start, goal)?;

        // The search runs in the key space of `nodes`, which is `(y, x)`.
        let start_key = Self::swap_axes(start);
        let goal_key = Self::swap_axes(goal);

        // Discard all state left over from a previous search.
        self.open_set = PriorityQueue::new();
        self.came_from.clear();
        let (g_score, f_score) = Self::initialize_scores(start_key, goal_key);
        self.g_score = g_score;
        self.f_score = f_score;

        let start_f_score = self.f_score[&start_key];
        self.open_set.push(State {
            cost: start_f_score,
            position: start_key,
        });

        let mut current_position = start_key;
        while let Some(current_state) = self.open_set.pop() {
            current_position = current_state.position;

            // A position may be pushed again after its score improves. Entries whose cost no
            // longer matches the best known f-score are stale and would add nothing.
            let is_stale = self
                .f_score
                .get(&current_position)
                .map_or(false, |&best| current_state.cost > best);
            if is_stale {
                continue;
            }

            if self.is_goal_reached(current_position, goal_key) {
                return Ok(Some(Self::reconstruct_path(
                    &self.came_from,
                    start_key,
                    goal_key,
                )));
            }

            let neighbors = match self.nodes.get(&current_position) {
                Some(current_node) => self.find_neighbors(current_node),
                None => continue,
            };
            for neighbor_pos in neighbors {
                self.process_neighbor(current_position, neighbor_pos, goal_key);
            }
        }

        // Report in caller coordinates, like every other error variant.
        Err(AStarError::PathNotFound(Self::swap_axes(current_position)))
    }
}
impl fmt::Display for AStarError {
/// Formats the error for display in user-facing contexts.
///
/// # Example
/// ```rust
/// let error = AStarError::StartNodeBlocked((0, 0));
/// println!("{}", error);
/// // Output: The start node at position (0, 0) is blocked!
/// ```
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
AStarError::StartNodeBlocked(coord) => {
write!(f, "The start node at position {:?} is blocked!", coord)
}
AStarError::GoalNodeBlocked(coord) => {
write!(f, "The goal node at position {:?} is blocked!", coord)
}
AStarError::NodeNotFound(coord) => {
write!(f, "The node at position {:?} was not found!", coord)
}
AStarError::PathNotFound(coord) => {
write!(f, "Path not found! Last checked position was {:?}.", coord)
}
}
}
}
impl fmt::Debug for AStarError {
/// Formats the error for debugging purposes.
///
/// # Example
/// ```rust
/// let error = AStarError::GoalNodeBlocked((5, 5));
/// println!("{:?}", error);
/// // Output: GoalNodeBlocked((5, 5))
/// ```
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}