From 8d185c5d42a56e91df812fada9f8cce8bc2f19e3 Mon Sep 17 00:00:00 2001
From: Warmybloodwolf <156565187+Warmybloodwolf@users.noreply.github.com>
Date: Fri, 13 Jun 2025 18:56:55 +0800
Subject: [PATCH] Optimize the batching logic of the embeddings

Fixes the "beyond max_tokens_per_request" error: estimate the total tokens
a request would consume and split the texts into batches whenever the
estimate exceeds max_tokens_per_request.
---
 backend/llm_model/embeddings.py | 39 +++++++++++++++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/backend/llm_model/embeddings.py b/backend/llm_model/embeddings.py
index e0bbd1a..4fb17cb 100644
--- a/backend/llm_model/embeddings.py
+++ b/backend/llm_model/embeddings.py
@@ -8,6 +8,19 @@
 from config.db_config import db, db_session_manager
 from config.logging_config import logger

+# OpenAI / Azure OpenAI allows up to 300,000 tokens per embedding request.
+# Leave some headroom to reduce the chance of repeated retries at the limit.
+_MAX_TOKENS_PER_REQ = 290_000
+
+
+def _estimate_tokens(text: str) -> int:
+    """
+    Roughly estimate how many tokens `text` consumes.
+    English text averages about 4 characters per token; Chinese about 1.3-2.
+    A compromise value of 3.5 characters per token keeps the estimate on the safe side.
+    """
+    return max(1, int(len(text) / 3.5))
+

 class EmbeddingManager:
     """Embedding Manager"""
@@ -144,9 +157,31 @@ async def _get_embeddings_with_context(text: Union[str, List[str]], model_name:
         if isinstance(text, str):
             embedding = await embedding_model.aembed_query(text[:8192])
         else:
-            text = [t[:8192] for t in text]
-            embedding = await embedding_model.aembed_documents(text)
+            # First, trim each text to 8,192 characters
+            texts = [t[:8192] for t in text]
+
+            # ---- Batching logic ----
+            batches, cur_batch, cur_tokens = [], [], 0
+            for t in texts:
+                tok = _estimate_tokens(t)
+                # If adding `t` would exceed the per-request token limit, finalize the current batch
+                if cur_batch and cur_tokens + tok > _MAX_TOKENS_PER_REQ:
+                    batches.append(cur_batch)
+                    cur_batch, cur_tokens = [], 0
+                cur_batch.append(t)
+                cur_tokens += tok
+            if cur_batch:  # Process the last batch
+                batches.append(cur_batch)
+
+            # Send requests sequentially to preserve output order
+            embedding = []
+            for bt in batches:
+                bt_emb = await embedding_model.aembed_documents(bt)
+                embedding.extend(bt_emb)
+
         return np.array(embedding)
+
     except Exception as e:
         logger.error(f"Failed to generate Embedding: {str(e)}")
         raise
+
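
Note for reviewers (not part of the patch): the sketch below is a minimal, standalone rendering of the same greedy batching idea so it can be exercised without the EmbeddingManager plumbing. The helper name split_into_batches and the demo inputs are made up for illustration; the 3.5-characters-per-token heuristic and the 290,000-token default mirror the values introduced in the patch.

    from typing import List


    def _estimate_tokens(text: str) -> int:
        # Same heuristic as the patch: roughly 3.5 characters per token.
        return max(1, int(len(text) / 3.5))


    def split_into_batches(texts: List[str], max_tokens: int = 290_000) -> List[List[str]]:
        # Greedily pack texts into batches whose estimated token total stays
        # under max_tokens, preserving the order of the input texts.
        batches, cur_batch, cur_tokens = [], [], 0
        for t in texts:
            tok = _estimate_tokens(t)
            if cur_batch and cur_tokens + tok > max_tokens:
                batches.append(cur_batch)
                cur_batch, cur_tokens = [], 0
            cur_batch.append(t)
            cur_tokens += tok
        if cur_batch:
            batches.append(cur_batch)
        return batches


    if __name__ == "__main__":
        # 20 texts of ~2,857 estimated tokens each with an artificially small
        # 10,000-token budget: expect batches of 3 texts (and a final batch of 2).
        docs = ["x" * 10_000] * 20
        for i, batch in enumerate(split_into_batches(docs, max_tokens=10_000)):
            print(f"batch {i}: {len(batch)} texts, "
                  f"~{sum(_estimate_tokens(t) for t in batch)} tokens")

One property worth noting: a single text whose estimate already exceeds the ceiling still ends up alone in its own batch, so that one request could still be rejected by the API. This matches the behaviour of the patched code and may warrant a follow-up.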