Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit f61f357

Browse files
Added a class which performs semantic routing (#1192)
* Added a class which performs semantic routing Related to: #1055 For the current implementation of muxing we only need to match a single Persona at a time. For example: 1. mux1 -> persona Architect -> openai o1 2. mux2 -> catch all -> openai gpt4o In the above case we would only need to know if the request matches the persona `Architect`. It's not needed to match any extra personas even if they exist in DB. This PR introduces what's necessary to do the above without actually wiring in muxing rules. The PR: - Creates the persona table in DB - Adds methods to write and read to the new persona table - Implements a function to check if a query matches to the specified persona To check more about the personas and the queries please check the unit tests * Attended PR comments
1 parent 7e3b19a commit f61f357

File tree

6 files changed

+940
-2
lines changed

6 files changed

+940
-2
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""add persona table
2+
3+
Revision ID: 02b710eda156
4+
Revises: 5e5cd2288147
5+
Create Date: 2025-03-03 10:08:16.206617+00:00
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
13+
# Alembic revision identifiers: this migration's id and the id of the
# migration it revises.
revision: str = "02b710eda156"
down_revision: Union[str, None] = "5e5cd2288147"
# No branch labels or cross-branch dependencies are declared.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
18+
19+
20+
def upgrade() -> None:
    """Create the ``personas`` table.

    The DDL is issued between an explicit ``BEGIN TRANSACTION;`` and
    ``COMMIT;``, each sent through ``op.execute`` in order.
    """
    create_personas_table = """
        CREATE TABLE IF NOT EXISTS personas (
            id TEXT PRIMARY KEY,  -- UUID stored as TEXT
            name TEXT NOT NULL UNIQUE,
            description TEXT NOT NULL,
            description_embedding BLOB NOT NULL
        );
        """

    # Begin transaction, create the table, finish transaction.
    for statement in ("BEGIN TRANSACTION;", create_personas_table, "COMMIT;"):
        op.execute(statement)
37+
38+
39+
def downgrade() -> None:
    """Drop the ``personas`` table created by ``upgrade``.

    ``upgrade`` uses ``CREATE TABLE IF NOT EXISTS``, so the drop mirrors it
    with ``IF EXISTS``: running the downgrade when the table is already gone
    is a no-op instead of an error raised inside an open transaction.
    """
    # Begin transaction
    op.execute("BEGIN TRANSACTION;")

    op.execute(
        """
        DROP TABLE IF EXISTS personas;
        """
    )

    # Finish transaction
    op.execute("COMMIT;")

src/codegate/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class Config:
5757
force_certs: bool = False
5858

5959
max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
60+
# Cosine distance: min value is 0 (max similarity), 1 means orthogonal, max value is 2 (opposite)
61+
# The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py
62+
persona_threshold = 0.75
6063

6164
# Provider URLs with defaults
6265
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

src/codegate/db/connection.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import asyncio
22
import json
3+
import sqlite3
34
import uuid
45
from pathlib import Path
56
from typing import Dict, List, Optional, Type
67

8+
import numpy as np
9+
import sqlite_vec_sl_tmp
710
import structlog
811
from alembic import command as alembic_command
912
from alembic.config import Config as AlembicConfig
@@ -22,6 +25,9 @@
2225
IntermediatePromptWithOutputUsageAlerts,
2326
MuxRule,
2427
Output,
28+
Persona,
29+
PersonaDistance,
30+
PersonaEmbedding,
2531
Prompt,
2632
ProviderAuthMaterial,
2733
ProviderEndpoint,
@@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
6571
# It should only be used for testing
6672
if "_no_singleton" in kwargs and kwargs["_no_singleton"]:
6773
kwargs.pop("_no_singleton")
68-
return super().__new__(cls, *args, **kwargs)
74+
return super().__new__(cls)
6975

7076
if cls._instance is None:
7177
cls._instance = super().__new__(cls)
@@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
9298
}
9399
self._async_db_engine = create_async_engine(**engine_dict)
94100

101+
def _get_vec_db_connection(self) -> sqlite3.Connection:
    """
    Open a separate sqlite3 connection with the vector extension loaded.

    Vector database connection is a separate connection to the SQLite database. aiosqlite
    does not support loading extensions, so we need to use the sqlite3 module to load the
    vector extension.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        Exception: any error raised while connecting or loading the extension is logged
            and re-raised unchanged.
    """
    conn = None
    try:
        conn = sqlite3.connect(self._db_path)
        # Extension loading is enabled only for the duration of the load call.
        conn.enable_load_extension(True)
        sqlite_vec_sl_tmp.load(conn)
        conn.enable_load_extension(False)
        return conn
    except Exception:
        logger.exception("Failed to initialize vector database connection")
        # Do not leak a half-initialized connection when setup fails after connect().
        if conn is not None:
            conn.close()
        raise
116+
95117
def does_db_exist(self):
96118
return self._db_path.is_file()
97119

@@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
523545
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
524546
return added_mux
525547

