diff --git a/ICD.Common.Utils/RecursionUtils.cs b/ICD.Common.Utils/RecursionUtils.cs index fd491f4..df9919d 100644 --- a/ICD.Common.Utils/RecursionUtils.cs +++ b/ICD.Common.Utils/RecursionUtils.cs @@ -3,127 +3,203 @@ using System.Collections.Generic; using System.Linq; using ICD.Common.Properties; using ICD.Common.Utils.Collections; +using ICD.Common.Utils.Extensions; namespace ICD.Common.Utils { - public static class RecursionUtils - { - /// - /// Returns all of the nodes in the tree via breadth-first search. - /// - /// - /// - /// - /// - public static IEnumerable BreadthFirstSearch(T root, Func> getChildren) - { - if (getChildren == null) - throw new ArgumentNullException("getChildren"); + public static class RecursionUtils + { + /// + /// Returns all of the nodes in the tree via breadth-first search. + /// + /// + /// + /// + /// + public static IEnumerable BreadthFirstSearch(T root, Func> getChildren) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); - Queue process = new Queue(); - process.Enqueue(root); + Queue process = new Queue(); + process.Enqueue(root); - while (process.Count > 0) - { - T current = process.Dequeue(); - yield return current; + while (process.Count > 0) + { + T current = process.Dequeue(); + yield return current; - foreach (T child in getChildren(current)) - process.Enqueue(child); - } - } + foreach (T child in getChildren(current)) + process.Enqueue(child); + } + } - /// - /// Returns the shortest path from root to destination via breadth-first search. - /// - /// - /// - /// - /// - /// - [CanBeNull] - public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren) - { - if (getChildren == null) - throw new ArgumentNullException("getChildren"); + /// + /// Returns the shortest path from root to destination via breadth-first search. + /// + /// + /// + /// + /// + /// + [CanBeNull] + public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); - return BreadthFirstSearchPath(root, destination, getChildren, EqualityComparer.Default); - } + return BreadthFirstSearchPath(root, destination, getChildren, EqualityComparer.Default); + } - /// - /// Returns the shortest path from root to destination via breadth-first search. - /// - /// - /// - /// - /// - /// - /// - [CanBeNull] - public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren, - IEqualityComparer comparer) - { - if (getChildren == null) - throw new ArgumentNullException("getChildren"); + [NotNull] + public static Dictionary> BreadthFirstSearchManyDestinations(T root, Dictionary destinations, Func> getChildren) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); - if (comparer == null) - throw new ArgumentNullException("comparer"); + return BreadthFirstSearchPathManyDestinations(root, destinations, getChildren, EqualityComparer.Default); + } - // Edge case - root and destination are the same - if (comparer.Equals(root, destination)) - return new[] {root}; + /// + /// Returns the shortest path from root to destination via breadth-first search. + /// + /// + /// + /// + /// + /// + /// + [CanBeNull] + public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren, + IEqualityComparer comparer) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); - Queue queue = new Queue(); - queue.Enqueue(root); + if (comparer == null) + throw new ArgumentNullException("comparer"); - Dictionary nodeParents = new Dictionary(); + // Edge case - root and destination are the same + if (comparer.Equals(root, destination)) + return new[] { root }; - while (queue.Count > 0) - { - T current = queue.Dequeue(); + Queue queue = new Queue(); + queue.Enqueue(root); - foreach (T node in getChildren(current)) - { - if (nodeParents.ContainsKey(node)) - continue; + Dictionary nodeParents = new Dictionary(); - queue.Enqueue(node); - nodeParents.Add(node, current); + while (queue.Count > 0) + { + T current = queue.Dequeue(); - // Found a path to the destination - if (comparer.Equals(node, destination)) - return GetPath(destination, nodeParents).Reverse(); - } - } + foreach (T node in getChildren(current)) + { + if (nodeParents.ContainsKey(node)) + continue; - return null; - } + queue.Enqueue(node); + nodeParents.Add(node, current); - /// - /// Walks through a map of nodes from the starting point. - /// - /// - /// - /// - /// - private static IEnumerable GetPath(T start, IDictionary nodeParents) - { - IcdHashSet visited = new IcdHashSet(); + // Found a path to the destination + if (comparer.Equals(node, destination)) + return GetPath(destination, nodeParents).Reverse(); + } + } - while (true) - { - yield return start; - visited.Add(start); + return null; + } - T next; - if (!nodeParents.TryGetValue(start, out next)) - break; + [NotNull] + public static Dictionary> BreadthFirstSearchPathManyDestinations(T root, Dictionary destinations, + Func> getChildren, + IEqualityComparer comparer) + { + if (getChildren == null) + throw new ArgumentNullException("getChildren"); - if (visited.Contains(next)) - break; + if (comparer == null) + throw new ArgumentNullException("comparer"); - start = next; - } - } - } + Dictionary destinationsToBeProcessed = new Dictionary(destinations); + List destinationsProcessed = new List(); + Dictionary> pathsToReturn = new Dictionary>(); + + //Edge case, root is the destination + foreach (var destination in destinationsToBeProcessed.Where(destination => comparer.Equals(root, destination.Value))) + { + destinationsProcessed.Add(destination.Value); + pathsToReturn.Add(destination.Key, new[] { root }); + } + + foreach (var destination in destinationsProcessed) + { + destinationsToBeProcessed.RemoveValue(destination); + } + destinationsProcessed.Clear(); + if (destinationsToBeProcessed.Count == 0) + { + return pathsToReturn; + } + + Queue queue = new Queue(); + queue.Enqueue(root); + + Dictionary nodeParents = new Dictionary(); + + while (queue.Count > 0) + { + T current = queue.Dequeue(); + + foreach (T node in getChildren(current).Where(node => !nodeParents.ContainsKey(node))) + { + queue.Enqueue(node); + nodeParents.Add(node, current); + + foreach (var destination in destinationsToBeProcessed.Where(destination => comparer.Equals(node, destination.Value))) + { + destinationsProcessed.Add(destination.Value); + pathsToReturn.Add(destination.Key, GetPath(destination.Value, nodeParents).Reverse()); + } + foreach (var destination in destinationsProcessed) + { + destinationsToBeProcessed.RemoveValue(destination); + } + destinationsProcessed.Clear(); + if (destinationsToBeProcessed.Count == 0) + { + return pathsToReturn; + } + } + } + + return pathsToReturn; + } + + /// + /// Walks through a map of nodes from the starting point. + /// + /// + /// + /// + /// + private static IEnumerable GetPath(T start, IDictionary nodeParents) + { + IcdHashSet visited = new IcdHashSet(); + + while (true) + { + yield return start; + visited.Add(start); + + T next; + if (!nodeParents.TryGetValue(start, out next)) + break; + + if (visited.Contains(next)) + break; + + start = next; + } + } + } }