| | | 1 | | using System.Collections.Immutable; |
| | | 2 | | using System.Linq; |
| | | 3 | | using System.Text; |
| | | 4 | | |
| | | 5 | | using Microsoft.CodeAnalysis; |
| | | 6 | | using Microsoft.CodeAnalysis.CSharp.Syntax; |
| | | 7 | | using Microsoft.CodeAnalysis.Text; |
| | | 8 | | |
| | | 9 | | using NexusLabs.Needlr.AgentFramework.Generators.CodeGen; |
| | | 10 | | using NexusLabs.Needlr.AgentFramework.Generators.Models; |
| | | 11 | | |
| | | 12 | | namespace NexusLabs.Needlr.AgentFramework.Generators |
| | | 13 | | { |
/// <summary>
/// Source generator for [AsyncLocalScoped]-decorated interfaces.
/// Emits an internal sealed class implementing the interface with proper
/// AsyncLocal scoping and dispose semantics.
/// </summary>
/// <remarks>
/// This type only discovers the annotated interfaces and extracts their shape into an
/// <see cref="AsyncLocalScopedInfo"/>; the source text itself is produced by
/// <c>AsyncLocalScopedCodeGenerator.Generate</c>.
/// </remarks>
[Generator]
public class AsyncLocalScopedGenerator : IIncrementalGenerator
{
    private const string AsyncLocalScopedAttributeName =
        "NexusLabs.Needlr.AgentFramework.AsyncLocalScopedAttribute";

    /// <summary>Name of the interface property whose type determines the scoped value type.</summary>
    private const string CurrentPropertyName = "Current";

    /// <summary>Name of the attribute's named argument that is read into <c>isMutable</c>.</summary>
    private const string MutableArgumentName = "Mutable";

    /// <summary>
    /// Display format used for proxied property types: fully qualified, with the
    /// nullable-reference-type modifier included. Built once instead of per property.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>WithMiscellaneousOptions</c> replaces the option set of
    /// <see cref="SymbolDisplayFormat.FullyQualifiedFormat"/> rather than adding to it, so
    /// keyword escaping and special-type names are not applied to these strings. Kept as-is
    /// so the emitted text does not change; confirm whether <c>AddMiscellaneousOptions</c>
    /// was intended.
    /// </remarks>
    private static readonly SymbolDisplayFormat ProxyPropertyTypeFormat =
        SymbolDisplayFormat.FullyQualifiedFormat
            .WithMiscellaneousOptions(
                SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier);

    /// <summary>
    /// Registers the incremental pipeline: interface declarations carrying the
    /// AsyncLocalScoped attribute are transformed by <see cref="ExtractInfo"/>, results
    /// without a value are dropped, and one source file is added per remaining interface.
    /// </summary>
    /// <param name="context">The initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        var interfaces = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                AsyncLocalScopedAttributeName,
                predicate: static (s, _) => s is InterfaceDeclarationSyntax,
                transform: static (ctx, _) => ExtractInfo(ctx))
            .Where(static m => m.HasValue)
            .Select(static (m, _) => m!.Value);

        context.RegisterSourceOutput(interfaces, static (spc, info) =>
        {
            var source = AsyncLocalScopedCodeGenerator.Generate(info);

            spc.AddSource(
                BuildHintName(info.InterfaceFullName),
                SourceText.From(source, Encoding.UTF8));
        });
    }

    /// <summary>
    /// Turns a fully qualified interface name into the hint name passed to <c>AddSource</c>
    /// by removing <c>global::</c> and replacing <c>.</c>, <c>&lt;</c> and <c>&gt;</c> with
    /// underscores, then appending <c>.AsyncLocalScoped.g.cs</c>.
    /// </summary>
    /// <param name="interfaceFullName">The fully qualified interface name.</param>
    /// <returns>The hint name for the generated file.</returns>
    private static string BuildHintName(string interfaceFullName)
    {
        var safeName = interfaceFullName
            .Replace("global::", "")
            .Replace(".", "_")
            .Replace("<", "_")
            .Replace(">", "_");

        return safeName + ".AsyncLocalScoped.g.cs";
    }

    /// <summary>
    /// Reads everything needed for code generation from an attributed interface symbol.
    /// </summary>
    /// <param name="ctx">The attribute context for one matched declaration.</param>
    /// <returns>
    /// The extracted info, or <c>null</c> when the target is not an interface, the
    /// attribute data cannot be found, the interface declares no <c>Current</c> property,
    /// or it declares no ordinary method whose return type name ends in <c>IDisposable</c>.
    /// </returns>
    private static AsyncLocalScopedInfo? ExtractInfo(GeneratorAttributeSyntaxContext ctx)
    {
        if (ctx.TargetSymbol is not INamedTypeSymbol typeSymbol ||
            typeSymbol.TypeKind != TypeKind.Interface)
            return null;

        // Find the [AsyncLocalScoped] attribute and its Mutable property
        var attrData = ctx.Attributes.FirstOrDefault(a =>
            a.AttributeClass != null &&
            a.AttributeClass.ToDisplayString() == AsyncLocalScopedAttributeName);

        if (attrData == null)
            return null;

        bool isMutable = false;
        foreach (var named in attrData.NamedArguments)
        {
            if (named.Key == MutableArgumentName && named.Value.Value is bool b)
                isMutable = b;
        }

        // Find the "Current" property to determine the value type.
        // Only members declared directly on this interface are inspected.
        var currentProp = typeSymbol.GetMembers()
            .OfType<IPropertySymbol>()
            .FirstOrDefault(p => p.Name == CurrentPropertyName);

        if (currentProp == null)
            return null;

        var valueTypeFullName = UnwrapNullable(currentProp.Type)
            .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

        // Find the scope method (returns IDisposable). GetMembers() also yields property
        // and event accessors as IMethodSymbol, so restrict the search to ordinary methods;
        // otherwise the getter of an IDisposable-typed property could be picked up.
        // NOTE(review): the suffix match also accepts any type whose display name merely
        // ends in "IDisposable"; kept for compatibility.
        var scopeMethod = typeSymbol.GetMembers()
            .OfType<IMethodSymbol>()
            .FirstOrDefault(m =>
                m.MethodKind == MethodKind.Ordinary &&
                m.ReturnType != null &&
                m.ReturnType.ToDisplayString()
                    .EndsWith("IDisposable", System.StringComparison.Ordinal));

        if (scopeMethod == null)
            return null;

        // Only the first parameter of the scope method is described.
        bool hasScopeParameter = scopeMethod.Parameters.Length > 0;
        string scopeParameterTypeFullName = hasScopeParameter
            ? scopeMethod.Parameters[0].Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
            : "";

        var interfaceFullName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

        // An interface in the global namespace (or without one) is reported with an empty name.
        var containingNamespace = typeSymbol.ContainingNamespace;
        var namespaceName = containingNamespace == null || containingNamespace.IsGlobalNamespace
            ? ""
            : containingNamespace.ToDisplayString();

        return new AsyncLocalScopedInfo(
            interfaceFullName: interfaceFullName,
            interfaceName: typeSymbol.Name,
            namespaceName: namespaceName,
            valueTypeFullName: valueTypeFullName,
            scopeMethodName: scopeMethod.Name,
            hasScopeParameter: hasScopeParameter,
            scopeParameterTypeFullName: scopeParameterTypeFullName,
            isMutable: isMutable,
            proxyProperties: CollectProxyProperties(typeSymbol));
    }

    /// <summary>
    /// Returns the non-nullable form of <paramref name="type"/>: the type argument of a
    /// <c>Nullable&lt;T&gt;</c>, or the same type with its nullable annotation removed.
    /// </summary>
    /// <param name="type">The declared type of the <c>Current</c> property.</param>
    /// <returns>The underlying non-nullable type symbol.</returns>
    private static ITypeSymbol UnwrapNullable(ITypeSymbol type)
    {
        if (type is INamedTypeSymbol namedType &&
            namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
        {
            return namedType.TypeArguments[0];
        }

        if (type.NullableAnnotation == NullableAnnotation.Annotated)
        {
            // Strip the annotation explicitly instead of relying on the display format
            // omitting the '?' modifier.
            return type.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
        }

        return type;
    }

    /// <summary>
    /// Discovers additional properties on the interface (beyond "Current") that should be
    /// proxied through to Current?.PropertyName. Static properties and indexers are skipped.
    /// </summary>
    /// <param name="typeSymbol">The attributed interface.</param>
    /// <returns>One entry per proxied property, in declaration order of <c>GetMembers()</c>.</returns>
    private static ImmutableArray<Models.AsyncLocalScopedPropertyInfo> CollectProxyProperties(
        INamedTypeSymbol typeSymbol)
    {
        var proxyProps = ImmutableArray.CreateBuilder<Models.AsyncLocalScopedPropertyInfo>();

        foreach (var prop in typeSymbol.GetMembers().OfType<IPropertySymbol>())
        {
            if (prop.Name == CurrentPropertyName || prop.IsStatic || prop.IsIndexer)
                continue;

            // True only for value types that are neither annotated nor Nullable<T>.
            bool isNonNullableValueType = prop.Type.IsValueType &&
                prop.Type.NullableAnnotation != NullableAnnotation.Annotated &&
                !(prop.Type is INamedTypeSymbol nt &&
                  nt.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T);

            proxyProps.Add(new Models.AsyncLocalScopedPropertyInfo(
                name: prop.Name,
                typeFullName: prop.Type.ToDisplayString(ProxyPropertyTypeFormat),
                hasSetter: prop.SetMethod != null,
                isNonNullableValueType: isNonNullableValueType));
        }

        return proxyProps.ToImmutable();
    }
}
| | | 160 | | } |