
Commit 315b495

Refactor for better prompt and tool description organization (NVIDIA#350)
This PR moves the default values of the tool descriptions into prompts.py and the default prompts into the config objects, so that all "prompt-engineerable" strings are tracked in a centralized file and also surfaced through the config, for easier prompt engineering and version tracking.

Authors:
- https://github.com/hsin-c

Approvers:
- Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah)

URL: NVIDIA#350
1 parent 50d631b commit 315b495

11 files changed, +109 −78 lines changed
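The pattern is easiest to see with the prompts.py side, which is not shown in this diff: each tool gets its own class holding its "prompt-engineerable" strings, and the tool's config fields default to those constants. The sketch below is illustrative only; the class and attribute names match the imports in the diff, the TOOL_DESCRIPTION values reuse the defaults removed from the config classes, and the PROMPT bodies are placeholders, not the actual text in prompts.py.

# prompts.py (sketch, not from the PR): one class per tool.
class CategorizerPrompts:
    TOOL_DESCRIPTION = "This is a categorization tool used at the end of the pipeline."
    PROMPT = "<categorization system prompt>"  # placeholder


class HardwareCheckPrompts:
    TOOL_DESCRIPTION = ("This tool checks hardware health status using IPMI monitoring to detect power state, "
                        "hardware degradation, and anomalies that could explain alerts. Args: host_id: str")
    PROMPT = "<hardware check reasoning prompt containing an {input_data} placeholder>"  # placeholder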

examples/alert_triage_agent/src/aiq_alert_triage_agent/categorizer.py

Lines changed: 4 additions & 5 deletions

@@ -28,13 +28,13 @@
 from aiq.data_models.function import FunctionBaseConfig
 
 from . import utils
-from .prompts import PipelineNodePrompts
+from .prompts import CategorizerPrompts
 
 
 class CategorizerToolConfig(FunctionBaseConfig, name="categorizer"):
-    description: str = Field(default="This is a categorization tool used at the end of the pipeline.",
-                             description="Description of the tool.")
+    description: str = Field(default=CategorizerPrompts.TOOL_DESCRIPTION, description="Description of the tool.")
     llm_name: LLMRef
+    prompt: str = Field(default=CategorizerPrompts.PROMPT, description="Main prompt for the categorization task.")
 
 
 def _extract_markdown_heading_level(report: str) -> str:

@@ -48,8 +48,7 @@ def _extract_markdown_heading_level(report: str) -> str:
 async def categorizer_tool(config: CategorizerToolConfig, builder: Builder):
     # Set up LLM and chain
     llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
-    prompt_template = ChatPromptTemplate([("system", PipelineNodePrompts.CATEGORIZER_PROMPT),
-                                          MessagesPlaceholder("msgs")])
+    prompt_template = ChatPromptTemplate([("system", config.prompt), MessagesPlaceholder("msgs")])
     categorization_chain = prompt_template | llm
 
     async def _arun(report: str) -> str:
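Because the prompt now lives on the config, a workflow can override the categorizer's system prompt without editing prompts.py. A minimal sketch, assuming LLMRef accepts a plain reference name and that the import path follows the package layout above; the prompt text and the "triage_llm" reference are invented for illustration.

from aiq_alert_triage_agent.categorizer import CategorizerToolConfig

# "triage_llm" is a hypothetical LLM reference defined elsewhere in the workflow config.
config = CategorizerToolConfig(
    llm_name="triage_llm",
    prompt="Classify the triage report into exactly one category and briefly justify the choice.",
)
# description was not overridden, so it keeps CategorizerPrompts.TOOL_DESCRIPTION as its default.
print(config.description)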

examples/alert_triage_agent/src/aiq_alert_triage_agent/hardware_check_tool.py

Lines changed: 4 additions & 6 deletions

@@ -24,15 +24,13 @@
 from aiq.data_models.function import FunctionBaseConfig
 
 from . import utils
-from .prompts import ToolReasoningLayerPrompts
+from .prompts import HardwareCheckPrompts
 
 
 class HardwareCheckToolConfig(FunctionBaseConfig, name="hardware_check"):
-    description: str = Field(
-        default=("This tool checks hardware health status using IPMI monitoring to detect power state, "
-                 "hardware degradation, and anomalies that could explain alerts. Args: host_id: str"),
-        description="Description of the tool for the agent.")
+    description: str = Field(default=HardwareCheckPrompts.TOOL_DESCRIPTION, description="Description of the tool.")
     llm_name: LLMRef
+    prompt: str = Field(default=HardwareCheckPrompts.PROMPT, description="Main prompt for the hardware check task.")
     offline_mode: bool = Field(default=True, description="Whether to run in offline model")
 
 

@@ -94,7 +92,7 @@ async def _arun(host_id: str) -> str:
         # Additional LLM reasoning layer on playbook output to provide a summary of the results
         utils.log_header("LLM Reasoning", dash_length=50)
 
-        prompt = ToolReasoningLayerPrompts.HARDWARE_CHECK.format(input_data=monitoring_data)
+        prompt = config.prompt.format(input_data=monitoring_data)
 
         # Get analysis from LLM
         conclusion = await utils.llm_ainvoke(config, builder, prompt)
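One consequence of moving the reasoning prompt into the config: the tool still calls config.prompt.format(input_data=...) at runtime, so any overridden prompt for this tool (and the other single-input check tools) must keep the {input_data} placeholder. A tiny sketch of that contract, with invented prompt text:

# A custom prompt supplied via config must keep the placeholder the tool formats at runtime.
custom_prompt = "Review this IPMI monitoring output and flag anything abnormal:\n{input_data}"
print(custom_prompt.format(input_data="Power status: on\nFan speeds: nominal"))

# A prompt without {input_data} still formats, but the monitoring data is silently dropped;
# an unknown placeholder such as {data} would raise KeyError instead.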

examples/alert_triage_agent/src/aiq_alert_triage_agent/host_performance_check_tool.py

Lines changed: 9 additions & 7 deletions

@@ -23,15 +23,17 @@
 
 from . import utils
 from .playbooks import HOST_PERFORMANCE_CHECK_PLAYBOOK
-from .prompts import ToolReasoningLayerPrompts
+from .prompts import HostPerformanceCheckPrompts
 
 
 class HostPerformanceCheckToolConfig(FunctionBaseConfig, name="host_performance_check"):
-    description: str = Field(
-        default=("This is the Host Performance Check Tool. This tool retrieves CPU usage, memory usage, "
-                 "and hardware I/O usage details for a given host. Args: host_id: str"),
-        description="Description of the tool for the agent.")
+    description: str = Field(default=HostPerformanceCheckPrompts.TOOL_DESCRIPTION,
+                             description="Description of the tool.")
     llm_name: LLMRef
+    parsing_prompt: str = Field(default=HostPerformanceCheckPrompts.PARSING_PROMPT,
+                                description="Prompt for parsing the raw host performance data.")
+    analysis_prompt: str = Field(default=HostPerformanceCheckPrompts.ANALYSIS_PROMPT,
+                                 description="Prompt for analyzing the parsed host performance data.")
     offline_mode: bool = Field(default=True, description="Whether to run in offline model")
 
 

@@ -97,7 +99,7 @@ async def _parse_stdout_lines(config, builder, stdout_lines):
         # Join the list of lines into a single text block
         input_data = "\n".join(stdout_lines) if stdout_lines else ""
 
-        prompt = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format(input_data=input_data)
+        prompt = config.parsing_prompt.format(input_data=input_data)
 
         response = await utils.llm_ainvoke(config=config, builder=builder, user_prompt=prompt)
     except Exception as e:

@@ -146,7 +148,7 @@ async def _arun(host_id: str) -> str:
         # Additional LLM reasoning layer on playbook output to provide a summary of the results
         utils.log_header("LLM Reasoning", dash_length=50)
 
-        prompt_template = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_ANALYSIS.format(input_data=output)
+        prompt_template = config.analysis_prompt.format(input_data=output)
 
         conclusion = await utils.llm_ainvoke(config, builder, user_prompt=prompt_template)
 
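The host performance check is the one tool that gains two prompt fields, because it makes two LLM passes: parsing_prompt structures the raw playbook stdout, then analysis_prompt reasons over the parsed result. A rough sketch of the two formatting steps; only the field split and the {input_data} placeholder come from the diff, the prompt text and sample data are invented.

# Stage 1: parse raw stdout from the performance playbook.
parsing_prompt = "Extract CPU, memory, and I/O figures from this output:\n{input_data}"
parse_request = parsing_prompt.format(input_data="cpu: 97%\nmem: 88%\nio_wait: 12%")

# Stage 2: analyze the parsed data produced by the first LLM call.
analysis_prompt = "Given these host metrics, explain whether they could account for the alert:\n{input_data}"
analysis_request = analysis_prompt.format(input_data='{"cpu": 97, "mem": 88, "io_wait": 12}')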

examples/alert_triage_agent/src/aiq_alert_triage_agent/maintenance_check.py

Lines changed: 13 additions & 9 deletions

@@ -31,17 +31,16 @@
 from aiq.data_models.function import FunctionBaseConfig
 
 from . import utils
-from .prompts import PipelineNodePrompts
+from .prompts import MaintenanceCheckPrompts
 
 NO_ONGOING_MAINTENANCE_STR = "No ongoing maintenance found for the host."
 
 
 class MaintenanceCheckToolConfig(FunctionBaseConfig, name="maintenance_check"):
-    description: str = Field(
-        default=("Check if a host is under maintenance during the time of an alert to help determine "
-                 "if the alert can be deprioritized."),
-        description="Description of the tool for the agent.")
+    description: str = Field(default=MaintenanceCheckPrompts.TOOL_DESCRIPTION, description="Description of the tool.")
     llm_name: LLMRef
+    prompt: str = Field(default=MaintenanceCheckPrompts.PROMPT,
+                        description="Main prompt for the maintenance check task.")
     static_data_path: str | None = Field(
         default="examples/alert_triage_agent/data/maintenance_static_dataset.csv",
         description=(

@@ -167,12 +166,13 @@ def _get_active_maintenance(df: pd.DataFrame, host_id: str, alert_time: datetime
     return start_time_str, end_time_str
 
 
-def _summarize_alert(llm, alert, maintenance_start_str, maintenance_end_str):
+def _summarize_alert(llm, prompt_template, alert, maintenance_start_str, maintenance_end_str):
     """
     Generate a summary report for an alert when the affected host is under maintenance.
 
     Args:
         llm: The language model to use for generating the summary
+        prompt_template: The prompt template to use for generating the summary
         alert (dict): Dictionary containing the alert details
         maintenance_start_str (str): Start time of maintenance window in "YYYY-MM-DD HH:MM:SS" format
         maintenance_end_str (str): End time of maintenance window in "YYYY-MM-DD HH:MM:SS" format,

@@ -181,8 +181,8 @@ def _summarize_alert(llm, alert, maintenance_start_str, maintenance_end_str):
     Returns:
         str: A markdown-formatted report summarizing the alert and maintenance status
     """
-    sys_prompt = PipelineNodePrompts.MAINTENANCE_CHECK_PROMPT.format(maintenance_start_str=maintenance_start_str,
-                                                                     maintenance_end_str=maintenance_end_str)
+    sys_prompt = prompt_template.format(maintenance_start_str=maintenance_start_str,
+                                        maintenance_end_str=maintenance_end_str)
     prompt_template = ChatPromptTemplate([("system", sys_prompt), MessagesPlaceholder("msgs")])
     summarization_chain = prompt_template | llm
     alert_json_str = json.dumps(alert)

@@ -249,7 +249,11 @@ async def _arun(input_message: str) -> str:
             # maintenance info found, summarize alert and return a report (agent execution will be skipped)
             utils.logger.info("Host: [%s] is under maintenance according to the maintenance database", host)
 
-            report = _summarize_alert(llm, alert, maintenance_start_str, maintenance_end_str)
+            report = _summarize_alert(llm=llm,
+                                      prompt_template=config.prompt,
+                                      alert=alert,
+                                      maintenance_start_str=maintenance_start_str,
+                                      maintenance_end_str=maintenance_end_str)
 
             utils.log_footer()
             return report
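Here the config-backed prompt is threaded through _summarize_alert as prompt_template and formatted with the maintenance window before being used as the system message, so a custom prompt set via config.prompt needs both window placeholders. Sketch with invented prompt text and sample times:

maintenance_prompt = ("The affected host is in a maintenance window from {maintenance_start_str} "
                      "to {maintenance_end_str}. Summarize the alert and recommend deprioritizing it.")

sys_prompt = maintenance_prompt.format(maintenance_start_str="2024-01-01 00:00:00",
                                       maintenance_end_str="2024-01-01 04:00:00")
print(sys_prompt)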

examples/alert_triage_agent/src/aiq_alert_triage_agent/monitoring_process_check_tool.py

Lines changed: 6 additions & 5 deletions

@@ -23,14 +23,15 @@
 
 from . import utils
 from .playbooks import MONITOR_PROCESS_CHECK_PLAYBOOK
-from .prompts import ToolReasoningLayerPrompts
+from .prompts import MonitoringProcessCheckPrompts
 
 
 class MonitoringProcessCheckToolConfig(FunctionBaseConfig, name="monitoring_process_check"):
-    description: str = Field(default=("This tool checks the status of critical monitoring processes and services "
-                                      "on a target host by executing system commands. Args: host_id: str"),
-                             description="Description of the tool for the agent.")
+    description: str = Field(default=MonitoringProcessCheckPrompts.TOOL_DESCRIPTION,
+                             description="Description of the tool.")
     llm_name: LLMRef
+    prompt: str = Field(default=MonitoringProcessCheckPrompts.PROMPT,
+                        description="Main prompt for the monitoring process check task.")
     offline_mode: bool = Field(default=True, description="Whether to run in offline model")
 
 

@@ -104,7 +105,7 @@ async def _arun(host_id: str) -> str:
         # Additional LLM reasoning layer on playbook output to provide a summary of the results
         utils.log_header("LLM Reasoning", dash_length=50)
 
-        prompt = ToolReasoningLayerPrompts.MONITORING_PROCESS_CHECK.format(input_data=output_for_prompt)
+        prompt = config.prompt.format(input_data=output_for_prompt)
 
         conclusion = await utils.llm_ainvoke(config, builder, prompt)
 

examples/alert_triage_agent/src/aiq_alert_triage_agent/network_connectivity_check_tool.py

Lines changed: 6 additions & 7 deletions

@@ -25,15 +25,15 @@
 from aiq.data_models.function import FunctionBaseConfig
 
 from . import utils
-from .prompts import ToolReasoningLayerPrompts
+from .prompts import NetworkConnectivityCheckPrompts
 
 
 class NetworkConnectivityCheckToolConfig(FunctionBaseConfig, name="network_connectivity_check"):
-    description: str = Field(
-        default=("This tool checks network connectivity of a host by running ping and socket connection tests. "
-                 "Args: host_id: str"),
-        description="Description of the tool for the agent.")
+    description: str = Field(default=NetworkConnectivityCheckPrompts.TOOL_DESCRIPTION,
+                             description="Description of the tool.")
     llm_name: LLMRef
+    prompt: str = Field(default=NetworkConnectivityCheckPrompts.PROMPT,
+                        description="Main prompt for the network connectivity check task.")
     offline_mode: bool = Field(default=True, description="Whether to run in offline model")
 
 

@@ -106,8 +106,7 @@ async def _arun(host_id: str) -> str:
         # Additional LLM reasoning layer on playbook output to provide a summary of the results
         utils.log_header("LLM Reasoning", dash_length=50)
 
-        prompt = ToolReasoningLayerPrompts.NETWORK_CONNECTIVITY_CHECK.format(ping_data=ping_data,
-                                                                             telnet_data=telnet_data)
+        prompt = config.prompt.format(ping_data=ping_data, telnet_data=telnet_data)
         conclusion = await utils.llm_ainvoke(config, builder, prompt)
 
         utils.logger.debug(conclusion)
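Unlike the single-input check tools, this prompt is formatted with two inputs, ping_data and telnet_data, so an overridden prompt must expose both placeholders. Sketch with invented prompt text and sample results:

connectivity_prompt = ("Given the ping results:\n{ping_data}\n\nand the socket connection results:\n{telnet_data}\n\n"
                       "state whether the host is reachable and which checks failed.")
print(connectivity_prompt.format(ping_data="3 packets transmitted, 0 received",
                                 telnet_data="port 22: connection timed out"))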
