Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions libs/langgraph/langgraph/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,16 @@ def create_agent(
)

def model_request(state: AgentState) -> AgentState:
request = state.model_request or ModelRequest(
request = state.get("model_request") or ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state.messages,
messages=state["messages"],
tool_choice=None,
)

# prepare messages
print(request.system_prompt)
if request.system_prompt:
messages = [SystemMessage(request.system_prompt)] + request.messages
else:
Expand All @@ -112,7 +111,7 @@ def model_request(state: AgentState) -> AgentState:
parallel_tool_calls=False,
)
output = model_.invoke(messages)
if state.response is not None:
if state.get("response") is not None:
return {"messages": output, "response": None}
else:
return {"messages": output}
Expand All @@ -125,7 +124,7 @@ def model_request(state: AgentState) -> AgentState:
graph.add_node(
f"{m.__class__.__name__}.before_model",
m.before_model,
input_schema=m.State,
input_schema=m.state_schema,
)
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:

Expand All @@ -137,27 +136,27 @@ def modify_model_request_node(state: AgentState) -> dict[str, ModelRequest]:
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state.messages,
messages=state["messages"],
tool_choice=None,
)

return {
"model_request": m.modify_model_request(
state.model_request or default_model_request, state
state.get("model_request") or default_model_request, state
)
}

graph.add_node(
f"{m.__class__.__name__}.modify_model_request",
modify_model_request_node,
input_schema=m.State,
input_schema=m.state_schema,
)

if m.__class__.after_model is not AgentMiddleware.after_model:
graph.add_node(
f"{m.__class__.__name__}.after_model",
m.after_model,
input_schema=m.State,
input_schema=m.state_schema,
)

# add start edge
Expand Down Expand Up @@ -258,9 +257,9 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:

def _make_model_to_tools_edge(first_node: str) -> Callable[[AgentState], str | None]:
def model_to_tools(state: AgentState) -> str | None:
if state.jump_to:
return _resolve_jump(state.jump_to, first_node)
message = state.messages[-1]
if jump_to := state.get("jump_to"):
return _resolve_jump(jump_to, first_node)
message = state["messages"][-1]
if isinstance(message, AIMessage) and message.tool_calls:
return "tools"

