diff --git a/ICD.Common.Utils.Tests/Extensions/TypeExtensionsTest.cs b/ICD.Common.Utils.Tests/Extensions/TypeExtensionsTest.cs index cf44533..1f8fb59 100644 --- a/ICD.Common.Utils.Tests/Extensions/TypeExtensionsTest.cs +++ b/ICD.Common.Utils.Tests/Extensions/TypeExtensionsTest.cs @@ -1,4 +1,6 @@ using System; +using System.Collections; +using System.Collections.Generic; using System.Linq; using ICD.Common.Utils.Extensions; using NUnit.Framework; @@ -93,6 +95,20 @@ namespace ICD.Common.Utils.Tests.Extensions Assert.IsTrue(baseTypes.Contains(typeof(object))); } + [Test] + public void GetImmediateInterfacesTest() + { + Type[] interfaces = typeof(ICollection).GetImmediateInterfaces().ToArray(); + + Assert.AreEqual(1, interfaces.Length); + Assert.AreEqual(typeof(IEnumerable), interfaces[0]); + + interfaces = typeof(IEnumerable).GetImmediateInterfaces().ToArray(); + + Assert.AreEqual(1, interfaces.Length); + Assert.AreEqual(typeof(IEnumerable), interfaces[0]); + } + private interface C { } diff --git a/ICD.Common.Utils/Extensions/TypeExtensions.cs b/ICD.Common.Utils/Extensions/TypeExtensions.cs index 08b0571..ad0a70d 100644 --- a/ICD.Common.Utils/Extensions/TypeExtensions.cs +++ b/ICD.Common.Utils/Extensions/TypeExtensions.cs @@ -47,6 +47,7 @@ namespace ICD.Common.Utils.Extensions private static readonly Dictionary s_TypeAllTypes; private static readonly Dictionary s_TypeBaseTypes; + private static readonly Dictionary s_TypeImmediateInterfaces; private static readonly Dictionary s_TypeMinimalInterfaces; /// @@ -56,6 +57,7 @@ namespace ICD.Common.Utils.Extensions { s_TypeAllTypes = new Dictionary(); s_TypeBaseTypes = new Dictionary(); + s_TypeImmediateInterfaces = new Dictionary(); s_TypeMinimalInterfaces = new Dictionary(); } @@ -184,6 +186,35 @@ namespace ICD.Common.Utils.Extensions } while (type != null); } + /// + /// Gets the interfaces that the given type implements that are not implemented further + /// down the inheritance chain. + /// + /// + /// + public static IEnumerable GetImmediateInterfaces(this Type extends) + { + if (extends == null) + throw new ArgumentNullException("extends"); + + if (!s_TypeImmediateInterfaces.ContainsKey(extends)) + { + IEnumerable allInterfaces = extends.GetInterfaces(); + + IEnumerable childInterfaces = + extends.GetAllTypes() + .Except(extends) + .SelectMany(t => t.GetImmediateInterfaces()) + .Distinct(); + + Type[] immediateInterfaces = allInterfaces.Except(childInterfaces).ToArray(); + + s_TypeImmediateInterfaces.Add(extends, immediateInterfaces); + } + + return s_TypeImmediateInterfaces[extends]; + } + /// /// Gets the smallest set of interfaces for the given type that cover all implemented interfaces. ///