From 4587b44eed6c2ee5bac5507ffc800d109ebec4be Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 8 Dec 2023 06:53:22 +0000 Subject: [PATCH 1/3] Allow latest inference framework tag --- .../model_engine_server/api/llms_v1.py | 2 + .../model_engine_server/core/docker/ecr.py | 12 +++++ .../domain/repositories/docker_repository.py | 11 +++++ .../use_cases/llm_model_endpoint_use_cases.py | 44 +++++++++++++------ .../repositories/ecr_docker_repository.py | 4 ++ .../repositories/fake_docker_repository.py | 3 ++ model-engine/tests/unit/conftest.py | 3 ++ model-engine/tests/unit/domain/conftest.py | 3 +- .../tests/unit/domain/test_llm_use_cases.py | 12 ++++- 9 files changed, 77 insertions(+), 17 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 64bbfc5f..8de1551d 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -164,6 +164,7 @@ async def create_model_endpoint( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, + docker_repository=external_interfaces.docker_repository, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: @@ -265,6 +266,7 @@ async def update_model_endpoint( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + docker_repository=external_interfaces.docker_repository, ) return await use_case.execute( user=auth, model_endpoint_name=model_endpoint_name, request=request diff --git a/model-engine/model_engine_server/core/docker/ecr.py b/model-engine/model_engine_server/core/docker/ecr.py index aaf9ef6f..fcd324b9 100644 --- a/model-engine/model_engine_server/core/docker/ecr.py +++ b/model-engine/model_engine_server/core/docker/ecr.py @@ -97,3 +97,15 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None): return True except ecr.exceptions.ImageNotFoundException: return False + + +def get_latest_image_tag(repository_name: str): + ecr = boto3.client("ecr", region_name=infra_config().default_region) + images = ecr.describe_images( + registryId=infra_config().ml_account_id, + repositoryName=repository_name, + filter=DEFAULT_FILTER, + maxResults=1000, + )["imageDetails"] + latest_image = max(images, key=lambda image: image["imagePushedAt"]) + return latest_image["imageTags"][0] diff --git a/model-engine/model_engine_server/domain/repositories/docker_repository.py b/model-engine/model_engine_server/domain/repositories/docker_repository.py index b2d410a1..f8ba774c 100644 --- a/model-engine/model_engine_server/domain/repositories/docker_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_repository.py @@ -49,6 +49,17 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: """ pass + @abstractmethod + def get_latest_image_tag(self, repository_name: str) -> str: + """ + Returns the Docker image tag of the most recently pushed image in the given repository + + Args: + repository_name: the name of the repository containing the image. + + Returns: the tag of the latest Docker image. + """ + def is_repo_name(self, repo_name: str): # We assume repository names must start with a letter and can only contain lowercase letters, numbers, hyphens, underscores, and forward slashes. # Based-off ECR naming standards diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 359c525b..21622bcc 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -85,6 +85,14 @@ logger = make_logger(logger_name()) +INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { + LLMInferenceFramework.DEEPSPEED: "instant-llm", + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: hmi_config.tgi_repository, + LLMInferenceFramework.VLLM: hmi_config.vllm_repository, + LLMInferenceFramework.LIGHTLLM: hmi_config.lightllm_repository, + LLMInferenceFramework.TENSORRT_LLM: hmi_config.tensorrt_llm_repository, +} + _SUPPORTED_MODELS_BY_FRAMEWORK = { LLMInferenceFramework.DEEPSPEED: set( [ @@ -328,8 +336,10 @@ async def execute( checkpoint_path: Optional[str], ) -> ModelBundle: if source == LLMSource.HUGGING_FACE: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework] + ) if framework == LLMInferenceFramework.DEEPSPEED: - self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm") bundle_id = await self.create_deepspeed_bundle( user, model_name, @@ -338,9 +348,6 @@ async def execute( endpoint_name, ) elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.tgi_repository - ) bundle_id = await self.create_text_generation_inference_bundle( user, model_name, @@ -351,9 +358,6 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.VLLM: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.vllm_repository - ) bundle_id = await self.create_vllm_bundle( user, model_name, @@ -364,9 +368,6 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.LIGHTLLM: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.lightllm_repository - ) bundle_id = await self.create_lightllm_bundle( user, model_name, @@ -858,10 +859,12 @@ def __init__( self, create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, model_endpoint_service: ModelEndpointService, + docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service + self.docker_repository = docker_repository async def execute( self, user: User, request: CreateLLMModelEndpointV1Request @@ -891,6 +894,11 @@ async def execute( f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM." ) + if request.inference_framework_image_tag == "latest": + request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag( + INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework] + ) + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, @@ -1055,11 +1063,13 @@ def __init__( create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, model_endpoint_service: ModelEndpointService, llm_model_endpoint_service: LLMModelEndpointService, + docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service + self.docker_repository = docker_repository async def execute( self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request @@ -1102,12 +1112,18 @@ async def execute( llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {}) inference_framework = llm_metadata["inference_framework"] + if request.inference_framework_image_tag == "latest": + inference_framework_image_tag = self.docker_repository.get_latest_image_tag( + INFERENCE_FRAMEWORK_REPOSITORY[inference_framework] + ) + else: + inference_framework_image_tag = ( + request.inference_framework_image_tag + or llm_metadata["inference_framework_image_tag"] + ) + model_name = request.model_name or llm_metadata["model_name"] source = request.source or llm_metadata["source"] - inference_framework_image_tag = ( - request.inference_framework_image_tag - or llm_metadata["inference_framework_image_tag"] - ) num_shards = request.num_shards or llm_metadata["num_shards"] quantize = request.quantize or llm_metadata.get("quantize") checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index 16c6b742..d283c4c4 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -3,6 +3,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.core.config import infra_config +from model_engine_server.core.docker.ecr import get_latest_image_tag from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists from model_engine_server.core.docker.remote_build import build_remote_block from model_engine_server.core.loggers import logger_name, make_logger @@ -52,3 +53,6 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: return BuildImageResponse( status=build_result.status, logs=build_result.logs, job_name=build_result.job_name ) + + def get_latest_image_tag(self, repository_name: str) -> str: + return get_latest_image_tag(repository_name=repository_name) diff --git a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py index b7fa39a6..142e2f68 100644 --- a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py @@ -19,3 +19,6 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str: def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: raise NotImplementedError("FakeDockerRepository build_image() not implemented") + + def get_latest_image_tag(self, repository_name: str) -> str: + return "fake_docker_repository_latest_image_tag" diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 3528d558..03ae16b8 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -672,6 +672,9 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: raise Exception("I hope you're handling this!") return BuildImageResponse(status=True, logs="", job_name="test-job-name") + def get_latest_image_tag(self, repository_name: str) -> str: + return "fake_docker_repository_latest_image_tag" + class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository): def __init__(self): diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index f433071c..798af362 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -203,7 +203,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request model_name="mpt-7b", source="hugging_face", inference_framework="deepspeed", - inference_framework_image_tag="test_tag", + inference_framework_image_tag="latest", num_shards=2, endpoint_type=ModelEndpointType.ASYNC, metadata={}, @@ -252,6 +252,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req @pytest.fixture def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: return UpdateLLMModelEndpointV1Request( + inference_framework_image_tag="latest", checkpoint_path="s3://test_checkpoint_path", memory="4G", min_workers=0, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c4fbb31f..9933190a 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -78,6 +78,7 @@ async def test_create_model_endpoint_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) @@ -97,7 +98,7 @@ async def test_create_model_endpoint_use_case_success( "model_name": create_llm_model_endpoint_request_async.model_name, "source": create_llm_model_endpoint_request_async.source, "inference_framework": create_llm_model_endpoint_request_async.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", "num_shards": create_llm_model_endpoint_request_async.num_shards, "quantize": None, "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, @@ -201,6 +202,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() @@ -241,6 +243,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -302,6 +305,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -362,6 +366,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -395,6 +400,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -463,11 +469,13 @@ async def test_update_model_endpoint_use_case_success( create_use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) update_use_case = UpdateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) @@ -501,7 +509,7 @@ async def test_update_model_endpoint_use_case_success( "model_name": create_llm_model_endpoint_request_streaming.model_name, "source": create_llm_model_endpoint_request_streaming.source, "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, From 541af18162ca7c6728b93c1b9a54e49b37d884ff Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 8 Dec 2023 07:13:58 +0000 Subject: [PATCH 2/3] documentation, fix --- clients/python/llmengine/model.py | 2 +- .../infra/repositories/fake_docker_repository.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 3bd88944..fa84d1e3 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -67,7 +67,7 @@ def create( Name of the base model inference_framework_image_tag (`str`): - Image tag for the inference framework + Image tag for the inference framework. Use "latest" for the most recent image source (`LLMSource`): Source of the LLM. Currently only HuggingFace is supported diff --git a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py index 142e2f68..2d12de6e 100644 --- a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py @@ -21,4 +21,4 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: raise NotImplementedError("FakeDockerRepository build_image() not implemented") def get_latest_image_tag(self, repository_name: str) -> str: - return "fake_docker_repository_latest_image_tag" + raise NotImplementedError("FakeDockerRepository get_latest_image_tag() not implemented") From dca0d95b8ec5bc35eaf6d7d40e41a81206aea6b4 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 8 Dec 2023 22:31:10 +0000 Subject: [PATCH 3/3] omit from code coverage --- model-engine/setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index a5f56d8a..053cae1e 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -5,6 +5,7 @@ test=pytest omit = model_engine_server/entrypoints/* model_engine_server/api/app.py + model_engine_server/core/docker/ecr.py # TODO: Fix pylint errors # [pylint]