| | | 1 | | using System.Collections.Concurrent; |
| | | 2 | | using System.Diagnostics; |
| | | 3 | | using System.Text; |
| | | 4 | | |
| | | 5 | | using Microsoft.Agents.AI; |
| | | 6 | | using Microsoft.Agents.AI.Workflows; |
| | | 7 | | using Microsoft.Extensions.AI; |
| | | 8 | | |
| | | 9 | | using NexusLabs.Needlr.AgentFramework; |
| | | 10 | | using NexusLabs.Needlr.AgentFramework.Diagnostics; |
| | | 11 | | using NexusLabs.Needlr.AgentFramework.Iterative; |
| | | 12 | | |
| | | 13 | | using ProgressEvents = NexusLabs.Needlr.AgentFramework.Progress; |
| | | 14 | | |
| | | 15 | | namespace NexusLabs.Needlr.AgentFramework.Workflows; |
| | | 16 | | |
| | | 17 | | /// <summary> |
| | | 18 | | /// Executes DAG/graph workflows using either MAF's native BSP engine or the |
| | | 19 | | /// Needlr-native executor, depending on declared topology. |
| | | 20 | | /// </summary> |
| | | 21 | | /// <remarks> |
| | | 22 | | /// All dependencies are resolved via DI — no reflection into private fields. |
| | | 23 | | /// </remarks> |
| | | 24 | | internal sealed class GraphWorkflowRunner : IGraphWorkflowRunner |
| | | 25 | | { |
| | | 26 | | private readonly IWorkflowFactory _workflowFactory; |
| | | 27 | | private readonly IAgentFactory _agentFactory; |
| | | 28 | | private readonly IChatClientAccessor _chatClientAccessor; |
| | | 29 | | private readonly IAgentDiagnosticsAccessor? _diagnosticsAccessor; |
| | | 30 | | private readonly GraphTopologyProvider _topologyProvider; |
| | | 31 | | private readonly GraphEdgeRouter _edgeRouter; |
| | | 32 | | |
| | 37 | 33 | | public GraphWorkflowRunner( |
| | 37 | 34 | | IWorkflowFactory workflowFactory, |
| | 37 | 35 | | IAgentFactory agentFactory, |
| | 37 | 36 | | IChatClientAccessor chatClientAccessor, |
| | 37 | 37 | | GraphTopologyProvider topologyProvider, |
| | 37 | 38 | | GraphEdgeRouter edgeRouter, |
| | 37 | 39 | | IAgentDiagnosticsAccessor? diagnosticsAccessor = null) |
| | | 40 | | { |
| | 37 | 41 | | _workflowFactory = workflowFactory; |
| | 37 | 42 | | _agentFactory = agentFactory; |
| | 37 | 43 | | _chatClientAccessor = chatClientAccessor; |
| | 37 | 44 | | _topologyProvider = topologyProvider; |
| | 37 | 45 | | _edgeRouter = edgeRouter; |
| | 37 | 46 | | _diagnosticsAccessor = diagnosticsAccessor; |
| | 37 | 47 | | } |
| | | 48 | | |
| | | 49 | | public async Task<IDagRunResult> RunGraphAsync( |
| | | 50 | | string graphName, |
| | | 51 | | string input, |
| | | 52 | | ProgressEvents.IProgressReporter? progress = null, |
| | | 53 | | CancellationToken cancellationToken = default) |
| | | 54 | | { |
| | 31 | 55 | | ArgumentException.ThrowIfNullOrWhiteSpace(graphName); |
| | 31 | 56 | | ArgumentException.ThrowIfNullOrWhiteSpace(input); |
| | | 57 | | |
| | 31 | 58 | | var topology = _topologyProvider.GetTopology(graphName); |
| | | 59 | | |
| | 31 | 60 | | if (!topology.RequiresNeedlrExecutor) |
| | | 61 | | { |
| | 1 | 62 | | return await RunWaitAllWithDiagnosticsAsync( |
| | 1 | 63 | | topology, graphName, input, progress, cancellationToken); |
| | | 64 | | } |
| | | 65 | | |
| | 30 | 66 | | return await RunWithNeedlrExecutorAsync( |
| | 30 | 67 | | topology, graphName, input, progress, cancellationToken); |
| | 31 | 68 | | } |
| | | 69 | | |
| | | 70 | | private async Task<IDagRunResult> RunWaitAllWithDiagnosticsAsync( |
| | | 71 | | GraphTopology topology, |
| | | 72 | | string graphName, |
| | | 73 | | string input, |
| | | 74 | | ProgressEvents.IProgressReporter? progress, |
| | | 75 | | CancellationToken cancellationToken) |
| | | 76 | | { |
| | 1 | 77 | | var workflow = _workflowFactory.CreateGraphWorkflow(graphName); |
| | 1 | 78 | | var dagStart = Stopwatch.GetTimestamp(); |
| | | 79 | | |
| | 1 | 80 | | var responses = new Dictionary<string, StringBuilder>(); |
| | 1 | 81 | | var invocationTimestamps = new List<(string ExecutorId, DateTimeOffset At)>(); |
| | 1 | 82 | | bool succeeded = true; |
| | 1 | 83 | | string? errorMessage = null; |
| | 1 | 84 | | Exception? caughtException = null; |
| | | 85 | | |
| | 1 | 86 | | var collector = _diagnosticsAccessor?.CompletionCollector; |
| | 1 | 87 | | var toolCollector = _diagnosticsAccessor?.ToolCallCollector; |
| | 1 | 88 | | collector?.DrainCompletions(); |
| | 1 | 89 | | toolCollector?.DrainToolCalls(); |
| | | 90 | | |
| | 1 | 91 | | progress?.Report(new ProgressEvents.WorkflowStartedEvent( |
| | 1 | 92 | | DateTimeOffset.UtcNow, |
| | 1 | 93 | | progress.WorkflowId, |
| | 1 | 94 | | progress.AgentId, |
| | 1 | 95 | | null, |
| | 1 | 96 | | progress.Depth, |
| | 1 | 97 | | progress.NextSequence())); |
| | | 98 | | |
| | | 99 | | try |
| | | 100 | | { |
| | 1 | 101 | | IDisposable? captureScope = _diagnosticsAccessor?.BeginCapture(); |
| | | 102 | | try |
| | | 103 | | { |
| | 1 | 104 | | await using var run = await InProcessExecution.RunStreamingAsync( |
| | 1 | 105 | | workflow, |
| | 1 | 106 | | new ChatMessage(ChatRole.User, input), |
| | 1 | 107 | | cancellationToken: cancellationToken); |
| | | 108 | | |
| | 1 | 109 | | await run.TrySendMessageAsync(new TurnToken(emitEvents: true)); |
| | | 110 | | |
| | 1 | 111 | | await using var budgetReg = cancellationToken.CanBeCanceled |
| | 0 | 112 | | ? cancellationToken.Register(() => _ = run.CancelRunAsync()) |
| | 1 | 113 | | : default(CancellationTokenRegistration?); |
| | | 114 | | |
| | 16 | 115 | | await foreach (var evt in run.WatchStreamAsync(cancellationToken)) |
| | | 116 | | { |
| | 7 | 117 | | if (evt is ExecutorInvokedEvent invoked) |
| | | 118 | | { |
| | 2 | 119 | | var id = invoked.ExecutorId ?? "unknown"; |
| | 2 | 120 | | invocationTimestamps.Add((id, DateTimeOffset.UtcNow)); |
| | | 121 | | |
| | 2 | 122 | | progress?.Report(new ProgressEvents.AgentInvokedEvent( |
| | 2 | 123 | | DateTimeOffset.UtcNow, |
| | 2 | 124 | | progress.WorkflowId, |
| | 2 | 125 | | id, |
| | 2 | 126 | | null, |
| | 2 | 127 | | progress.Depth + 1, |
| | 2 | 128 | | progress.NextSequence(), |
| | 2 | 129 | | AgentName: id, |
| | 2 | 130 | | GraphName: graphName, |
| | 2 | 131 | | NodeId: id)); |
| | 0 | 132 | | continue; |
| | | 133 | | } |
| | | 134 | | |
| | 5 | 135 | | if (evt is ExecutorFailedEvent executorFailed) |
| | | 136 | | { |
| | 1 | 137 | | succeeded = false; |
| | 1 | 138 | | errorMessage = executorFailed.Data?.Message; |
| | 1 | 139 | | var failedId = executorFailed.ExecutorId ?? "unknown"; |
| | | 140 | | |
| | 1 | 141 | | progress?.Report(new ProgressEvents.AgentFailedEvent( |
| | 1 | 142 | | DateTimeOffset.UtcNow, |
| | 1 | 143 | | progress.WorkflowId, |
| | 1 | 144 | | failedId, |
| | 1 | 145 | | null, |
| | 1 | 146 | | progress.Depth + 1, |
| | 1 | 147 | | progress.NextSequence(), |
| | 1 | 148 | | AgentName: failedId, |
| | 1 | 149 | | ErrorMessage: executorFailed.Data?.Message ?? "unknown error")); |
| | 0 | 150 | | continue; |
| | | 151 | | } |
| | | 152 | | |
| | 4 | 153 | | if (evt is WorkflowErrorEvent workflowError) |
| | | 154 | | { |
| | 1 | 155 | | succeeded = false; |
| | 1 | 156 | | errorMessage = workflowError.Exception?.Message; |
| | 1 | 157 | | continue; |
| | | 158 | | } |
| | | 159 | | |
| | 3 | 160 | | if (evt is not AgentResponseUpdateEvent update |
| | 3 | 161 | | || update.ExecutorId is null |
| | 3 | 162 | | || update.Data is null) |
| | | 163 | | { |
| | | 164 | | continue; |
| | | 165 | | } |
| | | 166 | | |
| | 0 | 167 | | var text = update.Data.ToString(); |
| | 0 | 168 | | if (string.IsNullOrEmpty(text)) |
| | | 169 | | { |
| | | 170 | | continue; |
| | | 171 | | } |
| | | 172 | | |
| | 0 | 173 | | if (!responses.TryGetValue(update.ExecutorId, out var sb)) |
| | | 174 | | { |
| | 0 | 175 | | responses[update.ExecutorId] = sb = new StringBuilder(); |
| | | 176 | | } |
| | | 177 | | |
| | 0 | 178 | | sb.Append(text); |
| | | 179 | | |
| | 0 | 180 | | progress?.Report(new ProgressEvents.AgentResponseChunkEvent( |
| | 0 | 181 | | DateTimeOffset.UtcNow, |
| | 0 | 182 | | progress.WorkflowId, |
| | 0 | 183 | | update.ExecutorId, |
| | 0 | 184 | | null, |
| | 0 | 185 | | progress.Depth + 1, |
| | 0 | 186 | | progress.NextSequence(), |
| | 0 | 187 | | AgentName: update.ExecutorId, |
| | 0 | 188 | | Text: text)); |
| | | 189 | | } |
| | 1 | 190 | | } |
| | | 191 | | finally |
| | | 192 | | { |
| | 1 | 193 | | captureScope?.Dispose(); |
| | | 194 | | } |
| | 1 | 195 | | } |
| | 0 | 196 | | catch (Exception ex) |
| | | 197 | | { |
| | 0 | 198 | | succeeded = false; |
| | 0 | 199 | | errorMessage = ex.Message; |
| | 0 | 200 | | caughtException = ex; |
| | 0 | 201 | | } |
| | | 202 | | |
| | 1 | 203 | | cancellationToken.ThrowIfCancellationRequested(); |
| | | 204 | | |
| | 1 | 205 | | var totalDuration = Stopwatch.GetElapsedTime(dagStart); |
| | | 206 | | |
| | 1 | 207 | | var allCompletions = collector?.DrainCompletions() |
| | 0 | 208 | | ?.OrderBy(c => c.StartedAt).ToList() |
| | 1 | 209 | | ?? []; |
| | 1 | 210 | | var allToolCalls = toolCollector?.DrainToolCalls() |
| | 0 | 211 | | ?.OrderBy(t => t.StartedAt).ToList() |
| | 1 | 212 | | ?? []; |
| | | 213 | | |
| | 3 | 214 | | var invokedIds = invocationTimestamps.Select(inv => inv.ExecutorId).ToHashSet(); |
| | 1 | 215 | | var respondedIds = responses.Keys.ToHashSet(); |
| | 1 | 216 | | var agentIds = invokedIds.Union(respondedIds).Distinct().ToList(); |
| | | 217 | | |
| | | 218 | | // Build a mapping from agent IDs (executor IDs from MAF) to their |
| | | 219 | | // corresponding Type for namespace-safe edge lookups. |
| | 1 | 220 | | var agentIdToType = new Dictionary<string, Type>(StringComparer.Ordinal); |
| | 4 | 221 | | foreach (var id in agentIds) |
| | | 222 | | { |
| | 1 | 223 | | var matchedType = topology.AllTypes.FirstOrDefault(t => |
| | 2 | 224 | | id.Equals(t.Name, StringComparison.Ordinal) || |
| | 2 | 225 | | id.StartsWith(t.Name + "_", StringComparison.Ordinal)); |
| | 1 | 226 | | if (matchedType is not null) |
| | | 227 | | { |
| | 1 | 228 | | agentIdToType[id] = matchedType; |
| | | 229 | | } |
| | | 230 | | } |
| | | 231 | | |
| | 1 | 232 | | var completionsByAgent = new Dictionary<string, List<ChatCompletionDiagnostics>>(); |
| | 1 | 233 | | var toolCallsByAgent = new Dictionary<string, List<ToolCallDiagnostics>>(); |
| | 4 | 234 | | foreach (var id in agentIds) |
| | | 235 | | { |
| | 1 | 236 | | completionsByAgent[id] = []; |
| | 1 | 237 | | toolCallsByAgent[id] = []; |
| | | 238 | | } |
| | | 239 | | |
| | 2 | 240 | | foreach (var c in allCompletions) |
| | | 241 | | { |
| | 0 | 242 | | var matched = agentIds.FirstOrDefault(id => |
| | 0 | 243 | | c.AgentName is not null && |
| | 0 | 244 | | (id.Equals(c.AgentName, StringComparison.Ordinal) || |
| | 0 | 245 | | id.StartsWith(c.AgentName + "_", StringComparison.Ordinal))); |
| | 0 | 246 | | if (matched is not null) |
| | | 247 | | { |
| | 0 | 248 | | completionsByAgent[matched].Add(c); |
| | | 249 | | } |
| | | 250 | | } |
| | | 251 | | |
| | 2 | 252 | | foreach (var tc in allToolCalls) |
| | | 253 | | { |
| | 0 | 254 | | var matched = agentIds.FirstOrDefault(id => |
| | 0 | 255 | | tc.AgentName is not null && |
| | 0 | 256 | | (id.Equals(tc.AgentName, StringComparison.Ordinal) || |
| | 0 | 257 | | id.StartsWith(tc.AgentName + "_", StringComparison.Ordinal))); |
| | 0 | 258 | | if (matched is not null) |
| | | 259 | | { |
| | 0 | 260 | | toolCallsByAgent[matched].Add(tc); |
| | | 261 | | } |
| | | 262 | | } |
| | | 263 | | |
| | 1 | 264 | | var nodeResults = new Dictionary<string, IDagNodeResult>(); |
| | 1 | 265 | | var stages = new List<IAgentStageResult>(); |
| | 1 | 266 | | var branchResults = new Dictionary<string, IReadOnlyList<IAgentStageResult>>(); |
| | | 267 | | |
| | 4 | 268 | | foreach (var agentId in agentIds) |
| | | 269 | | { |
| | 1 | 270 | | var responseText = responses.TryGetValue(agentId, out var respSb) |
| | 1 | 271 | | ? respSb.ToString() |
| | 1 | 272 | | : string.Empty; |
| | | 273 | | |
| | 1 | 274 | | ChatResponse? finalResponse = !string.IsNullOrEmpty(responseText) |
| | 1 | 275 | | ? new ChatResponse(new ChatMessage(ChatRole.Assistant, responseText)) |
| | 1 | 276 | | : null; |
| | | 277 | | |
| | 1 | 278 | | var agentCompletions = completionsByAgent.GetValueOrDefault(agentId, []); |
| | 1 | 279 | | var agentToolCalls = toolCallsByAgent.GetValueOrDefault(agentId, []); |
| | | 280 | | |
| | | 281 | | TimeSpan nodeDuration; |
| | | 282 | | DateTimeOffset nodeStartedAt; |
| | 1 | 283 | | if (agentCompletions.Count > 0) |
| | | 284 | | { |
| | 0 | 285 | | nodeStartedAt = agentCompletions[0].StartedAt; |
| | 0 | 286 | | nodeDuration = agentCompletions[^1].CompletedAt - nodeStartedAt; |
| | | 287 | | } |
| | | 288 | | else |
| | | 289 | | { |
| | 1 | 290 | | var invTs = invocationTimestamps |
| | 2 | 291 | | .FirstOrDefault(x => x.ExecutorId == agentId).At; |
| | 1 | 292 | | nodeStartedAt = invTs != default ? invTs : DateTimeOffset.UtcNow; |
| | 1 | 293 | | nodeDuration = totalDuration / Math.Max(agentIds.Count, 1); |
| | | 294 | | } |
| | | 295 | | |
| | 1 | 296 | | var dagStartTime = DateTimeOffset.UtcNow - totalDuration; |
| | 1 | 297 | | var startOffset = nodeStartedAt - dagStartTime; |
| | 1 | 298 | | if (startOffset < TimeSpan.Zero) |
| | | 299 | | { |
| | 0 | 300 | | startOffset = TimeSpan.Zero; |
| | | 301 | | } |
| | | 302 | | |
| | 1 | 303 | | var tokenUsage = new TokenUsage( |
| | 0 | 304 | | InputTokens: agentCompletions.Sum(c => c.Tokens.InputTokens), |
| | 0 | 305 | | OutputTokens: agentCompletions.Sum(c => c.Tokens.OutputTokens), |
| | 0 | 306 | | TotalTokens: agentCompletions.Sum(c => c.Tokens.TotalTokens), |
| | 0 | 307 | | CachedInputTokens: agentCompletions.Sum(c => c.Tokens.CachedInputTokens), |
| | 1 | 308 | | ReasoningTokens: agentCompletions.Sum(c => c.Tokens.ReasoningTokens)); |
| | | 309 | | |
| | 1 | 310 | | IAgentRunDiagnostics diag = new AgentRunDiagnostics( |
| | 1 | 311 | | AgentName: agentId, |
| | 1 | 312 | | TotalDuration: nodeDuration, |
| | 1 | 313 | | AggregateTokenUsage: tokenUsage, |
| | 1 | 314 | | ChatCompletions: agentCompletions, |
| | 1 | 315 | | ToolCalls: agentToolCalls, |
| | 1 | 316 | | TotalInputMessages: 0, |
| | 1 | 317 | | TotalOutputMessages: 0, |
| | 1 | 318 | | InputMessages: [], |
| | 1 | 319 | | OutputResponse: null, |
| | 1 | 320 | | Succeeded: true, |
| | 1 | 321 | | ErrorMessage: null, |
| | 1 | 322 | | StartedAt: nodeStartedAt, |
| | 1 | 323 | | CompletedAt: nodeStartedAt + nodeDuration); |
| | | 324 | | |
| | | 325 | | // Resolve Type-based edges to FullName strings for the public interface. |
| | 1 | 326 | | var resolvedType = agentIdToType.GetValueOrDefault(agentId); |
| | 1 | 327 | | var inEdges = resolvedType is not null |
| | 1 | 328 | | ? topology.InboundEdges.GetValueOrDefault(resolvedType, []) |
| | 0 | 329 | | .Select(t => t.FullName ?? t.Name).ToList() |
| | 1 | 330 | | : (IReadOnlyList<string>)[]; |
| | 1 | 331 | | var outEdges = resolvedType is not null |
| | 1 | 332 | | ? topology.OutboundEdges.GetValueOrDefault(resolvedType, []) |
| | 2 | 333 | | .Select(t => t.FullName ?? t.Name).ToList() |
| | 1 | 334 | | : (IReadOnlyList<string>)[]; |
| | | 335 | | |
| | 1 | 336 | | var nodeResult = new DagNodeResult( |
| | 1 | 337 | | nodeId: agentId, |
| | 1 | 338 | | agentName: agentId, |
| | 1 | 339 | | kind: NodeKind.Agent, |
| | 1 | 340 | | diagnostics: diag, |
| | 1 | 341 | | finalResponse: finalResponse, |
| | 1 | 342 | | inboundEdges: inEdges, |
| | 1 | 343 | | outboundEdges: outEdges, |
| | 1 | 344 | | startOffset: startOffset, |
| | 1 | 345 | | duration: nodeDuration); |
| | 1 | 346 | | nodeResults[agentId] = nodeResult; |
| | | 347 | | |
| | 1 | 348 | | var stageResult = new AgentStageResult(agentId, finalResponse, diag); |
| | 1 | 349 | | stages.Add(stageResult); |
| | | 350 | | |
| | 1 | 351 | | progress?.Report(new ProgressEvents.AgentCompletedEvent( |
| | 1 | 352 | | DateTimeOffset.UtcNow, |
| | 1 | 353 | | progress.WorkflowId, |
| | 1 | 354 | | agentId, |
| | 1 | 355 | | null, |
| | 1 | 356 | | progress.Depth + 1, |
| | 1 | 357 | | progress.NextSequence(), |
| | 1 | 358 | | AgentName: agentId, |
| | 1 | 359 | | Duration: nodeDuration, |
| | 1 | 360 | | TotalTokens: tokenUsage.TotalTokens, |
| | 1 | 361 | | InputTokens: tokenUsage.InputTokens, |
| | 1 | 362 | | OutputTokens: tokenUsage.OutputTokens)); |
| | | 363 | | } |
| | | 364 | | |
| | 1 | 365 | | var branchIndex = 0; |
| | 1 | 366 | | var nodesByInbound = stages |
| | 1 | 367 | | .Where(s => nodeResults.ContainsKey(s.AgentName)) |
| | 1 | 368 | | .GroupBy(s => string.Join(",", nodeResults[s.AgentName].InboundEdges)) |
| | 2 | 369 | | .Where(g => g.Count() > 1); |
| | 2 | 370 | | foreach (var group in nodesByInbound) |
| | | 371 | | { |
| | 0 | 372 | | branchResults[$"branch-{branchIndex++}"] = group.ToList(); |
| | | 373 | | } |
| | | 374 | | |
| | 1 | 375 | | progress?.Report(new ProgressEvents.WorkflowCompletedEvent( |
| | 1 | 376 | | DateTimeOffset.UtcNow, |
| | 1 | 377 | | progress.WorkflowId, |
| | 1 | 378 | | progress.AgentId, |
| | 1 | 379 | | null, |
| | 1 | 380 | | progress.Depth, |
| | 1 | 381 | | progress.NextSequence(), |
| | 1 | 382 | | Succeeded: succeeded, |
| | 1 | 383 | | ErrorMessage: errorMessage, |
| | 1 | 384 | | TotalDuration: totalDuration)); |
| | | 385 | | |
| | 1 | 386 | | return new DagRunResult( |
| | 1 | 387 | | stages: stages, |
| | 1 | 388 | | nodeResults: nodeResults, |
| | 1 | 389 | | branchResults: branchResults, |
| | 1 | 390 | | totalDuration: totalDuration, |
| | 1 | 391 | | succeeded: succeeded, |
| | 1 | 392 | | errorMessage: errorMessage, |
| | 1 | 393 | | exception: caughtException); |
| | 1 | 394 | | } |
| | | 395 | | |
| | | 396 | | private async Task<IDagRunResult> RunWithNeedlrExecutorAsync( |
| | | 397 | | GraphTopology topology, |
| | | 398 | | string graphName, |
| | | 399 | | string input, |
| | | 400 | | ProgressEvents.IProgressReporter? progress, |
| | | 401 | | CancellationToken cancellationToken) |
| | | 402 | | { |
| | 30 | 403 | | if (topology.EntryType is null) |
| | | 404 | | { |
| | 0 | 405 | | throw new InvalidOperationException( |
| | 0 | 406 | | $"Cannot run graph workflow '{graphName}': no entry point found."); |
| | | 407 | | } |
| | | 408 | | |
| | 30 | 409 | | var agents = new Dictionary<Type, AIAgent>(); |
| | 282 | 410 | | foreach (var type in topology.AllTypes) |
| | | 411 | | { |
| | 111 | 412 | | agents[type] = _agentFactory.CreateAgent(type.FullName ?? type.Name); |
| | | 413 | | } |
| | | 414 | | |
| | 30 | 415 | | var completionSources = new Dictionary<Type, TaskCompletionSource<string>>(); |
| | 282 | 416 | | foreach (var type in topology.AllTypes) |
| | | 417 | | { |
| | 111 | 418 | | completionSources[type] = new TaskCompletionSource<string>(); |
| | | 419 | | } |
| | | 420 | | |
| | 30 | 421 | | var skippedNodes = new ConcurrentDictionary<Type, bool>(); |
| | 30 | 422 | | var dagStart = Stopwatch.GetTimestamp(); |
| | 30 | 423 | | var nodeTimings = new ConcurrentDictionary<Type, (TimeSpan StartOffset, TimeSpan Duration)>(); |
| | 30 | 424 | | var nodeDiagnostics = new ConcurrentDictionary<Type, IAgentRunDiagnostics?>(); |
| | 30 | 425 | | var nodeExceptions = new ConcurrentDictionary<Type, Exception>(); |
| | | 426 | | |
| | 30 | 427 | | progress?.Report(new ProgressEvents.WorkflowStartedEvent( |
| | 30 | 428 | | DateTimeOffset.UtcNow, |
| | 30 | 429 | | progress?.WorkflowId ?? string.Empty, |
| | 30 | 430 | | progress?.AgentId, |
| | 30 | 431 | | null, |
| | 30 | 432 | | progress?.Depth ?? 0, |
| | 30 | 433 | | progress?.NextSequence() ?? 0)); |
| | | 434 | | |
| | 30 | 435 | | var routingChatClient = _chatClientAccessor.ChatClient; |
| | 30 | 436 | | var nodeTasks = new List<Task>(); |
| | | 437 | | |
| | | 438 | | // Create a linked CTS for each WaitAny join node so that remaining |
| | | 439 | | // branches can be cancelled once the first valid result arrives. |
| | 30 | 440 | | var waitAnyCtsMap = new ConcurrentDictionary<Type, CancellationTokenSource>(); |
| | 282 | 441 | | foreach (var type in topology.AllTypes) |
| | | 442 | | { |
| | 111 | 443 | | if (topology.JoinModes.GetValueOrDefault(type, GraphJoinMode.WaitAll) == GraphJoinMode.WaitAny) |
| | | 444 | | { |
| | 19 | 445 | | waitAnyCtsMap[type] = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); |
| | | 446 | | } |
| | | 447 | | } |
| | | 448 | | |
| | | 449 | | // Pre-compute the effective cancellation token for each node. |
| | | 450 | | // Nodes that are dependencies of a WaitAny join use the linked token |
| | | 451 | | // so they can be cancelled when the winning branch completes. |
| | 30 | 452 | | var nodeEffectiveTokens = new Dictionary<Type, CancellationToken>(); |
| | 282 | 453 | | foreach (var type in topology.AllTypes) |
| | | 454 | | { |
| | 111 | 455 | | CancellationToken effectiveToken = cancellationToken; |
| | 344 | 456 | | foreach (var (waitAnyType, cts) in waitAnyCtsMap) |
| | | 457 | | { |
| | 78 | 458 | | var waitAnyDeps = topology.IncomingTypes.GetValueOrDefault(waitAnyType, []); |
| | 78 | 459 | | if (waitAnyDeps.Contains(type)) |
| | | 460 | | { |
| | 34 | 461 | | effectiveToken = cts.Token; |
| | 34 | 462 | | break; |
| | | 463 | | } |
| | | 464 | | } |
| | | 465 | | |
| | 111 | 466 | | nodeEffectiveTokens[type] = effectiveToken; |
| | | 467 | | } |
| | | 468 | | |
| | 282 | 469 | | foreach (var type in topology.AllTypes) |
| | | 470 | | { |
| | 111 | 471 | | var nodeType = type; |
| | 111 | 472 | | var deps = topology.IncomingTypes.GetValueOrDefault(nodeType, []); |
| | 111 | 473 | | var joinMode = topology.JoinModes.GetValueOrDefault(nodeType, GraphJoinMode.WaitAll); |
| | | 474 | | |
| | 111 | 475 | | nodeTasks.Add(Task.Run(async () => |
| | 111 | 476 | | { |
| | 111 | 477 | | try |
| | 111 | 478 | | { |
| | 111 | 479 | | string nodeInput; |
| | 107 | 480 | | if (nodeType == topology.EntryType) |
| | 111 | 481 | | { |
| | 29 | 482 | | nodeInput = input; |
| | 111 | 483 | | } |
| | 78 | 484 | | else if (deps.Count == 0) |
| | 111 | 485 | | { |
| | 0 | 486 | | nodeInput = input; |
| | 111 | 487 | | } |
| | 111 | 488 | | else |
| | 111 | 489 | | { |
| | 78 | 490 | | if (joinMode == GraphJoinMode.WaitAny) |
| | 111 | 491 | | { |
| | 18 | 492 | | var taskToDepType = new Dictionary<Task<string>, Type>(); |
| | 100 | 493 | | foreach (var dep in deps) |
| | 111 | 494 | | { |
| | 32 | 495 | | taskToDepType[completionSources[dep].Task] = dep; |
| | 111 | 496 | | } |
| | 111 | 497 | | |
| | 18 | 498 | | var remaining = new HashSet<Task<string>>(taskToDepType.Keys); |
| | 18 | 499 | | nodeInput = input; |
| | 111 | 500 | | |
| | 25 | 501 | | while (remaining.Count > 0) |
| | 111 | 502 | | { |
| | 25 | 503 | | var first = await Task.WhenAny(remaining).WaitAsync(cancellationToken); |
| | 25 | 504 | | remaining.Remove(first); |
| | 111 | 505 | | |
| | 25 | 506 | | var depType = taskToDepType[first]; |
| | 25 | 507 | | var result = await first; |
| | 111 | 508 | | |
| | 23 | 509 | | if (!skippedNodes.ContainsKey(depType) && !string.IsNullOrWhiteSpace(result)) |
| | 111 | 510 | | { |
| | 16 | 511 | | nodeInput = result; |
| | 111 | 512 | | |
| | 111 | 513 | | // Cancel remaining branches for this WaitAny scope. |
| | 16 | 514 | | if (waitAnyCtsMap.TryGetValue(nodeType, out var waitAnyCts)) |
| | 111 | 515 | | { |
| | 16 | 516 | | waitAnyCts.Cancel(); |
| | 111 | 517 | | } |
| | 111 | 518 | | |
| | 16 | 519 | | break; |
| | 111 | 520 | | } |
| | 7 | 521 | | } |
| | 16 | 522 | | } |
| | 111 | 523 | | else |
| | 111 | 524 | | { |
| | 60 | 525 | | var pendingResults = new List<string>(); |
| | 245 | 526 | | foreach (var dep in deps) |
| | 111 | 527 | | { |
| | 65 | 528 | | if (skippedNodes.ContainsKey(dep)) |
| | 111 | 529 | | continue; |
| | 111 | 530 | | |
| | 111 | 531 | | try |
| | 111 | 532 | | { |
| | 65 | 533 | | var depResult = await completionSources[dep].Task.WaitAsync(cancellationToken); |
| | 60 | 534 | | if (!string.IsNullOrEmpty(depResult)) |
| | 60 | 535 | | pendingResults.Add(depResult); |
| | 60 | 536 | | } |
| | 5 | 537 | | catch when (IsOptionalEdge(dep, nodeType, topology)) |
| | 111 | 538 | | { |
| | 111 | 539 | | // Optional upstream failed — treat as degraded. |
| | 0 | 540 | | } |
| | 60 | 541 | | } |
| | 111 | 542 | | |
| | 55 | 543 | | if (pendingResults.Count >= 2 && topology.ReducerFunc is not null) |
| | 111 | 544 | | { |
| | 5 | 545 | | var reducerStart = Stopwatch.GetTimestamp(); |
| | 5 | 546 | | nodeInput = topology.ReducerFunc(pendingResults); |
| | 4 | 547 | | var reducerDuration = Stopwatch.GetElapsedTime(reducerStart); |
| | 111 | 548 | | |
| | 4 | 549 | | progress?.Report(new ProgressEvents.ReducerNodeInvokedEvent( |
| | 4 | 550 | | DateTimeOffset.UtcNow, |
| | 4 | 551 | | progress.WorkflowId, |
| | 4 | 552 | | progress.AgentId, |
| | 4 | 553 | | null, |
| | 4 | 554 | | progress.Depth + 1, |
| | 4 | 555 | | progress.NextSequence(), |
| | 4 | 556 | | NodeId: topology.ReducerType?.FullName ?? topology.ReducerType?.Name ?? "reducer", |
| | 4 | 557 | | GraphName: graphName, |
| | 4 | 558 | | BranchId: null, |
| | 4 | 559 | | InputBranchCount: pendingResults.Count, |
| | 4 | 560 | | Duration: reducerDuration)); |
| | 111 | 561 | | } |
| | 50 | 562 | | else if (pendingResults.Count > 0) |
| | 111 | 563 | | { |
| | 50 | 564 | | nodeInput = string.Join("\n\n---\n\n", pendingResults); |
| | 111 | 565 | | } |
| | 111 | 566 | | else |
| | 111 | 567 | | { |
| | 0 | 568 | | nodeInput = input; |
| | 111 | 569 | | } |
| | 54 | 570 | | } |
| | 111 | 571 | | } |
| | 111 | 572 | | |
| | 99 | 573 | | if (skippedNodes.ContainsKey(nodeType)) |
| | 111 | 574 | | { |
| | 12 | 575 | | return; |
| | 111 | 576 | | } |
| | 111 | 577 | | |
| | 87 | 578 | | var agent = agents[nodeType]; |
| | 87 | 579 | | var agentName = agent.Name ?? nodeType.Name; |
| | 111 | 580 | | |
| | 87 | 581 | | progress?.Report(new ProgressEvents.AgentInvokedEvent( |
| | 87 | 582 | | DateTimeOffset.UtcNow, |
| | 87 | 583 | | progress.WorkflowId, |
| | 87 | 584 | | agentName, |
| | 87 | 585 | | null, |
| | 87 | 586 | | progress.Depth + 1, |
| | 87 | 587 | | progress.NextSequence(), |
| | 87 | 588 | | AgentName: agentName, |
| | 87 | 589 | | GraphName: graphName, |
| | 87 | 590 | | NodeId: nodeType.FullName ?? nodeType.Name)); |
| | 111 | 591 | | |
| | 87 | 592 | | var nodeStart = Stopwatch.GetTimestamp(); |
| | 87 | 593 | | using var diagnosticsBuilder = AgentRunDiagnosticsBuilder.StartNew(agentName); |
| | 87 | 594 | | var nodeToken = nodeEffectiveTokens[nodeType]; |
| | 87 | 595 | | var response = await agent.RunAsync(nodeInput, cancellationToken: nodeToken); |
| | 81 | 596 | | var nodeElapsed = Stopwatch.GetElapsedTime(nodeStart); |
| | 81 | 597 | | var startOffset = Stopwatch.GetElapsedTime(dagStart, nodeStart); |
| | 111 | 598 | | |
| | 81 | 599 | | var diag = diagnosticsBuilder.Build(); |
| | 81 | 600 | | nodeDiagnostics[nodeType] = diag; |
| | 81 | 601 | | nodeTimings[nodeType] = (startOffset, nodeElapsed); |
| | 111 | 602 | | |
| | 81 | 603 | | var text = string.Join("\n", response.Messages |
| | 81 | 604 | | .Where(m => !string.IsNullOrEmpty(m.Text)) |
| | 162 | 605 | | .Select(m => m.Text)); |
| | 111 | 606 | | |
| | 81 | 607 | | if (topology.OutgoingEdgesBySource.TryGetValue(nodeType, out var outEdges) && outEdges.Count > 0) |
| | 111 | 608 | | { |
| | 59 | 609 | | var conditionInput = !string.IsNullOrWhiteSpace(text) ? text : nodeInput; |
| | 59 | 610 | | var resolvedEdges = await _edgeRouter.ResolveOutgoingEdgesAsync( |
| | 59 | 611 | | nodeType, conditionInput, topology, routingChatClient, nodeToken); |
| | 124 | 612 | | var resolvedTargets = resolvedEdges.Select(e => e.Target).ToHashSet(); |
| | 111 | 613 | | |
| | 272 | 614 | | foreach (var edge in outEdges) |
| | 111 | 615 | | { |
| | 80 | 616 | | if (!resolvedTargets.Contains(edge.Target)) |
| | 111 | 617 | | { |
| | 12 | 618 | | skippedNodes[edge.Target] = true; |
| | 12 | 619 | | completionSources[edge.Target].TrySetResult(string.Empty); |
| | 111 | 620 | | } |
| | 111 | 621 | | } |
| | 111 | 622 | | } |
| | 111 | 623 | | |
| | 78 | 624 | | var totalTokens = diag.AggregateTokenUsage.TotalTokens; |
| | 78 | 625 | | progress?.Report(new ProgressEvents.AgentCompletedEvent( |
| | 78 | 626 | | DateTimeOffset.UtcNow, |
| | 78 | 627 | | progress.WorkflowId, |
| | 78 | 628 | | agentName, |
| | 78 | 629 | | null, |
| | 78 | 630 | | progress.Depth + 1, |
| | 78 | 631 | | progress.NextSequence(), |
| | 78 | 632 | | AgentName: agentName, |
| | 78 | 633 | | Duration: nodeElapsed, |
| | 78 | 634 | | TotalTokens: totalTokens, |
| | 78 | 635 | | InputTokens: diag.AggregateTokenUsage.InputTokens, |
| | 78 | 636 | | OutputTokens: diag.AggregateTokenUsage.OutputTokens)); |
| | 111 | 637 | | |
| | 78 | 638 | | completionSources[nodeType].TrySetResult(text); |
| | 78 | 639 | | } |
| | 1 | 640 | | catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) |
| | 111 | 641 | | { |
| | 111 | 642 | | // Graceful cancellation from a WaitAny scope — treat as skip. |
| | 1 | 643 | | skippedNodes[nodeType] = true; |
| | 1 | 644 | | completionSources[nodeType].TrySetResult(string.Empty); |
| | 1 | 645 | | } |
| | 16 | 646 | | catch (Exception ex) |
| | 111 | 647 | | { |
| | 16 | 648 | | var agentName = agents.TryGetValue(nodeType, out var a) |
| | 16 | 649 | | ? a.Name ?? nodeType.Name |
| | 16 | 650 | | : nodeType.Name; |
| | 111 | 651 | | |
| | 16 | 652 | | nodeExceptions[nodeType] = ex; |
| | 111 | 653 | | |
| | 16 | 654 | | progress?.Report(new ProgressEvents.AgentFailedEvent( |
| | 16 | 655 | | DateTimeOffset.UtcNow, |
| | 16 | 656 | | progress.WorkflowId, |
| | 16 | 657 | | agentName, |
| | 16 | 658 | | null, |
| | 16 | 659 | | progress.Depth + 1, |
| | 16 | 660 | | progress.NextSequence(), |
| | 16 | 661 | | AgentName: agentName, |
| | 16 | 662 | | ErrorMessage: ex.Message)); |
| | 111 | 663 | | |
| | 16 | 664 | | if (IsNodeRequiredByAllIncomingEdges(nodeType, topology)) |
| | 111 | 665 | | { |
| | 14 | 666 | | completionSources[nodeType].TrySetException(ex); |
| | 111 | 667 | | } |
| | 111 | 668 | | else |
| | 111 | 669 | | { |
| | 2 | 670 | | completionSources[nodeType].TrySetResult(string.Empty); |
| | 111 | 671 | | } |
| | 16 | 672 | | } |
| | 218 | 673 | | }, cancellationToken)); |
| | | 674 | | } |
| | | 675 | | |
| | 30 | 676 | | Exception? dagException = null; |
| | | 677 | | bool succeeded; |
| | 30 | 678 | | string? errorMessage = null; |
| | | 679 | | |
| | | 680 | | try |
| | | 681 | | { |
| | 30 | 682 | | await Task.WhenAll(nodeTasks).WaitAsync(cancellationToken); |
| | | 683 | | |
| | 29 | 684 | | var requiredFailures = nodeExceptions |
| | 16 | 685 | | .Where(kv => IsNodeRequiredByAllIncomingEdges(kv.Key, topology)) |
| | 29 | 686 | | .ToList(); |
| | | 687 | | |
| | 29 | 688 | | succeeded = requiredFailures.Count == 0; |
| | 29 | 689 | | if (!succeeded) |
| | | 690 | | { |
| | 6 | 691 | | var firstError = requiredFailures.First().Value; |
| | 6 | 692 | | errorMessage = firstError.Message; |
| | 20 | 693 | | dagException = new AggregateException(requiredFailures.Select(kv => kv.Value)); |
| | | 694 | | } |
| | 29 | 695 | | } |
| | 1 | 696 | | catch (Exception ex) |
| | | 697 | | { |
| | 1 | 698 | | succeeded = false; |
| | 1 | 699 | | errorMessage = ex.Message; |
| | 1 | 700 | | dagException = ex; |
| | 1 | 701 | | } |
| | | 702 | | finally |
| | | 703 | | { |
| | 98 | 704 | | foreach (var cts in waitAnyCtsMap.Values) |
| | | 705 | | { |
| | 19 | 706 | | cts.Dispose(); |
| | | 707 | | } |
| | | 708 | | } |
| | | 709 | | |
| | 30 | 710 | | var totalDuration = Stopwatch.GetElapsedTime(dagStart); |
| | | 711 | | |
| | 30 | 712 | | var nodeResultsDict = new Dictionary<string, IDagNodeResult>(); |
| | 30 | 713 | | var stagesList = new List<IAgentStageResult>(); |
| | | 714 | | |
| | 282 | 715 | | foreach (var type in topology.AllTypes) |
| | | 716 | | { |
| | 111 | 717 | | if (skippedNodes.ContainsKey(type)) |
| | | 718 | | continue; |
| | | 719 | | |
| | 98 | 720 | | var agentName = agents.TryGetValue(type, out var ag) |
| | 98 | 721 | | ? ag.Name ?? type.Name |
| | 98 | 722 | | : type.Name; |
| | 98 | 723 | | var (startOffsetVal, duration) = nodeTimings.GetValueOrDefault(type, (TimeSpan.Zero, TimeSpan.Zero)); |
| | 98 | 724 | | var diag = nodeDiagnostics.GetValueOrDefault(type); |
| | | 725 | | |
| | 98 | 726 | | ChatResponse? finalResponse = null; |
| | 98 | 727 | | if (completionSources[type].Task.IsCompletedSuccessfully) |
| | | 728 | | { |
| | 80 | 729 | | var text = completionSources[type].Task.Result; |
| | 80 | 730 | | if (!string.IsNullOrEmpty(text)) |
| | | 731 | | { |
| | 78 | 732 | | finalResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, text)); |
| | | 733 | | } |
| | | 734 | | } |
| | | 735 | | |
| | 98 | 736 | | var nodeResult = new DagNodeResult( |
| | 98 | 737 | | nodeId: type.FullName ?? type.Name, |
| | 98 | 738 | | agentName: agentName, |
| | 98 | 739 | | kind: NodeKind.Agent, |
| | 98 | 740 | | diagnostics: diag, |
| | 98 | 741 | | finalResponse: finalResponse, |
| | 98 | 742 | | inboundEdges: topology.InboundEdges.GetValueOrDefault(type, []) |
| | 88 | 743 | | .Select(t => t.FullName ?? t.Name).ToList(), |
| | 98 | 744 | | outboundEdges: topology.OutboundEdges.GetValueOrDefault(type, []) |
| | 94 | 745 | | .Select(t => t.FullName ?? t.Name).ToList(), |
| | 98 | 746 | | startOffset: startOffsetVal, |
| | 98 | 747 | | duration: duration); |
| | | 748 | | |
| | 98 | 749 | | nodeResultsDict[type.FullName ?? type.Name] = nodeResult; |
| | 98 | 750 | | stagesList.Add(new AgentStageResult(agentName, finalResponse, diag)); |
| | | 751 | | } |
| | | 752 | | |
| | 30 | 753 | | var branchResults = new Dictionary<string, IReadOnlyList<IAgentStageResult>>(); |
| | 30 | 754 | | var branchIndex = 0; |
| | 30 | 755 | | var nodesByInbound = topology.AllTypes |
| | 111 | 756 | | .Where(t => !skippedNodes.ContainsKey(t) && |
| | 111 | 757 | | topology.InboundEdges.ContainsKey(t) && |
| | 111 | 758 | | topology.InboundEdges[t].Count > 0) |
| | 68 | 759 | | .GroupBy(t => string.Join(",", |
| | 68 | 760 | | topology.InboundEdges[t] |
| | 88 | 761 | | .Select(dep => dep.FullName ?? dep.Name) |
| | 156 | 762 | | .OrderBy(n => n))) |
| | 83 | 763 | | .Where(g => g.Count() > 1); |
| | 90 | 764 | | foreach (var group in nodesByInbound) |
| | | 765 | | { |
| | 15 | 766 | | var groupStages = group |
| | 30 | 767 | | .Select(t => stagesList.FirstOrDefault(s => |
| | 105 | 768 | | s.AgentName == (agents.TryGetValue(t, out var a) |
| | 105 | 769 | | ? a.Name ?? t.Name |
| | 105 | 770 | | : t.Name))) |
| | 30 | 771 | | .Where(s => s is not null) |
| | 15 | 772 | | .Cast<IAgentStageResult>() |
| | 15 | 773 | | .ToList(); |
| | 15 | 774 | | if (groupStages.Count > 1) |
| | 15 | 775 | | branchResults[$"branch-{branchIndex++}"] = groupStages; |
| | | 776 | | } |
| | | 777 | | |
| | 30 | 778 | | progress?.Report(new ProgressEvents.WorkflowCompletedEvent( |
| | 30 | 779 | | DateTimeOffset.UtcNow, |
| | 30 | 780 | | progress.WorkflowId, |
| | 30 | 781 | | progress.AgentId, |
| | 30 | 782 | | null, |
| | 30 | 783 | | progress.Depth, |
| | 30 | 784 | | progress.NextSequence(), |
| | 30 | 785 | | Succeeded: succeeded, |
| | 30 | 786 | | ErrorMessage: errorMessage, |
| | 30 | 787 | | TotalDuration: totalDuration)); |
| | | 788 | | |
| | 30 | 789 | | return new DagRunResult( |
| | 30 | 790 | | stages: stagesList, |
| | 30 | 791 | | nodeResults: nodeResultsDict, |
| | 30 | 792 | | branchResults: branchResults, |
| | 30 | 793 | | totalDuration: totalDuration, |
| | 30 | 794 | | succeeded: succeeded, |
| | 30 | 795 | | errorMessage: errorMessage, |
| | 30 | 796 | | exception: dagException); |
| | 30 | 797 | | } |
| | | 798 | | |
| | | 799 | | private static bool IsNodeRequiredByAllIncomingEdges(Type nodeType, GraphTopology topology) |
| | | 800 | | { |
| | 32 | 801 | | var incomingDeps = topology.IncomingTypes.GetValueOrDefault(nodeType, []); |
| | 32 | 802 | | if (incomingDeps.Count == 0) |
| | 6 | 803 | | return true; |
| | | 804 | | |
| | 112 | 805 | | foreach (var dep in incomingDeps) |
| | | 806 | | { |
| | 32 | 807 | | if (topology.EdgeIsRequired.TryGetValue((dep, nodeType), out var isReq) && !isReq) |
| | 4 | 808 | | return false; |
| | | 809 | | } |
| | | 810 | | |
| | 22 | 811 | | return true; |
| | 4 | 812 | | } |
| | | 813 | | |
| | | 814 | | private static bool IsOptionalEdge(Type sourceType, Type targetType, GraphTopology topology) |
| | | 815 | | { |
| | 5 | 816 | | if (topology.EdgeIsRequired.TryGetValue((sourceType, targetType), out var isReq)) |
| | 5 | 817 | | return !isReq; |
| | 0 | 818 | | return false; |
| | | 819 | | } |
| | | 820 | | } |