Some time ago (over 4.5 years ago!) I wrote a blog post titled: "A better enumeration - Type safe from start to end". While cool and so - it had some issues. Let's tackle them!
Enumeration
The basic Enumeration class - modernized - looks like this:
/// <summary>
/// Base record for string-keyed, type-safe enumerations. A derived type exposes its
/// values as public static fields of its own type; those fields are discovered once
/// via reflection and published through <see cref="All"/>.
/// </summary>
/// <typeparam name="TEnumeration">The concrete enumeration type deriving from this record.</typeparam>
/// <param name="Key">The string that identifies a single enumeration value.</param>
public abstract record Enumeration<TEnumeration>(string Key)
    where TEnumeration : Enumeration<TEnumeration>
{
    /// <summary>All values declared as public static fields on <typeparamref name="TEnumeration"/>.</summary>
    public static FrozenSet<TEnumeration> All { get; } = GetEnumerations();

    /// <summary>
    /// Compares the key of <paramref name="a"/> with <paramref name="b"/> using ordinal comparison.
    /// Yields <c>false</c> when either operand is <c>null</c>.
    /// </summary>
    public static bool operator ==(Enumeration<TEnumeration>? a, string? b)
    {
        if (a is null || b is null)
        {
            return false;
        }

        return a.Key.Equals(b, StringComparison.Ordinal);
    }

    /// <summary>Negation of the enumeration/string equality operator.</summary>
    public static bool operator !=(Enumeration<TEnumeration>? a, string? b) => !(a == b);

    /// <summary>Returns the declared value whose key equals <paramref name="key"/>.</summary>
    /// <exception cref="InvalidOperationException">No declared value has the given key.</exception>
    public static TEnumeration Create(string key)
    {
        var match = All.SingleOrDefault(candidate => candidate.Key == key);
        if (match is null)
        {
            throw new InvalidOperationException($"{key} is not a valid value for {typeof(TEnumeration).Name}");
        }

        return match;
    }

    /// <summary>Returns the key of this value.</summary>
    public sealed override string ToString() => Key;

    // Collects every public static field declared directly on TEnumeration whose
    // field type is exactly TEnumeration.
    private static FrozenSet<TEnumeration> GetEnumerations()
    {
        const BindingFlags declaredPublicStatics =
            BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly;

        var members =
            from field in typeof(TEnumeration).GetFields(declaredPublicStatics)
            where field.FieldType == typeof(TEnumeration)
            select (TEnumeration)field.GetValue(null)!;

        return members.ToFrozenSet();
    }
}
Which lets you do fun things like:
/// <summary>
/// Hand-written enumeration of colors; every value is a public static field so that
/// the base record can pick it up.
/// </summary>
public record Color(string Key) : Enumeration<Color>(Key)
{
    public static readonly Color Red = new Color("Red");

    public static readonly Color Green = new Color("Green");

    public static readonly Color Blue = new Color("Blue");
}
// Resolve a value from its key, then compare it against plain strings.
Color color = Color.Create("Red");
bool isRed = color == "Red"; // true
bool isGreen = color == "Green"; // false
Issues
We have two major issues:
1. We are using reflection to fill the `FrozenSet<TEnumeration> All` property. This is not only slow but can also cause issues in some environments (like AOT compilation).
2. We have to define each `static readonly` field ourselves, even though no one is forcing us to! If you define a new `Color` without defining the static fields, you won't get any errors until you try to use the `Create` method with a key that doesn't exist in the `All` set.
Point 2 comes with another design flaw: there is no easy way to do exhaustiveness checks! So let's solve that:
## Source Code Generation
Let's first look at what the generated code will look like and how to use it:
// Each string passed to the attribute becomes one value of the enumeration;
// the generator emits the rest of the type into the other half of this partial record.
[Enumeration("Red", "Green", "Blue")]
public sealed partial record Color;
Yes, this is absolutely all. The EnumerationAttribute is generated by the generator and the generated class looks like this:
// <auto-generated/>
#nullable enable
using System;
using System.Collections.Frozen;
using System.Linq;
namespace LinkDotNet.Blog.Infrastructure.Persistence;
// Generated half of the partial record: adds the key, the declared values,
// lookup by key, comparison against strings and the Match helpers.
public sealed partial record Color
{
// The string that identifies this value.
public string Key { get; init; } = default!;
// Private constructor; rejects null, empty or whitespace-only keys.
private Color(string key)
{
ArgumentException.ThrowIfNullOrWhiteSpace(key);
Key = key;
}
// One static field per value listed in the attribute.
public static readonly Color Red = new("Red");
public static readonly Color Green = new("Green");
public static readonly Color Blue = new("Blue");
// All declared values, built directly from the fields above (no reflection).
public static FrozenSet<Color> All { get; } =
new Color[] { Red, Green, Blue }.ToFrozenSet();
// Returns the declared value whose Key equals key.
// Throws ArgumentException for a null/empty/whitespace key and
// InvalidOperationException when no declared value has that key.
public static Color Create(string key)
{
ArgumentException.ThrowIfNullOrWhiteSpace(key);
return All.SingleOrDefault(p => p.Key == key)
?? throw new InvalidOperationException($"{key} is not a valid value for Color");
}
// Ordinal comparison of the value's Key with a plain string; false when either side is null.
public static bool operator ==(Color? a, string? b)
=> a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);
public static bool operator !=(Color? a, string? b) => !(a == b);
// Returns the Key.
public override string ToString() => Key;
// Invokes the callback that belongs to the current value and returns its result.
// There is one parameter per declared value, so a caller has to supply a handler for each.
public T Match<T>(Func<T> onRed, Func<T> onGreen, Func<T> onBlue)
{
if (Key == Red.Key) return onRed();
if (Key == Green.Key) return onGreen();
if (Key == Blue.Key) return onBlue();
// Reached only when Key matches none of the declared values.
throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
}
// Same dispatch as Match<T>, for callbacks that return nothing.
public void Match(Action onRed, Action onGreen, Action onBlue)
{
if (Key == Red.Key) { onRed(); return; }
if (Key == Green.Key) { onGreen(); return; }
if (Key == Blue.Key) { onBlue(); return; }
// Reached only when Key matches none of the declared values.
throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
}
}
So the generator does all the boilerplate stuff! But it also does something more: Match. That is my way of enforcing exhaustiveness checks. So imagine you have a method like this:
/// <summary>
/// Writes the name of <paramref name="color"/> to the console by testing each known
/// value in turn. Nothing is written when the value matches none of the tested ones.
/// </summary>
public void PrintColor(Color color)
{
    if (color == Color.Red)
    {
        Console.WriteLine("Red");
        return;
    }

    if (color == Color.Green)
    {
        Console.WriteLine("Green");
        return;
    }

    if (color == Color.Blue)
    {
        Console.WriteLine("Blue");
    }
}
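You can forget to handle a new color if you add it later on (or one that is already present), and the compiler will not tell you. The Match method takes one handler per value, so adding a value changes its signature and every call site stops compiling until the new case is handled. With the Match method you can do this: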
/// <summary>
/// Writes the name of <paramref name="color"/> to the console, supplying one handler
/// per declared value to <c>Match</c>.
/// </summary>
public void PrintColor(Color color)
    => color.Match(
        onRed: () => Console.WriteLine("Red"),
        onGreen: () => Console.WriteLine("Green"),
        onBlue: () => Console.WriteLine("Blue"));
The generator itself
The generator itself is pretty straightforward and basically just a) generates the attribute and b) generates the class based on the attribute.
/// <summary>
/// Incremental source generator that (a) adds the <c>EnumerationAttribute</c> to the
/// compilation and (b) emits, for every record carrying that attribute, the other half
/// of the partial record: Key, private constructor, one static field per value,
/// All, Create, string comparison operators, ToString and the two Match overloads.
/// </summary>
[Generator]
public sealed class EnumerationGenerator : IIncrementalGenerator
{
    private const string FullAttributeName = "LinkDotNet.Blog.Generators.EnumerationAttribute";

    private const string AttributeSource = """
    // <auto-generated/>
    #nullable enable
    using System;
    namespace LinkDotNet.Blog.Generators;
    /// <summary>
    /// Marks a <c>sealed partial record</c> as a source-generated enumeration.
    /// The generator emits: Key property, private constructor, static readonly fields,
    /// All, Create, == / != operators, ToString, Match<T> and Match(Action).
    /// </summary>
    [AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
    internal sealed class EnumerationAttribute : Attribute
    {
    public string[] Values { get; }
    public EnumerationAttribute(params string[] values) => Values = values;
    }
    """;

    // Reported instead of emitting source that cannot compile because a value
    // cannot be used as the name of a generated static field.
    private static readonly DiagnosticDescriptor InvalidValueRule = new(
        id: "ENUMGEN001",
        title: "Invalid enumeration value",
        messageFormat: "'{0}' cannot be used as an enumeration value for '{1}': it must be a valid, non-keyword C# identifier that does not collide with a generated member",
        category: "EnumerationGenerator",
        defaultSeverity: DiagnosticSeverity.Error,
        isEnabledByDefault: true);

    /// <summary>
    /// Registers the attribute source and the pipeline that turns every attributed
    /// record declaration into one generated file.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static ctx =>
            ctx.AddSource("EnumerationAttribute.g.cs", SourceText.From(AttributeSource, Encoding.UTF8)));

        var models = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                FullAttributeName,
                predicate: static (node, _) => node is RecordDeclarationSyntax,
                transform: static (ctx, _) => GetModel(ctx))
            .Where(static m => m is not null);

        context.RegisterSourceOutput(models, static (spc, model) => Emit(spc, model!));
    }

    /// <summary>
    /// Extracts namespace, type name and the distinct string values from the attribute.
    /// Returns <c>null</c> when the target is not a named type, the attribute has no
    /// (or a null) argument, or no string value remains.
    /// </summary>
    private static EnumerationModel? GetModel(GeneratorAttributeSyntaxContext ctx)
    {
        if (ctx.TargetSymbol is not INamedTypeSymbol type)
        {
            return null;
        }

        var attr = ctx.Attributes.FirstOrDefault();
        if (attr is null || attr.ConstructorArguments.Length == 0)
        {
            return null;
        }

        var arg = attr.ConstructorArguments[0];

        // [Enumeration(null)] yields a null array constant whose Values is a default
        // ImmutableArray; enumerating it would throw inside the generator.
        if (arg.IsNull)
        {
            return null;
        }

        var rawValues = arg.Kind == TypedConstantKind.Array ? arg.Values : [arg];

        // Keep the declaration order but drop repeats: a repeated value would otherwise
        // be emitted as two fields with the same name.
        var values = rawValues
            .Select(static v => v.Value as string)
            .Where(static v => v is not null)
            .Select(static v => v!)
            .Distinct(System.StringComparer.Ordinal)
            .ToImmutableArray();

        if (values.IsEmpty)
        {
            return null;
        }

        var ns = type.ContainingNamespace.IsGlobalNamespace
            ? null
            : type.ContainingNamespace.ToDisplayString();

        return new EnumerationModel(ns, type.Name, values);
    }

    /// <summary>
    /// Writes the generated half of the record described by <paramref name="model"/>.
    /// Reports <c>ENUMGEN001</c> and emits nothing when a value cannot be used as a field name.
    /// </summary>
    private static void Emit(SourceProductionContext ctx, EnumerationModel model)
    {
        var typeName = model.TypeName;

        // Every value is spliced into the output as an identifier (field name, onXxx
        // parameter), so reject anything that would not compile before generating.
        var hasInvalidValue = false;
        foreach (var value in model.Values)
        {
            if (IsUsableFieldName(value, typeName))
            {
                continue;
            }

            ctx.ReportDiagnostic(Diagnostic.Create(InvalidValueRule, Location.None, value, typeName));
            hasInvalidValue = true;
        }

        if (hasInvalidValue)
        {
            return;
        }

        var sb = new StringBuilder();
        var allFieldNames = string.Join(", ", model.Values);

        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");
        sb.AppendLine();
        sb.AppendLine("using System;");
        sb.AppendLine("using System.Collections.Frozen;");
        sb.AppendLine("using System.Linq;");
        sb.AppendLine();

        if (model.Namespace is not null)
        {
            sb.AppendLine($"namespace {model.Namespace};");
            sb.AppendLine();
        }

        // NOTE: the accessibility is fixed to public here, so the attributed declaration
        // has to be public as well (partial declarations must agree on accessibility).
        sb.AppendLine($"public sealed partial record {typeName}");
        sb.AppendLine("{");
        sb.AppendLine(" public string Key { get; init; } = default!;");
        sb.AppendLine();
        sb.AppendLine($" private {typeName}(string key)");
        sb.AppendLine(" {");
        sb.AppendLine(" ArgumentException.ThrowIfNullOrWhiteSpace(key);");
        sb.AppendLine(" Key = key;");
        sb.AppendLine(" }");
        sb.AppendLine();

        // One static field per value.
        foreach (var value in model.Values)
        {
            sb.AppendLine($" public static readonly {typeName} {value} = new(\"{value}\");");
        }

        sb.AppendLine();
        sb.AppendLine($" public static FrozenSet<{typeName}> All {{ get; }} =");
        sb.AppendLine($" new {typeName}[] {{ {allFieldNames} }}.ToFrozenSet();");
        sb.AppendLine();
        sb.AppendLine($" public static {typeName} Create(string key)");
        sb.AppendLine(" {");
        sb.AppendLine(" ArgumentException.ThrowIfNullOrWhiteSpace(key);");
        sb.AppendLine($" return All.SingleOrDefault(p => p.Key == key)");
        sb.AppendLine($" ?? throw new InvalidOperationException($\"{{key}} is not a valid value for {typeName}\");");
        sb.AppendLine(" }");
        sb.AppendLine();
        sb.AppendLine($" public static bool operator ==({typeName}? a, string? b)");
        sb.AppendLine(" => a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);");
        sb.AppendLine();
        sb.AppendLine($" public static bool operator !=({typeName}? a, string? b) => !(a == b);");
        sb.AppendLine();
        sb.AppendLine(" public override string ToString() => Key;");
        sb.AppendLine();

        // Match<T>: one Func<T> parameter per value.
        var funcParams = string.Join(", ", model.Values.Select(static v => $"Func<T> on{v}"));
        sb.AppendLine($" public T Match<T>({funcParams})");
        sb.AppendLine(" {");
        foreach (var value in model.Values)
        {
            sb.AppendLine($" if (Key == {value}.Key) return on{value}();");
        }

        sb.AppendLine(" throw new InvalidOperationException($\"Unhandled enumeration value: {Key}\");");
        sb.AppendLine(" }");
        sb.AppendLine();

        // Match: one Action parameter per value.
        var actionParams = string.Join(", ", model.Values.Select(static v => $"Action on{v}"));
        sb.AppendLine($" public void Match({actionParams})");
        sb.AppendLine(" {");
        foreach (var value in model.Values)
        {
            sb.AppendLine($" if (Key == {value}.Key) {{ on{value}(); return; }}");
        }

        sb.AppendLine(" throw new InvalidOperationException($\"Unhandled enumeration value: {Key}\");");
        sb.AppendLine(" }");
        sb.AppendLine("}");

        // Hint names must be unique per generator; qualifying with the namespace keeps
        // two equally named enumerations in different namespaces from colliding.
        var hintName = model.Namespace is null
            ? $"{typeName}.g.cs"
            : $"{model.Namespace}.{typeName}.g.cs";

        ctx.AddSource(hintName, SourceText.From(sb.ToString(), Encoding.UTF8));
    }

    // True when value can be emitted verbatim as the name of a static field of typeName:
    // a syntactically valid identifier that is not a reserved keyword, not the name of the
    // enclosing type and not the name of a member the record already has.
    private static bool IsUsableFieldName(string value, string typeName)
    {
        if (!Microsoft.CodeAnalysis.CSharp.SyntaxFacts.IsValidIdentifier(value))
        {
            return false;
        }

        if (Microsoft.CodeAnalysis.CSharp.SyntaxFacts.GetKeywordKind(value) != Microsoft.CodeAnalysis.CSharp.SyntaxKind.None)
        {
            return false;
        }

        if (string.Equals(value, typeName, System.StringComparison.Ordinal))
        {
            return false;
        }

        return value is not ("Key" or "All" or "Create" or "Match" or "ToString"
            or "Equals" or "GetHashCode" or "EqualityContract" or "PrintMembers");
    }
}
/// <summary>
/// Data handed from the transform step to the emit step of the generator: the target
/// type's namespace (null for the global namespace), its name and the enumeration values.
/// </summary>
/// <remarks>
/// Implements value equality. The incremental pipeline compares the model produced by
/// the current run with the cached one to decide whether the output step can be skipped;
/// with reference equality (and an <see cref="ImmutableArray{T}"/> that only compares
/// its backing array by reference) two identical models would never be considered equal.
/// </remarks>
internal sealed class EnumerationModel : System.IEquatable<EnumerationModel>
{
    public EnumerationModel(string? ns, string typeName, ImmutableArray<string> values)
    {
        Namespace = ns;
        TypeName = typeName;
        Values = values;
    }

    /// <summary>Namespace of the target type, or <c>null</c> for the global namespace.</summary>
    public string? Namespace { get; }

    /// <summary>Simple name of the target type.</summary>
    public string TypeName { get; }

    /// <summary>The enumeration values in declaration order.</summary>
    public ImmutableArray<string> Values { get; }

    /// <summary>
    /// Two models are equal when namespace, type name and the value sequence
    /// (same order, ordinal comparison) are equal.
    /// </summary>
    public bool Equals(EnumerationModel? other)
    {
        if (other is null)
        {
            return false;
        }

        if (ReferenceEquals(this, other))
        {
            return true;
        }

        return string.Equals(Namespace, other.Namespace, System.StringComparison.Ordinal)
            && string.Equals(TypeName, other.TypeName, System.StringComparison.Ordinal)
            && SameValues(Values, other.Values);
    }

    /// <inheritdoc />
    public override bool Equals(object? obj) => Equals(obj as EnumerationModel);

    /// <inheritdoc />
    public override int GetHashCode()
    {
        // Hand-rolled combination; System.HashCode is avoided because it is not
        // available on every target framework.
        unchecked
        {
            var hash = 17;
            hash = (hash * 31) + HashOf(Namespace);
            hash = (hash * 31) + HashOf(TypeName);

            if (!Values.IsDefault)
            {
                foreach (var value in Values)
                {
                    hash = (hash * 31) + HashOf(value);
                }
            }

            return hash;
        }
    }

    // Ordinal hash that tolerates null.
    private static int HashOf(string? text)
        => text is null ? 0 : System.StringComparer.Ordinal.GetHashCode(text);

    // Element-wise ordinal comparison; a default (uninitialized) array only equals another default array.
    private static bool SameValues(ImmutableArray<string> left, ImmutableArray<string> right)
    {
        if (left.IsDefault || right.IsDefault)
        {
            return left.IsDefault == right.IsDefault;
        }

        if (left.Length != right.Length)
        {
            return false;
        }

        for (var i = 0; i < left.Length; i++)
        {
            if (!string.Equals(left[i], right[i], System.StringComparison.Ordinal))
            {
                return false;
            }
        }

        return true;
    }
}
You are interested and wanna use it?

