| | | 1 | | using Microsoft.Extensions.AI; |
| | | 2 | | |
| | | 3 | | using NexusLabs.Needlr.AgentFramework; |
| | | 4 | | |
| | | 5 | | using NexusLabs.Needlr.AgentFramework.Progress; |
| | | 6 | | |
| | | 7 | | namespace NexusLabs.Needlr.AgentFramework.Workflows.Budget; |
| | | 8 | | |
/// <summary>
/// <see cref="DelegatingChatClient"/> safety net that estimates cumulative context
/// size across LLM calls and emits a warning when approaching a configurable limit.
/// Optionally prunes oldest non-system messages to keep context under the limit.
/// </summary>
/// <remarks>
/// <para>
/// This middleware is a safety net for <c>FunctionInvokingChatClient</c> (FIC) usage
/// where conversation history accumulates. It does NOT replace the iterative loop
/// pattern — prefer <see cref="NexusLabs.Needlr.AgentFramework.Iterative.IIterativeAgentLoop"/>
/// for tool-heavy stages. Use this middleware on stages that remain FIC-based as a
/// guard against context window overflow.
/// </para>
/// <para>
/// Token estimation is approximate: each message's text content length is divided by
/// <see cref="CharsPerToken"/> (default 4) since exact tokenization requires a
/// model-specific tokenizer. This is conservative — it may trigger warnings earlier
/// than necessary, but never later.
/// </para>
/// <para>
/// Both the buffered and streaming paths apply the same guard: a
/// <c>BudgetUpdatedEvent</c> at the warning threshold, a <c>BudgetExceededEvent</c>
/// (and optional pruning) past the hard limit.
/// </para>
/// </remarks>
public sealed class ContextWindowGuardMiddleware : DelegatingChatClient
{
    /// <summary>
    /// Fixed character overhead charged per function call to account for
    /// name/framing metadata beyond the serialized arguments.
    /// </summary>
    private const int FunctionCallOverheadChars = 50;

    private readonly int _maxContextTokens;
    private readonly double _warningThreshold;
    private readonly bool _pruneOnOverflow;
    private readonly IProgressReporterAccessor _progressAccessor;

    private int _charsPerToken = 4;

    /// <summary>
    /// Approximate characters per token for estimation. Defaults to 4.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when set to a value less than 1 (a zero value would cause a
    /// division by zero during estimation).
    /// </exception>
    public int CharsPerToken
    {
        get => _charsPerToken;
        set
        {
            ArgumentOutOfRangeException.ThrowIfNegativeOrZero(value);
            _charsPerToken = value;
        }
    }

    /// <param name="innerClient">The inner chat client to delegate to.</param>
    /// <param name="maxContextTokens">
    /// Estimated maximum context window size in tokens. When the message list
    /// exceeds this, a warning is emitted and optionally oldest messages are pruned.
    /// Must be positive.
    /// </param>
    /// <param name="progressAccessor">Progress reporter for emitting warning events.</param>
    /// <param name="warningThreshold">
    /// Fraction of <paramref name="maxContextTokens"/> at which to emit a warning.
    /// Defaults to <c>0.8</c> (80%). Must not be negative.
    /// </param>
    /// <param name="pruneOnOverflow">
    /// When <see langword="true"/>, automatically removes oldest non-system messages
    /// to keep estimated context under <paramref name="maxContextTokens"/>.
    /// Defaults to <see langword="false"/> (warn only).
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="progressAccessor"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when <paramref name="maxContextTokens"/> is not positive or
    /// <paramref name="warningThreshold"/> is negative.
    /// </exception>
    public ContextWindowGuardMiddleware(
        IChatClient innerClient,
        int maxContextTokens,
        IProgressReporterAccessor progressAccessor,
        double warningThreshold = 0.8,
        bool pruneOnOverflow = false)
        : base(innerClient)
    {
        ArgumentNullException.ThrowIfNull(progressAccessor);
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(maxContextTokens);
        ArgumentOutOfRangeException.ThrowIfNegative(warningThreshold);

        _maxContextTokens = maxContextTokens;
        _warningThreshold = warningThreshold;
        _pruneOnOverflow = pruneOnOverflow;
        _progressAccessor = progressAccessor;
    }

    /// <inheritdoc />
    public override async Task<ChatResponse> GetResponseAsync(
        IEnumerable<ChatMessage> messages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        // Materialize once so the guard and the inner client see the same
        // (possibly pruned) list without re-enumerating the source.
        var messageList = messages as List<ChatMessage> ?? [.. messages];
        GuardContext(messageList);

        return await base.GetResponseAsync(messageList, options, cancellationToken)
            .ConfigureAwait(false);
    }

