Allow latest inference framework tag #403


Merged (4 commits) on Dec 8, 2023
2 changes: 1 addition & 1 deletion clients/python/llmengine/model.py
@@ -67,7 +67,7 @@ def create(
Name of the base model

inference_framework_image_tag (`str`):
Image tag for the inference framework
Image tag for the inference framework. Use "latest" for the most recent image

source (`LLMSource`):
Source of the LLM. Currently only HuggingFace is supported
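For context, a minimal sketch of how a caller might use the documented "latest" value from the Python client; keyword arguments other than inference_framework_image_tag are illustrative assumptions, not the exact Model.create signature:

    from llmengine import Model

    # Hypothetical call: instead of pinning a specific framework image tag,
    # pass "latest" and let the server resolve it to the newest pushed image.
    response = Model.create(
        name="llama-2-7b-endpoint",  # assumed parameter name
        model="llama-2-7b",          # assumed parameter name
        inference_framework_image_tag="latest",
        num_shards=1,
    )
    print(response)
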
2 changes: 2 additions & 0 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -164,6 +164,7 @@ async def create_model_endpoint(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
model_endpoint_service=external_interfaces.model_endpoint_service,
docker_repository=external_interfaces.docker_repository,
)
return await use_case.execute(user=auth, request=request)
except ObjectAlreadyExistsException as exc:
@@ -265,6 +266,7 @@ async def update_model_endpoint(
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
docker_repository=external_interfaces.docker_repository,
)
return await use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
12 changes: 12 additions & 0 deletions model-engine/model_engine_server/core/docker/ecr.py
@@ -97,3 +97,15 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None):
return True
except ecr.exceptions.ImageNotFoundException:
return False


def get_latest_image_tag(repository_name: str):
ecr = boto3.client("ecr", region_name=infra_config().default_region)
images = ecr.describe_images(
registryId=infra_config().ml_account_id,
repositoryName=repository_name,
filter=DEFAULT_FILTER,
maxResults=1000,
)["imageDetails"]
latest_image = max(images, key=lambda image: image["imagePushedAt"])
return latest_image["imageTags"][0]
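
As a quick sanity check of the new helper, something like the following could be run, assuming AWS credentials and the infra_config() region/account are configured; the repository name below is a placeholder, not a value taken from this PR:

    from model_engine_server.core.docker.ecr import get_latest_image_tag

    # Resolve the most recently pushed tag in an ECR repository.
    # "vllm" is an illustrative repository name; in the server code the real
    # names come from hmi_config (e.g. hmi_config.vllm_repository).
    latest_tag = get_latest_image_tag(repository_name="vllm")
    print(f"Latest image tag: {latest_tag}")

Note that describe_images is called with maxResults=1000 and no pagination, so a repository with more pushed images than that could miss the true latest tag.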
@@ -49,6 +49,17 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
"""
pass

@abstractmethod
def get_latest_image_tag(self, repository_name: str) -> str:
"""
Returns the Docker image tag of the most recently pushed image in the given repository

Args:
repository_name: the name of the repository containing the image.

Returns: the tag of the latest Docker image.
"""

def is_repo_name(self, repo_name: str):
# We assume repository names must start with a letter and can only contain lowercase letters, numbers, hyphens, underscores, and forward slashes.
# Based-off ECR naming standards
@@ -89,6 +89,14 @@
logger = make_logger(logger_name())


INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = {
LLMInferenceFramework.DEEPSPEED: "instant-llm",
LLMInferenceFramework.TEXT_GENERATION_INFERENCE: hmi_config.tgi_repository,
LLMInferenceFramework.VLLM: hmi_config.vllm_repository,
LLMInferenceFramework.LIGHTLLM: hmi_config.lightllm_repository,
LLMInferenceFramework.TENSORRT_LLM: hmi_config.tensorrt_llm_repository,
}

_SUPPORTED_MODELS_BY_FRAMEWORK = {
LLMInferenceFramework.DEEPSPEED: set(
[
@@ -332,8 +340,10 @@ async def execute(
checkpoint_path: Optional[str],
) -> ModelBundle:
if source == LLMSource.HUGGING_FACE:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework]
)
if framework == LLMInferenceFramework.DEEPSPEED:
self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm")
bundle_id = await self.create_deepspeed_bundle(
user,
model_name,
@@ -342,9 +352,6 @@
endpoint_name,
)
elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.tgi_repository
)
bundle_id = await self.create_text_generation_inference_bundle(
user,
model_name,
@@ -355,9 +362,6 @@
checkpoint_path,
)
elif framework == LLMInferenceFramework.VLLM:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.vllm_repository
)
bundle_id = await self.create_vllm_bundle(
user,
model_name,
@@ -368,9 +372,6 @@
checkpoint_path,
)
elif framework == LLMInferenceFramework.LIGHTLLM:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.lightllm_repository
)
bundle_id = await self.create_lightllm_bundle(
user,
model_name,
@@ -862,10 +863,12 @@ def __init__(
self,
create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase,
model_endpoint_service: ModelEndpointService,
docker_repository: DockerRepository,
):
self.authz_module = LiveAuthorizationModule()
self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
self.model_endpoint_service = model_endpoint_service
self.docker_repository = docker_repository

async def execute(
self, user: User, request: CreateLLMModelEndpointV1Request
@@ -895,6 +898,11 @@ async def execute(
f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM."
)

if request.inference_framework_image_tag == "latest":
request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework]
)

bundle = await self.create_llm_model_bundle_use_case.execute(
user,
endpoint_name=request.name,
@@ -1059,11 +1067,13 @@ def __init__(
create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase,
model_endpoint_service: ModelEndpointService,
llm_model_endpoint_service: LLMModelEndpointService,
docker_repository: DockerRepository,
):
self.authz_module = LiveAuthorizationModule()
self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
self.model_endpoint_service = model_endpoint_service
self.llm_model_endpoint_service = llm_model_endpoint_service
self.docker_repository = docker_repository

async def execute(
self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request
@@ -1106,12 +1116,18 @@ async def execute(
llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {})
inference_framework = llm_metadata["inference_framework"]

if request.inference_framework_image_tag == "latest":
inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
INFERENCE_FRAMEWORK_REPOSITORY[inference_framework]
)
else:
inference_framework_image_tag = (
request.inference_framework_image_tag
or llm_metadata["inference_framework_image_tag"]
)

model_name = request.model_name or llm_metadata["model_name"]
source = request.source or llm_metadata["source"]
inference_framework_image_tag = (
request.inference_framework_image_tag
or llm_metadata["inference_framework_image_tag"]
)
num_shards = request.num_shards or llm_metadata["num_shards"]
quantize = request.quantize or llm_metadata.get("quantize")
checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path")
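To make the new precedence in the update path easier to follow, here is a small standalone sketch of the same logic; the function and variable names are invented for illustration and mirror, rather than replace, the code above:

    from typing import Optional

    def resolve_image_tag(
        requested_tag: Optional[str],
        existing_tag: str,
        latest_pushed_tag: str,
    ) -> str:
        """Mirror of the update-endpoint logic: "latest" resolves to the newest
        pushed image, an explicit tag is used as-is, and None falls back to the
        tag already stored in the endpoint's metadata."""
        if requested_tag == "latest":
            return latest_pushed_tag
        return requested_tag or existing_tag

    # Illustrative checks of the three branches.
    assert resolve_image_tag("latest", "0.9.4", "1.2.0") == "1.2.0"
    assert resolve_image_tag("1.0.0", "0.9.4", "1.2.0") == "1.0.0"
    assert resolve_image_tag(None, "0.9.4", "1.2.0") == "0.9.4"
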
@@ -3,6 +3,7 @@
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
from model_engine_server.core.config import infra_config
from model_engine_server.core.docker.ecr import get_latest_image_tag
from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists
from model_engine_server.core.docker.remote_build import build_remote_block
from model_engine_server.core.loggers import logger_name, make_logger
@@ -52,3 +53,6 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
return BuildImageResponse(
status=build_result.status, logs=build_result.logs, job_name=build_result.job_name
)

def get_latest_image_tag(self, repository_name: str) -> str:
return get_latest_image_tag(repository_name=repository_name)
@@ -19,3 +19,6 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str:

def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
raise NotImplementedError("FakeDockerRepository build_image() not implemented")

def get_latest_image_tag(self, repository_name: str) -> str:
raise NotImplementedError("FakeDockerRepository get_latest_image_tag() not implemented")
1 change: 1 addition & 0 deletions model-engine/setup.cfg
@@ -5,6 +5,7 @@ test=pytest
omit =
model_engine_server/entrypoints/*
model_engine_server/api/app.py
model_engine_server/core/docker/ecr.py

# TODO: Fix pylint errors
# [pylint]
3 changes: 3 additions & 0 deletions model-engine/tests/unit/conftest.py
@@ -672,6 +672,9 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
raise Exception("I hope you're handling this!")
return BuildImageResponse(status=True, logs="", job_name="test-job-name")

def get_latest_image_tag(self, repository_name: str) -> str:
return "fake_docker_repository_latest_image_tag"


class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository):
def __init__(self):
3 changes: 2 additions & 1 deletion model-engine/tests/unit/domain/conftest.py
@@ -203,7 +203,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request
model_name="mpt-7b",
source="hugging_face",
inference_framework="deepspeed",
inference_framework_image_tag="test_tag",
inference_framework_image_tag="latest",
num_shards=2,
endpoint_type=ModelEndpointType.ASYNC,
metadata={},
@@ -252,6 +252,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req
@pytest.fixture
def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request:
return UpdateLLMModelEndpointV1Request(
inference_framework_image_tag="latest",
checkpoint_path="s3://test_checkpoint_path",
memory="4G",
min_workers=0,
12 changes: 10 additions & 2 deletions model-engine/tests/unit/domain/test_llm_use_cases.py
@@ -78,6 +78,7 @@ async def test_create_model_endpoint_use_case_success(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)

user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
@@ -97,7 +98,7 @@
"model_name": create_llm_model_endpoint_request_async.model_name,
"source": create_llm_model_endpoint_request_async.source,
"inference_framework": create_llm_model_endpoint_request_async.inference_framework,
"inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag,
"inference_framework_image_tag": "fake_docker_repository_latest_image_tag",
"num_shards": create_llm_model_endpoint_request_async.num_shards,
"quantize": None,
"checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path,
@@ -201,6 +202,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)

request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
@@ -241,6 +243,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
response_1 = await use_case.execute(
@@ -302,6 +305,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
response_1 = await use_case.execute(
@@ -362,6 +366,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
with pytest.raises(ObjectHasInvalidValueException):
@@ -395,6 +400,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception(
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
with pytest.raises(ObjectHasInvalidValueException):
@@ -463,11 +469,13 @@ async def test_update_model_endpoint_use_case_success(
create_use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)
update_use_case = UpdateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=llm_bundle_use_case,
model_endpoint_service=fake_model_endpoint_service,
llm_model_endpoint_service=fake_llm_model_endpoint_service,
docker_repository=fake_docker_repository_image_always_exists,
)

user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
Expand Down Expand Up @@ -501,7 +509,7 @@ async def test_update_model_endpoint_use_case_success(
"model_name": create_llm_model_endpoint_request_streaming.model_name,
"source": create_llm_model_endpoint_request_streaming.source,
"inference_framework": create_llm_model_endpoint_request_streaming.inference_framework,
"inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag,
"inference_framework_image_tag": "fake_docker_repository_latest_image_tag",
"num_shards": create_llm_model_endpoint_request_streaming.num_shards,
"quantize": None,
"checkpoint_path": update_llm_model_endpoint_request.checkpoint_path,