< Summary

Information
Class: NexusLabs.Needlr.AgentFramework.Workflows.Budget.ContextWindowGuardMiddleware
Assembly: NexusLabs.Needlr.AgentFramework.Workflows
File(s): /home/runner/work/needlr/needlr/src/NexusLabs.Needlr.AgentFramework.Workflows/Budget/ContextWindowGuardMiddleware.cs
Line coverage
89%
Covered lines: 83
Uncovered lines: 10
Coverable lines: 93
Total lines: 222
Line coverage: 89.2%
Branch coverage
68%
Covered branches: 41
Total branches: 60
Branch coverage: 68.3%
Method coverage

Feature is only available for sponsors

Upgrade to PRO version

Metrics

MethodBranch coverage Crap Score Cyclomatic complexity Line coverage
get_CharsPerToken()100%11100%
.ctor(...)100%11100%
GetResponseAsync()87.5%88100%
GetStreamingResponseAsync()75%8895%
EstimateTokenCount(...)80%212084.61%
PruneMessages(...)50%582461.11%

File(s)

/home/runner/work/needlr/needlr/src/NexusLabs.Needlr.AgentFramework.Workflows/Budget/ContextWindowGuardMiddleware.cs

#LineLine coverage
 1using Microsoft.Extensions.AI;
 2
 3using NexusLabs.Needlr.AgentFramework;
 4
 5using NexusLabs.Needlr.AgentFramework.Progress;
 6
 7namespace NexusLabs.Needlr.AgentFramework.Workflows.Budget;
 8
 9/// <summary>
 10/// <see cref="DelegatingChatClient"/> safety net that estimates cumulative context
 11/// size across LLM calls and emits a warning when approaching a configurable limit.
 12/// Optionally prunes oldest non-system messages to keep context under the limit.
 13/// </summary>
 14/// <remarks>
 15/// <para>
 16/// This middleware is a safety net for <c>FunctionInvokingChatClient</c> (FIC) usage
 17/// where conversation history accumulates. It does NOT replace the iterative loop
 18/// pattern — prefer <see cref="NexusLabs.Needlr.AgentFramework.Iterative.IIterativeAgentLoop"/>
 19/// for tool-heavy stages. Use this middleware on stages that remain FIC-based as a
 20/// guard against context window overflow.
 21/// </para>
 22/// <para>
 23/// Token estimation is approximate: each message's text content length is divided by
 24/// <see cref="CharsPerToken"/> (default 4) since exact tokenization requires a
 25/// model-specific tokenizer. This is conservative — it may trigger warnings earlier
 26/// than necessary, but never later.
 27/// </para>
 28/// </remarks>
 29public sealed class ContextWindowGuardMiddleware : DelegatingChatClient
 30{
 31    private readonly int _maxContextTokens;
 32    private readonly double _warningThreshold;
 33    private readonly bool _pruneOnOverflow;
 34    private readonly IProgressReporterAccessor _progressAccessor;
 35
 36    /// <summary>
 37    /// Approximate characters per token for estimation. Defaults to 4.
 38    /// </summary>
 1439    public int CharsPerToken { get; set; } = 4;
 40
 41    /// <param name="innerClient">The inner chat client to delegate to.</param>
 42    /// <param name="maxContextTokens">
 43    /// Estimated maximum context window size in tokens. When the message list
 44    /// exceeds this, a warning is emitted and optionally oldest messages are pruned.
 45    /// </param>
 46    /// <param name="progressAccessor">Progress reporter for emitting warning events.</param>
 47    /// <param name="warningThreshold">
 48    /// Fraction of <paramref name="maxContextTokens"/> at which to emit a warning.
 49    /// Defaults to <c>0.8</c> (80%).
 50    /// </param>
 51    /// <param name="pruneOnOverflow">
 52    /// When <see langword="true"/>, automatically removes oldest non-system messages
 53    /// to keep estimated context under <paramref name="maxContextTokens"/>.
 54    /// Defaults to <see langword="false"/> (warn only).
 55    /// </param>
 56    public ContextWindowGuardMiddleware(
 57        IChatClient innerClient,
 58        int maxContextTokens,
 59        IProgressReporterAccessor progressAccessor,
 60        double warningThreshold = 0.8,
 61        bool pruneOnOverflow = false)
 662        : base(innerClient)
 63    {
 664        ArgumentNullException.ThrowIfNull(progressAccessor);
 665        _maxContextTokens = maxContextTokens;
 666        _warningThreshold = warningThreshold;
 667        _pruneOnOverflow = pruneOnOverflow;
 668        _progressAccessor = progressAccessor;
 669    }
 70
 71    /// <inheritdoc />
 72    public override async Task<ChatResponse> GetResponseAsync(
 73        IEnumerable<ChatMessage> messages,
 74        ChatOptions? options = null,
 75        CancellationToken cancellationToken = default)
 76    {
 577        var messageList = messages as List<ChatMessage> ?? [.. messages];
 578        var estimatedTokens = EstimateTokenCount(messageList);
 579        var reporter = _progressAccessor.Current;
 80
 581        var warningLimit = (long)(_maxContextTokens * _warningThreshold);
 82
 583        if (estimatedTokens > _maxContextTokens)
 84        {
 385            reporter.Report(new BudgetExceededEvent(
 386                Timestamp: DateTimeOffset.UtcNow,
 387                WorkflowId: reporter.WorkflowId,
 388                AgentId: reporter.AgentId,
 389                ParentAgentId: null,
 390                Depth: reporter.Depth,
 391                SequenceNumber: reporter.NextSequence(),
 392                LimitType: "context_window",
 393                CurrentValue: estimatedTokens,
 394                MaxValue: _maxContextTokens));
 95
 396            if (_pruneOnOverflow)
 97            {
 198                PruneMessages(messageList, estimatedTokens);
 99            }
 100        }
 2101        else if (estimatedTokens >= warningLimit)
 102        {
 1103            reporter.Report(new BudgetUpdatedEvent(
 1104                Timestamp: DateTimeOffset.UtcNow,
 1105                WorkflowId: reporter.WorkflowId,
 1106                AgentId: reporter.AgentId,
 1107                ParentAgentId: null,
 1108                Depth: reporter.Depth,
 1109                SequenceNumber: reporter.NextSequence(),
 1110                CurrentInputTokens: estimatedTokens,
 1111                CurrentOutputTokens: 0,
 1112                CurrentTotalTokens: estimatedTokens,
 1113                MaxInputTokens: _maxContextTokens,
 1114                MaxOutputTokens: null,
 1115                MaxTotalTokens: _maxContextTokens));
 116        }
 117
 5118        return await base.GetResponseAsync(messageList, options, cancellationToken)
 5119            .ConfigureAwait(false);
 5120    }
 121
 122    /// <inheritdoc />
 123    public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
 124        IEnumerable<ChatMessage> messages,
 125        ChatOptions? options = null,
 126        [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
 127    {
 1128        var messageList = messages as List<ChatMessage> ?? [.. messages];
 1129        var estimatedTokens = EstimateTokenCount(messageList);
 1130        var reporter = _progressAccessor.Current;
 131
 1132        if (estimatedTokens > _maxContextTokens)
 133        {
 1134            reporter.Report(new BudgetExceededEvent(
 1135                Timestamp: DateTimeOffset.UtcNow,
 1136                WorkflowId: reporter.WorkflowId,
 1137                AgentId: reporter.AgentId,
 1138                ParentAgentId: null,
 1139                Depth: reporter.Depth,
 1140                SequenceNumber: reporter.NextSequence(),
 1141                LimitType: "context_window",
 1142                CurrentValue: estimatedTokens,
 1143                MaxValue: _maxContextTokens));
 144
 1145            if (_pruneOnOverflow)
 146            {
 0147                PruneMessages(messageList, estimatedTokens);
 148            }
 149        }
 150
 4151        await foreach (var update in base.GetStreamingResponseAsync(messageList, options, cancellationToken)
 1152            .ConfigureAwait(false))
 153        {
 1154            yield return update;
 155        }
 1156    }
 157
 158    private long EstimateTokenCount(IEnumerable<ChatMessage> messages)
 159    {
 6160        long totalChars = 0;
 32161        foreach (var msg in messages)
 162        {
 40163            foreach (var content in msg.Contents)
 164            {
 10165                if (content is TextContent tc && tc.Text is { } text)
 166                {
 9167                    totalChars += text.Length;
 168                }
 1169                else if (content is FunctionCallContent fc)
 170                {
 1171                    totalChars += fc.Name.Length + 50;
 1172                    if (fc.Arguments is { } args)
 173                    {
 4174                        foreach (var (_, value) in args)
 175                        {
 1176                            totalChars += value?.ToString()?.Length ?? 0;
 177                        }
 178                    }
 179                }
 0180                else if (content is FunctionResultContent fr)
 181                {
 0182                    totalChars += ToolResultSerializer.Serialize(fr.Result).Length;
 183                }
 184            }
 185        }
 186
 6187        return totalChars / CharsPerToken;
 188    }
 189
 190    private void PruneMessages(List<ChatMessage> messages, long currentTokens)
 191    {
 192        // Remove oldest non-system messages until under the limit.
 193        // Never remove the system message (index 0) or the last user message.
 3194        while (currentTokens > _maxContextTokens && messages.Count > 2)
 195        {
 2196            var idx = messages[0].Role == ChatRole.System ? 1 : 0;
 2197            if (idx >= messages.Count - 1) break;
 198
 2199            var removed = messages[idx];
 2200            long removedTokens = 0;
 8201            foreach (var content in removed.Contents)
 202            {
 2203                if (content is TextContent tc && tc.Text is { } text)
 2204                    removedTokens += text.Length / CharsPerToken;
 0205                else if (content is FunctionCallContent fc)
 206                {
 0207                    removedTokens += (fc.Name.Length + 50) / CharsPerToken;
 0208                    if (fc.Arguments is { } args)
 209                    {
 0210                        foreach (var (_, value) in args)
 0211                            removedTokens += (value?.ToString()?.Length ?? 0) / CharsPerToken;
 212                    }
 213                }
 0214                else if (content is FunctionResultContent fr)
 0215                    removedTokens += ToolResultSerializer.Serialize(fr.Result).Length / CharsPerToken;
 216            }
 217
 2218            messages.RemoveAt(idx);
 2219            currentTokens -= removedTokens;
 220        }
 1221    }
 222}