Expand All @@ -273,7 +272,7 @@ def _make_tools_to_model_edge(
tool_node: ToolNode, next_node: str
) -> Callable[[AgentState], str | None]:
def tools_to_model(state: AgentState) -> str | None:
ai_message = [m for m in state.messages if isinstance(m, AIMessage)][-1]
ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
if all(
tool_node.tools_by_name[c["name"]].return_direct
for c in ai_message.tool_calls
Expand Down Expand Up @@ -302,7 +301,7 @@ def _add_middleware_edge(

def jump_edge(state: AgentState) -> str:
return (
_resolve_jump(state.jump_to, model_destination) or default_destination
_resolve_jump(state.get("jump_to"), model_destination) or default_destination
)

destinations = [default_destination, END, "tools"]
Expand Down
8 changes: 0 additions & 8 deletions libs/langgraph/langgraph/agent/middleware/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ class ToolCallLimitMiddleware(AgentMiddleware):
class State(AgentMiddleware.State):
important: Annotated[dict[str, int], Input, Output] = field(default_factory=dict)

@dataclass
class InputState(AgentMiddleware.State):
important: dict[str, int]

@dataclass
class OutputState(AgentMiddleware.State):
important: dict[str, int]

def __init__(self, tool_limits: dict[str, int]):
self.tool_limits = tool_limits

Expand Down
30 changes: 16 additions & 14 deletions libs/langgraph/langgraph/agent/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, TypeVar, Generic, ClassVar

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AnyMessage
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing_extensions import TypedDict
from typing_extensions import TypedDict, Required

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import Messages, add_messages
Expand All @@ -28,27 +28,29 @@ class ModelRequest:
response_format: ResponseFormat | None


@dataclass
class AgentState:
class AgentState(TypedDict, total=False):
# TODO: figure out Required/NotRequired wrapping annotated and still registering reducer properly
messages: Annotated[list[AnyMessage], add_messages]
model_request: Annotated[ModelRequest | None, EphemeralValue] = None
jump_to: Annotated[JumpTo | None, EphemeralValue] = None
response: dict | None = None
model_request: Annotated[ModelRequest | None, EphemeralValue]
jump_to: Annotated[JumpTo | None, EphemeralValue]
response: dict

StateT = TypeVar("StateT", bound=AgentState, default=AgentState, contravariant=True)

class AgentMiddleware:
class State(AgentState):
pass
class AgentMiddleware(Generic[StateT]):

tools: list[BaseTool]
# TODO: I thought this should be a ClassVar[type[StateT]] but inherently class vars can't use type vars
# bc they're instance dependent
state_schema: type[StateT]
tools: list[BaseTool] = []

def before_model(self, state: State) -> AgentUpdate | AgentJump | None:
def before_model(self, state: StateT) -> AgentUpdate | AgentJump | None:
pass

def modify_model_request(self, request: ModelRequest, state: State) -> ModelRequest:
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:
return request

def after_model(self, state: State) -> AgentUpdate | AgentJump | None:
def after_model(self, state: StateT) -> AgentUpdate | AgentJump | None:
pass


Expand Down
56 changes: 56 additions & 0 deletions test_poc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass, field
from typing import Annotated, Any, Dict, List, cast

from langchain_core.messages import AIMessage, HumanMessage
from typing_extensions import Annotated
from typing import ClassVar
from operator import add

from langgraph.agent import create_agent
from langgraph.agent.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate, ModelRequest

class AcceptInput:
    """Marker class placed in ``Annotated`` metadata to tag a state field.

    Presumably flags the field as accepted in the agent's input schema —
    TODO confirm against how create_agent consumes the metadata.
    """

    pass

class ExposeOutput:
    """Marker class placed in ``Annotated`` metadata to tag a state field.

    Presumably flags the field as exposed in the agent's output schema —
    TODO confirm against how create_agent consumes the metadata.
    """

    pass

# other ideas
# state_extensions: ClassVar[dict[str, type | type[Annotated]]] = {
# "int1": Annotated[int, add],
# "int2": Annotated[int, add, AcceptInput],
# "int3": Annotated[int, add, ExposeOutput],
# "int4": Annotated[int, add, AcceptInput, ExposeOutput],
# }

class State(AgentState):
    """Agent state extended with four int counters.

    Each counter uses ``operator.add`` as its reducer, so updates returned
    by middleware hooks accumulate rather than overwrite. The four fields
    exercise every combination of the AcceptInput / ExposeOutput markers
    in the ``Annotated`` metadata — presumably to test input/output schema
    exposure; TODO confirm how create_agent interprets the markers.
    """

    int1: Annotated[int, add]
    int2: Annotated[int, AcceptInput, add]
    int3: Annotated[int, ExposeOutput, add]
    int4: Annotated[int, AcceptInput, ExposeOutput, add]

class StateModMidleware(AgentMiddleware[State]):
    """Middleware that increments each custom State counter before every model call.

    NOTE(review): the original docstring ("Terminates after a specific tool is
    called N times.") was copied from ToolCallLimitMiddleware and did not
    describe this class, which never terminates anything.
    """

    # Declares the extended state schema this middleware operates on;
    # create_agent uses it as the node's input_schema.
    state_schema: type[State] = State

    def __init__(self):
        # No configuration needed; present only to mirror the middleware shape.
        pass

    def before_model(self, state: State) -> AgentUpdate | AgentJump | None:
        # Bump every counter by 1; the `add` reducer on State accumulates
        # these across invocations of the hook.
        return {"int1": 1, "int2": 1, "int3": 1, "int4": 1}

    def modify_model_request(self, request: ModelRequest, state: State) -> ModelRequest:
        # No-op passthrough; present to exercise the modify_model_request hook.
        return request

# Build the agent with the state-modifying middleware attached, then
# compile it into a runnable graph.
builder = create_agent(
    model="gpt-4o",
    system_prompt="You are a helpful assistant.",
    tools=[],
    # TODO: figure out invariance here
    middleware=[StateModMidleware()],
)
agent = builder.compile()

# Invoke once with a single human message and dump the final state.
result = agent.invoke({"messages": [HumanMessage("What is 2+2?")]})
print(result)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will want to see all state updates above

Loading