Skip to content

Commit a0d7cd1

Browse files
authored
Make telemetry batch size configurable and add time-based flush (#622)
Configurable telemetry batch size and time-based flush. Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 59d28b0 commit a0d7cd1

File tree

4 files changed

+52
-4
lines changed

4 files changed

+52
-4
lines changed

src/databricks/sql/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ def read(self) -> Optional[OAuthToken]:
254254
self.telemetry_enabled = (
255255
self.client_telemetry_enabled and self.server_telemetry_enabled
256256
)
257+
self.telemetry_batch_size = kwargs.get(
258+
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
259+
)
257260

258261
try:
259262
self.session = Session(
@@ -290,6 +293,7 @@ def read(self) -> Optional[OAuthToken]:
290293
session_id_hex=self.get_session_id_hex(),
291294
auth_provider=self.session.auth_provider,
292295
host_url=self.session.host,
296+
batch_size=self.telemetry_batch_size,
293297
)
294298

295299
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,18 @@ class TelemetryClient(BaseTelemetryClient):
138138
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
139139
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
140140

141-
DEFAULT_BATCH_SIZE = 100
142-
143141
def __init__(
144142
self,
145143
telemetry_enabled,
146144
session_id_hex,
147145
auth_provider,
148146
host_url,
149147
executor,
148+
batch_size,
150149
):
151150
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
152151
self._telemetry_enabled = telemetry_enabled
153-
self._batch_size = self.DEFAULT_BATCH_SIZE
152+
self._batch_size = batch_size
154153
self._session_id_hex = session_id_hex
155154
self._auth_provider = auth_provider
156155
self._user_agent = None
@@ -318,7 +317,7 @@ def close(self):
318317
class TelemetryClientFactory:
319318
"""
320319
Static factory class for creating and managing telemetry clients.
321-
It uses a thread pool to handle asynchronous operations.
320+
It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
322321
"""
323322

324323
_clients: Dict[
@@ -331,6 +330,13 @@ class TelemetryClientFactory:
331330
_original_excepthook = None
332331
_excepthook_installed = False
333332

333+
# Shared flush thread for all clients
334+
_flush_thread = None
335+
_flush_event = threading.Event()
336+
_flush_interval_seconds = 90
337+
338+
DEFAULT_BATCH_SIZE = 100
339+
334340
@classmethod
335341
def _initialize(cls):
336342
"""Initialize the factory if not already initialized"""
@@ -341,11 +347,39 @@ def _initialize(cls):
341347
max_workers=10
342348
) # Thread pool for async operations
343349
cls._install_exception_hook()
350+
cls._start_flush_thread()
344351
cls._initialized = True
345352
logger.debug(
346353
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
347354
)
348355

356+
@classmethod
357+
def _start_flush_thread(cls):
    """Launch the shared daemon thread that periodically flushes all clients.

    The stop event is cleared first: the worker loop exits as soon as that
    event is set, so a stale "set" state would end the new thread at once.
    """
    cls._flush_event.clear()
    worker = threading.Thread(target=cls._flush_worker, daemon=True)
    cls._flush_thread = worker
    worker.start()
362+
363+
@classmethod
364+
def _flush_worker(cls):
    """Run the periodic flush loop shared by every telemetry client.

    Waits on ``cls._flush_event`` with a timeout of
    ``cls._flush_interval_seconds``. Each time the wait times out, every
    client currently registered in ``cls._clients`` is flushed. The loop
    ends as soon as the event is set.

    A failure while flushing one client is logged and skipped, so a single
    misbehaving client cannot terminate the shared thread and stop periodic
    flushing for all the others.
    """
    while not cls._flush_event.wait(cls._flush_interval_seconds):
        logger.debug("Performing periodic flush for all telemetry clients")

        # Snapshot the clients under the lock, then flush without holding
        # it, so the lock is not held for the duration of the flushes.
        with cls._lock:
            clients_to_flush = list(cls._clients.values())

        for client in clients_to_flush:
            try:
                client._flush()
            except Exception as e:
                # Also covers a registered client that lacks `_flush`
                # (NOTE(review): confirm whether no-op clients define it).
                logger.debug("Periodic telemetry flush failed: %s", e)
374+
375+
@classmethod
376+
def _stop_flush_thread(cls):
    """Signal the shared flush thread to exit and wait briefly for it.

    Does nothing when no flush thread is recorded. Otherwise sets the stop
    event, joins the thread for at most one second, and clears the stored
    thread reference.
    """
    worker = cls._flush_thread
    if worker is None:
        return

    cls._flush_event.set()
    worker.join(timeout=1.0)
    cls._flush_thread = None
382+
349383
@classmethod
350384
def _install_exception_hook(cls):
351385
"""Install global exception handler for unhandled exceptions"""
@@ -374,6 +408,7 @@ def initialize_telemetry_client(
374408
session_id_hex,
375409
auth_provider,
376410
host_url,
411+
batch_size,
377412
):
378413
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
379414
try:
@@ -395,6 +430,7 @@ def initialize_telemetry_client(
395430
auth_provider=auth_provider,
396431
host_url=host_url,
397432
executor=TelemetryClientFactory._executor,
433+
batch_size=batch_size,
398434
)
399435
else:
400436
TelemetryClientFactory._clients[
@@ -433,6 +469,7 @@ def close(session_id_hex):
433469
"No more telemetry clients, shutting down thread pool executor"
434470
)
435471
try:
472+
TelemetryClientFactory._stop_flush_thread()
436473
TelemetryClientFactory._executor.shutdown(wait=True)
437474
TelemetryHttpClient.close()
438475
except Exception as e:
@@ -458,6 +495,7 @@ def connection_failure_log(
458495
session_id_hex=UNAUTH_DUMMY_SESSION_ID,
459496
auth_provider=None,
460497
host_url=host_url,
498+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
461499
)
462500

463501
telemetry_client = TelemetryClientFactory.get_telemetry_client(

tests/unit/test_telemetry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def mock_telemetry_client():
3030
auth_provider=auth_provider,
3131
host_url="test-host.com",
3232
executor=executor,
33+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
3334
)
3435

3536

@@ -214,6 +215,7 @@ def test_client_lifecycle_flow(self):
214215
session_id_hex=session_id_hex,
215216
auth_provider=auth_provider,
216217
host_url="test-host.com",
218+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
217219
)
218220

219221
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -238,6 +240,7 @@ def test_disabled_telemetry_flow(self):
238240
session_id_hex=session_id_hex,
239241
auth_provider=None,
240242
host_url="test-host.com",
243+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
241244
)
242245

243246
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -257,6 +260,7 @@ def test_factory_error_handling(self):
257260
session_id_hex=session_id,
258261
auth_provider=AccessTokenAuthProvider("token"),
259262
host_url="test-host.com",
263+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
260264
)
261265

262266
# Should fall back to NoopTelemetryClient
@@ -275,6 +279,7 @@ def test_factory_shutdown_flow(self):
275279
session_id_hex=session,
276280
auth_provider=AccessTokenAuthProvider("token"),
277281
host_url="test-host.com",
282+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
278283
)
279284

280285
# Factory should be initialized

tests/unit/test_telemetry_retry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_client(self, session_id, num_retries=3):
4747
session_id_hex=session_id,
4848
auth_provider=None,
4949
host_url="test.databricks.com",
50+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
5051
)
5152
client = TelemetryClientFactory.get_telemetry_client(session_id)
5253

0 commit comments

Comments
 (0)