
feat: add AI module for LLM interaction and a heuristic for checking code–docstring consistency #1121

Open · wants to merge 3 commits into `main`
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -38,6 +38,8 @@ dependencies = [
"problog >= 2.2.6,<3.0.0",
"cryptography >=44.0.0,<45.0.0",
"semgrep == 1.113.0",
"pydantic >= 2.11.5,<2.12.0",
"gradio_client == 1.4.3",
]
keywords = []
# https://pypi.org/classifiers/
50 changes: 50 additions & 0 deletions src/macaron/ai/README.md
@@ -0,0 +1,50 @@
# Macaron AI Module

This module provides the foundation for interacting with Large Language Models (LLMs) in a provider-agnostic way. It includes an abstract client definition, provider-specific client implementations, a client factory, and utility functions for processing responses.

## Module Components

- **clients/base.py**
  Defines the abstract [`AIClient`](./clients/base.py) class. This class stores the LLM configuration and serves as the base for all specific AI client implementations.

- **clients/openai_client.py**
  Implements the [`OpenAiClient`](./clients/openai_client.py) class, a concrete subclass of [`AIClient`](./clients/base.py). This client interacts with OpenAI-compatible APIs by sending HTTP requests and processing the responses, using the tools module to extract structured JSON from them.

- **clients/ai_factory.py**
  Contains the [`AIClientFactory`](./clients/ai_factory.py) class, which reads the provider configuration from the defaults and creates the correct AI client instance.

- **ai_tools.py**
  Offers utility functions such as `extract_json` to assist with parsing the JSON response returned by an LLM, even when the model wraps it in extra text.

## Usage

1. **Configuration:**
The module reads the LLM configuration from the application defaults (using the `defaults` module). Make sure that the `llm` section in your configuration includes valid settings for `enabled`, `provider`, `api_key`, `api_endpoint`, and `model`.

2. **Creating a Client:**
Use the [`AIClientFactory`](./clients/ai_factory.py) to create an AI client instance. The factory checks the configured provider and returns a client (e.g., an instance of [`OpenAiClient`](./clients/openai_client.py)) that can be used to invoke the LLM.

Example:
```py
from macaron.ai.clients.ai_factory import AIClientFactory

factory = AIClientFactory()
client = factory.create_client(system_prompt="You are a helpful assistant.")
if client is not None:  # create_client returns None when the LLM is disabled
    response = client.invoke("Hello, how can you assist me?")
    print(response)
```

3. **Response Processing:**
When a structured response is required, pass a JSON schema via the `invoke` method's `response_format` argument. The [`ai_tools.py`](./ai_tools.py) module takes care of extracting and parsing the JSON object from the response text.
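For example, such a schema can be derived from a Pydantic model. The `DocstringCheck` model and its fields below are illustrative, not part of the module:

```python
from pydantic import BaseModel


class DocstringCheck(BaseModel):
    """Illustrative schema for a structured LLM verdict; not part of the module."""

    consistent: bool
    explanation: str


# Derive the JSON schema to pass as `response_format`:
schema = DocstringCheck.model_json_schema()

# With a client obtained from AIClientFactory (None when the LLM is disabled):
# result = client.invoke("Does this docstring match the code? ...", response_format=schema)
```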

## Logging and Error Handling

- The module uses Python's logging framework to report important events, such as token usage and warnings when prompts exceed the allowed context window.
- Configuration errors (e.g., a missing API key or endpoint) are handled by raising descriptive exceptions such as [`ConfigurationError`](../errors.py).

## Extensibility

The design of the AI module is provider-agnostic. To add support for additional LLM providers:
- Implement a new client by subclassing [`AIClient`](./clients/base.py).
- Add the new client to the [`PROVIDER_MAPPING`](./clients/ai_factory.py).
- Update the configuration defaults accordingly.
2 changes: 2 additions & 0 deletions src/macaron/ai/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
43 changes: 43 additions & 0 deletions src/macaron/ai/ai_tools.py
@@ -0,0 +1,43 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides utility functions for Large Language Model (LLM)."""
import json
import logging
import re
from typing import Any

logger: logging.Logger = logging.getLogger(__name__)


def extract_json(response_text: str) -> Any:
    """
    Parse the response from the LLM.

    If raw JSON parsing fails, attempts to extract a JSON object from the text.

    Parameters
    ----------
    response_text: str
        The response text from the LLM.

    Returns
    -------
    Any
        The parsed JSON value, or None if parsing fails.
    """
    try:
        data = json.loads(response_text)
    except json.JSONDecodeError:
        logger.debug("Full JSON parse failed; trying to extract JSON from text.")
        # If the response is not valid JSON, try to extract a JSON object from the text.
        match = re.search(r"\{.*\}", response_text, re.DOTALL)
        if not match:
            return None
        try:
            data = json.loads(match.group(0))
        except json.JSONDecodeError as e:
            logger.debug("Failed to parse extracted JSON: %s", e)
            return None

    return data
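Both parsing branches can be exercised directly; the helper is inlined below so the snippet is self-contained:

```python
import json
import re
from typing import Any


def extract_json(response_text: str) -> Any:
    """Mirror of the helper above, duplicated here for a standalone demo."""
    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", response_text, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None


# A clean JSON response parses directly.
assert extract_json('{"ok": true}') == {"ok": True}
# LLMs often wrap JSON in prose or Markdown fences; the regex fallback recovers it.
assert extract_json('Sure! ```json\n{"ok": true}\n```') == {"ok": True}
# When no JSON object is present, the helper returns None.
assert extract_json("no json here") is None
```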
9 changes: 9 additions & 0 deletions src/macaron/ai/clients/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a mapping of AI client providers to their respective client classes."""

from macaron.ai.clients.base import AIClient
from macaron.ai.clients.openai_client import OpenAiClient

PROVIDER_MAPPING: dict[str, type[AIClient]] = {"openai": OpenAiClient}
62 changes: 62 additions & 0 deletions src/macaron/ai/clients/ai_factory.py
@@ -0,0 +1,62 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the AIClientFactory class for creating AI clients based on provider configuration."""

import logging

from macaron.ai.clients import PROVIDER_MAPPING
from macaron.ai.clients.base import AIClient
from macaron.config.defaults import defaults
from macaron.errors import ConfigurationError

logger: logging.Logger = logging.getLogger(__name__)


class AIClientFactory:
    """Factory to create AI clients based on provider configuration."""

    def __init__(self) -> None:
        """
        Initialize the factory.

        The LLM configuration is read from defaults.
        """
        self.params = self._load_defaults()

    def _load_defaults(self) -> dict:
        section_name = "llm"
        default_values = {
            "enabled": False,
            "provider": "",
            "api_key": "",
            "api_endpoint": "",
            "model": "",
        }

        if defaults.has_section(section_name):
            section = defaults[section_name]
            default_values["enabled"] = section.getboolean("enabled", default_values["enabled"])
            for key, default_value in default_values.items():
                if isinstance(default_value, str):
                    default_values[key] = str(section.get(key, default_value)).strip()
            # Only the provider name is case-insensitive; API keys, endpoints,
            # and model names must keep their original case.
            default_values["provider"] = str(default_values["provider"]).lower()

        if default_values["enabled"]:
            for key, value in default_values.items():
                if not value:
                    raise ConfigurationError(
                        f"AI client configuration '{key}' is required but not set in the defaults."
                    )

        return default_values

    def create_client(self, system_prompt: str) -> AIClient | None:
        """Create an AI client based on the configured provider."""
        if not self.params or not self.params["enabled"]:
            return None

        client_class = PROVIDER_MAPPING.get(self.params["provider"])
        if client_class is None:
            logger.error("Provider '%s' is not supported.", self.params["provider"])
            return None
        return client_class(system_prompt, self.params)
45 changes: 45 additions & 0 deletions src/macaron/ai/clients/base.py
@@ -0,0 +1,45 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the abstract AIClient class for implementing AI clients."""

from abc import ABC, abstractmethod
from typing import Any


class AIClient(ABC):
    """This abstract class is used to implement AI clients."""

    def __init__(self, system_prompt: str, params: dict) -> None:
        """
        Initialize the AI client.

        Parameters
        ----------
        system_prompt: str
            The system prompt to send with every request.
        params: dict
            The LLM configuration, as read from defaults by the factory.
        """
        self.system_prompt = system_prompt
        self.params = params

    @abstractmethod
    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        response_format: dict | None = None,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        response_format: dict | None
            The JSON schema to validate the response against.

        Returns
        -------
        Any
            The parsed and validated response if `response_format` is provided,
            or the raw response if not.
        """
89 changes: 89 additions & 0 deletions src/macaron/ai/clients/openai_client.py
@@ -0,0 +1,89 @@
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a client for interacting with a Large Language Model (LLM) that is Openai like."""

import logging
from typing import Any

from macaron.ai.ai_tools import extract_json
from macaron.ai.clients.base import AIClient
from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
from macaron.util import send_post_http_raw

logger: logging.Logger = logging.getLogger(__name__)


class OpenAiClient(AIClient):
    """A client for interacting with a Large Language Model through an OpenAI-compatible API."""

    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        response_format: dict | None = None,
        max_tokens: int = 4000,
        timeout: int = 30,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        response_format: dict | None
            The JSON schema to validate the response against. If provided, the response will be parsed and validated.
        max_tokens: int
            The maximum number of tokens for the LLM response.
        timeout: int
            The timeout for the HTTP request in seconds.

        Returns
        -------
        Any
            The parsed JSON response, or None if parsing fails.

        Raises
        ------
        HeuristicAnalyzerValueError
            If there is an error in getting or parsing the response.
        """
        if not self.params["enabled"]:
            raise ConfigurationError("AI client is not enabled. Please check your configuration.")

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.params['api_key']}"}
        payload = {
            "model": self.params["model"],
            "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
            "response_format": response_format,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            response = send_post_http_raw(
                url=self.params["api_endpoint"], json_data=payload, headers=headers, timeout=timeout
            )
            if not response:
                raise HeuristicAnalyzerValueError("No response received from the LLM.")
            response_json = response.json()
            usage = response_json.get("usage", {})

            if usage:
                usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
                logger.info("LLM call token usage: %s", usage_str)

            message_content = response_json["choices"][0]["message"]["content"]
            return extract_json(message_content)

        except Exception as e:
            logger.error("Error during LLM invocation: %s", e)
            raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e
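For callers, the `response_format` dict is forwarded verbatim inside the request payload. A sketch in the OpenAI structured-output style follows; the schema name and fields are illustrative, not defined by this module:

```python
# JSON-schema response_format in the OpenAI structured-output style; the
# schema name and fields below are illustrative, not part of the module.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "docstring_consistency",
        "schema": {
            "type": "object",
            "properties": {
                "consistent": {"type": "boolean"},
                "explanation": {"type": "string"},
            },
            "required": ["consistent", "explanation"],
        },
    },
}

# The dict is passed straight into the request payload:
# client.invoke("Check this function ...", response_format=response_format)
```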
2 changes: 2 additions & 0 deletions src/macaron/ai/prompts/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
2 changes: 2 additions & 0 deletions src/macaron/ai/schemas/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
15 changes: 15 additions & 0 deletions src/macaron/config/defaults.ini
@@ -635,3 +635,18 @@ custom_semgrep_rules_path =
# .yaml prefix. Note, this will be ignored if a path to custom semgrep rules is not provided. This list may not contain
# duplicated elements, meaning that ruleset names must be unique.
disabled_custom_rulesets =

[llm]
# The LLM configuration for Macaron.
# If enabled, the LLM will be used to analyze the results and provide insights.
enabled = False
# The provider for the LLM service.
# Supported providers:
# - openai: OpenAI's GPT models.
provider =
# The API key for the LLM service.
api_key =
# The API endpoint for the LLM service.
api_endpoint =
# The model to use for the LLM service.
model =
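For reference, a filled-in section might look like the following; all values are placeholders, and the endpoint and model shown are only examples:

```ini
[llm]
enabled = True
provider = openai
# Placeholder values; substitute your own credentials and endpoint.
api_key = <your-api-key>
api_endpoint = https://api.openai.com/v1/chat/completions
model = gpt-4o-mini
```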
6 changes: 6 additions & 0 deletions src/macaron/malware_analyzer/pypi_heuristics/heuristics.py
@@ -43,6 +43,12 @@ class Heuristics(str, Enum):
    #: Indicates that the package source code contains suspicious code patterns.
    SUSPICIOUS_PATTERNS = "suspicious_patterns"

    #: Indicates that the package contains code that does not match its docstrings.
    MATCHING_DOCSTRINGS = "matching_docstrings"

    #: Indicates that the package description is inconsistent.
    INCONSISTENT_DESCRIPTION = "inconsistent_description"


class HeuristicResult(str, Enum):
"""Result type indicating the outcome of a heuristic."""