feat: Added breadth-first search method that returns the path from root for every node in the graph

This commit is contained in:
Chris Cameron
2020-10-22 12:16:21 -04:00
parent c349757c00
commit f174c32721
2 changed files with 203 additions and 104 deletions

View File

@@ -250,5 +250,40 @@ namespace ICD.Common.Utils.Tests
Assert.AreEqual(6, paths[62].ToArray()[1]);
Assert.AreEqual(62, paths[62].ToArray()[2]);
}
[Test]
public void BreadthFirstSearchPathsTest()
{
Dictionary<int, IEnumerable<int>> paths =
RecursionUtils.BreadthFirstSearchPaths(1, Graph)
.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
// Verify we visited all destinations
Assert.AreEqual(4, paths.Count);
Assert.IsTrue(paths.ContainsKey(1));
Assert.IsTrue(paths.ContainsKey(2));
Assert.IsTrue(paths.ContainsKey(3));
Assert.IsTrue(paths.ContainsKey(4));
// Verify path for destination 1
Assert.AreEqual(1, paths[1].ToArray().Length);
Assert.AreEqual(1, paths[1].ToArray()[0]);
// Verify path for destination 2
Assert.AreEqual(2, paths[2].ToArray().Length);
Assert.AreEqual(1, paths[2].ToArray()[0]);
Assert.AreEqual(2, paths[2].ToArray()[1]);
// Verify path for destination 3
Assert.AreEqual(2, paths[3].ToArray().Length);
Assert.AreEqual(1, paths[3].ToArray()[0]);
Assert.AreEqual(3, paths[3].ToArray()[1]);
// Verify path for destination 4
Assert.AreEqual(3, paths[4].ToArray().Length);
Assert.AreEqual(1, paths[4].ToArray()[0]);
Assert.AreEqual(2, paths[4].ToArray()[1]);
Assert.AreEqual(4, paths[4].ToArray()[2]);
}
}
}

View File

