Skip to content

Commit 9d0577a

Browse files
Support additional_instructions (#302)
React agent supports additional_instructions configuration option. Add test cases for custom system prompt and adding additional_instructions. Closes #301 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/AIQToolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - https://github.com/gfreeman-nvidia Approvers: - David Gardner (https://github.com/dagardner-nv) URL: #302
1 parent c5c1043 commit 9d0577a

File tree

4 files changed

+70
-31
lines changed

4 files changed

+70
-31
lines changed

src/aiq/agent/react_agent/agent.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from langchain_core.messages.base import BaseMessage
2727
from langchain_core.messages.human import HumanMessage
2828
from langchain_core.messages.tool import ToolMessage
29-
from langchain_core.prompts.chat import ChatPromptTemplate
29+
from langchain_core.prompts import ChatPromptTemplate
30+
from langchain_core.prompts import MessagesPlaceholder
3031
from langchain_core.runnables.config import RunnableConfig
3132
from langchain_core.tools import BaseTool
3233
from pydantic import BaseModel
@@ -42,6 +43,9 @@
4243
from aiq.agent.dual_node import DualNodeAgent
4344
from aiq.agent.react_agent.output_parser import ReActOutputParser
4445
from aiq.agent.react_agent.output_parser import ReActOutputParserException
46+
from aiq.agent.react_agent.prompt import SYSTEM_PROMPT
47+
from aiq.agent.react_agent.prompt import USER_PROMPT
48+
from aiq.agent.react_agent.register import ReActAgentWorkflowConfig
4549

4650
logger = logging.getLogger(__name__)
4751

@@ -320,3 +324,32 @@ def validate_system_prompt(system_prompt: str) -> bool:
320324
logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
321325
raise ValueError(error_text)
322326
return True
327+
328+
329+
def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
330+
"""
331+
Create a ReAct Agent prompt from the config.
332+
333+
Args:
334+
config (ReActAgentWorkflowConfig): The config to use for the prompt.
335+
336+
Returns:
337+
ChatPromptTemplate: The ReAct Agent prompt.
338+
"""
339+
# the ReAct Agent prompt can be customized via config option system_prompt and additional_instructions.
340+
341+
if config.system_prompt:
342+
prompt_str = config.system_prompt
343+
else:
344+
prompt_str = SYSTEM_PROMPT
345+
346+
if config.additional_instructions:
347+
prompt_str += f" {config.additional_instructions}"
348+
349+
valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
350+
if not valid_prompt:
351+
logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
352+
raise ValueError("Invalid system_prompt")
353+
prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
354+
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
355+
return prompt

src/aiq/agent/react_agent/prompt.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515

1616
# flake8: noqa
17-
from langchain_core.prompts.chat import ChatPromptTemplate
18-
from langchain_core.prompts.chat import MessagesPlaceholder
1917

2018
SYSTEM_PROMPT = """
2119
Answer the following questions as best you can. You may ask the human to use the following tools:
@@ -37,10 +35,7 @@
3735
Thought: I now know the final answer
3836
Final Answer: the final answer to the original input question
3937
"""
38+
4039
USER_PROMPT = """
4140
Question: {question}
4241
"""
43-
44-
# This is the prompt - (ReAct Agent prompt)
45-
react_agent_prompt = ChatPromptTemplate([("system", SYSTEM_PROMPT), ("user", USER_PROMPT),
46-
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])

src/aiq/agent/react_agent/register.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,29 +63,13 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
6363
async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
6464
from langchain.schema import BaseMessage
6565
from langchain_core.messages import trim_messages
66-
from langchain_core.prompts import ChatPromptTemplate
67-
from langchain_core.prompts import MessagesPlaceholder
6866
from langgraph.graph.graph import CompiledGraph
6967

70-
from aiq.agent.react_agent.prompt import USER_PROMPT
71-
72-
from .agent import ReActAgentGraph
73-
from .agent import ReActGraphState
74-
from .prompt import react_agent_prompt
75-
76-
# the ReAct Agent prompt comes from prompt.py, and can be customized there or via config option system_prompt.
77-
if config.system_prompt:
78-
_prompt_str = config.system_prompt
79-
if config.additional_instructions:
80-
_prompt_str += f" {config.additional_instructions}"
81-
valid_prompt = ReActAgentGraph.validate_system_prompt(config.system_prompt)
82-
if not valid_prompt:
83-
logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
84-
raise ValueError("Invalid system_prompt")
85-
prompt = ChatPromptTemplate([("system", config.system_prompt), ("user", USER_PROMPT),
86-
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
87-
else:
88-
prompt = react_agent_prompt
68+
from aiq.agent.react_agent.agent import ReActAgentGraph
69+
from aiq.agent.react_agent.agent import ReActGraphState
70+
from aiq.agent.react_agent.agent import create_react_agent_prompt
71+
72+
prompt = create_react_agent_prompt(config)
8973

9074
# we can choose an LLM for the ReAct agent in the config file
9175
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

tests/aiq/agent/test_react.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from aiq.agent.react_agent.agent import TOOL_NOT_FOUND_ERROR_MESSAGE
2626
from aiq.agent.react_agent.agent import ReActAgentGraph
2727
from aiq.agent.react_agent.agent import ReActGraphState
28+
from aiq.agent.react_agent.agent import create_react_agent_prompt
2829
from aiq.agent.react_agent.output_parser import FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE
2930
from aiq.agent.react_agent.output_parser import MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE
3031
from aiq.agent.react_agent.output_parser import MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE
3132
from aiq.agent.react_agent.output_parser import ReActOutputParser
3233
from aiq.agent.react_agent.output_parser import ReActOutputParserException
33-
from aiq.agent.react_agent.prompt import react_agent_prompt
3434
from aiq.agent.react_agent.register import ReActAgentWorkflowConfig
3535

3636

@@ -59,7 +59,7 @@ def mock_config():
5959

6060
def test_react_init(mock_config_react_agent, mock_llm, mock_tool):
6161
tools = [mock_tool('Tool A'), mock_tool('Tool B')]
62-
prompt = react_agent_prompt
62+
prompt = create_react_agent_prompt(mock_config_react_agent)
6363
agent = ReActAgentGraph(llm=mock_llm, prompt=prompt, tools=tools, detailed_logs=mock_config_react_agent.verbose)
6464
assert isinstance(agent, ReActAgentGraph)
6565
assert agent.llm == mock_llm
@@ -72,7 +72,7 @@ def test_react_init(mock_config_react_agent, mock_llm, mock_tool):
7272
@pytest.fixture(name='mock_react_agent', scope="module")
7373
def mock_agent(mock_config_react_agent, mock_llm, mock_tool):
7474
tools = [mock_tool('Tool A'), mock_tool('Tool B')]
75-
prompt = react_agent_prompt
75+
prompt = create_react_agent_prompt(mock_config_react_agent)
7676
agent = ReActAgentGraph(llm=mock_llm, prompt=prompt, tools=tools, detailed_logs=mock_config_react_agent.verbose)
7777
return agent
7878

@@ -412,3 +412,30 @@ async def test_output_parser_missing_action_input(mock_react_output_parser):
412412
await mock_react_output_parser.aparse(mock_input)
413413
assert isinstance(ex.value, ReActOutputParserException)
414414
assert ex.value.observation == MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE
415+
416+
417+
def test_react_additional_instructions(mock_llm, mock_tool):
418+
config_react_agent = ReActAgentWorkflowConfig(tool_names=['test'],
419+
llm_name='test',
420+
verbose=True,
421+
retry_parsing_errors=False,
422+
additional_instructions="Talk like a parrot and repeat the question.")
423+
tools = [mock_tool('Tool A'), mock_tool('Tool B')]
424+
prompt = create_react_agent_prompt(config_react_agent)
425+
agent = ReActAgentGraph(llm=mock_llm, prompt=prompt, tools=tools, detailed_logs=config_react_agent.verbose)
426+
assert isinstance(agent, ReActAgentGraph)
427+
assert "Talk like a parrot" in agent.agent.get_prompts()[0].messages[0].prompt.template
428+
429+
430+
def test_react_custom_system_prompt(mock_llm, mock_tool):
431+
config_react_agent = ReActAgentWorkflowConfig(
432+
tool_names=['test'],
433+
llm_name='test',
434+
verbose=True,
435+
retry_parsing_errors=False,
436+
system_prompt="Refuse to run any of the following tools: {tools}. or ones named: {tool_names}")
437+
tools = [mock_tool('Tool A'), mock_tool('Tool B')]
438+
prompt = create_react_agent_prompt(config_react_agent)
439+
agent = ReActAgentGraph(llm=mock_llm, prompt=prompt, tools=tools, detailed_logs=config_react_agent.verbose)
440+
assert isinstance(agent, ReActAgentGraph)
441+
assert "Refuse" in agent.agent.get_prompts()[0].messages[0].prompt.template

0 commit comments

Comments
 (0)