refactor: Reworking EnumUtils to rely less on casting back and forth to int, added method for getting enum names

This commit is contained in:
Chris Cameron
2020-04-13 14:49:57 -04:00
parent 5e61098e71
commit 79344a3667
2 changed files with 188 additions and 133 deletions

View File

@@ -5,8 +5,8 @@ using System.Linq;
namespace ICD.Common.Utils.Tests namespace ICD.Common.Utils.Tests
{ {
[TestFixture] [TestFixture]
public sealed class EnumUtilsTest public sealed class EnumUtilsTest
{ {
public enum eTestEnum public enum eTestEnum
{ {
None = 0, None = 0,
@@ -57,17 +57,29 @@ namespace ICD.Common.Utils.Tests
Assert.IsFalse(EnumUtils.IsEnum("")); Assert.IsFalse(EnumUtils.IsEnum(""));
} }
[Test] [Test]
public void IsDefinedTest() public void IsDefinedTest()
{ {
Assert.IsFalse(EnumUtils.IsDefined((eTestEnum)20)); Assert.IsFalse(EnumUtils.IsDefined((eTestEnum)20));
Assert.IsTrue(EnumUtils.IsDefined((eTestEnum)2)); Assert.IsTrue(EnumUtils.IsDefined((eTestEnum)2));
Assert.IsFalse(EnumUtils.IsDefined((eTestFlagsEnum)8)); Assert.IsFalse(EnumUtils.IsDefined((eTestFlagsEnum)8));
Assert.IsTrue(EnumUtils.IsDefined((eTestFlagsEnum)3)); Assert.IsTrue(EnumUtils.IsDefined((eTestFlagsEnum)3));
} }
#region Values #region Values
[Test]
public void GetNamesTest()
{
string[] names = EnumUtils.GetNames<eTestEnum>().ToArray();
Assert.AreEqual(4, names.Length);
Assert.IsTrue(names.Contains("None"));
Assert.IsTrue(names.Contains("A"));
Assert.IsTrue(names.Contains("B"));
Assert.IsTrue(names.Contains("C"));
}
[Test] [Test]
public void GetValuesGenericTest() public void GetValuesGenericTest()
@@ -107,7 +119,9 @@ namespace ICD.Common.Utils.Tests
[Test] [Test]
public void GetFlagsIntersectionGenericTest() public void GetFlagsIntersectionGenericTest()
{ {
Assert.AreEqual(eTestFlagsEnum.B, EnumUtils.GetFlagsIntersection(eTestFlagsEnum.A | eTestFlagsEnum.B, eTestFlagsEnum.B | eTestFlagsEnum.C)); Assert.AreEqual(eTestFlagsEnum.B,
EnumUtils.GetFlagsIntersection(eTestFlagsEnum.A | eTestFlagsEnum.B,
eTestFlagsEnum.B | eTestFlagsEnum.C));
} }
[Test] [Test]
@@ -138,11 +152,11 @@ namespace ICD.Common.Utils.Tests
Assert.IsTrue(aValues.Contains(eTestFlagsEnum.D)); Assert.IsTrue(aValues.Contains(eTestFlagsEnum.D));
} }
[Test] [Test]
public void GetAllFlagCombinationsExceptNoneGenericTest() public void GetAllFlagCombinationsExceptNoneGenericTest()
{ {
eTestFlagsEnum a = eTestFlagsEnum.A | eTestFlagsEnum.B | eTestFlagsEnum.C | eTestFlagsEnum.D; eTestFlagsEnum a = eTestFlagsEnum.A | eTestFlagsEnum.B | eTestFlagsEnum.C | eTestFlagsEnum.D;
eTestFlagsEnum[] aValues = EnumUtils.GetAllFlagCombinationsExceptNone(a).ToArray(); eTestFlagsEnum[] aValues = EnumUtils.GetAllFlagCombinationsExceptNone(a).ToArray();
Assert.AreEqual(15, aValues.Length); Assert.AreEqual(15, aValues.Length);
Assert.IsFalse(aValues.Contains(eTestFlagsEnum.None)); Assert.IsFalse(aValues.Contains(eTestFlagsEnum.None));
@@ -167,7 +181,8 @@ namespace ICD.Common.Utils.Tests
public void GetFlagsAllValueGenericTest() public void GetFlagsAllValueGenericTest()
{ {
eTestFlagsEnum value = EnumUtils.GetFlagsAllValue<eTestFlagsEnum>(); eTestFlagsEnum value = EnumUtils.GetFlagsAllValue<eTestFlagsEnum>();
Assert.AreEqual(eTestFlagsEnum.None | eTestFlagsEnum.A | eTestFlagsEnum.B | eTestFlagsEnum.C | eTestFlagsEnum.D, value); Assert.AreEqual(eTestFlagsEnum.None | eTestFlagsEnum.A | eTestFlagsEnum.B | eTestFlagsEnum.C | eTestFlagsEnum.D,
value);
} }
[Test] [Test]
@@ -181,7 +196,8 @@ namespace ICD.Common.Utils.Tests
public void HasFlagsGenericTest() public void HasFlagsGenericTest()
{ {
Assert.IsTrue(EnumUtils.HasFlags(eTestFlagsEnum.A | eTestFlagsEnum.B, eTestFlagsEnum.A | eTestFlagsEnum.B)); Assert.IsTrue(EnumUtils.HasFlags(eTestFlagsEnum.A | eTestFlagsEnum.B, eTestFlagsEnum.A | eTestFlagsEnum.B));
Assert.IsFalse(EnumUtils.HasFlags(eTestFlagsEnum.A | eTestFlagsEnum.B, eTestFlagsEnum.A | eTestFlagsEnum.C)); Assert.IsFalse(EnumUtils.HasFlags(eTestFlagsEnum.A | eTestFlagsEnum.B,
eTestFlagsEnum.A | eTestFlagsEnum.C));
} }
[Test] [Test]
@@ -200,28 +216,28 @@ namespace ICD.Common.Utils.Tests
Assert.IsTrue(EnumUtils.HasMultipleFlags(eTestFlagsEnum.A | eTestFlagsEnum.B)); Assert.IsTrue(EnumUtils.HasMultipleFlags(eTestFlagsEnum.A | eTestFlagsEnum.B));
} }
[TestCase(false, eTestFlagsEnum.None)] [TestCase(false, eTestFlagsEnum.None)]
[TestCase(true, eTestFlagsEnum.B)] [TestCase(true, eTestFlagsEnum.B)]
[TestCase(true, eTestFlagsEnum.B | eTestFlagsEnum.C)] [TestCase(true, eTestFlagsEnum.B | eTestFlagsEnum.C)]
public void HasAnyFlagsTest(bool expected, eTestFlagsEnum value) public void HasAnyFlagsTest(bool expected, eTestFlagsEnum value)
{ {
Assert.AreEqual(expected, EnumUtils.HasAnyFlags(value)); Assert.AreEqual(expected, EnumUtils.HasAnyFlags(value));
} }
[TestCase(false, eTestFlagsEnum.None, eTestFlagsEnum.None)] [TestCase(false, eTestFlagsEnum.None, eTestFlagsEnum.None)]
[TestCase(false, eTestFlagsEnum.None, eTestFlagsEnum.B)] [TestCase(false, eTestFlagsEnum.None, eTestFlagsEnum.B)]
[TestCase(true, eTestFlagsEnum.B, eTestFlagsEnum.B)] [TestCase(true, eTestFlagsEnum.B, eTestFlagsEnum.B)]
[TestCase(false, eTestFlagsEnum.None | eTestFlagsEnum.A, eTestFlagsEnum.B | eTestFlagsEnum.C)] [TestCase(false, eTestFlagsEnum.None | eTestFlagsEnum.A, eTestFlagsEnum.B | eTestFlagsEnum.C)]
public void HasAnyFlagsValueTest(bool expected, eTestFlagsEnum value, eTestFlagsEnum other) public void HasAnyFlagsValueTest(bool expected, eTestFlagsEnum value, eTestFlagsEnum other)
{ {
Assert.AreEqual(expected, EnumUtils.HasAnyFlags(value, other)); Assert.AreEqual(expected, EnumUtils.HasAnyFlags(value, other));
} }
#endregion #endregion
#region Conversion #region Conversion
[Test] [Test]
public void ParseGenericTest() public void ParseGenericTest()
{ {
Assert.AreEqual(eTestEnum.A, EnumUtils.Parse<eTestEnum>("A", false)); Assert.AreEqual(eTestEnum.A, EnumUtils.Parse<eTestEnum>("A", false));
@@ -243,35 +259,35 @@ namespace ICD.Common.Utils.Tests
Assert.AreEqual(default(eTestEnum), output); Assert.AreEqual(default(eTestEnum), output);
} }
[Test] [Test]
public void ParseStrictGenericTest() public void ParseStrictGenericTest()
{ {
Assert.AreEqual(eTestEnum.A, EnumUtils.ParseStrict<eTestEnum>("1", false)); Assert.AreEqual(eTestEnum.A, EnumUtils.ParseStrict<eTestEnum>("1", false));
Assert.Throws<ArgumentOutOfRangeException>(() => EnumUtils.ParseStrict<eTestEnum>("4", false)); Assert.Throws<ArgumentOutOfRangeException>(() => EnumUtils.ParseStrict<eTestEnum>("4", false));
Assert.AreEqual(eTestFlagsEnum.A | eTestFlagsEnum.B, EnumUtils.ParseStrict<eTestFlagsEnum>("3", false)); Assert.AreEqual(eTestFlagsEnum.A | eTestFlagsEnum.B, EnumUtils.ParseStrict<eTestFlagsEnum>("3", false));
Assert.Throws<ArgumentOutOfRangeException>(() => EnumUtils.ParseStrict<eTestFlagsEnum>("8", false)); Assert.Throws<ArgumentOutOfRangeException>(() => EnumUtils.ParseStrict<eTestFlagsEnum>("8", false));
} }
[Test] [Test]
public void TryParseStrictGenericTest() public void TryParseStrictGenericTest()
{ {
eTestEnum outputA; eTestEnum outputA;
Assert.AreEqual(true, EnumUtils.TryParseStrict("1", false, out outputA)); Assert.AreEqual(true, EnumUtils.TryParseStrict("1", false, out outputA));
Assert.AreEqual(eTestEnum.A, outputA); Assert.AreEqual(eTestEnum.A, outputA);
Assert.AreEqual(false, EnumUtils.TryParseStrict("4", false, out outputA)); Assert.AreEqual(false, EnumUtils.TryParseStrict("4", false, out outputA));
Assert.AreEqual(eTestEnum.None, outputA); Assert.AreEqual(eTestEnum.None, outputA);
eTestFlagsEnum outputB; eTestFlagsEnum outputB;
Assert.AreEqual(true, EnumUtils.TryParseStrict("3", false, out outputB)); Assert.AreEqual(true, EnumUtils.TryParseStrict("3", false, out outputB));
Assert.AreEqual(eTestFlagsEnum.A | eTestFlagsEnum.B, outputB); Assert.AreEqual(eTestFlagsEnum.A | eTestFlagsEnum.B, outputB);
Assert.AreEqual(false, EnumUtils.TryParseStrict("8", false, out outputB)); Assert.AreEqual(false, EnumUtils.TryParseStrict("8", false, out outputB));
Assert.AreEqual(eTestFlagsEnum.None, outputB); Assert.AreEqual(eTestFlagsEnum.None, outputB);
} }
#endregion #endregion
} }

