Source code generated string enums with exhaustion support

4/6/2026
11 minute read

Some time ago (over 4.5 years ago!) I wrote a blog post titled: "A better enumeration - Type safe from start to end". While cool, it had some issues. Let's tackle them!

Enumeration

The basic Enumeration class, modernized, looks like this:

using System;
using System.Collections.Frozen;
using System.Linq;
using System.Reflection;

public abstract record Enumeration<TEnumeration>(string Key)
    where TEnumeration : Enumeration<TEnumeration>
{
    public static FrozenSet<TEnumeration> All { get; } = GetEnumerations();

    public static bool operator ==(Enumeration<TEnumeration>? a, string? b)
        => a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);

    public static bool operator !=(Enumeration<TEnumeration>? a, string? b) => !(a == b);

    public static TEnumeration Create(string key)
        => All.SingleOrDefault(p => p.Key == key)
           ?? throw new InvalidOperationException($"{key} is not a valid value for {typeof(TEnumeration).Name}");

    public sealed override string ToString() => Key;

    private static FrozenSet<TEnumeration> GetEnumerations()
    {
        var enumerationType = typeof(TEnumeration);

        return enumerationType
            .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
            .Where(info => info.FieldType == typeof(TEnumeration))
            .Select(info => (TEnumeration)info.GetValue(null)!)
            .ToFrozenSet();
    }
}

This lets you do fun things like:

public record Color(string Key) : Enumeration<Color>(Key)
{
    public static readonly Color Red = new("Red");
    public static readonly Color Green = new("Green");
    public static readonly Color Blue = new("Blue");
}

var color = Color.Create("Red");
var isGreen = color == "Green"; // false
var isRed = color == "Red"; // true

Issues

We have two major issues:

  1. We are using reflection to fill the FrozenSet<TEnumeration> All property. This is not only slower than plain field access but can also cause problems in some environments (like trimming or Native AOT compilation).
  2. We have to define each static readonly field by hand, even though nothing forces us to! If you create a new Color without adding a matching static field, you won't get any error until you call Create with a key that doesn't exist in the All set.
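To make issue 2 concrete, here is a trimmed-down sketch of the reflection-based version from above (only All and Create are kept). The Issue2Demo class and the Yellow value are hypothetical and exist only for illustration: the code compiles fine, yet Yellow never ends up in All, so the mistake only shows up at runtime.

```csharp
using System;
using System.Collections.Frozen;
using System.Linq;
using System.Reflection;

// Trimmed copy of the reflection-based base class shown earlier.
public abstract record Enumeration<TEnumeration>(string Key)
    where TEnumeration : Enumeration<TEnumeration>
{
    public static FrozenSet<TEnumeration> All { get; } = GetEnumerations();

    public static TEnumeration Create(string key)
        => All.SingleOrDefault(p => p.Key == key)
           ?? throw new InvalidOperationException($"{key} is not a valid value for {typeof(TEnumeration).Name}");

    // Collects every public static field of type TEnumeration via reflection.
    private static FrozenSet<TEnumeration> GetEnumerations()
        => typeof(TEnumeration)
            .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
            .Where(info => info.FieldType == typeof(TEnumeration))
            .Select(info => (TEnumeration)info.GetValue(null)!)
            .ToFrozenSet();
}

public record Color(string Key) : Enumeration<Color>(Key)
{
    public static readonly Color Red = new("Red");
    public static readonly Color Green = new("Green");
    public static readonly Color Blue = new("Blue");
}

// Hypothetical demo: a value created without a matching static field on Color.
public static class Issue2Demo
{
    public static Color Yellow { get; } = new("Yellow");

    // Yellow is not a static field of Color, so reflection never sees it.
    public static bool YellowIsKnown() => Color.All.Contains(Yellow);

    public static bool CreateYellowThrows()
    {
        try
        {
            Color.Create("Yellow");
            return false;
        }
        catch (InvalidOperationException)
        {
            return true; // the problem is only detected here, at runtime
        }
    }
}
```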

Point 2 also exposes another design flaw: I can't easily do exhaustiveness checks, i.e. make sure every value is handled wherever a Color is consumed. So let's solve that:

Source Code Generation

Let's start with how to use it and what the generated code looks like:

[Enumeration("Red", "Green", "Blue")]
public sealed partial record Color;

Yes, this is absolutely all. The EnumerationAttribute itself is emitted by the generator, and the generated class looks like this:

// <auto-generated/>
#nullable enable

using System;
using System.Collections.Frozen;
using System.Linq;

namespace LinkDotNet.Blog.Infrastructure.Persistence;

public sealed partial record Color
{
    public string Key { get; init; } = default!;

    private Color(string key)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(key);
        Key = key;
    }

    public static readonly Color Red = new("Red");
    public static readonly Color Green = new("Green");
    public static readonly Color Blue = new("Blue");

    public static FrozenSet<Color> All { get; } =
        new Color[] { Red, Green, Blue }.ToFrozenSet();

    public static Color Create(string key)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(key);
        return All.SingleOrDefault(p => p.Key == key)
            ?? throw new InvalidOperationException($"{key} is not a valid value for Color");
    }

    public static bool operator ==(Color? a, string? b)
        => a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);

    public static bool operator !=(Color? a, string? b) => !(a == b);

    public override string ToString() => Key;

    public T Match<T>(Func<T> onRed, Func<T> onGreen, Func<T> onBlue)
    {
        if (Key == Red.Key) return onRed();
        if (Key == Green.Key) return onGreen();
        if (Key == Blue.Key) return onBlue();
        throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
    }

    public void Match(Action onRed, Action onGreen, Action onBlue)
    {
        if (Key == Red.Key) { onRed(); return; }
        if (Key == Green.Key) { onGreen(); return; }
        if (Key == Blue.Key) { onBlue(); return; }
        throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
    }
}
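Since the generated code is plain C#, a hand-copied subset of it is enough to try the API out without running the generator. The sketch below keeps only Key, the constructor, the fields, All, Create, the string operators and ToString from the output above and drops the namespace; it is an illustration, not the generator's actual output file. Note that All is now built from direct field references, so no reflection is involved.

```csharp
using System;
using System.Collections.Frozen;
using System.Linq;

// Hand-copied subset of the generated Color record shown above.
public sealed partial record Color
{
    public string Key { get; init; } = default!;

    private Color(string key)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(key);
        Key = key;
    }

    // Declared before All, so they are initialized before All is built.
    public static readonly Color Red = new("Red");
    public static readonly Color Green = new("Green");
    public static readonly Color Blue = new("Blue");

    // Built from direct field references: no reflection needed.
    public static FrozenSet<Color> All { get; } =
        new Color[] { Red, Green, Blue }.ToFrozenSet();

    public static Color Create(string key)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(key);
        return All.SingleOrDefault(p => p.Key == key)
            ?? throw new InvalidOperationException($"{key} is not a valid value for Color");
    }

    public static bool operator ==(Color? a, string? b)
        => a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);

    public static bool operator !=(Color? a, string? b) => !(a == b);

    public override string ToString() => Key;
}
```

With this in place, Color.Create("Green") returns the very same instance as Color.Green, Color.Create("Green") == "Green" is true, and Color.Create("Purple") throws an InvalidOperationException.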

So the generator does all the boilerplate! But it also adds something more: Match. That is my way of enforcing exhaustiveness checks: Match takes one delegate per value, so the caller has to handle every single one. Imagine you have a method like this:

public void PrintColor(Color color)
{
    if (color == Color.Red)
    {
        Console.WriteLine("Red");
    }
    else if (color == Color.Green)
    {
        Console.WriteLine("Green");
    }
    else if (color == Color.Blue)
    {
        Console.WriteLine("Blue");
    }
}

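Here it is easy to forget a color, either one you add later on or one that is already there, and nothing warns you. With the Match method you can do this instead (every value is a required parameter, so a newly added value breaks each call site until it is handled):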

public void PrintColor(Color color)
{
    color.Match(
        onRed: () => Console.WriteLine("Red"),
        onGreen: () => Console.WriteLine("Green"),
        onBlue: () => Console.WriteLine("Blue")
    );
}
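The Func<T> overload of Match works the same way but returns a value. To see the exhaustiveness check in action, the sketch below hand-writes roughly what the generator would emit for a hypothetical fourth value, i.e. [Enumeration("Red", "Green", "Blue", "Yellow")], trimmed to Key, the fields and the two Match overloads. The ToHex extension and its hex strings are made up for illustration.

```csharp
using System;

// Hand-written stand-in for the generated code of
// [Enumeration("Red", "Green", "Blue", "Yellow")], trimmed to what Match needs.
public sealed partial record Color
{
    public string Key { get; init; } = default!;

    private Color(string key) => Key = key;

    public static readonly Color Red = new("Red");
    public static readonly Color Green = new("Green");
    public static readonly Color Blue = new("Blue");
    public static readonly Color Yellow = new("Yellow");

    public T Match<T>(Func<T> onRed, Func<T> onGreen, Func<T> onBlue, Func<T> onYellow)
    {
        if (Key == Red.Key) return onRed();
        if (Key == Green.Key) return onGreen();
        if (Key == Blue.Key) return onBlue();
        if (Key == Yellow.Key) return onYellow();
        throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
    }

    public void Match(Action onRed, Action onGreen, Action onBlue, Action onYellow)
    {
        if (Key == Red.Key) { onRed(); return; }
        if (Key == Green.Key) { onGreen(); return; }
        if (Key == Blue.Key) { onBlue(); return; }
        if (Key == Yellow.Key) { onYellow(); return; }
        throw new InvalidOperationException($"Unhandled enumeration value: {Key}");
    }
}

public static class ColorExtensions
{
    // Hypothetical helper: every value has to be mapped, otherwise this won't compile.
    public static string ToHex(this Color color) => color.Match(
        onRed: () => "#FF0000",
        onGreen: () => "#00FF00",
        onBlue: () => "#0000FF",
        onYellow: () => "#FFFF00");
}
```

Delete the onYellow line from ToHex and the call no longer compiles, because neither Match overload accepts only three delegates. That is the whole trick: a forgotten value becomes a build error instead of a silent fall-through like in the if/else chain above.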

The generator itself

The generator itself is pretty straightforward. It basically does two things: a) it emits the attribute and b) it generates the partial record for every type the attribute is applied to.

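using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

[Generator]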
public sealed class EnumerationGenerator : IIncrementalGenerator
{
    private const string FullAttributeName = "LinkDotNet.Blog.Generators.EnumerationAttribute";

    private const string AttributeSource = """
        // <auto-generated/>
        #nullable enable

        using System;

        namespace LinkDotNet.Blog.Generators;

        /// <summary>
        /// Marks a <c>sealed partial record</c> as a source-generated enumeration.
        /// The generator emits: Key property, private constructor, static readonly fields,
        /// All, Create, == / != operators, ToString, Match&lt;T&gt; and Match(Action).
        /// </summary>
        [AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
        internal sealed class EnumerationAttribute : Attribute
        {
            public string[] Values { get; }
            public EnumerationAttribute(params string[] values) => Values = values;
        }
        """;

    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static ctx =>
            ctx.AddSource("EnumerationAttribute.g.cs", SourceText.From(AttributeSource, Encoding.UTF8)));

        var models = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                FullAttributeName,
                predicate: static (node, _) => node is RecordDeclarationSyntax,
                transform: static (ctx, _) => GetModel(ctx))
            .Where(static m => m is not null);

        context.RegisterSourceOutput(models, static (spc, model) => Emit(spc, model!));
    }

    private static EnumerationModel? GetModel(GeneratorAttributeSyntaxContext ctx)
    {
        if (ctx.TargetSymbol is not INamedTypeSymbol type)
        {
            return null;
        }

        var attr = ctx.Attributes.FirstOrDefault();
        if (attr is null || attr.ConstructorArguments.Length == 0)
        {
            return null;
        }

        var arg = attr.ConstructorArguments[0];
        // params string[] normally arrives as a single array-typed argument.
        var rawValues = arg.Kind == TypedConstantKind.Array ? arg.Values : ImmutableArray.Create(arg);

        var values = rawValues
            .Select(static v => v.Value as string)
            .Where(static v => v is not null)
            .Select(static v => v!)
            .ToImmutableArray();

        if (values.IsEmpty)
        {
            return null;
        }

        var ns = type.ContainingNamespace.IsGlobalNamespace
            ? null
            : type.ContainingNamespace.ToDisplayString();

        return new EnumerationModel(ns, type.Name, values);
    }

    private static void Emit(SourceProductionContext ctx, EnumerationModel model)
    {
        var sb = new StringBuilder();
        var typeName = model.TypeName;
        var allFieldNames = string.Join(", ", model.Values);

        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");
        sb.AppendLine();
        sb.AppendLine("using System;");
        sb.AppendLine("using System.Collections.Frozen;");
        sb.AppendLine("using System.Linq;");
        sb.AppendLine();

        if (model.Namespace is not null)
        {
            sb.AppendLine($"namespace {model.Namespace};");
            sb.AppendLine();
        }

        sb.AppendLine($"public sealed partial record {typeName}");
        sb.AppendLine("{");

        sb.AppendLine("    public string Key { get; init; } = default!;");
        sb.AppendLine();

        sb.AppendLine($"    private {typeName}(string key)");
        sb.AppendLine("    {");
        sb.AppendLine("        ArgumentException.ThrowIfNullOrWhiteSpace(key);");
        sb.AppendLine("        Key = key;");
        sb.AppendLine("    }");
        sb.AppendLine();

        foreach (var value in model.Values)
        {
            sb.AppendLine($"    public static readonly {typeName} {value} = new(\"{value}\");");
        }

        sb.AppendLine();

        sb.AppendLine($"    public static FrozenSet<{typeName}> All {{ get; }} =");
        sb.AppendLine($"        new {typeName}[] {{ {allFieldNames} }}.ToFrozenSet();");
        sb.AppendLine();

        sb.AppendLine($"    public static {typeName} Create(string key)");
        sb.AppendLine("    {");
        sb.AppendLine("        ArgumentException.ThrowIfNullOrWhiteSpace(key);");
        sb.AppendLine($"        return All.SingleOrDefault(p => p.Key == key)");
        sb.AppendLine($"            ?? throw new InvalidOperationException($\"{{key}} is not a valid value for {typeName}\");");
        sb.AppendLine("    }");
        sb.AppendLine();

        sb.AppendLine($"    public static bool operator ==({typeName}? a, string? b)");
        sb.AppendLine("        => a is not null && b is not null && a.Key.Equals(b, StringComparison.Ordinal);");
        sb.AppendLine();
        sb.AppendLine($"    public static bool operator !=({typeName}? a, string? b) => !(a == b);");
        sb.AppendLine();

        sb.AppendLine("    public override string ToString() => Key;");
        sb.AppendLine();

        var funcParams = string.Join(", ", model.Values.Select(static v => $"Func<T> on{v}"));
        sb.AppendLine($"    public T Match<T>({funcParams})");
        sb.AppendLine("    {");
        foreach (var value in model.Values)
        {
            sb.AppendLine($"        if (Key == {value}.Key) return on{value}();");
        }

        sb.AppendLine("        throw new InvalidOperationException($\"Unhandled enumeration value: {Key}\");");
        sb.AppendLine("    }");
        sb.AppendLine();

        var actionParams = string.Join(", ", model.Values.Select(static v => $"Action on{v}"));
        sb.AppendLine($"    public void Match({actionParams})");
        sb.AppendLine("    {");
        foreach (var value in model.Values)
        {
            sb.AppendLine($"        if (Key == {value}.Key) {{ on{value}(); return; }}");
        }

        sb.AppendLine("        throw new InvalidOperationException($\"Unhandled enumeration value: {Key}\");");
        sb.AppendLine("    }");

        sb.AppendLine("}");

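        // Include the namespace in the hint name so two types with the same name
        // in different namespaces don't collide (AddSource requires unique hint names).
        var hintName = model.Namespace is null ? $"{typeName}.g.cs" : $"{model.Namespace}.{typeName}.g.cs";
        ctx.AddSource(hintName, SourceText.From(sb.ToString(), Encoding.UTF8));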
    }
}

internal sealed class EnumerationModel
{
    public EnumerationModel(string? ns, string typeName, ImmutableArray<string> values)
    {
        Namespace = ns;
        TypeName = typeName;
        Values = values;
    }

    public string? Namespace { get; }
    public string TypeName { get; }
    public ImmutableArray<string> Values { get; }
}
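One possible refinement concerns incremental caching: Roslyn compares the models returned from the transform step (by default with EqualityComparer<T>.Default) to decide whether Emit has to run again. EnumerationModel is a plain class without an Equals override, and ImmutableArray<string> itself compares by the underlying array reference, so whenever the transform reruns the model looks "new" and the source is regenerated. A minimal sketch of a value-equatable drop-in replacement, assuming only System.Collections.Immutable; the hash combining is hand-rolled because generators usually target netstandard2.0, where System.HashCode isn't available without an extra package.

```csharp
using System;
using System.Collections.Immutable;
using System.Linq;

// Sketch: same constructor and properties as EnumerationModel, plus value equality.
internal sealed class EnumerationModel : IEquatable<EnumerationModel>
{
    public EnumerationModel(string? ns, string typeName, ImmutableArray<string> values)
    {
        Namespace = ns;
        TypeName = typeName;
        Values = values;
    }

    public string? Namespace { get; }
    public string TypeName { get; }
    public ImmutableArray<string> Values { get; }

    public bool Equals(EnumerationModel? other)
        => other is not null
           && Namespace == other.Namespace
           && TypeName == other.TypeName
           && Values.SequenceEqual(other.Values); // compare contents, not array references

    public override bool Equals(object? obj) => Equals(obj as EnumerationModel);

    public override int GetHashCode()
    {
        unchecked
        {
            var hash = 17;
            hash = hash * 31 + (Namespace?.GetHashCode() ?? 0);
            hash = hash * 31 + TypeName.GetHashCode();
            foreach (var value in Values)
            {
                hash = hash * 31 + value.GetHashCode();
            }

            return hash;
        }
    }
}
```

Because the constructor and properties are unchanged, GetModel and Emit keep working as they are, and identical attribute arguments now produce equal models, which gives the pipeline a chance to skip re-running Emit.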

Are you interested and want to use it?
