diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 2a160b9e..d4695ffd 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,21 +1,24 @@ from typing import List, Optional +from uuid import UUID import requests import structlog -from fastapi import APIRouter, HTTPException, Response +from fastapi import APIRouter, Depends, HTTPException, Response from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from codegate import __version__ from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader +from codegate.providers import crud as provendcrud from codegate.workspaces import crud logger = structlog.get_logger("codegate") v1 = APIRouter() wscrud = crud.WorkspaceCrud() +pcrud = provendcrud.ProviderCrud() # This is a singleton object dbreader = DbReader() @@ -25,38 +28,78 @@ def uniq_name(route: APIRoute): return f"v1_{route.name}" +class FilterByNameParams(BaseModel): + name: Optional[str] = None + + @v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name) -async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]: +async def list_provider_endpoints( + filter_query: FilterByNameParams = Depends(), +) -> List[v1_models.ProviderEndpoint]: """List all provider endpoints.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the provider endpoints from the database. - return [ - v1_models.ProviderEndpoint( - id=1, - name="dummy", - description="Dummy provider endpoint", - endpoint="http://example.com", - provider_type=v1_models.ProviderType.openai, - auth_type=v1_models.ProviderAuthType.none, - ) - ] + if filter_query.name is None: + try: + return await pcrud.list_endpoints() + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + try: + provend = await pcrud.get_endpoint_by_name(filter_query.name) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + if provend is None: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + return [provend] + + +# This needs to be above /provider-endpoints/{provider_id} to avoid conflict +@v1.get( + "/provider-endpoints/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]: + """List all models for all providers.""" + try: + return await pcrud.get_all_models() + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get( + "/provider-endpoints/{provider_id}/models", + tags=["Providers"], + generate_unique_id_function=uniq_name, +) +async def list_models_by_provider( + provider_id: UUID, +) -> List[v1_models.ModelByProvider]: + """List models by provider.""" + + try: + return await pcrud.models_by_provider(provider_id) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider not found") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @v1.get( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) -async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint: +async def get_provider_endpoint( + provider_id: UUID, +) -> v1_models.ProviderEndpoint: """Get a provider endpoint by ID.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the provider endpoint from the database. - return v1_models.ProviderEndpoint( - id=provider_id, - name="dummy", - description="Dummy provider endpoint", - endpoint="http://example.com", - provider_type=v1_models.ProviderType.openai, - auth_type=v1_models.ProviderAuthType.none, - ) + try: + provend = await pcrud.get_endpoint_by_id(provider_id) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + if provend is None: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + return provend @v1.post( @@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint: generate_unique_id_function=uniq_name, status_code=201, ) -async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint: +async def add_provider_endpoint( + request: v1_models.ProviderEndpoint, +) -> v1_models.ProviderEndpoint: """Add a provider endpoint.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that adds the provider endpoint to the database. - return request + try: + provend = await pcrud.add_endpoint(request) + except AlreadyExistsError: + raise HTTPException(status_code=409, detail="Provider endpoint already exists") + except ValidationError as e: + # TODO: This should be more specific + raise HTTPException( + status_code=400, + detail=str(e), + ) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + return provend @v1.put( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def update_provider_endpoint( - provider_id: int, request: v1_models.ProviderEndpoint + provider_id: UUID, + request: v1_models.ProviderEndpoint, ) -> v1_models.ProviderEndpoint: """Update a provider endpoint by ID.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that updates the provider endpoint in the database. - return request + try: + request.id = provider_id + provend = await pcrud.update_endpoint(request) + except ValidationError as e: + # TODO: This should be more specific + raise HTTPException( + status_code=400, + detail=str(e), + ) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + return provend @v1.delete( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) -async def delete_provider_endpoint(provider_id: int): +async def delete_provider_endpoint( + provider_id: UUID, +): """Delete a provider endpoint by id.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that deletes the provider endpoint from the database. + try: + await pcrud.delete_endpoint(provider_id) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) -@v1.get( - "/provider-endpoints/{provider_name}/models", - tags=["Providers"], - generate_unique_id_function=uniq_name, -) -async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]: - """List models by provider.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches the models by provider from the database. - return [v1_models.ModelByProvider(name="dummy", provider="dummy")] - - -@v1.get( - "/provider-endpoints/models", - tags=["Providers"], - generate_unique_id_function=uniq_name, -) -async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]: - """List all models for all providers.""" - # NOTE: This is a dummy implementation. In the future, we should have a proper - # implementation that fetches all the models for all providers from the database. - return [v1_models.ModelByProvider(name="dummy", provider="dummy")] - - @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) async def list_workspaces() -> v1_models.ListWorkspacesResponse: """List all workspaces.""" @@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str): tags=["Workspaces", "Muxes"], generate_unique_id_function=uniq_name, ) -async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]: +async def get_workspace_muxes( + workspace_name: str, +) -> List[v1_models.MuxRule]: """Get the mux rules of a workspace. The list is ordered in order of priority. That is, the first rule in the list @@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]: generate_unique_id_function=uniq_name, status_code=204, ) -async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]): +async def set_workspace_muxes( + workspace_name: str, + request: List[v1_models.MuxRule], +): """Set the mux rules of a workspace.""" # TODO: This is a dummy implementation. In the future, we should have a proper # implementation that sets the mux rules in the database. diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index fb4e90d3..3f1f37a6 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -6,6 +6,8 @@ from codegate.db import models as db_models from codegate.pipeline.base import CodeSnippet +from codegate.providers.base import BaseProvider +from codegate.providers.registry import ProviderRegistry class Workspace(pydantic.BaseModel): @@ -122,6 +124,8 @@ class ProviderType(str, Enum): openai = "openai" anthropic = "anthropic" vllm = "vllm" + ollama = "ollama" + lm_studio = "lm_studio" class TokenUsageByModel(pydantic.BaseModel): @@ -191,13 +195,38 @@ class ProviderEndpoint(pydantic.BaseModel): so we can use this for muxing messages. """ - id: int + # This will be set on creation + id: Optional[str] = "" name: str description: str = "" provider_type: ProviderType endpoint: str auth_type: ProviderAuthType + @staticmethod + def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint": + return ProviderEndpoint( + id=db_model.id, + name=db_model.name, + description=db_model.description, + provider_type=db_model.provider_type, + endpoint=db_model.endpoint, + auth_type=db_model.auth_type, + ) + + def to_db_model(self) -> db_models.ProviderEndpoint: + return db_models.ProviderEndpoint( + id=self.id, + name=self.name, + description=self.description, + provider_type=self.provider_type, + endpoint=self.endpoint, + auth_type=self.auth_type, + ) + + def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider]: + return registry.get_provider(self.provider_type) + class ModelByProvider(pydantic.BaseModel): """ @@ -207,10 +236,11 @@ class ModelByProvider(pydantic.BaseModel): """ name: str - provider: str + provider_id: str + provider_name: str def __str__(self): - return f"{self.provider}/{self.name}" + return f"{self.provider_name} / {self.name}" class MuxMatcherType(str, Enum): diff --git a/src/codegate/cli.py b/src/codegate/cli.py index dc05ed25..ba3016eb 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -17,6 +17,7 @@ from codegate.db.connection import init_db_sync, init_session_if_not_exists from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.secrets.manager import SecretsManager +from codegate.providers import crud as provendcrud from codegate.providers.copilot.provider import CopilotProvider from codegate.server import init_app from codegate.storage.utils import restore_storage_backup @@ -338,6 +339,9 @@ def serve( # noqa: C901 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + registry = app.provider_registry + loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry)) + # Run the server try: loop.run_until_complete(run_servers(cfg, app)) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 15305790..10c1c81f 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -21,6 +21,9 @@ GetWorkspaceByNameConditions, Output, Prompt, + ProviderAuthMaterial, + ProviderEndpoint, + ProviderModel, Session, WorkspaceRow, WorkspaceWithSessionInfo, @@ -368,6 +371,72 @@ async def recover_workspace(self, workspace: WorkspaceRow) -> Optional[Workspace ) return recovered_workspace + async def add_provider_endpoint(self, provider: ProviderEndpoint) -> ProviderEndpoint: + sql = text( + """ + INSERT INTO provider_endpoints ( + id, name, description, provider_type, endpoint, auth_type, auth_blob + ) + VALUES (:id, :name, :description, :provider_type, :endpoint, :auth_type, "") + RETURNING * + """ + ) + added_provider = await self._execute_update_pydantic_model(provider, sql, should_raise=True) + return added_provider + + async def update_provider_endpoint(self, provider: ProviderEndpoint) -> ProviderEndpoint: + sql = text( + """ + UPDATE provider_endpoints + SET name = :name, description = :description, provider_type = :provider_type, + endpoint = :endpoint, auth_type = :auth_type + WHERE id = :id + RETURNING * + """ + ) + updated_provider = await self._execute_update_pydantic_model( + provider, sql, should_raise=True + ) + return updated_provider + + async def delete_provider_endpoint( + self, + provider: ProviderEndpoint, + ) -> Optional[ProviderEndpoint]: + sql = text( + """ + DELETE FROM provider_endpoints + WHERE id = :id + RETURNING * + """ + ) + deleted_provider = await self._execute_update_pydantic_model( + provider, sql, should_raise=True + ) + return deleted_provider + + async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial): + sql = text( + """ + UPDATE provider_endpoints + SET auth_type = :auth_type, auth_blob = :auth_blob + WHERE id = :provider_endpoint_id + """ + ) + _ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True) + return + + async def add_provider_model(self, model: ProviderModel) -> ProviderModel: + sql = text( + """ + INSERT INTO provider_models (provider_endpoint_id, name) + VALUES (:provider_endpoint_id, :name) + RETURNING * + """ + ) + added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True) + return added_model + class DbReader(DbCodeGate): @@ -537,6 +606,69 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: active_workspace = await self._execute_select_pydantic_model(ActiveWorkspace, sql) return active_workspace[0] if active_workspace else None + async def get_provider_endpoint_by_name(self, provider_name: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE name = :name + """ + ) + conditions = {"name": provider_name} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + return provider[0] if provider else None + + async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE id = :id + """ + ) + conditions = {"id": provider_id} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + return provider[0] if provider else None + + async def get_provider_endpoints(self) -> List[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + """ + ) + providers = await self._execute_select_pydantic_model(ProviderEndpoint, sql) + return providers + + async def get_provider_models_by_provider_id(self, provider_id: str) -> List[ProviderModel]: + sql = text( + """ + SELECT provider_endpoint_id, name + FROM provider_models + WHERE provider_endpoint_id = :provider_endpoint_id + """ + ) + conditions = {"provider_endpoint_id": provider_id} + models = await self._exec_select_conditions_to_pydantic( + ProviderModel, sql, conditions, should_raise=True + ) + return models + + async def get_all_provider_models(self) -> List[ProviderModel]: + sql = text( + """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + """ + ) + models = await self._execute_select_pydantic_model(ProviderModel, sql) + return models + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 23cbea5d..2a6434ef 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -99,3 +99,24 @@ class ActiveWorkspace(BaseModel): custom_instructions: Optional[str] session_id: str last_update: datetime.datetime + + +class ProviderEndpoint(BaseModel): + id: str + name: str + description: str + provider_type: str + endpoint: str + auth_type: str + + +class ProviderAuthMaterial(BaseModel): + provider_endpoint_id: str + auth_type: str + auth_blob: str + + +class ProviderModel(BaseModel): + provider_endpoint_id: str + provider_endpoint_name: Optional[str] = None + name: str diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 10215c9e..48821de0 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -1,5 +1,7 @@ import json +from typing import List +import httpx import structlog from fastapi import Header, HTTPException, Request @@ -27,6 +29,21 @@ def __init__( def provider_route_name(self) -> str: return "anthropic" + def models(self) -> List[str]: + # TODO: This won't work since we need an API Key being set. + resp = httpx.get("https://api.anthropic.com/models") + # If Anthropic returned 404, it means it's not accepting our + # requests. We should throw an error. + if resp.status_code == 404: + raise HTTPException( + status_code=404, + detail="The Anthropic API is not accepting requests. Please check your API key.", + ) + + respjson = resp.json() + + return [model["id"] for model in respjson.get("data", [])] + def _setup_routes(self): """ Sets up the /messages route for the provider as expected by the Anthropic diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 515be531..1ab055ea 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Callable, Dict, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union import structlog from fastapi import APIRouter, Request @@ -54,6 +54,10 @@ def __init__( def _setup_routes(self) -> None: pass + @abstractmethod + def models(self) -> List[str]: + pass + @property @abstractmethod def provider_route_name(self) -> str: diff --git a/src/codegate/providers/crud/__init__.py b/src/codegate/providers/crud/__init__.py new file mode 100644 index 00000000..58adb943 --- /dev/null +++ b/src/codegate/providers/crud/__init__.py @@ -0,0 +1,3 @@ +from .crud import ProviderCrud, ProviderNotFoundError, initialize_provider_endpoints + +__all__ = ["ProviderCrud", "initialize_provider_endpoints", "ProviderNotFoundError"] diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py new file mode 100644 index 00000000..637375e8 --- /dev/null +++ b/src/codegate/providers/crud/crud.py @@ -0,0 +1,229 @@ +import asyncio +from typing import List, Optional +from urllib.parse import urlparse +from uuid import UUID, uuid4 + +import structlog +from pydantic import ValidationError + +from codegate.api import v1_models as apimodelsv1 +from codegate.config import Config +from codegate.db import models as dbmodels +from codegate.db.connection import DbReader, DbRecorder +from codegate.providers.base import BaseProvider +from codegate.providers.registry import ProviderRegistry + +logger = structlog.get_logger("codegate") + + +class ProviderNotFoundError(Exception): + pass + + +class ProviderCrud: + """The CRUD operations for the provider endpoint references within + Codegate. + + This is meant to handle all the transformations in between the + database and the API, as well as other sources of information. All + operations should result in the API models being returned. + """ + + def __init__(self): + self._db_reader = DbReader() + self._db_writer = DbRecorder() + + async def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]: + """List all the endpoints.""" + + outendpoints = [] + dbendpoints = await self._db_reader.get_provider_endpoints() + for dbendpoint in dbendpoints: + outendpoints.append(apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)) + + return outendpoints + + async def get_endpoint_by_id(self, id: UUID) -> Optional[apimodelsv1.ProviderEndpoint]: + """Get an endpoint by ID.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(id)) + if dbendpoint is None: + return None + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.ProviderEndpoint]: + """Get an endpoint by name.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_name(name) + if dbendpoint is None: + return None + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def add_endpoint( + self, endpoint: apimodelsv1.ProviderEndpoint + ) -> apimodelsv1.ProviderEndpoint: + """Add an endpoint.""" + dbend = endpoint.to_db_model() + + # We override the ID here, as we want to generate it. + dbend.id = str(uuid4()) + + dbendpoint = await self._db_writer.add_provider_endpoint() + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def update_endpoint( + self, endpoint: apimodelsv1.ProviderEndpoint + ) -> apimodelsv1.ProviderEndpoint: + """Update an endpoint.""" + + dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def delete_endpoint(self, provider_id: UUID): + """Delete an endpoint.""" + + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id)) + if dbendpoint is None: + raise ProviderNotFoundError("Provider not found") + + await self._db_writer.delete_provider_endpoint(dbendpoint) + + async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelByProvider]: + """Get the models by provider.""" + + # First we try to get the provider + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id)) + if dbendpoint is None: + raise ProviderNotFoundError("Provider not found") + + outmodels = [] + dbmodels = await self._db_reader.get_provider_models_by_provider_id(str(provider_id)) + for dbmodel in dbmodels: + outmodels.append( + apimodelsv1.ModelByProvider( + name=dbmodel.name, + provider_id=dbmodel.provider_endpoint_id, + provider_name=dbendpoint.name, + ) + ) + + return outmodels + + async def get_all_models(self) -> List[apimodelsv1.ModelByProvider]: + """Get all the models.""" + + outmodels = [] + dbmodels = await self._db_reader.get_all_provider_models() + for dbmodel in dbmodels: + ename = dbmodel.provider_endpoint_name if dbmodel.provider_endpoint_name else "" + outmodels.append( + apimodelsv1.ModelByProvider( + name=dbmodel.name, + provider_id=dbmodel.provider_endpoint_id, + provider_name=ename, + ) + ) + + return outmodels + + +async def initialize_provider_endpoints(preg: ProviderRegistry): + db_writer = DbRecorder() + db_reader = DbReader() + config = Config.get_config() + if config is None: + provided_urls = {} + else: + provided_urls = config.provider_urls + + for provider_name, provider_url in provided_urls.items(): + provend = __provider_endpoint_from_cfg(provider_name, provider_url) + if provend is None: + continue + + # Check if the provider is already in the db + dbprovend = await db_reader.get_provider_endpoint_by_name(provend.name) + if dbprovend is not None: + logger.debug( + "Provider already in DB. Not re-adding.", + provider=provend.name, + endpoint=provend.endpoint, + ) + continue + + pimpl = provend.get_from_registry(preg) + await try_initialize_provider_endpoints(provend, pimpl, db_writer) + + +async def try_initialize_provider_endpoints( + provend: apimodelsv1.ProviderEndpoint, + pimpl: BaseProvider, + db_writer: DbRecorder, +): + try: + models = pimpl.models() + except Exception as err: + logger.debug( + "Unable to get models from provider", + provider=provend.name, + err=str(err), + ) + return + + logger.info( + "initializing provider to DB", + provider=provend.name, + endpoint=provend.endpoint, + models=models, + ) + # We only try to add the provider if we have models + await db_writer.add_provider_endpoint(provend.to_db_model()) + + tasks = set() + for model in models: + tasks.add( + db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=provend.id, + name=model, + ) + ) + ) + + await asyncio.gather(*tasks) + + +def __provider_endpoint_from_cfg( + provider_name: str, provider_url: str +) -> Optional[apimodelsv1.ProviderEndpoint]: + """Create a provider endpoint from the config entry.""" + + try: + _ = urlparse(provider_url) + except Exception: + logger.warning( + "Invalid provider URL", provider_name=provider_name, provider_url=provider_url + ) + return None + + try: + return apimodelsv1.ProviderEndpoint( + id=str(uuid4()), + name=provider_name, + endpoint=provider_url, + description=("Endpoint for the {} provided via the CodeGate configuration.").format( + provider_name + ), + provider_type=provider_name, + auth_type=apimodelsv1.ProviderAuthType.passthrough, + ) + except ValidationError as err: + logger.warning( + "Invalid provider name", + provider_name=provider_name, + provider_url=provider_url, + err=str(err), + ) + return None diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 7f90619e..4478d137 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,5 +1,6 @@ import json +import httpx import structlog from fastapi import HTTPException, Request @@ -26,6 +27,13 @@ def __init__( def provider_route_name(self) -> str: return "llamacpp" + def models(self): + # HACK: This is using OpenAI's /v1/models endpoint to get the list of models + resp = httpx.get(f"{self.base_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /completions and /chat/completions routes for the diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index ac8013b9..b8e0477b 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -34,6 +34,12 @@ def __init__( def provider_route_name(self) -> str: return "ollama" + def models(self): + resp = httpx.get(f"{self.base_url}/api/tags") + jsonresp = resp.json() + + return [model["name"] for model in jsonresp.get("models", [])] + def _setup_routes(self): """ Sets up Ollama API routes. diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 8a00c68c..87588265 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,5 +1,7 @@ import json +from typing import List +import httpx import structlog from fastapi import Header, HTTPException, Request from fastapi.responses import JSONResponse @@ -33,6 +35,13 @@ def __init__( def provider_route_name(self) -> str: return "openai" + def models(self) -> List[str]: + # NOTE: This won't work since we need an API Key being set. + resp = httpx.get(f"{self.lm_studio_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /chat/completions route for the provider as expected by the diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index f39ed8d6..303b907b 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -31,6 +31,12 @@ def __init__( def provider_route_name(self) -> str: return "vllm" + def models(self): + resp = httpx.get(f"{self.base_url}/v1/models") + jsonresp = resp.json() + + return [model["id"] for model in jsonresp.get("data", [])] + def _setup_routes(self): """ Sets up the /chat/completions route for the provider as expected by the diff --git a/src/codegate/server.py b/src/codegate/server.py index 857fb064..ece60c0c 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -30,9 +30,16 @@ async def custom_error_handler(request, exc: Exception): return JSONResponse({"error": str(exc)}, status_code=500) -def init_app(pipeline_factory: PipelineFactory) -> FastAPI: +class CodeGateServer(FastAPI): + provider_registry: ProviderRegistry = None + + def set_provider_registry(self, registry: ProviderRegistry): + self.provider_registry = registry + + +def init_app(pipeline_factory: PipelineFactory) -> CodeGateServer: """Create the FastAPI application.""" - app = FastAPI( + app = CodeGateServer( title="CodeGate", description=__description__, version=__version__, @@ -58,6 +65,7 @@ async def log_user_agent(request: Request, call_next): # Create provider registry registry = ProviderRegistry(app) + app.set_provider_registry(registry) # Register all known providers registry.add_provider( diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index d7c97da9..4922a5ef 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -93,6 +93,9 @@ def __init__( def provider_route_name(self) -> str: return "mock_provider" + def models(self): + return [] + def _setup_routes(self) -> None: @self.router.get(f"/{self.provider_route_name}/test") def test_route(): diff --git a/tests/test_provider.py b/tests/test_provider.py index 95361c97..3539b942 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -19,6 +19,9 @@ def __init__(self): mocked_factory, ) + def models(self): + return [] + def _setup_routes(self) -> None: pass