| | | 1 | | using Microsoft.Extensions.AI; |
| | | 2 | | using Microsoft.Extensions.DependencyInjection; |
| | | 3 | | |
| | | 4 | | using NexusLabs.Needlr.AgentFramework; |
| | | 5 | | using NexusLabs.Needlr.AgentFramework.Budget; |
| | | 6 | | using NexusLabs.Needlr.AgentFramework.Progress; |
| | | 7 | | |
| | | 8 | | namespace NexusLabs.Needlr.AgentFramework.Workflows.Budget; |
| | | 9 | | |
| | | 10 | | /// <summary> |
| | | 11 | | /// Extension methods for wiring token-budget enforcement into the agent framework. |
| | | 12 | | /// </summary> |
| | | 13 | | public static class TokenBudgetExtensions |
| | | 14 | | { |
| | | 15 | | /// <summary> |
| | | 16 | | /// Wires <see cref="TokenUsageRecordingMiddleware"/> to record token usage from |
| | | 17 | | /// every LLM call into <see cref="ITokenBudgetTracker"/>. This enables |
| | | 18 | | /// <see cref="ITokenBudgetTracker.CurrentTokens"/> without enforcing any budget. |
| | | 19 | | /// </summary> |
| | | 20 | | /// <remarks> |
| | | 21 | | /// Idempotent — calling this multiple times (or via both <c>UsingTokenBudget()</c> |
| | | 22 | | /// and <c>UsingDiagnostics()</c>) wires the recording middleware exactly once. |
| | | 23 | | /// </remarks> |
| | | 24 | | public static AgentFrameworkSyringe UsingTokenTracking( |
| | | 25 | | this AgentFrameworkSyringe syringe) |
| | | 26 | | { |
| | 37 | 27 | | ArgumentNullException.ThrowIfNull(syringe); |
| | 38 | 28 | | if (syringe.TokenTrackingWired) return syringe; |
| | | 29 | | |
| | 36 | 30 | | var result = syringe.Configure(opts => |
| | 36 | 31 | | { |
| | 35 | 32 | | var tracker = opts.ServiceProvider.GetRequiredService<ITokenBudgetTracker>(); |
| | 35 | 33 | | var existingFactory = opts.ChatClientFactory; |
| | 35 | 34 | | opts.ChatClientFactory = sp => |
| | 35 | 35 | | { |
| | 44 | 36 | | var innerClient = existingFactory?.Invoke(sp) |
| | 44 | 37 | | ?? sp.GetRequiredService<IChatClient>(); |
| | 44 | 38 | | return new TokenUsageRecordingMiddleware(innerClient, tracker); |
| | 35 | 39 | | }; |
| | 71 | 40 | | }); |
| | | 41 | | |
| | 36 | 42 | | return result with { TokenTrackingWired = true }; |
| | | 43 | | } |
| | | 44 | | |
| | | 45 | | /// <summary> |
| | | 46 | | /// Wraps the configured <see cref="IChatClient"/> with <see cref="TokenBudgetChatMiddleware"/>, |
| | | 47 | | /// enabling per-pipeline token budgets via <see cref="ITokenBudgetTracker"/>. |
| | | 48 | | /// Automatically includes <see cref="UsingTokenTracking"/> for token recording. |
| | | 49 | | /// </summary> |
| | | 50 | | public static AgentFrameworkSyringe UsingTokenBudget( |
| | | 51 | | this AgentFrameworkSyringe syringe) |
| | | 52 | | { |
| | 1 | 53 | | ArgumentNullException.ThrowIfNull(syringe); |
| | | 54 | | |
| | 1 | 55 | | syringe = syringe.UsingTokenTracking(); |
| | | 56 | | |
| | 1 | 57 | | return syringe.Configure(opts => |
| | 1 | 58 | | { |
| | 1 | 59 | | var tracker = opts.ServiceProvider.GetRequiredService<ITokenBudgetTracker>(); |
| | 1 | 60 | | var progressAccessor = opts.ServiceProvider.GetRequiredService<IProgressReporterAccessor>(); |
| | 1 | 61 | | |
| | 1 | 62 | | var existingFactory = opts.ChatClientFactory; |
| | 1 | 63 | | opts.ChatClientFactory = sp => |
| | 1 | 64 | | { |
| | 1 | 65 | | var innerClient = existingFactory?.Invoke(sp) |
| | 1 | 66 | | ?? sp.GetRequiredService<IChatClient>(); |
| | 1 | 67 | | |
| | 1 | 68 | | return new TokenBudgetChatMiddleware(innerClient, tracker, progressAccessor); |
| | 1 | 69 | | }; |
| | 2 | 70 | | }); |
| | | 71 | | } |
| | | 72 | | } |