< Summary

Information
Class: NexusLabs.Needlr.AgentFramework.Generators.GraphData
Assembly: NexusLabs.Needlr.AgentFramework.Generators
File(s): /home/runner/work/needlr/needlr/src/NexusLabs.Needlr.AgentFramework.Generators/AgentFrameworkFunctionRegistryGenerator.cs
Line coverage
100%
Covered lines: 14
Uncovered lines: 0
Coverable lines: 14
Total lines: 351
Line coverage: 100%
Branch coverage
N/A
Covered branches: 0
Total branches: 0
Branch coverage: N/A
Method coverage

Feature is only available for sponsors

Upgrade to PRO version

Metrics

Method | Branch coverage | Crap Score | Cyclomatic complexity | Line coverage
.ctor(...) | 100% | 1 | 1 | 100%
get_Edges() | 100% | 1 | 1 | 100%
get_EntryPoints() | 100% | 1 | 1 | 100%
get_Nodes() | 100% | 1 | 1 | 100%
get_Reducers() | 100% | 1 | 1 | 100%

File(s)

/home/runner/work/needlr/needlr/src/NexusLabs.Needlr.AgentFramework.Generators/AgentFrameworkFunctionRegistryGenerator.cs

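# (hits) | Line | Line coverage — in the listing below, the hit count (when present) is printed directly in front of the line number, e.g. `21335` means 21 hits on line 335.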
 1// Copyright (c) NexusLabs. All rights reserved.
 2// Licensed under the MIT License.
 3
 4using System;
 5using System.Collections.Generic;
 6using System.Collections.Immutable;
 7using System.Linq;
 8using System.Text;
 9using System.Threading;
 10
 11using Microsoft.CodeAnalysis;
 12using Microsoft.CodeAnalysis.CSharp.Syntax;
 13using Microsoft.CodeAnalysis.Text;
 14
 15namespace NexusLabs.Needlr.AgentFramework.Generators;
 16
 17/// <summary>
 18/// Source generator for Microsoft Agent Framework functions.
 19/// Discovers classes with [AgentFunction] methods and generates a compile-time type registry.
 20/// Also discovers classes with [AgentFunctionGroup] attributes and generates a group registry.
 21/// Also discovers classes with [NeedlrAiAgent] attributes and generates an agent registry.
 22/// Always emits a [ModuleInitializer] that auto-registers all discovered types with
 23/// AgentFrameworkGeneratedBootstrap on assembly load.
 24/// </summary>
 25[Generator]
 26public class AgentFrameworkFunctionRegistryGenerator : IIncrementalGenerator
 27{
 28    private const string NeedlrAiAgentAttributeName = "NexusLabs.Needlr.AgentFramework.NeedlrAiAgentAttribute";
 29    private const string AgentHandoffsToAttributeName = "NexusLabs.Needlr.AgentFramework.AgentHandoffsToAttribute";
 30    private const string AgentGroupChatMemberAttributeName = "NexusLabs.Needlr.AgentFramework.AgentGroupChatMemberAttrib
 31    private const string AgentSequenceMemberAttributeName = "NexusLabs.Needlr.AgentFramework.AgentSequenceMemberAttribut
 32    private const string WorkflowRunTerminationConditionAttributeName = "NexusLabs.Needlr.AgentFramework.WorkflowRunTerm
 33    private const string ProgressSinksAttributeName = "NexusLabs.Needlr.AgentFramework.ProgressSinksAttribute";
 34    private const string AgentGraphEdgeAttributeName = "NexusLabs.Needlr.AgentFramework.AgentGraphEdgeAttribute";
 35    private const string AgentGraphEntryAttributeName = "NexusLabs.Needlr.AgentFramework.AgentGraphEntryAttribute";
 36    private const string AgentGraphNodeAttributeName = "NexusLabs.Needlr.AgentFramework.AgentGraphNodeAttribute";
 37    private const string AgentGraphReducerAttributeName = "NexusLabs.Needlr.AgentFramework.AgentGraphReducerAttribute";
 38
 39    public void Initialize(IncrementalGeneratorInitializationContext context)
 40    {
 41        // [AgentFunction] method-bearing classes → AgentFrameworkFunctionRegistry
 42        var functionClasses = context.SyntaxProvider
 43            .CreateSyntaxProvider(
 44                predicate: static (s, _) => s is ClassDeclarationSyntax,
 45                transform: static (ctx, ct) => AgentDiscoveryHelper.GetAgentFunctionTypeInfo(ctx, ct))
 46            .Where(static m => m is not null);
 47
 48        // [AgentFunctionGroup] class-level annotations → AgentFrameworkFunctionGroupRegistry
 49        var groupClasses = context.SyntaxProvider
 50            .CreateSyntaxProvider(
 51                predicate: static (s, _) => s is ClassDeclarationSyntax,
 52                transform: static (ctx, ct) => AgentDiscoveryHelper.GetAgentFunctionGroupEntries(ctx, ct))
 53            .Where(static arr => arr.Length > 0);
 54
 55        // [NeedlrAiAgent] declared agent types → AgentRegistry + partial companions
 56        var agentClasses = context.SyntaxProvider
 57            .ForAttributeWithMetadataName(
 58                NeedlrAiAgentAttributeName,
 59                predicate: static (s, _) => s is ClassDeclarationSyntax,
 60                transform: static (ctx, ct) => AgentDiscoveryHelper.GetNeedlrAiAgentTypeInfo(ctx, ct))
 61            .Where(static m => m is not null);
 62
 63        // [AgentHandoffsTo] annotations → handoff topology registry
 64        var handoffEntries = context.SyntaxProvider
 65            .ForAttributeWithMetadataName(
 66                AgentHandoffsToAttributeName,
 67                predicate: static (s, _) => s is ClassDeclarationSyntax,
 68                transform: static (ctx, ct) => AgentDiscoveryHelper.GetHandoffEntries(ctx, ct))
 69            .Where(static arr => arr.Length > 0);
 70
 71        // [AgentGroupChatMember] annotations → group chat registry
 72        var groupChatEntries = context.SyntaxProvider
 73            .ForAttributeWithMetadataName(
 74                AgentGroupChatMemberAttributeName,
 75                predicate: static (s, _) => s is ClassDeclarationSyntax,
 76                transform: static (ctx, ct) => AgentDiscoveryHelper.GetGroupChatEntries(ctx, ct))
 77            .Where(static arr => arr.Length > 0);
 78
 79        // [AgentSequenceMember] annotations → sequential pipeline registry
 80        var sequenceEntries = context.SyntaxProvider
 81            .ForAttributeWithMetadataName(
 82                AgentSequenceMemberAttributeName,
 83                predicate: static (s, _) => s is ClassDeclarationSyntax,
 84                transform: static (ctx, ct) => AgentDiscoveryHelper.GetSequenceEntries(ctx, ct))
 85            .Where(static arr => arr.Length > 0);
 86
 87        // [WorkflowRunTerminationCondition] → termination conditions per agent
 88        var terminationConditionEntries = context.SyntaxProvider
 89            .ForAttributeWithMetadataName(
 90                WorkflowRunTerminationConditionAttributeName,
 91                predicate: static (s, _) => s is ClassDeclarationSyntax,
 92                transform: static (ctx, ct) => AgentDiscoveryHelper.GetTerminationConditionEntries(ctx, ct))
 93            .Where(static arr => arr.Length > 0);
 94
 95        // [ProgressSinks] → per-agent progress sink declarations
 96        var progressSinksEntries = context.SyntaxProvider
 97            .ForAttributeWithMetadataName(
 98                ProgressSinksAttributeName,
 99                predicate: static (s, _) => s is ClassDeclarationSyntax,
 100                transform: static (ctx, ct) => AgentDiscoveryHelper.GetProgressSinksEntries(ctx, ct))
 101            .Where(static arr => arr.Length > 0);
 102
 103        // [AgentGraphEdge] annotations → graph edge topology
 104        var graphEdgeEntries = context.SyntaxProvider
 105            .ForAttributeWithMetadataName(
 106                AgentGraphEdgeAttributeName,
 107                predicate: static (s, _) => s is ClassDeclarationSyntax,
 108                transform: static (ctx, ct) => GraphDiscoveryHelper.GetGraphEdgeEntries(ctx, ct))
 109            .Where(static arr => arr.Length > 0);
 110
 111        // [AgentGraphEntry] annotations → graph entry points
 112        var graphEntryPointEntries = context.SyntaxProvider
 113            .ForAttributeWithMetadataName(
 114                AgentGraphEntryAttributeName,
 115                predicate: static (s, _) => s is ClassDeclarationSyntax,
 116                transform: static (ctx, ct) => GraphDiscoveryHelper.GetGraphEntryPointEntries(ctx, ct))
 117            .Where(static arr => arr.Length > 0);
 118
 119        // [AgentGraphNode] annotations → graph node join modes
 120        var graphNodeEntries = context.SyntaxProvider
 121            .ForAttributeWithMetadataName(
 122                AgentGraphNodeAttributeName,
 123                predicate: static (s, _) => s is ClassDeclarationSyntax,
 124                transform: static (ctx, ct) => GraphDiscoveryHelper.GetGraphNodeEntries(ctx, ct))
 125            .Where(static arr => arr.Length > 0);
 126
 127        // [AgentGraphReducer] annotations → graph reducer metadata
 128        var graphReducerEntries = context.SyntaxProvider
 129            .ForAttributeWithMetadataName(
 130                AgentGraphReducerAttributeName,
 131                predicate: static (s, _) => s is ClassDeclarationSyntax,
 132                transform: static (ctx, ct) => GraphDiscoveryHelper.GetGraphReducerEntries(ctx, ct))
 133            .Where(static arr => arr.Length > 0);
 134
 135        // Unified output: all pipelines combined with compilation metadata and build config.
 136        // Always emits all registries + [ModuleInitializer] bootstrap, even when empty.
 137        var combined = functionClasses.Collect()
 138            .Combine(groupClasses.Collect())
 139            .Combine(agentClasses.Collect())
 140            .Combine(handoffEntries.Collect())
 141            .Combine(groupChatEntries.Collect())
 142            .Combine(sequenceEntries.Collect())
 143            .Combine(terminationConditionEntries.Collect())
 144            .Combine(progressSinksEntries.Collect())
 145            .Combine(graphEdgeEntries.Collect())
 146            .Combine(graphEntryPointEntries.Collect())
 147            .Combine(graphNodeEntries.Collect())
 148            .Combine(graphReducerEntries.Collect())
 149            .Combine(context.CompilationProvider)
 150            .Combine(context.AnalyzerConfigOptionsProvider);
 151
 152        context.RegisterSourceOutput(combined,
 153            static (spc, data) =>
 154            {
 155                var (((((((((((((functionData, groupData), agentData), handoffData), groupChatData), sequenceData), term
 156                ExecuteAll(functionData, groupData, agentData, handoffData, groupChatData, sequenceData, terminationData
 157            });
 158    }
 159
 160    private static void ExecuteAll(
 161        ImmutableArray<AgentFunctionTypeInfo?> functionData,
 162        ImmutableArray<ImmutableArray<AgentFunctionGroupEntry>> groupData,
 163        ImmutableArray<NeedlrAiAgentTypeInfo?> agentData,
 164        ImmutableArray<ImmutableArray<HandoffEntry>> handoffData,
 165        ImmutableArray<ImmutableArray<GroupChatEntry>> groupChatData,
 166        ImmutableArray<ImmutableArray<SequenceEntry>> sequenceData,
 167        ImmutableArray<ImmutableArray<TerminationConditionEntry>> terminationData,
 168        ImmutableArray<ImmutableArray<ProgressSinksEntry>> progressSinksData,
 169        ImmutableArray<ImmutableArray<GraphEdgeEntry>> graphEdgeData,
 170        ImmutableArray<ImmutableArray<GraphEntryPointEntry>> graphEntryData,
 171        ImmutableArray<ImmutableArray<GraphNodeEntry>> graphNodeData,
 172        ImmutableArray<ImmutableArray<GraphReducerEntry>> graphReducerData,
 173        Compilation compilation,
 174        Microsoft.CodeAnalysis.Diagnostics.AnalyzerConfigOptionsProvider configOptions,
 175        SourceProductionContext spc)
 176    {
 177        var assemblyName = compilation.AssemblyName ?? "UnknownAssembly";
 178        var safeAssemblyName = AgentDiscoveryHelper.SanitizeIdentifier(assemblyName);
 179
 180        var validFunctionTypes = functionData
 181            .Where(t => t.HasValue)
 182            .Select(t => t!.Value)
 183            .ToList();
 184
 185        var allGroupEntries = groupData.SelectMany(a => a).ToList();
 186        var groupedByName = allGroupEntries
 187            .GroupBy(e => e.GroupName)
 188            .ToDictionary(g => g.Key, g => g.Select(e => e.TypeName).Distinct().ToList());
 189
 190        var validAgentTypes = agentData
 191            .Where(t => t.HasValue)
 192            .Select(t => t!.Value)
 193            .ToList();
 194
 195        var allHandoffEntries = handoffData.SelectMany(a => a).ToList();
 196        var handoffByInitialAgent = allHandoffEntries
 197            .GroupBy(e => (e.InitialAgentTypeName, e.InitialAgentClassName))
 198            .ToDictionary(
 199                g => g.Key,
 200                g => g.Select(e => (e.TargetAgentTypeName, e.HandoffReason)).ToList());
 201
 202        var allGroupChatEntries = groupChatData.SelectMany(a => a).ToList();
 203        var groupChatByGroupName = allGroupChatEntries
 204            .GroupBy(e => e.GroupName)
 205            .ToDictionary(
 206                g => g.Key,
 207                g => g.OrderBy(e => e.Order)
 208                    .ThenBy(e => e.AgentTypeName, StringComparer.Ordinal)
 209                    .Select(e => e.AgentTypeName)
 210                    .Distinct()
 211                    .ToList());
 212
 213        var allSequenceEntries = sequenceData.SelectMany(a => a).ToList();
 214        var sequenceByPipelineName = allSequenceEntries
 215            .GroupBy(e => e.PipelineName)
 216            .ToDictionary(
 217                g => g.Key,
 218                g => g.OrderBy(e => e.Order).Select(e => e.AgentTypeName).ToList());
 219
 220        var conditionsByAgentTypeName = terminationData
 221            .SelectMany(a => a)
 222            .GroupBy(e => e.AgentTypeName)
 223            .ToDictionary(g => g.Key, g => g.ToList());
 224
 225        var progressSinksByAgent = progressSinksData
 226            .SelectMany(a => a)
 227            .ToDictionary(e => e.AgentClassName, e => e.SinkTypeFQNs);
 228
 229        var allGraphEdges = graphEdgeData.SelectMany(a => a).ToList();
 230        var allGraphEntryPoints = graphEntryData.SelectMany(a => a).ToList();
 231        var allGraphNodes = graphNodeData.SelectMany(a => a).ToList();
 232        var allGraphReducers = graphReducerData.SelectMany(a => a).ToList();
 233
 234        var graphDataByName = BuildGraphDataByName(allGraphEdges, allGraphEntryPoints, allGraphNodes, allGraphReducers);
 235
 236        // Always emit all registries (may be empty) and the bootstrap
 237        spc.AddSource("AgentFrameworkFunctions.g.cs",
 238            SourceText.From(RegistryCodeGenerator.GenerateRegistrySource(validFunctionTypes, safeAssemblyName), Encoding
 239
 240        spc.AddSource("AgentFrameworkFunctionGroups.g.cs",
 241            SourceText.From(RegistryCodeGenerator.GenerateGroupRegistrySource(groupedByName, safeAssemblyName), Encoding
 242
 243        spc.AddSource("AgentRegistry.g.cs",
 244            SourceText.From(RegistryCodeGenerator.GenerateAgentRegistrySource(validAgentTypes, safeAssemblyName), Encodi
 245
 246        spc.AddSource("AgentHandoffTopologyRegistry.g.cs",
 247            SourceText.From(RegistryCodeGenerator.GenerateHandoffTopologyRegistrySource(handoffByInitialAgent, safeAssem
 248
 249        spc.AddSource("AgentGroupChatRegistry.g.cs",
 250            SourceText.From(RegistryCodeGenerator.GenerateGroupChatRegistrySource(groupChatByGroupName, safeAssemblyName
 251
 252        spc.AddSource("AgentSequentialTopologyRegistry.g.cs",
 253            SourceText.From(RegistryCodeGenerator.GenerateSequentialTopologyRegistrySource(sequenceByPipelineName, safeA
 254
 255        spc.AddSource("AgentGraphTopologyRegistry.g.cs",
 256            SourceText.From(RegistryCodeGenerator.GenerateGraphTopologyRegistrySource(graphDataByName, safeAssemblyName)
 257
 258        spc.AddSource("NeedlrAgentFrameworkBootstrap.g.cs",
 259            SourceText.From(BootstrapCodeGenerator.GenerateBootstrapSource(safeAssemblyName), Encoding.UTF8));
 260
 261        spc.AddSource("WorkflowFactoryExtensions.g.cs",
 262            SourceText.From(ExtensionsCodeGenerator.GenerateWorkflowFactoryExtensionsSource(
 263                handoffByInitialAgent, groupChatByGroupName, sequenceByPipelineName,
 264                conditionsByAgentTypeName, graphDataByName, safeAssemblyName), Encoding.UTF8));
 265
 266        spc.AddSource("AgentFactoryExtensions.g.cs",
 267            SourceText.From(ExtensionsCodeGenerator.GenerateAgentFactoryExtensionsSource(validAgentTypes, progressSinksB
 268
 269        spc.AddSource("AgentTopologyConstants.g.cs",
 270            SourceText.From(ExtensionsCodeGenerator.GenerateAgentTopologyConstantsSource(validAgentTypes, allGroupEntrie
 271
 272        spc.AddSource("AgentFrameworkSyringeExtensions.g.cs",
 273            SourceText.From(ExtensionsCodeGenerator.GenerateSyringeExtensionsSource(allGroupEntries, safeAssemblyName), 
 274
 275        if (progressSinksByAgent.Count > 0)
 276        {
 277            spc.AddSource("GeneratedProgressSinkRegistrations.g.cs",
 278                SourceText.From(ExtensionsCodeGenerator.GenerateProgressSinkRegistrationSource(progressSinksByAgent, saf
 279        }
 280
 281        spc.AddSource("GeneratedAIFunctionProvider.g.cs",
 282            SourceText.From(AIFunctionProviderCodeGenerator.GenerateAIFunctionProviderSource(validFunctionTypes, safeAss
 283
 284        configOptions.GlobalOptions.TryGetValue("build_property.NeedlrDiagnostics", out var diagValue);
 285        if (string.Equals(diagValue, "true", StringComparison.OrdinalIgnoreCase))
 286        {
 287            var mermaid = TopologyGraphCodeGenerator.GenerateMermaidDiagram(handoffByInitialAgent, groupChatByGroupName,
 288
 289            spc.AddSource("AgentTopologyGraph.g.cs",
 290                SourceText.From(TopologyGraphCodeGenerator.GenerateTopologyGraphSource(mermaid, safeAssemblyName), Encod
 291        }
 292
 293        // Partial companions for [NeedlrAiAgent] classes declared as partial
 294        foreach (var agentType in validAgentTypes.Where(a => a.IsPartial))
 295        {
 296            var safeTypeName = agentType.TypeName
 297                .Replace("global::", "")
 298                .Replace(".", "_")
 299                .Replace("<", "_")
 300                .Replace(">", "_");
 301
 302            spc.AddSource($"{safeTypeName}.NeedlrAiAgent.g.cs",
 303                SourceText.From(BootstrapCodeGenerator.GeneratePartialCompanionSource(agentType, groupedByName), Encodin
 304        }
 305    }
 306
 307    private static Dictionary<string, GraphData> BuildGraphDataByName(
 308        List<GraphEdgeEntry> allEdges,
 309        List<GraphEntryPointEntry> allEntryPoints,
 310        List<GraphNodeEntry> allNodes,
 311        List<GraphReducerEntry> allReducers)
 312    {
 313        var graphNames = new HashSet<string>(StringComparer.Ordinal);
 314        foreach (var e in allEdges) graphNames.Add(e.GraphName);
 315        foreach (var e in allEntryPoints) graphNames.Add(e.GraphName);
 316        foreach (var e in allNodes) graphNames.Add(e.GraphName);
 317        foreach (var e in allReducers) graphNames.Add(e.GraphName);
 318
 319        var result = new Dictionary<string, GraphData>(StringComparer.Ordinal);
 320        foreach (var name in graphNames)
 321        {
 322            result[name] = new GraphData(
 323                allEdges.Where(e => string.Equals(e.GraphName, name, StringComparison.Ordinal)).ToList(),
 324                allEntryPoints.Where(e => string.Equals(e.GraphName, name, StringComparison.Ordinal)).ToList(),
 325                allNodes.Where(e => string.Equals(e.GraphName, name, StringComparison.Ordinal)).ToList(),
 326                allReducers.Where(e => string.Equals(e.GraphName, name, StringComparison.Ordinal)).ToList());
 327        }
 328
 329        return result;
 330    }
 331}
 332
/// <summary>
/// Holds the graph-attribute entries that share a single graph name: its edges,
/// entry points, node declarations and reducers.
/// </summary>
internal sealed class GraphData
{
    /// <summary>
    /// Wraps the supplied lists. They are stored by reference, not copied, so later
    /// changes to a list are visible through the matching property.
    /// </summary>
    /// <param name="edges">Edge entries belonging to this graph.</param>
    /// <param name="entryPoints">Entry-point entries belonging to this graph.</param>
    /// <param name="nodes">Node entries belonging to this graph.</param>
    /// <param name="reducers">Reducer entries belonging to this graph.</param>
    public GraphData(
        List<GraphEdgeEntry> edges,
        List<GraphEntryPointEntry> entryPoints,
        List<GraphNodeEntry> nodes,
        List<GraphReducerEntry> reducers)
        => (Edges, EntryPoints, Nodes, Reducers) = (edges, entryPoints, nodes, reducers);

    /// <summary>Edge entries for this graph.</summary>
    public List<GraphEdgeEntry> Edges { get; }

    /// <summary>Entry-point entries for this graph.</summary>
    public List<GraphEntryPointEntry> EntryPoints { get; }

    /// <summary>Node entries for this graph.</summary>
    public List<GraphNodeEntry> Nodes { get; }

    /// <summary>Reducer entries for this graph.</summary>
    public List<GraphReducerEntry> Reducers { get; }
}