| | | 1 | | using System.Collections.Generic; |
| | | 2 | | using System.Linq; |
| | | 3 | | using Microsoft.CodeAnalysis; |
| | | 4 | | using NexusLabs.Needlr.Generators.Models; |
| | | 5 | | |
| | | 6 | | namespace NexusLabs.Needlr.Generators; |
| | | 7 | | |
| | | 8 | | /// <summary> |
| | | 9 | | /// Helper for discovering [Provider] attributes from Roslyn symbols. |
| | | 10 | | /// </summary> |
| | | 11 | | internal static class ProviderDiscoveryHelper |
| | | 12 | | { |
    // Simple (unqualified) attribute type name; compared against AttributeClass.Name.
    private const string ProviderAttributeName = "ProviderAttribute";
    // Namespace-qualified attribute type name; compared against AttributeClass.ToDisplayString().
    private const string ProviderAttributeFullName = "NexusLabs.Needlr.Generators.ProviderAttribute";
| | | 15 | | |
/// <summary>
/// Checks if a type has the [Provider] attribute.
/// </summary>
/// <param name="typeSymbol">The type whose attributes are inspected.</param>
/// <returns>
/// <c>true</c> when any attribute applied to <paramref name="typeSymbol"/> has an attribute class
/// whose simple name is <c>ProviderAttribute</c>; otherwise <c>false</c>.
/// </returns>
public static bool HasProviderAttribute(INamedTypeSymbol typeSymbol)
{
    foreach (var attribute in typeSymbol.GetAttributes())
    {
        // Only the simple name is compared. A fallback comparison of the display string against
        // the namespace-qualified name could succeed only for a symbol whose simple name is
        // already "ProviderAttribute", so it never changed the outcome; it merely formatted a
        // display string for every non-matching attribute on this frequently executed path.
        // An unresolved attribute class (null) yields null here and therefore does not match.
        if (attribute.AttributeClass?.Name == ProviderAttributeName)
        {
            return true;
        }
    }

    return false;
}
| | | 38 | | |
/// <summary>
/// Gets the [Provider] attribute data from a type.
/// </summary>
/// <param name="typeSymbol">The type whose attributes are inspected.</param>
/// <returns>
/// The first attribute whose class is named <c>ProviderAttribute</c> (or whose display string equals
/// the namespace-qualified name), or <c>null</c> when no such attribute is applied.
/// </returns>
public static AttributeData? GetProviderAttribute(INamedTypeSymbol typeSymbol)
{
    foreach (var attribute in typeSymbol.GetAttributes())
    {
        var attributeClass = attribute.AttributeClass;
        if (attributeClass is null)
        {
            // Attribute class could not be resolved; nothing to compare against.
            continue;
        }

        // Short-circuit evaluation: the display string is only built when the cheap
        // simple-name comparison fails, instead of being formatted for every attribute.
        if (attributeClass.Name == ProviderAttributeName
            || attributeClass.ToDisplayString() == ProviderAttributeFullName)
        {
            return attribute;
        }
    }

    return null;
}
| | | 59 | | |
/// <summary>
/// Builds a <see cref="DiscoveredProvider"/> description for a type carrying the [Provider] attribute.
/// </summary>
/// <param name="typeSymbol">The type symbol to discover.</param>
/// <param name="assemblyName">The assembly name containing the type; passed through to the result.</param>
/// <param name="generatedNamespace">The namespace where generated types are placed (e.g., "AssemblyName.Generated"); only consulted for non-interface types.</param>
/// <returns>
/// The discovered provider, or <c>null</c> when <paramref name="typeSymbol"/> has no [Provider] attribute.
/// </returns>
public static DiscoveredProvider? DiscoverProvider(INamedTypeSymbol typeSymbol, string assemblyName, string generatedNamespace)
{
    var attributeData = GetProviderAttribute(typeSymbol);
    if (attributeData is null)
    {
        return null;
    }

    var qualifiedTypeName = TypeDiscoveryHelper.GetFullyQualifiedName(typeSymbol);
    var declaredAsInterface = typeSymbol.TypeKind == TypeKind.Interface;
    var declaredAsPartial = IsPartialType(typeSymbol);
    var declaringFile = typeSymbol.Locations.FirstOrDefault()?.SourceTree?.FilePath;

    // Interfaces describe their services through their own properties; any other type
    // describes them through the arguments supplied to the attribute.
    var discoveredProperties = declaredAsInterface
        ? ExtractPropertiesFromInterface(typeSymbol)
        : ExtractPropertiesFromAttribute(attributeData, generatedNamespace);

    var properties = new List<ProviderPropertyInfo>(discoveredProperties);

    return new DiscoveredProvider(
        qualifiedTypeName,
        assemblyName,
        declaredAsInterface,
        declaredAsPartial,
        properties,
        declaringFile);
}
| | | 98 | | |
/// <summary>
/// Reports whether any type declaration of the symbol carries the <c>partial</c> modifier.
/// </summary>
/// <param name="typeSymbol">The type whose declaring syntax is examined.</param>
/// <returns>
/// <c>true</c> as soon as one declaring syntax node is a type declaration with a <c>partial</c>
/// keyword among its modifiers; otherwise <c>false</c>.
/// </returns>
private static bool IsPartialType(INamedTypeSymbol typeSymbol)
{
    // The pipeline is lazy, so syntax nodes are only materialized until the first match.
    return typeSymbol.DeclaringSyntaxReferences
        .Select(reference => reference.GetSyntax())
        .OfType<Microsoft.CodeAnalysis.CSharp.Syntax.TypeDeclarationSyntax>()
        .Any(declaration => declaration.Modifiers.Any(
            modifier => modifier.IsKind(Microsoft.CodeAnalysis.CSharp.SyntaxKind.PartialKeyword)));
}
| | | 115 | | |
/// <summary>
/// Yields one <see cref="ProviderPropertyInfo"/> per get-only property declared on the interface.
/// </summary>
/// <param name="interfaceSymbol">The interface whose members are scanned.</param>
/// <returns>
/// A lazily evaluated sequence. Members that are not properties, properties without a getter,
/// properties with a setter, and properties whose type is not a named type are skipped.
/// </returns>
private static IEnumerable<ProviderPropertyInfo> ExtractPropertiesFromInterface(INamedTypeSymbol interfaceSymbol)
{
    foreach (var candidate in interfaceSymbol.GetMembers())
    {
        // Keep only properties that expose a getter and no setter.
        if (candidate is not IPropertySymbol { GetMethod: not null, SetMethod: null } getter)
        {
            continue;
        }

        if (getter.Type is not INamedTypeSymbol serviceType)
        {
            continue;
        }

        yield return new ProviderPropertyInfo(
            getter.Name,
            TypeDiscoveryHelper.GetFullyQualifiedName(serviceType),
            DeterminePropertyKind(getter.Type, getter.NullableAnnotation));
    }
}
| | | 136 | | |
/// <summary>
/// Yields provider properties described by the arguments of a [Provider] attribute.
/// </summary>
/// <remarks>
/// The first constructor argument, when it is an array, lists required services. The named
/// arguments <c>Required</c>, <c>Optional</c>, <c>Collections</c> and <c>Factories</c> list services
/// of the corresponding kind; any other named argument is ignored.
/// </remarks>
/// <param name="attribute">The attribute data to extract from.</param>
/// <param name="generatedNamespace">The namespace where generated types are placed; used to build factory type names.</param>
/// <returns>A lazily evaluated sequence of property descriptions, constructor arguments first.</returns>
private static IEnumerable<ProviderPropertyInfo> ExtractPropertiesFromAttribute(AttributeData attribute, string generatedNamespace)
{
    // Constructor arguments: required services.
    var constructorArguments = attribute.ConstructorArguments;
    if (constructorArguments.Length > 0 && constructorArguments[0].Kind == TypedConstantKind.Array)
    {
        foreach (var element in constructorArguments[0].Values)
        {
            if (element.Value is not INamedTypeSymbol requiredType)
            {
                continue;
            }

            var requiredPropertyName = DerivePropertyName(requiredType);
            var requiredTypeName = TypeDiscoveryHelper.GetFullyQualifiedName(requiredType);
            yield return new ProviderPropertyInfo(requiredPropertyName, requiredTypeName, ProviderPropertyKind.Required);
        }
    }

    // Named arguments: the argument name selects the property kind.
    foreach (var argument in attribute.NamedArguments)
    {
        ProviderPropertyKind propertyKind;
        switch (argument.Key)
        {
            case "Required":
                propertyKind = ProviderPropertyKind.Required;
                break;
            case "Optional":
                propertyKind = ProviderPropertyKind.Optional;
                break;
            case "Collections":
                propertyKind = ProviderPropertyKind.Collection;
                break;
            case "Factories":
                propertyKind = ProviderPropertyKind.Factory;
                break;
            default:
                // Unrecognized named argument: contributes no properties.
                continue;
        }

        if (argument.Value.Kind != TypedConstantKind.Array)
        {
            continue;
        }

        foreach (var element in argument.Value.Values)
        {
            if (element.Value is not INamedTypeSymbol serviceType)
            {
                continue;
            }

            var derivedName = DerivePropertyName(serviceType, propertyKind);
            var qualifiedName = TypeDiscoveryHelper.GetFullyQualifiedName(serviceType);

            var exposedTypeName = propertyKind switch
            {
                // Collections are exposed as IEnumerable<T> of the service type.
                ProviderPropertyKind.Collection => $"global::System.Collections.Generic.IEnumerable<{qualifiedName}>",
                // Factories are exposed through the factory interface in the generated namespace.
                ProviderPropertyKind.Factory => DeriveFactoryTypeName(serviceType, generatedNamespace),
                _ => qualifiedName
            };

            yield return new ProviderPropertyInfo(derivedName, exposedTypeName, propertyKind);
        }
    }
}
| | | 200 | | |
/// <summary>
/// Derives a property name from a type (e.g., IOrderRepository → OrderRepository).
/// For collections, pluralizes the name (e.g., IHandler → Handlers).
/// For factories, appends Factory suffix (e.g., IOrderService → OrderServiceFactory).
/// </summary>
/// <param name="typeSymbol">The service type the property name is derived from.</param>
/// <param name="kind">The property kind; only Collection and Factory alter the base name.</param>
/// <returns>The derived property name.</returns>
private static string DerivePropertyName(INamedTypeSymbol typeSymbol, ProviderPropertyKind kind = ProviderPropertyKind.Required)
{
    var name = typeSymbol.Name;

    // Remove the leading 'I' from interface names. The prefix is tested with an ordinal
    // character comparison so the result never depends on the current culture, and the
    // second character must be upper case so names such as "Item" are left untouched.
    if (typeSymbol.TypeKind == TypeKind.Interface
        && name.Length > 1
        && name[0] == 'I'
        && char.IsUpper(name[1]))
    {
        name = name.Substring(1);
    }

    if (kind == ProviderPropertyKind.Collection)
    {
        // Pluralize collection property names.
        name = Pluralize(name);
    }
    else if (kind == ProviderPropertyKind.Factory)
    {
        // Append Factory suffix for factory properties.
        name += "Factory";
    }

    return name;
}
| | | 229 | | |
/// <summary>
/// Simple pluralization for property names.
/// </summary>
/// <remarks>
/// Rules, applied in order: a trailing 'y' preceded by a non-vowel becomes "ies";
/// a trailing "s", "x", "ch" or "sh" gets "es"; everything else gets "s".
/// All suffix checks are ordinal, so the result does not depend on the current culture.
/// </remarks>
/// <param name="name">The singular name.</param>
/// <returns>The pluralized name, or <paramref name="name"/> unchanged when it is null or empty.</returns>
private static string Pluralize(string name)
{
    if (string.IsNullOrEmpty(name))
    {
        return name;
    }

    var last = name[name.Length - 1];

    // consonant + y  ->  ies   (Repository -> Repositories, but Key -> Keys)
    if (last == 'y' && name.Length > 1 && !IsVowel(name[name.Length - 2]))
    {
        return name.Substring(0, name.Length - 1) + "ies";
    }

    // Sibilant endings take "es"   (Box -> Boxes, Batch -> Batches)
    if (last == 's'
        || last == 'x'
        || name.EndsWith("ch", System.StringComparison.Ordinal)
        || name.EndsWith("sh", System.StringComparison.Ordinal))
    {
        return name + "es";
    }

    return name + "s";
}

/// <summary>
/// Reports whether the character is one of the ASCII vowels a, e, i, o, u in either case.
/// </summary>
private static bool IsVowel(char c) => "aeiouAEIOU".IndexOf(c) >= 0;
| | | 251 | | |
/// <summary>
/// Derives a factory interface type name from a service type.
/// E.g., IOrderService → IOrderServiceFactory (in Generated namespace)
/// </summary>
/// <param name="typeSymbol">The type to derive the factory name from.</param>
/// <param name="generatedNamespace">The namespace where generated types are placed.</param>
/// <returns>The <c>global::</c>-qualified name <c>{generatedNamespace}.I{Name}Factory</c>.</returns>
private static string DeriveFactoryTypeName(INamedTypeSymbol typeSymbol, string generatedNamespace)
{
    var name = typeSymbol.Name;

    // Remove the leading 'I' from interface names to get the base name. The prefix is tested
    // with an ordinal character comparison so the result never depends on the current culture.
    if (typeSymbol.TypeKind == TypeKind.Interface
        && name.Length > 1
        && name[0] == 'I'
        && char.IsUpper(name[1]))
    {
        name = name.Substring(1);
    }

    // Factory interface is I{Name}Factory in the assembly's generated namespace
    return $"global::{generatedNamespace}.I{name}Factory";
}
| | | 271 | | |
/// <summary>
/// Determines the property kind based on the type.
/// </summary>
/// <param name="type">The property type.</param>
/// <param name="nullableAnnotation">The nullable annotation of the property.</param>
/// <returns>
/// Collection when the type's original definition is IEnumerable&lt;T&gt;, IReadOnlyCollection&lt;T&gt;
/// or IReadOnlyList&lt;T&gt; from System.Collections.Generic; otherwise Optional when the property
/// is annotated as nullable; otherwise Required.
/// </returns>
private static ProviderPropertyKind DeterminePropertyKind(ITypeSymbol type, NullableAnnotation nullableAnnotation)
{
    // Collection check comes first, so a nullable collection is still reported as Collection.
    if (type is INamedTypeSymbol namedType)
    {
        // Type names are identifiers, not linguistic text: compare ordinally so the
        // outcome cannot vary with the current culture.
        var displayName = namedType.OriginalDefinition.ToDisplayString();
        if (displayName.StartsWith("System.Collections.Generic.IEnumerable<", System.StringComparison.Ordinal) ||
            displayName.StartsWith("System.Collections.Generic.IReadOnlyCollection<", System.StringComparison.Ordinal) ||
            displayName.StartsWith("System.Collections.Generic.IReadOnlyList<", System.StringComparison.Ordinal))
        {
            return ProviderPropertyKind.Collection;
        }
    }

    // Check for nullable annotation
    if (nullableAnnotation == NullableAnnotation.Annotated)
    {
        return ProviderPropertyKind.Optional;
    }

    return ProviderPropertyKind.Required;
}
| | | 297 | | } |