diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index bba6ab8e..0c756e2e 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -12,7 +12,12 @@ from codegate import __version__ from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader -from codegate.db.models import AlertSeverity, WorkspaceWithModel +from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel +from codegate.muxing.persona import ( + PersonaDoesNotExistError, + PersonaManager, + PersonaSimilarDescriptionError, +) from codegate.providers import crud as provendcrud from codegate.workspaces import crud @@ -21,6 +26,7 @@ v1 = APIRouter() wscrud = crud.WorkspaceCrud() pcrud = provendcrud.ProviderCrud() +persona_manager = PersonaManager() # This is a singleton object dbreader = DbReader() @@ -665,3 +671,89 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage except Exception: logger.exception("Error while getting messages") raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get("/personas", tags=["Personas"], generate_unique_id_function=uniq_name) +async def list_personas() -> List[Persona]: + """List all personas.""" + try: + personas = await dbreader.get_all_personas() + return personas + except Exception: + logger.exception("Error while getting personas") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name) +async def get_persona(persona_name: str) -> Persona: + """Get a persona by name.""" + try: + persona = await dbreader.get_persona_by_name(persona_name) + if not persona: + raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found") + return persona + except Exception as e: + if isinstance(e, HTTPException): + raise e + logger.exception(f"Error while getting persona {persona_name}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201) +async def create_persona(request: v1_models.PersonaRequest) -> Persona: + """Create a new persona.""" + try: + await persona_manager.add_persona(request.name, request.description) + persona = await dbreader.get_persona_by_name(request.name) + return persona + except PersonaSimilarDescriptionError: + logger.exception("Error while creating persona") + raise HTTPException(status_code=409, detail="Persona has a similar description to another") + except AlreadyExistsError: + logger.exception("Error while creating persona") + raise HTTPException(status_code=409, detail="Persona already exists") + except Exception: + logger.exception("Error while creating persona") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.put("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name) +async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequest) -> Persona: + """Update an existing persona.""" + try: + await persona_manager.update_persona( + persona_name, request.new_name, request.new_description + ) + persona = await dbreader.get_persona_by_name(request.new_name) + return persona + except PersonaSimilarDescriptionError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=409, detail="Persona has a similar description to another") + except PersonaDoesNotExistError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=404, detail="Persona does not exist") + except AlreadyExistsError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=409, detail="Persona already exists") + except Exception: + logger.exception("Error while updating persona") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.delete( + "/personas/{persona_name}", + tags=["Personas"], + generate_unique_id_function=uniq_name, + status_code=204, +) +async def delete_persona(persona_name: str): + """Delete a persona.""" + try: + await persona_manager.delete_persona(persona_name) + return Response(status_code=204) + except PersonaDoesNotExistError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=404, detail="Persona does not exist") + except Exception: + logger.exception("Error while deleting persona") + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 6cbc2be3..dff26489 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -315,3 +315,21 @@ class ModelByProvider(pydantic.BaseModel): def __str__(self): return f"{self.provider_name} / {self.name}" + + +class PersonaRequest(pydantic.BaseModel): + """ + Model for creating a new Persona. + """ + + name: str + description: str + + +class PersonaUpdateRequest(pydantic.BaseModel): + """ + Model for updating a Persona. + """ + + new_name: str + new_description: str diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 420f27e8..5c514e5c 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -561,15 +561,41 @@ async def add_persona(self, persona: PersonaEmbedding) -> None: ) try: - # For Pydantic we convert the numpy array to string when serializing with .model_dumpy() - # We need to convert it back to a numpy array before inserting it into the DB. - persona_dict = persona.model_dump() - persona_dict["description_embedding"] = persona.description_embedding - await self._execute_with_no_return(sql, persona_dict) + await self._execute_with_no_return(sql, persona.model_dump()) except IntegrityError as e: logger.debug(f"Exception type: {type(e)}") raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") + async def update_persona(self, persona: PersonaEmbedding) -> None: + """ + Update an existing Persona in the DB. + + This handles validation and update of an existing persona. + """ + sql = text( + """ + UPDATE personas + SET name = :name, + description = :description, + description_embedding = :description_embedding + WHERE id = :id + """ + ) + + try: + await self._execute_with_no_return(sql, persona.model_dump()) + except IntegrityError as e: + logger.debug(f"Exception type: {type(e)}") + raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") + + async def delete_persona(self, persona_id: str) -> None: + """ + Delete an existing Persona from the DB. + """ + sql = text("DELETE FROM personas WHERE id = :id") + conditions = {"id": persona_id} + await self._execute_with_no_return(sql, conditions) + class DbReader(DbCodeGate): def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs): @@ -588,7 +614,10 @@ async def _dump_result_to_pydantic_model( return None async def _execute_select_pydantic_model( - self, model_type: Type[BaseModel], sql_command: TextClause + self, + model_type: Type[BaseModel], + sql_command: TextClause, + should_raise: bool = False, ) -> Optional[List[BaseModel]]: async with self._async_db_engine.begin() as conn: try: @@ -596,6 +625,9 @@ async def _execute_select_pydantic_model( return await self._dump_result_to_pydantic_model(model_type, result) except Exception as e: logger.error(f"Failed to select model: {model_type}.", error=str(e)) + # Exposes errors to the caller + if should_raise: + raise e return None async def _exec_select_conditions_to_pydantic( @@ -1005,7 +1037,7 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]: return personas[0] if personas else None async def get_distance_to_existing_personas( - self, query_embedding: np.ndarray + self, query_embedding: np.ndarray, exclude_id: Optional[str] ) -> List[PersonaDistance]: """ Get the distance between a persona and a query embedding. @@ -1019,6 +1051,13 @@ async def get_distance_to_existing_personas( FROM personas """ conditions = {"query_embedding": query_embedding} + + # Exclude this persona from the SQL query. Used when checking the descriptions + # for updating the persona. Exclude the persona to update itself from the query. + if exclude_id: + sql += " WHERE id != :exclude_id" + conditions["exclude_id"] = exclude_id + persona_distances = await self._exec_vec_db_query_to_pydantic( sql, conditions, PersonaDistance ) @@ -1045,6 +1084,20 @@ async def get_distance_to_persona( ) return persona_distance[0] + async def get_all_personas(self) -> List[Persona]: + """ + Get all the personas. + """ + sql = text( + """ + SELECT + id, name, description + FROM personas + """ + ) + personas = await self._execute_select_pydantic_model(Persona, sql, should_raise=True) + return personas + class DbTransaction: def __init__(self): diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index f71e3c62..6f146b34 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -252,7 +252,7 @@ def nd_array_custom_before_validator(x): def nd_array_custom_serializer(x): # custome serialization logic - return str(x) + return x # Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type. diff --git a/src/codegate/muxing/semantic_router.py b/src/codegate/muxing/persona.py similarity index 74% rename from src/codegate/muxing/semantic_router.py rename to src/codegate/muxing/persona.py index 27a25754..615b3256 100644 --- a/src/codegate/muxing/semantic_router.py +++ b/src/codegate/muxing/persona.py @@ -1,5 +1,6 @@ import unicodedata import uuid +from typing import Optional import numpy as np import regex as re @@ -32,11 +33,12 @@ class PersonaSimilarDescriptionError(Exception): pass -class SemanticRouter: +class PersonaManager: def __init__(self): - self._inference_engine = LlamaCppInferenceEngine() + Config.load() conf = Config.get_config() + self._inference_engine = LlamaCppInferenceEngine() self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}" self._n_gpu = conf.chat_model_n_gpu_layers self._persona_threshold = conf.persona_threshold @@ -110,13 +112,15 @@ async def _embed_text(self, text: str) -> np.ndarray: logger.debug("Text embedded in semantic routing", text=cleaned_text[:50]) return np.array(embed_list[0], dtype=np.float32) - async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bool: + async def _is_persona_description_diff( + self, emb_persona_desc: np.ndarray, exclude_id: Optional[str] + ) -> bool: """ Check if the persona description is different enough from existing personas. """ # The distance calculation is done in the database persona_distances = await self._db_reader.get_distance_to_existing_personas( - emb_persona_desc + emb_persona_desc, exclude_id ) if not persona_distances: return True @@ -131,16 +135,26 @@ async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bo return False return True - async def add_persona(self, persona_name: str, persona_desc: str) -> None: + async def _validate_persona_description( + self, persona_desc: str, exclude_id: str = None + ) -> np.ndarray: """ - Add a new persona to the database. The persona description is embedded - and stored in the database. + Validate the persona description by embedding the text and checking if it is + different enough from existing personas. """ emb_persona_desc = await self._embed_text(persona_desc) - if not await self._is_persona_description_diff(emb_persona_desc): + if not await self._is_persona_description_diff(emb_persona_desc, exclude_id): raise PersonaSimilarDescriptionError( "The persona description is too similar to existing personas." ) + return emb_persona_desc + + async def add_persona(self, persona_name: str, persona_desc: str) -> None: + """ + Add a new persona to the database. The persona description is embedded + and stored in the database. + """ + emb_persona_desc = await self._validate_persona_description(persona_desc) new_persona = db_models.PersonaEmbedding( id=str(uuid.uuid4()), @@ -151,6 +165,43 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None: await self._db_recorder.add_persona(new_persona) logger.info(f"Added persona {persona_name} to the database.") + async def update_persona( + self, persona_name: str, new_persona_name: str, new_persona_desc: str + ) -> None: + """ + Update an existing persona in the database. The name and description are + updated in the database, but the ID remains the same. + """ + # First we check if the persona exists, if not we raise an error + found_persona = await self._db_reader.get_persona_by_name(persona_name) + if not found_persona: + raise PersonaDoesNotExistError(f"Person {persona_name} does not exist.") + + emb_persona_desc = await self._validate_persona_description( + new_persona_desc, exclude_id=found_persona.id + ) + + # Then we update the attributes in the database + updated_persona = db_models.PersonaEmbedding( + id=found_persona.id, + name=new_persona_name, + description=new_persona_desc, + description_embedding=emb_persona_desc, + ) + await self._db_recorder.update_persona(updated_persona) + logger.info(f"Updated persona {persona_name} in the database.") + + async def delete_persona(self, persona_name: str) -> None: + """ + Delete a persona from the database. + """ + persona = await self._db_reader.get_persona_by_name(persona_name) + if not persona: + raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") + + await self._db_recorder.delete_persona(persona.id) + logger.info(f"Deleted persona {persona_name} from the database.") + async def check_persona_match(self, persona_name: str, query: str) -> bool: """ Check if the query matches the persona description. A vector similarity diff --git a/tests/muxing/test_semantic_router.py b/tests/muxing/test_persona.py similarity index 81% rename from tests/muxing/test_semantic_router.py rename to tests/muxing/test_persona.py index 57687567..4e221d8a 100644 --- a/tests/muxing/test_semantic_router.py +++ b/tests/muxing/test_persona.py @@ -6,10 +6,10 @@ from pydantic import BaseModel from codegate.db import connection -from codegate.muxing.semantic_router import ( +from codegate.muxing.persona import ( PersonaDoesNotExistError, + PersonaManager, PersonaSimilarDescriptionError, - SemanticRouter, ) @@ -40,16 +40,16 @@ def db_reader(db_path) -> connection.DbReader: @pytest.fixture() def semantic_router_mocked_db( db_recorder: connection.DbRecorder, db_reader: connection.DbReader -) -> SemanticRouter: +) -> PersonaManager: """Creates a SemanticRouter instance with mocked database.""" - semantic_router = SemanticRouter() + semantic_router = PersonaManager() semantic_router._db_reader = db_reader semantic_router._db_recorder = db_recorder return semantic_router @pytest.mark.asyncio -async def test_add_persona(semantic_router_mocked_db: SemanticRouter): +async def test_add_persona(semantic_router_mocked_db: PersonaManager): """Test adding a persona to the database.""" persona_name = "test_persona" persona_desc = "test_persona_desc" @@ -60,7 +60,7 @@ async def test_add_persona(semantic_router_mocked_db: SemanticRouter): @pytest.mark.asyncio -async def test_add_duplicate_persona(semantic_router_mocked_db: SemanticRouter): +async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager): """Test adding a persona to the database.""" persona_name = "test_persona" persona_desc = "test_persona_desc" @@ -73,7 +73,7 @@ async def test_add_duplicate_persona(semantic_router_mocked_db: SemanticRouter): @pytest.mark.asyncio -async def test_persona_not_exist_match(semantic_router_mocked_db: SemanticRouter): +async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager): """Test checking persona match when persona does not exist""" persona_name = "test_persona" query = "test_query" @@ -311,7 +311,7 @@ class PersonaMatchTest(BaseModel): ], ) async def test_check_persona_pass_match( - semantic_router_mocked_db: SemanticRouter, persona_match_test: PersonaMatchTest + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest ): """Test checking persona match.""" await semantic_router_mocked_db.add_persona( @@ -337,7 +337,7 @@ async def test_check_persona_pass_match( ], ) async def test_check_persona_fail_match( - semantic_router_mocked_db: SemanticRouter, persona_match_test: PersonaMatchTest + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest ): """Test checking persona match.""" await semantic_router_mocked_db.add_persona( @@ -364,7 +364,7 @@ async def test_check_persona_fail_match( ], ) async def test_persona_diff_description( - semantic_router_mocked_db: SemanticRouter, + semantic_router_mocked_db: PersonaManager, personas: List[PersonaMatchTest], ): # First, add all existing personas @@ -376,3 +376,78 @@ async def test_persona_diff_description( await semantic_router_mocked_db.add_persona( "repeated persona", last_added_persona.persona_desc ) + + +@pytest.mark.asyncio +async def test_update_persona(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database different name and description.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + updated_description = "foo and bar description" + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=updated_description + ) + + +@pytest.mark.asyncio +async def test_update_persona_same_desc(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database with same description.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc + ) + + +@pytest.mark.asyncio +async def test_update_persona_not_exists(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc + ) + + +@pytest.mark.asyncio +async def test_update_persona_same_name(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + persona_name_2 = "test_persona_2" + persona_desc_2 = "foo and bar" + await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2) + + with pytest.raises(connection.AlreadyExistsError): + await semantic_router_mocked_db.update_persona( + persona_name_2, new_persona_name=persona_name, new_persona_desc=persona_desc_2 + ) + + +@pytest.mark.asyncio +async def test_delete_persona(semantic_router_mocked_db: PersonaManager): + """Test deleting a persona from the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + await semantic_router_mocked_db.delete_persona(persona_name) + + persona_found = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name) + assert persona_found is None + + +@pytest.mark.asyncio +async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManager): + persona_name = "test_persona" + + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.delete_persona(persona_name)