548+
async def add_persona(self, persona: PersonaEmbedding) -> None:
    """Add a new Persona to the DB.

    Inserts the persona's id, name, description and description embedding into the
    ``personas`` table.

    Raises:
        AlreadyExistsError: if the insert fails with an ``IntegrityError``. The original
            error is chained as the cause.
    """
    sql = text(
        """
        INSERT INTO personas (id, name, description, description_embedding)
        VALUES (:id, :name, :description, :description_embedding)
        """
    )

    # Pydantic converts the numpy array to a string when serializing with .model_dump().
    # Put the original numpy array back before inserting it into the DB.
    persona_dict = persona.model_dump()
    persona_dict["description_embedding"] = persona.description_embedding

    try:
        await self._execute_with_no_return(sql, persona_dict)
    except IntegrityError as e:
        logger.debug(f"Exception type: {type(e)}")
        # Chain the DB error so the root cause is not lost in tracebacks.
        raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") from e
571+
526572

527573
class DbReader(DbCodeGate):
528574
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
@@ -569,6 +615,20 @@ async def _exec_select_conditions_to_pydantic(
569615
raise e
570616
return None
571617

618+
async def _exec_vec_db_query_to_pydantic(
    self, sql_command: str, conditions: dict, model_type: Type[BaseModel]
) -> List[BaseModel]:
    """
    Execute a query on the vector database. This is a separate connection to the SQLite
    database that has the vector extension loaded.

    Args:
        sql_command: SQL text, with named placeholders matching the keys of `conditions`.
        conditions: parameters bound to the placeholders in `sql_command`.
        model_type: model class instantiated once per returned row, using the row's
            columns as keyword arguments.

    Returns:
        One `model_type` instance per row returned by the query.
    """
    conn = self._get_vec_db_connection()
    try:
        # sqlite3.Row exposes column names, so each row can be unpacked as keyword arguments.
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        return [model_type(**row) for row in cursor.execute(sql_command, conditions)]
    finally:
        # Always release the connection, including when the query or model building fails.
        conn.close()
631+
572632
async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
573633
sql = text(
574634
"""
@@ -893,6 +953,45 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
893953
)
894954
return muxes
895955

956+
async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
    """
    Look up a persona by its name.

    Returns the first matching row as a `Persona` (id, name, description), or None when
    the query yields no rows.
    """
    query = text(
        """
        SELECT
            id, name, description
        FROM personas
        WHERE name = :name
        """
    )
    matches = await self._exec_select_conditions_to_pydantic(
        Persona, query, {"name": persona_name}, should_raise=True
    )
    if not matches:
        return None
    return matches[0]
973+
974+
async def get_distance_to_persona(
    self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
    """
    Compute the cosine distance between a stored persona's description embedding and
    `query_embedding`.

    The query runs through the vector-enabled connection and returns the persona's id,
    name and description together with the computed distance.

    Raises:
        IndexError: if the query returns no rows for `persona_id`.
    """
    distance_query = """
        SELECT
            id,
            name,
            description,
            vec_distance_cosine(description_embedding, :query_embedding) as distance
        FROM personas
        WHERE id = :id
    """
    params = {"id": persona_id, "query_embedding": query_embedding}
    rows = await self._exec_vec_db_query_to_pydantic(
        distance_query, params, PersonaDistance
    )
    return rows[0]
994+
896995

897996
def init_db_sync(db_path: Optional[str] = None):
898997
"""DB will be initialized in the constructor in case it doesn't exist."""

src/codegate/db/models.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from enum import Enum
33
from typing import Annotated, Any, Dict, List, Optional
44

5-
from pydantic import BaseModel, StringConstraints
5+
import numpy as np
6+
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
67

78

89
class AlertSeverity(str, Enum):
@@ -240,3 +241,58 @@ class MuxRule(BaseModel):
240241
priority: int
241242
created_at: Optional[datetime.datetime] = None
242243
updated_at: Optional[datetime.datetime] = None
244+
245+
246+
def nd_array_custom_before_validator(x):
    """Before-validation hook for the numpy array type: returns the value unchanged."""
    return x
249+
250+
251+
def nd_array_custom_serializer(x):
    """Serialization hook for the numpy array type: returns ``str(x)``."""
    as_text = str(x)
    return as_text
254+
255+
256+
# Pydantic has no built-in support for numpy arrays, so a custom annotated type is used.
# A Pydantic custom type needs two pieces, both attached below: a validator and a serializer.
# Docs: https://docs.pydantic.dev/latest/concepts/types/#adding-validation-and-serialization
# Open Pydantic issue for numpy support: https://github.com/pydantic/pydantic/issues/7017
NdArray = Annotated[
    np.ndarray,
    BeforeValidator(nd_array_custom_before_validator),
    PlainSerializer(nd_array_custom_serializer, return_type=str),
]
266+
267+
268+
class Persona(BaseModel):
    """
    A persona: its identifier, its name and a free-text description.
    """

    id: str
    name: str
    description: str
276+
277+
278+
class PersonaEmbedding(Persona):
    """
    A persona together with the embedding of its description.
    """

    # Part of the workaround to allow numpy arrays in pydantic models
    model_config = ConfigDict(arbitrary_types_allowed=True)

    description_embedding: NdArray
287+
288+
289+
class PersonaDistance(Persona):
    """
    Result of an SQL query to get the distance between the query and the persona description.

    A vector similarity search is performed to get the cosine distance. Distance values
    range over [0, 2]: 0 means the vectors point in the same direction, 1 means they are
    orthogonal and 2 means they point in opposite directions.
    See [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
    """

    distance: float

0 commit comments

Comments
 (0)