Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 0e37312

Browse files
Attended PR comments
1 parent d22abcc commit 0e37312

File tree

4 files changed

+57
-27
lines changed

4 files changed

+57
-27
lines changed

src/codegate/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ class Config:
5757
force_certs: bool = False
5858

5959
max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
60-
persona_threshold = 0.75 # Min value is 0 (max similarity), max value is 2 (orthogonal)
60+
# Cosine distance: min value is 0 (max similarity), 1 is orthogonal, max value is 2 (opposite)
61+
# The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py
62+
persona_threshold = 0.75
6163

6264
# Provider URLs with defaults
6365
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

src/codegate/db/connection.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ async def add_persona(self, persona: PersonaEmbedding) -> None:
560560
)
561561

562562
try:
563-
# For Pydantic we conver the numpy array to a string when serializing.
563+
# For Pydantic we convert the numpy array to a string when serializing with .model_dump()
564564
# We need to convert it back to a numpy array before inserting it into the DB.
565565
persona_dict = persona.model_dump()
566566
persona_dict["description_embedding"] = persona.description_embedding
@@ -615,17 +615,19 @@ async def _exec_select_conditions_to_pydantic(
615615
raise e
616616
return None
617617

618-
async def _exec_vec_db_query(
619-
self, sql_command: str, conditions: dict
620-
) -> Optional[CursorResult]:
618+
async def _exec_vec_db_query_to_pydantic(
619+
self, sql_command: str, conditions: dict, model_type: Type[BaseModel]
620+
) -> List[BaseModel]:
621621
"""
622622
Execute a query on the vector database. This is a separate connection to the SQLite
623623
database that has the vector extension loaded.
624624
"""
625625
conn = self._get_vec_db_connection()
626+
conn.row_factory = sqlite3.Row
626627
cursor = conn.cursor()
627-
cursor.execute(sql_command, conditions)
628-
return cursor
628+
results = [model_type(**row) for row in cursor.execute(sql_command, conditions)]
629+
conn.close()
630+
return results
629631

630632
async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
631633
sql = text(
@@ -985,14 +987,10 @@ async def get_distance_to_persona(
985987
WHERE id = :id
986988
"""
987989
conditions = {"id": persona_id, "query_embedding": query_embedding}
988-
persona_distance_cursor = await self._exec_vec_db_query(sql, conditions)
989-
persona_distance_raw = persona_distance_cursor.fetchone()
990-
return PersonaDistance(
991-
id=persona_distance_raw[0],
992-
name=persona_distance_raw[1],
993-
description=persona_distance_raw[2],
994-
distance=persona_distance_raw[3],
990+
persona_distance = await self._exec_vec_db_query_to_pydantic(
991+
sql, conditions, PersonaDistance
995992
)
993+
return persona_distance[0]
996994

997995

998996
def init_db_sync(db_path: Optional[str] = None):

src/codegate/db/models.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,6 @@ class MuxRule(BaseModel):
243243
updated_at: Optional[datetime.datetime] = None
244244

245245

246-
# Pydantic doesn't support numpy arrays out of the box. Defining a custom type
247-
# Reference: https://github.com/pydantic/pydantic/issues/7017
248246
def nd_array_custom_before_validator(x):
249247
# custom before-validation logic
250248
return x
@@ -255,6 +253,11 @@ def nd_array_custom_serializer(x):
255253
return str(x)
256254

257255

256+
# Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type.
257+
# There are 2 things necessary for a Pydantic custom type: Validator and Serializer
258+
# The lines below build our custom type
259+
# Docs: https://docs.pydantic.dev/latest/concepts/types/#adding-validation-and-serialization
260+
# Open Pydantic issue for npy support: https://github.com/pydantic/pydantic/issues/7017
258261
NdArray = Annotated[
259262
np.ndarray,
260263
BeforeValidator(nd_array_custom_before_validator),
@@ -263,17 +266,33 @@ def nd_array_custom_serializer(x):
263266

264267

265268
class Persona(BaseModel):
269+
"""
270+
Represents a persona object.
271+
"""
272+
266273
id: str
267274
name: str
268275
description: str
269276

270277

271278
class PersonaEmbedding(Persona):
272-
description_embedding: NdArray # sqlite-vec will handle numpy arrays directly
279+
"""
280+
Represents a persona object with an embedding.
281+
"""
282+
283+
description_embedding: NdArray
273284

274285
# Part of the workaround to allow numpy arrays in pydantic models
275286
model_config = ConfigDict(arbitrary_types_allowed=True)
276287

277288

278289
class PersonaDistance(Persona):
290+
"""
291+
Result of an SQL query to get the distance between the query and the persona description.
292+
293+
A vector similarity search is performed to get the distance. Distance values range over [0, 2].
294+
0 means the vectors point in the same direction, 1 means they are orthogonal, 2 means they are opposite.
295+
See [sqlite-vec docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
296+
"""
297+
279298
distance: float

src/codegate/muxing/semantic_router.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
logger = structlog.get_logger("codegate")
1414

1515

16+
REMOVE_URLS = re.compile(r"https?://\S+|www\.\S+")
17+
REMOVE_EMAILS = re.compile(r"\S+@\S+")
18+
REMOVE_CODE_BLOCKS = re.compile(r"```[\s\S]*?```")
19+
REMOVE_INLINE_CODE = re.compile(r"`[^`]*`")
20+
REMOVE_HTML_TAGS = re.compile(r"<[^>]+>")
21+
REMOVE_PUNCTUATION = re.compile(r"[^\w\s\']")
22+
NORMALIZE_WHITESPACE = re.compile(r"\s+")
23+
NORMALIZE_DECIMAL_NUMBERS = re.compile(r"\b\d+\.\d+\b")
24+
NORMALIZE_INTEGER_NUMBERS = re.compile(r"\b\d+\b")
25+
26+
1627
class PersonaDoesNotExistError(Exception):
1728
pass
1829

@@ -54,27 +65,27 @@ def _clean_text_for_embedding(self, text: str) -> str:
5465
text = "".join([c for c in text if not unicodedata.combining(c)])
5566

5667
# Remove URLs
57-
text = re.sub(r"https?://\S+|www\.\S+", " ", text)
68+
text = REMOVE_URLS.sub(" ", text)
5869

5970
# Remove email addresses
60-
text = re.sub(r"\S+@\S+", " ", text)
71+
text = REMOVE_EMAILS.sub(" ", text)
6172

6273
# Remove code block markers and other markdown/code syntax
63-
text = re.sub(r"```[\s\S]*?```", " ", text) # Code blocks
64-
text = re.sub(r"`[^`]*`", " ", text) # Inline code
74+
text = REMOVE_CODE_BLOCKS.sub(" ", text)
75+
text = REMOVE_INLINE_CODE.sub(" ", text)
6576

6677
# Remove HTML/XML tags
67-
text = re.sub(r"<[^>]+>", " ", text)
78+
text = REMOVE_HTML_TAGS.sub(" ", text)
6879

6980
# Normalize numbers (replace with placeholder)
70-
text = re.sub(r"\b\d+\.\d+\b", " NUM ", text) # Decimal numbers
71-
text = re.sub(r"\b\d+\b", " NUM ", text) # Integer numbers
81+
text = NORMALIZE_DECIMAL_NUMBERS.sub(" NUM ", text) # Decimal numbers
82+
text = NORMALIZE_INTEGER_NUMBERS.sub(" NUM ", text) # Integer numbers
7283

7384
# Replace punctuation with spaces (keeping apostrophes for contractions)
74-
text = re.sub(r"[^\w\s\']", " ", text)
85+
text = REMOVE_PUNCTUATION.sub(" ", text)
7586

7687
# Normalize whitespace (replace multiple spaces with a single space)
77-
text = re.sub(r"\s+", " ", text)
88+
text = NORMALIZE_WHITESPACE.sub(" ", text)
7889

7990
# Convert to lowercase and strip
8091
text = text.strip()
@@ -91,7 +102,7 @@ async def _embed_text(self, text: str) -> np.ndarray:
91102
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
92103
)
93104
# Use only the first entry in the list and make sure we have the appropriate type
94-
logger.debug("Text embedded in semantic routing", text=cleaned_text[:100])
105+
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
95106
return np.array(embed_list[0], dtype=np.float32)
96107

97108
async def add_persona(self, persona_name: str, persona_desc: str) -> None:

0 commit comments

Comments
 (0)