@@ -16,7 +16,9 @@ namespace ICD.Common.Utils
/// <param name="nodes"></param>
/// <param name="getAdjacent"></param>
/// <returns></returns>
public static IEnumerable<IEnumerable<T>> GetCliques<T>(IEnumerable<T> nodes, Func<T, IEnumerable<T>> getAdjacent)
[NotNull]
public static IEnumerable<IEnumerable<T>> GetCliques<T>([NotNull] IEnumerable<T> nodes,
[NotNull] Func<T, IEnumerable<T>> getAdjacent)
{
if (nodes == null)
throw new ArgumentNullException("nodes");
@@ -39,7 +41,8 @@ namespace ICD.Common.Utils
/// <param name="node"></param>
/// <param name="getAdjacent"></param>
/// <returns></returns>
public static IEnumerable<T> GetClique<T>(T node, Func<T, IEnumerable<T>> getAdjacent)
[NotNull]
public static IEnumerable<T> GetClique<T>([NotNull] T node, [NotNull] Func<T, IEnumerable<T>> getAdjacent)
{
if (node == null)
throw new ArgumentNullException("node");
@@ -57,7 +60,8 @@ namespace ICD.Common.Utils
/// <param name="map"></param>
/// <param name="node"></param>
/// <returns></returns>
public static IEnumerable<T> GetClique<T>(IDictionary<T, IEnumerable<T>> map, T node)
[NotNull]
public static IEnumerable<T> GetClique<T>([NotNull] IDictionary<T, IEnumerable<T>> map, [NotNull] T node)
{
if (map == null)
throw new ArgumentNullException("map");
@@ -77,7 +81,8 @@ namespace ICD.Common.Utils
/// <param name="visited"></param>
/// <param name="node"></param>
/// <returns></returns>
private static IEnumerable<T> GetClique<T>(IDictionary<T, IEnumerable<T>> map, IcdHashSet<T> visited, T node)
private static IEnumerable<T> GetClique<T>([NotNull] IDictionary<T, IEnumerable<T>> map,
[NotNull] IcdHashSet<T> visited, [NotNull] T node)
{
if (map == null)
throw new ArgumentNullException("map");
@@ -89,19 +94,6 @@ namespace ICD.Common.Utils
if (node == null)
throw new ArgumentNullException("node");
return GetCliqueIterator(map, visited, node);
}
/// <summary>
/// Gets the clique containing the node.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="map"></param>
/// <param name="visited"></param>
/// <param name="node"></param>
/// <returns></returns>
private static IEnumerable<T> GetCliqueIterator<T>(IDictionary<T, IEnumerable<T>> map, IcdHashSet<T> visited, T node)
{
if (!visited.Add(node))
yield break;
@@ -123,8 +115,19 @@ namespace ICD.Common.Utils
/// <param name="destination"></param>
/// <param name="getChildren"></param>
/// <returns></returns>
public static bool BreadthFirstSearch<T>(T root, T destination, Func<T, IEnumerable<T>> getChildren)
public static bool BreadthFirstSearch<T>([NotNull] T root, [NotNull] T destination,
[NotNull] Func<T, IEnumerable<T>> getChildren)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
// ReSharper disable CompareNonConstrainedGenericWithNull
if (destination == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("destination");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
@@ -140,8 +143,20 @@ namespace ICD.Common.Utils
/// <param name="getChildren"></param>
/// <param name="comparer"></param>
/// <returns></returns>
public static bool BreadthFirstSearch<T>(T root, T destination, Func<T, IEnumerable<T>> getChildren, IEqualityComparer<T> comparer)
public static bool BreadthFirstSearch<T>([NotNull] T root, [NotNull] T destination,
[NotNull] Func<T, IEnumerable<T>> getChildren,
[NotNull] IEqualityComparer<T> comparer)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
// ReSharper disable CompareNonConstrainedGenericWithNull
if (destination == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("destination");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
@@ -158,38 +173,17 @@ namespace ICD.Common.Utils
/// <param name="root"></param>
/// <param name="getChildren"></param>
/// <returns></returns>
public static IEnumerable<T> BreadthFirstSearch<T>(T root, Func<T, IEnumerable<T>> getChildren)
public static IEnumerable<T> BreadthFirstSearch<T>([NotNull] T root, [NotNull] Func<T, IEnumerable<T>> getChildren)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
return BreadthFirstSearchIterator(root, getChildren);
}
/// <summary>
/// Returns all of the nodes in the graph via breadth-first search.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="root"></param>
/// <param name="getChildren"></param>
/// <returns></returns>
private static IEnumerable<T> BreadthFirstSearchIterator<T>(T root, Func<T, IEnumerable<T>> getChildren)
{
IcdHashSet<T> visited = new IcdHashSet<T> {root};
Queue<T> process = new Queue<T>();
process.Enqueue(root);
T current;
while (process.Dequeue(out current))
{
yield return current;
foreach (T child in getChildren(current).Where(c => !visited.Contains(c)))
{
visited.Add(child);
process.Enqueue(child);
}
}
return BreadthFirstSearchPaths(root, getChildren, EqualityComparer<T>.Default).Select(kvp => kvp.Key);
}
/// <summary>
@@ -201,8 +195,19 @@ namespace ICD.Common.Utils
/// <param name="getChildren"></param>
/// <returns></returns>
[CanBeNull]
public static IEnumerable<T> BreadthFirstSearchPath<T>(T root, T destination, Func<T, IEnumerable<T>> getChildren)
public static IEnumerable<T> BreadthFirstSearchPath<T>([NotNull] T root, [NotNull] T destination,
[NotNull] Func<T, IEnumerable<T>> getChildren)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
// ReSharper disable CompareNonConstrainedGenericWithNull
if (destination == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("destination");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
@@ -219,42 +224,29 @@ namespace ICD.Common.Utils
/// <param name="comparer"></param>
/// <returns></returns>
[CanBeNull]
public static IEnumerable<T> BreadthFirstSearchPath<T>(T root, T destination, Func<T, IEnumerable<T>> getChildren,
IEqualityComparer<T> comparer)
public static IEnumerable<T> BreadthFirstSearchPath<T>([NotNull] T root, [NotNull] T destination,
[NotNull] Func<T, IEnumerable<T>> getChildren,
[NotNull] IEqualityComparer<T> comparer)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
// ReSharper disable CompareNonConstrainedGenericWithNull
if (destination == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("destination");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
if (comparer == null)
throw new ArgumentNullException("comparer");
// Edge case - root and destination are the same
if (comparer.Equals(root, destination))
return new[] {root};
Queue<T> queue = new Queue<T>();
queue.Enqueue(root);
Dictionary<T, T> nodeParents = new Dictionary<T, T>(comparer);
T current;
while (queue.Dequeue(out current))
{
foreach (T node in getChildren(current))
{
if (nodeParents.ContainsKey(node))
continue;
queue.Enqueue(node);
nodeParents.Add(node, current);
// Found a path to the destination
if (comparer.Equals(node, destination))
return GetPath(destination, root, nodeParents, comparer).Reverse();
}
}
return null;
return BreadthFirstSearchPathManyDestinations(root, destination.Yield(), getChildren, comparer)
.Select(kvp => kvp.Value)
.FirstOrDefault();
}
/// <summary>
@@ -265,9 +257,16 @@ namespace ICD.Common.Utils
/// <param name="destinations"></param>
/// <param name="getChildren"></param>
/// <returns></returns>
[NotNull]
public static IEnumerable<KeyValuePair<T, IEnumerable<T>>>
BreadthFirstSearchPathManyDestinations<T>(T root, IEnumerable<T> destinations, Func<T, IEnumerable<T>> getChildren)
BreadthFirstSearchPathManyDestinations<T>([NotNull] T root, [NotNull] IEnumerable<T> destinations,
[NotNull] Func<T, IEnumerable<T>> getChildren)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
if (destinations == null)
throw new ArgumentNullException("destinations");
@@ -286,10 +285,17 @@ namespace ICD.Common.Utils
/// <param name="getChildren"></param>
/// <param name="comparer"></param>
/// <returns></returns>
[NotNull]
public static IEnumerable<KeyValuePair<T, IEnumerable<T>>>
BreadthFirstSearchPathManyDestinations<T>(T root, IEnumerable<T> destinations, Func<T, IEnumerable<T>> getChildren,
IEqualityComparer<T> comparer)
BreadthFirstSearchPathManyDestinations<T>([NotNull] T root, [NotNull] IEnumerable<T> destinations,
[NotNull] Func<T, IEnumerable<T>> getChildren,
[NotNull] IEqualityComparer<T> comparer)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
if (destinations == null)
throw new ArgumentNullException("destinations");
@@ -299,23 +305,68 @@ namespace ICD.Common.Utils
if (comparer == null)
throw new ArgumentNullException("comparer");
IcdHashSet<T> destinationsToBeProcessed = new IcdHashSet<T>(destinations);
IcdHashSet<T> destinationsSet = destinations as IcdHashSet<T> ?? destinations.ToIcdHashSet();
return BreadthFirstSearchPaths(root, getChildren, comparer)
.Where(kvp => destinationsSet.Contains(kvp.Key))
.Take(destinationsSet.Count);
}
// Edge case, root is the destination
foreach (T destination in
destinationsToBeProcessed.Where(destination => comparer.Equals(root, destination)).ToArray())
{
destinationsToBeProcessed.Remove(destination);
yield return new KeyValuePair<T, IEnumerable<T>>(destination, new[] {root});
}
/// <summary>
/// Returns the shortest path from root to each destination via breadth-first search.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="root"></param>
/// <param name="getChildren"></param>
/// <returns></returns>
[NotNull]
public static IEnumerable<KeyValuePair<T, IEnumerable<T>>>
BreadthFirstSearchPaths<T>([NotNull] T root, [NotNull] Func<T, IEnumerable<T>> getChildren)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
if (destinationsToBeProcessed.Count == 0)
yield break;
if (getChildren == null)
throw new ArgumentNullException("getChildren");
return BreadthFirstSearchPaths(root, getChildren, EqualityComparer<T>.Default);
}
/// <summary>
/// Returns the shortest path from root to each destination via breadth-first search.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="root"></param>
/// <param name="getChildren"></param>
/// <param name="comparer"></param>
/// <returns></returns>
[NotNull]
public static IEnumerable<KeyValuePair<T, IEnumerable<T>>>
BreadthFirstSearchPaths<T>([NotNull] T root, [NotNull] Func<T, IEnumerable<T>> getChildren,
[NotNull] IEqualityComparer<T> comparer)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (root == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("root");
if (getChildren == null)
throw new ArgumentNullException("getChildren");
if (comparer == null)
throw new ArgumentNullException("comparer");
// Edge case - root is the same as destination
yield return new KeyValuePair<T, IEnumerable<T>>(root, root.Yield());
Queue<T> queue = new Queue<T>();
queue.Enqueue(root);
Dictionary<T, T> nodeParents = new Dictionary<T, T>(comparer);
Dictionary<T, T> nodeParents = new Dictionary<T, T>(comparer)
{
{root, default(T)}
};
while (queue.Count > 0)
{
@@ -326,18 +377,8 @@ namespace ICD.Common.Utils
queue.Enqueue(node);
nodeParents.Add(node, current);
T closureNode = node;
foreach (T destination in
destinationsToBeProcessed.Where(destination => comparer.Equals(closureNode, destination)).ToArray())
{
destinationsToBeProcessed.Remove(destination);
yield return
new KeyValuePair<T, IEnumerable<T>>(destination, GetPath(destination, root, nodeParents, comparer).Reverse());
}
if (destinationsToBeProcessed.Count == 0)
yield break;
yield return
new KeyValuePair<T, IEnumerable<T>>(node, GetPath(node, root, nodeParents, comparer).Reverse());
}
}
}
@@ -351,8 +392,26 @@ namespace ICD.Common.Utils
/// <param name="nodeParents"></param>
/// <param name="comparer"></param>
/// <returns></returns>
private static IEnumerable<T> GetPath<T>(T start, T end, IDictionary<T, T> nodeParents, IEqualityComparer<T> comparer)
[NotNull]
private static IEnumerable<T> GetPath<T>([NotNull] T start, [NotNull] T end, [NotNull] IDictionary<T, T> nodeParents,
[NotNull] IEqualityComparer<T> comparer)
{
// ReSharper disable CompareNonConstrainedGenericWithNull
if (start == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("start");
// ReSharper disable CompareNonConstrainedGenericWithNull
if (end == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
throw new ArgumentNullException("end");
if (nodeParents == null)
throw new ArgumentNullException("nodeParents");
if (comparer == null)
throw new ArgumentNullException("comparer");
IcdHashSet<T> visited = new IcdHashSet<T>();
while (true)
@@ -368,6 +427,11 @@ namespace ICD.Common.Utils
if (!nodeParents.TryGetValue(start, out next))
break;
// ReSharper disable CompareNonConstrainedGenericWithNull
if (next == null)
// ReSharper restore CompareNonConstrainedGenericWithNull
break;
if (visited.Contains(next))
break;