Skip to content

feat: #864 support streaming nested tool events in Agent.as_tool #1057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if TYPE_CHECKING:
from .lifecycle import AgentHooks
from .mcp import MCPServer
from .result import RunResult
from .result import RunResult, RunResultStreaming


@dataclass
Expand Down Expand Up @@ -356,9 +356,11 @@ def as_tool(
self,
tool_name: str | None,
tool_description: str | None,
*,
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
stream_inner_events: bool = False,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand Down Expand Up @@ -387,17 +389,36 @@ def as_tool(
async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
)
output_run: RunResult | RunResultStreaming
if stream_inner_events:
from .stream_events import RunItemStreamEvent

sub_run = Runner.run_streamed(
self,
input=input,
context=context.context,
)
parent_queue = getattr(context, "_event_queue", None)
async for ev in sub_run.stream_events():
if parent_queue is not None and isinstance(ev, RunItemStreamEvent):
if ev.name in ("tool_called", "tool_output"):
parent_queue.put_nowait(ev)
output_run = sub_run
else:
output_run = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
)

if custom_output_extractor:
return await custom_output_extractor(output)
return await custom_output_extractor(cast(Any, output_run))

return ItemHelpers.text_message_outputs(output.new_items)
return ItemHelpers.text_message_outputs(output_run.new_items)

return run_agent
tool = run_agent
tool.stream_inner_events = stream_inner_events # type: ignore[attr-defined]
return tool

async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
if isinstance(self.instructions, str):
Expand Down
17 changes: 14 additions & 3 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass, field, fields
from typing import Any, Optional
from typing import Any, Optional, Union

from openai.types.responses import ResponseFunctionToolCall

Expand All @@ -21,14 +22,24 @@ class ToolContext(RunContextWrapper[TContext]):
tool_name: str = field(default_factory=_assert_must_pass_tool_name)
"""The name of the tool being invoked."""

tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
tool_call_id: Union[str, int] = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

_event_queue: Optional[asyncio.Queue[Any]] = field(default=None, init=False, repr=False)

@property
def event_queue(self) -> Optional[asyncio.Queue[Any]]:
return self._event_queue

@event_queue.setter
def event_queue(self, queue: Optional[asyncio.Queue[Any]]) -> None:
self._event_queue = queue

@classmethod
def from_agent_context(
cls,
context: RunContextWrapper[TContext],
tool_call_id: str,
tool_call_id: Union[str, int],
tool_call: Optional[ResponseFunctionToolCall] = None,
) -> "ToolContext":
"""
Expand Down
Loading