    /// <inheritdoc />
    public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
        IEnumerable<ChatMessage> messages,
        ChatOptions? options = null,
        [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var messageList = messages as List<ChatMessage> ?? [.. messages];
        GuardContext(messageList);

        await foreach (var update in base.GetStreamingResponseAsync(messageList, options, cancellationToken)
            .ConfigureAwait(false))
        {
            yield return update;
        }
    }

    /// <summary>
    /// Shared guard for both response paths: estimates the context size, emits a
    /// <c>BudgetExceededEvent</c> (and optionally prunes <paramref name="messageList"/>
    /// in place) when over the hard limit, or a <c>BudgetUpdatedEvent</c> when past
    /// the warning threshold.
    /// </summary>
    private void GuardContext(List<ChatMessage> messageList)
    {
        var estimatedTokens = EstimateTokenCount(messageList);
        var reporter = _progressAccessor.Current;
        var warningLimit = (long)(_maxContextTokens * _warningThreshold);

        if (estimatedTokens > _maxContextTokens)
        {
            reporter.Report(new BudgetExceededEvent(
                Timestamp: DateTimeOffset.UtcNow,
                WorkflowId: reporter.WorkflowId,
                AgentId: reporter.AgentId,
                ParentAgentId: null,
                Depth: reporter.Depth,
                SequenceNumber: reporter.NextSequence(),
                LimitType: "context_window",
                CurrentValue: estimatedTokens,
                MaxValue: _maxContextTokens));

            if (_pruneOnOverflow)
            {
                PruneMessages(messageList, estimatedTokens);
            }
        }
        else if (estimatedTokens >= warningLimit)
        {
            reporter.Report(new BudgetUpdatedEvent(
                Timestamp: DateTimeOffset.UtcNow,
                WorkflowId: reporter.WorkflowId,
                AgentId: reporter.AgentId,
                ParentAgentId: null,
                Depth: reporter.Depth,
                SequenceNumber: reporter.NextSequence(),
                CurrentInputTokens: estimatedTokens,
                CurrentOutputTokens: 0,
                CurrentTotalTokens: estimatedTokens,
                MaxInputTokens: _maxContextTokens,
                MaxOutputTokens: null,
                MaxTotalTokens: _maxContextTokens));
        }
    }

    /// <summary>
    /// Estimates the character contribution of a single message: plain text length,
    /// function-call name and arguments (plus <see cref="FunctionCallOverheadChars"/>
    /// framing overhead), and the serialized length of function results.
    /// Other content kinds (e.g. images) contribute nothing.
    /// </summary>
    private static long EstimateChars(ChatMessage message)
    {
        long chars = 0;
        foreach (var content in message.Contents)
        {
            switch (content)
            {
                case TextContent { Text: { } text }:
                    chars += text.Length;
                    break;

                case FunctionCallContent fc:
                    chars += fc.Name.Length + FunctionCallOverheadChars;
                    if (fc.Arguments is { } args)
                    {
                        foreach (var (_, value) in args)
                        {
                            chars += value?.ToString()?.Length ?? 0;
                        }
                    }
                    break;

                case FunctionResultContent fr:
                    chars += ToolResultSerializer.Serialize(fr.Result).Length;
                    break;
            }
        }

        return chars;
    }

    /// <summary>
    /// Approximates the total token count of <paramref name="messages"/> by summing
    /// character contributions and dividing once by <see cref="CharsPerToken"/>.
    /// </summary>
    private long EstimateTokenCount(IEnumerable<ChatMessage> messages)
    {
        long totalChars = 0;
        foreach (var msg in messages)
        {
            totalChars += EstimateChars(msg);
        }

        return totalChars / CharsPerToken;
    }

    /// <summary>
    /// Removes oldest non-system messages in place until the estimated token count
    /// drops under <see cref="_maxContextTokens"/>. Never removes the system message
    /// (index 0) or the final message, and always leaves at least two messages.
    /// </summary>
    private void PruneMessages(List<ChatMessage> messages, long currentTokens)
    {
        while (currentTokens > _maxContextTokens && messages.Count > 2)
        {
            // Skip a leading system message; prune the oldest remaining message.
            var idx = messages[0].Role == ChatRole.System ? 1 : 0;
            if (idx >= messages.Count - 1) break;

            // Same per-message accounting as EstimateTokenCount, so the running
            // total stays consistent with the original estimate.
            currentTokens -= EstimateChars(messages[idx]) / CharsPerToken;
            messages.RemoveAt(idx);
        }
    }
}