diff --git a/ICD.Common.Utils/RecursionUtils.cs b/ICD.Common.Utils/RecursionUtils.cs index bef79d9..b1617ee 100644 --- a/ICD.Common.Utils/RecursionUtils.cs +++ b/ICD.Common.Utils/RecursionUtils.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Linq; +using ICD.Common.Properties; namespace ICD.Common.Utils { @@ -14,6 +16,9 @@ namespace ICD.Common.Utils /// public static IEnumerable BreadthFirstSearch(T root, Func> getChildren) { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); + Queue process = new Queue(); process.Enqueue(root); @@ -26,5 +31,84 @@ namespace ICD.Common.Utils process.Enqueue(child); } } + + /// + /// Returns the shortest path from root to destination via breadth-first search. + /// + /// + /// + /// + /// + /// + [CanBeNull] + public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); + + return BreadthFirstSearchPath(root, destination, getChildren, EqualityComparer.Default); + } + + /// + /// Returns the shortest path from root to destination via breadth-first search. + /// + /// + /// + /// + /// + /// + /// + [CanBeNull] + public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren, + IEqualityComparer comparer) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); + + if (comparer == null) + throw new ArgumentNullException("comparer"); + + Queue queue = new Queue(); + queue.Enqueue(root); + + Dictionary nodeParents = new Dictionary(); + + while (queue.Count > 0) + { + T current = queue.Dequeue(); + + foreach (T node in getChildren(current)) + { + if (nodeParents.ContainsKey(node)) + continue; + + queue.Enqueue(node); + nodeParents.Add(node, current); + + // Found a path to the destination + if (comparer.Equals(node, destination)) + return GetPath(destination, nodeParents).Reverse(); + } + } + + return null; + } + + /// + /// Walks through a map of nodes from the starting point. + /// + /// + /// + /// + /// + private static IEnumerable GetPath(T start, IDictionary nodeParents) + { + while (true) + { + yield return start; + if (!nodeParents.TryGetValue(start, out start)) + break; + } + } } }