diff --git a/ICD.Common.Utils.Tests/RecursionUtilsTest.cs b/ICD.Common.Utils.Tests/RecursionUtilsTest.cs index ff96110..3c6cadb 100644 --- a/ICD.Common.Utils.Tests/RecursionUtilsTest.cs +++ b/ICD.Common.Utils.Tests/RecursionUtilsTest.cs @@ -250,5 +250,40 @@ namespace ICD.Common.Utils.Tests Assert.AreEqual(6, paths[62].ToArray()[1]); Assert.AreEqual(62, paths[62].ToArray()[2]); } + + [Test] + public void BreadthFirstSearchPathsTest() + { + Dictionary> paths = + RecursionUtils.BreadthFirstSearchPaths(1, Graph) + .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + + // Verify we visited all destinations + Assert.AreEqual(4, paths.Count); + Assert.IsTrue(paths.ContainsKey(1)); + Assert.IsTrue(paths.ContainsKey(2)); + Assert.IsTrue(paths.ContainsKey(3)); + Assert.IsTrue(paths.ContainsKey(4)); + + // Verify path for destination 1 + Assert.AreEqual(1, paths[1].ToArray().Length); + Assert.AreEqual(1, paths[1].ToArray()[0]); + + // Verify path for destination 2 + Assert.AreEqual(2, paths[2].ToArray().Length); + Assert.AreEqual(1, paths[2].ToArray()[0]); + Assert.AreEqual(2, paths[2].ToArray()[1]); + + // Verify path for destination 3 + Assert.AreEqual(2, paths[3].ToArray().Length); + Assert.AreEqual(1, paths[3].ToArray()[0]); + Assert.AreEqual(3, paths[3].ToArray()[1]); + + // Verify path for destination 4 + Assert.AreEqual(3, paths[4].ToArray().Length); + Assert.AreEqual(1, paths[4].ToArray()[0]); + Assert.AreEqual(2, paths[4].ToArray()[1]); + Assert.AreEqual(4, paths[4].ToArray()[2]); + } } } diff --git a/ICD.Common.Utils/RecursionUtils.cs b/ICD.Common.Utils/RecursionUtils.cs index ad0e857..1ef3408 100644 --- a/ICD.Common.Utils/RecursionUtils.cs +++ b/ICD.Common.Utils/RecursionUtils.cs @@ -16,7 +16,9 @@ namespace ICD.Common.Utils /// /// /// - public static IEnumerable> GetCliques(IEnumerable nodes, Func> getAdjacent) + [NotNull] + public static IEnumerable> GetCliques([NotNull] IEnumerable nodes, + [NotNull] Func> getAdjacent) { if (nodes == null) throw new ArgumentNullException("nodes"); @@ -39,7 +41,8 @@ namespace ICD.Common.Utils /// /// /// - public static IEnumerable GetClique(T node, Func> getAdjacent) + [NotNull] + public static IEnumerable GetClique([NotNull] T node, [NotNull] Func> getAdjacent) { if (node == null) throw new ArgumentNullException("node"); @@ -57,7 +60,8 @@ namespace ICD.Common.Utils /// /// /// - public static IEnumerable GetClique(IDictionary> map, T node) + [NotNull] + public static IEnumerable GetClique([NotNull] IDictionary> map, [NotNull] T node) { if (map == null) throw new ArgumentNullException("map"); @@ -77,7 +81,8 @@ namespace ICD.Common.Utils /// /// /// - private static IEnumerable GetClique(IDictionary> map, IcdHashSet visited, T node) + private static IEnumerable GetClique([NotNull] IDictionary> map, + [NotNull] IcdHashSet visited, [NotNull] T node) { if (map == null) throw new ArgumentNullException("map"); @@ -89,19 +94,6 @@ namespace ICD.Common.Utils if (node == null) throw new ArgumentNullException("node"); - return GetCliqueIterator(map, visited, node); - } - - /// - /// Gets the clique containing the node. - /// - /// - /// - /// - /// - /// - private static IEnumerable GetCliqueIterator(IDictionary> map, IcdHashSet visited, T node) - { if (!visited.Add(node)) yield break; @@ -123,8 +115,19 @@ namespace ICD.Common.Utils /// /// /// - public static bool BreadthFirstSearch(T root, T destination, Func> getChildren) + public static bool BreadthFirstSearch([NotNull] T root, [NotNull] T destination, + [NotNull] Func> getChildren) { +// ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) +// ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + +// ReSharper disable CompareNonConstrainedGenericWithNull + if (destination == null) +// ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("destination"); + if (getChildren == null) throw new ArgumentNullException("getChildren"); @@ -140,8 +143,20 @@ namespace ICD.Common.Utils /// /// /// - public static bool BreadthFirstSearch(T root, T destination, Func> getChildren, IEqualityComparer comparer) + public static bool BreadthFirstSearch([NotNull] T root, [NotNull] T destination, + [NotNull] Func> getChildren, + [NotNull] IEqualityComparer comparer) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + + // ReSharper disable CompareNonConstrainedGenericWithNull + if (destination == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("destination"); + if (getChildren == null) throw new ArgumentNullException("getChildren"); @@ -158,38 +173,17 @@ namespace ICD.Common.Utils /// /// /// - public static IEnumerable BreadthFirstSearch(T root, Func> getChildren) + public static IEnumerable BreadthFirstSearch([NotNull] T root, [NotNull] Func> getChildren) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + if (getChildren == null) throw new ArgumentNullException("getChildren"); - return BreadthFirstSearchIterator(root, getChildren); - } - - /// - /// Returns all of the nodes in the graph via breadth-first search. - /// - /// - /// - /// - /// - private static IEnumerable BreadthFirstSearchIterator(T root, Func> getChildren) - { - IcdHashSet visited = new IcdHashSet {root}; - Queue process = new Queue(); - process.Enqueue(root); - - T current; - while (process.Dequeue(out current)) - { - yield return current; - - foreach (T child in getChildren(current).Where(c => !visited.Contains(c))) - { - visited.Add(child); - process.Enqueue(child); - } - } + return BreadthFirstSearchPaths(root, getChildren, EqualityComparer.Default).Select(kvp => kvp.Key); } /// @@ -201,8 +195,19 @@ namespace ICD.Common.Utils /// /// [CanBeNull] - public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren) + public static IEnumerable BreadthFirstSearchPath([NotNull] T root, [NotNull] T destination, + [NotNull] Func> getChildren) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + + // ReSharper disable CompareNonConstrainedGenericWithNull + if (destination == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("destination"); + if (getChildren == null) throw new ArgumentNullException("getChildren"); @@ -219,42 +224,29 @@ namespace ICD.Common.Utils /// /// [CanBeNull] - public static IEnumerable BreadthFirstSearchPath(T root, T destination, Func> getChildren, - IEqualityComparer comparer) + public static IEnumerable BreadthFirstSearchPath([NotNull] T root, [NotNull] T destination, + [NotNull] Func> getChildren, + [NotNull] IEqualityComparer comparer) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + + // ReSharper disable CompareNonConstrainedGenericWithNull + if (destination == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("destination"); + if (getChildren == null) throw new ArgumentNullException("getChildren"); if (comparer == null) throw new ArgumentNullException("comparer"); - // Edge case - root and destination are the same - if (comparer.Equals(root, destination)) - return new[] {root}; - - Queue queue = new Queue(); - queue.Enqueue(root); - - Dictionary nodeParents = new Dictionary(comparer); - - T current; - while (queue.Dequeue(out current)) - { - foreach (T node in getChildren(current)) - { - if (nodeParents.ContainsKey(node)) - continue; - - queue.Enqueue(node); - nodeParents.Add(node, current); - - // Found a path to the destination - if (comparer.Equals(node, destination)) - return GetPath(destination, root, nodeParents, comparer).Reverse(); - } - } - - return null; + return BreadthFirstSearchPathManyDestinations(root, destination.Yield(), getChildren, comparer) + .Select(kvp => kvp.Value) + .FirstOrDefault(); } /// @@ -265,9 +257,16 @@ namespace ICD.Common.Utils /// /// /// + [NotNull] public static IEnumerable>> - BreadthFirstSearchPathManyDestinations(T root, IEnumerable destinations, Func> getChildren) + BreadthFirstSearchPathManyDestinations([NotNull] T root, [NotNull] IEnumerable destinations, + [NotNull] Func> getChildren) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + if (destinations == null) throw new ArgumentNullException("destinations"); @@ -286,10 +285,17 @@ namespace ICD.Common.Utils /// /// /// + [NotNull] public static IEnumerable>> - BreadthFirstSearchPathManyDestinations(T root, IEnumerable destinations, Func> getChildren, - IEqualityComparer comparer) + BreadthFirstSearchPathManyDestinations([NotNull] T root, [NotNull] IEnumerable destinations, + [NotNull] Func> getChildren, + [NotNull] IEqualityComparer comparer) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + if (destinations == null) throw new ArgumentNullException("destinations"); @@ -299,23 +305,68 @@ namespace ICD.Common.Utils if (comparer == null) throw new ArgumentNullException("comparer"); - IcdHashSet destinationsToBeProcessed = new IcdHashSet(destinations); + IcdHashSet destinationsSet = destinations as IcdHashSet ?? destinations.ToIcdHashSet(); + return BreadthFirstSearchPaths(root, getChildren, comparer) + .Where(kvp => destinationsSet.Contains(kvp.Key)) + .Take(destinationsSet.Count); + } - // Edge case, root is the destination - foreach (T destination in - destinationsToBeProcessed.Where(destination => comparer.Equals(root, destination)).ToArray()) - { - destinationsToBeProcessed.Remove(destination); - yield return new KeyValuePair>(destination, new[] {root}); - } + /// + /// Returns the shortest path from root to each destination via breadth-first search. + /// + /// + /// + /// + /// + [NotNull] + public static IEnumerable>> + BreadthFirstSearchPaths([NotNull] T root, [NotNull] Func> getChildren) + { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); - if (destinationsToBeProcessed.Count == 0) - yield break; + if (getChildren == null) + throw new ArgumentNullException("getChildren"); + + return BreadthFirstSearchPaths(root, getChildren, EqualityComparer.Default); + } + + /// + /// Returns the shortest path from root to each destination via breadth-first search. + /// + /// + /// + /// + /// + /// + [NotNull] + public static IEnumerable>> + BreadthFirstSearchPaths([NotNull] T root, [NotNull] Func> getChildren, + [NotNull] IEqualityComparer comparer) + { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (root == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("root"); + + if (getChildren == null) + throw new ArgumentNullException("getChildren"); + + if (comparer == null) + throw new ArgumentNullException("comparer"); + + // Edge case - root is the same as destination + yield return new KeyValuePair>(root, root.Yield()); Queue queue = new Queue(); queue.Enqueue(root); - Dictionary nodeParents = new Dictionary(comparer); + Dictionary nodeParents = new Dictionary(comparer) + { + {root, default(T)} + }; while (queue.Count > 0) { @@ -326,18 +377,8 @@ namespace ICD.Common.Utils queue.Enqueue(node); nodeParents.Add(node, current); - T closureNode = node; - foreach (T destination in - destinationsToBeProcessed.Where(destination => comparer.Equals(closureNode, destination)).ToArray()) - { - destinationsToBeProcessed.Remove(destination); - - yield return - new KeyValuePair>(destination, GetPath(destination, root, nodeParents, comparer).Reverse()); - } - - if (destinationsToBeProcessed.Count == 0) - yield break; + yield return + new KeyValuePair>(node, GetPath(node, root, nodeParents, comparer).Reverse()); } } } @@ -351,8 +392,26 @@ namespace ICD.Common.Utils /// /// /// - private static IEnumerable GetPath(T start, T end, IDictionary nodeParents, IEqualityComparer comparer) + [NotNull] + private static IEnumerable GetPath([NotNull] T start, [NotNull] T end, [NotNull] IDictionary nodeParents, + [NotNull] IEqualityComparer comparer) { + // ReSharper disable CompareNonConstrainedGenericWithNull + if (start == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("start"); + + // ReSharper disable CompareNonConstrainedGenericWithNull + if (end == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + throw new ArgumentNullException("end"); + + if (nodeParents == null) + throw new ArgumentNullException("nodeParents"); + + if (comparer == null) + throw new ArgumentNullException("comparer"); + IcdHashSet visited = new IcdHashSet(); while (true) @@ -368,6 +427,11 @@ namespace ICD.Common.Utils if (!nodeParents.TryGetValue(start, out next)) break; + // ReSharper disable CompareNonConstrainedGenericWithNull + if (next == null) + // ReSharper restore CompareNonConstrainedGenericWithNull + break; + if (visited.Contains(next)) break;