View File

@@ -13,16 +13,16 @@ namespace ICD.Common.Utils
{ {
public static class EnumUtils public static class EnumUtils
{ {
private static readonly Dictionary<Type, object> s_EnumValuesCache; private static readonly Dictionary<Type, object[]> s_EnumValuesCache;
private static readonly Dictionary<Type, Dictionary<int, object>> s_EnumFlagsCache; private static readonly Dictionary<Type, Dictionary<object, object[]>> s_EnumFlagsCache;
/// <summary> /// <summary>
/// Static constructor. /// Static constructor.
/// </summary> /// </summary>
static EnumUtils() static EnumUtils()
{ {
s_EnumValuesCache = new Dictionary<Type, object>(); s_EnumValuesCache = new Dictionary<Type, object[]>();
s_EnumFlagsCache = new Dictionary<Type, Dictionary<int, object>>(); s_EnumFlagsCache = new Dictionary<Type, Dictionary<object, object[]>>();
} }
#region Validation #region Validation
@@ -58,7 +58,6 @@ namespace ICD.Common.Utils
/// <returns></returns> /// <returns></returns>
public static bool IsEnum<T>(T value) public static bool IsEnum<T>(T value)
{ {
// ReSharper disable once CompareNonConstrainedGenericWithNull
return value != null && IsEnumType(value.GetType()); return value != null && IsEnumType(value.GetType());
} }
@@ -123,6 +122,30 @@ namespace ICD.Common.Utils
#region Values #region Values
/// <summary>
/// Gets the names of the values in the enumeration.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static IEnumerable<string> GetNames<T>()
where T : struct, IConvertible
{
return GetNames(typeof(T));
}
/// <summary>
/// Gets the names of the values in the enumeration.
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
public static IEnumerable<string> GetNames([NotNull] Type type)
{
if (type == null)
throw new ArgumentNullException("type");
return GetValues(type).Select(v => v.ToString());
}
/// <summary> /// <summary>
/// Gets the values from an enumeration. /// Gets the values from an enumeration.
/// </summary> /// </summary>
@@ -131,46 +154,51 @@ namespace ICD.Common.Utils
public static IEnumerable<T> GetValues<T>() public static IEnumerable<T> GetValues<T>()
where T : struct, IConvertible where T : struct, IConvertible
{ {
Type type = typeof(T); return GetValues(typeof(T)).Cast<T>();
// Reflection is slow and this method is called a lot, so we cache the results.
object cache;
if (!s_EnumValuesCache.TryGetValue(type, out cache))
{
cache = GetValuesUncached<T>().ToArray();
s_EnumValuesCache[type] = cache;
}
return cache as T[];
} }
/// <summary> /// <summary>
/// Gets the values from an enumeration without performing any caching. This is slow because of reflection. /// Gets the values from an enumeration.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public static IEnumerable<int> GetValues(Type type) public static IEnumerable<object> GetValues(Type type)
{ {
if (type == null) if (type == null)
throw new ArgumentNullException("type"); throw new ArgumentNullException("type");
return GetValuesUncached(type); // Reflection is slow and this method is called a lot, so we cache the results.
return s_EnumValuesCache.GetOrAddNew(type, () => GetValuesUncached(type).ToArray());
} }
/// <summary> /// <summary>
/// Gets the values from an enumeration without performing any caching. This is slow because of reflection. /// Gets the values from an enumeration except the 0 value.
/// </summary> /// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns> /// <returns></returns>
private static IEnumerable<T> GetValuesUncached<T>() public static IEnumerable<T> GetValuesExceptNone<T>()
where T : struct, IConvertible where T : struct, IConvertible
{ {
return GetValuesUncached(typeof(T)).Select(i => (T)(object)i); return GetValuesExceptNone(typeof(T)).Cast<T>();
}
/// <summary>
/// Gets the values from an enumeration except the 0 value.
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
public static IEnumerable<object> GetValuesExceptNone([NotNull] Type type)
{
if (type == null)
throw new ArgumentNullException("type");
return GetFlagsExceptNone(type);
} }
/// <summary> /// <summary>
/// Gets the values from an enumeration without performing any caching. This is slow because of reflection. /// Gets the values from an enumeration without performing any caching. This is slow because of reflection.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
private static IEnumerable<int> GetValuesUncached(Type type) private static IEnumerable<object> GetValuesUncached(Type type)
{ {
if (type == null) if (type == null)
throw new ArgumentNullException("type"); throw new ArgumentNullException("type");
@@ -182,34 +210,10 @@ namespace ICD.Common.Utils
#if SIMPLSHARP #if SIMPLSHARP
.GetCType() .GetCType()
#else #else
.GetTypeInfo() .GetTypeInfo()
#endif #endif
.GetFields(BindingFlags.Static | BindingFlags.Public) .GetFields(BindingFlags.Static | BindingFlags.Public)
.Select(x => x.GetValue(null)) .Select(x => x.GetValue(null));
.Cast<int>();
}
/// <summary>
/// Gets the values from an enumeration except the 0 value.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static IEnumerable<T> GetValuesExceptNone<T>()
where T : struct, IConvertible
{
return GetFlagsExceptNone<T>();
}
/// <summary>
/// Gets the values from an enumeration except the 0 value without performing any caching. This is slow because of reflection.
/// </summary>
/// <returns></returns>
public static IEnumerable<int> GetValuesExceptNone(Type type)
{
if (type == null)
throw new ArgumentNullException("type");
return GetValues(type).Except(0);
} }
#endregion #endregion
@@ -330,24 +334,21 @@ namespace ICD.Common.Utils
public static IEnumerable<T> GetFlags<T>(T value) public static IEnumerable<T> GetFlags<T>(T value)
where T : struct, IConvertible where T : struct, IConvertible
{ {
Type type = typeof(T); return GetFlags(typeof(T), value).Cast<T>();
int valueInt = (int)(object)value; }
Dictionary<int, object> cache; /// <summary>
if (!s_EnumFlagsCache.TryGetValue(type, out cache)) /// Gets all of the set flags on the given enum.
{ /// </summary>
cache = new Dictionary<int, object>(); /// <param name="type"></param>
s_EnumFlagsCache[type] = cache; /// <param name="value"></param>
} /// <returns></returns>
public static IEnumerable<object> GetFlags(Type type, object value)
object flags; {
if (!cache.TryGetValue(valueInt, out flags)) return s_EnumFlagsCache.GetOrAddNew(type, () => new Dictionary<object, object[]>())
{ .GetOrAddNew(value, () => GetValues(type)
flags = GetValues<T>().Where(e => HasFlag(value, e)).ToArray(); .Where(f => HasFlag(value, f))
cache[valueInt] = flags; .ToArray());
}
return flags as T[];
} }
/// <summary> /// <summary>
@@ -358,8 +359,7 @@ namespace ICD.Common.Utils
public static IEnumerable<T> GetFlagsExceptNone<T>() public static IEnumerable<T> GetFlagsExceptNone<T>()
where T : struct, IConvertible where T : struct, IConvertible
{ {
T allValue = GetFlagsAllValue<T>(); return GetFlagsExceptNone(typeof(T)).Cast<T>();
return GetFlagsExceptNone(allValue);
} }
/// <summary> /// <summary>
@@ -371,7 +371,35 @@ namespace ICD.Common.Utils
public static IEnumerable<T> GetFlagsExceptNone<T>(T value) public static IEnumerable<T> GetFlagsExceptNone<T>(T value)
where T : struct, IConvertible where T : struct, IConvertible
{ {
return GetFlags(value).Except(default(T)); return GetFlagsExceptNone(typeof(T), value).Cast<T>();
}
/// <summary>
/// Gets all of the set flags on the given enum type except 0.
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
public static IEnumerable<object> GetFlagsExceptNone([NotNull] Type type)
{
if (type == null)
throw new ArgumentNullException("type");
object allValue = GetFlagsAllValue(type);
return GetFlagsExceptNone(type, allValue);
}
/// <summary>
/// Gets all of the set flags on the given enum except 0.
/// </summary>
/// <param name="type"></param>
/// <param name="value"></param>
/// <returns></returns>
public static IEnumerable<object> GetFlagsExceptNone([NotNull] Type type, object value)
{
if (type == null)
throw new ArgumentNullException("type");
return GetFlags(type, value).Where(f => (int)f != 0);
} }
/// <summary> /// <summary>
@@ -412,7 +440,7 @@ namespace ICD.Common.Utils
if (type == null) if (type == null)
throw new ArgumentNullException("type"); throw new ArgumentNullException("type");
return GetValuesUncached(type).Aggregate(0, (current, value) => current | value); return GetValuesUncached(type).Aggregate(0, (current, value) => current | (int)value);
} }
/// <summary> /// <summary>
@@ -428,6 +456,17 @@ namespace ICD.Common.Utils
return HasFlags(value, flag); return HasFlags(value, flag);
} }
/// <summary>
/// Returns true if the enum contains the given flag.
/// </summary>
/// <param name="value"></param>
/// <param name="flag"></param>
/// <returns></returns>
public static bool HasFlag(object value, object flag)
{
return HasFlag((int)value, (int)flag);
}
/// <summary> /// <summary>
/// Returns true if the enum contains the given flag. /// Returns true if the enum contains the given flag.
/// </summary> /// </summary>