From 18ec3317e500a6fee72fc8c24552c21808437bef Mon Sep 17 00:00:00 2001 From: Damon Date: Mon, 30 Dec 2024 12:04:36 -0800 Subject: [PATCH] Fix missing worker_id from interceptor (#33453) * Fix missing worker_id from interceptor * Add worker_id attribute * reorder and default parameters for GrpcStateHandlerFactory --- sdks/python/apache_beam/runners/worker/sdk_worker.py | 11 +++++++---- .../apache_beam/runners/worker/worker_status.py | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index b091220a06b5..3cb1a26b77f1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -201,7 +201,9 @@ def __init__( self._data_channel_factory = data_plane.GrpcClientDataChannelFactory( credentials, self._worker_id, data_buffer_time_limit_ms) self._state_handler_factory = GrpcStateHandlerFactory( - self._state_cache, credentials) + state_cache=self._state_cache, + credentials=credentials, + worker_id=self._worker_id) self._profiler_factory = profiler_factory self.data_sampler = data_sampler self.runner_capabilities = runner_capabilities @@ -893,13 +895,14 @@ class GrpcStateHandlerFactory(StateHandlerFactory): Caches the created channels by ``state descriptor url``. """ - def __init__(self, state_cache, credentials=None): - # type: (StateCache, Optional[grpc.ChannelCredentials]) -> None + def __init__(self, state_cache, credentials=None, worker_id=None): + # type: (StateCache, Optional[grpc.ChannelCredentials], Optional[str]) -> None self._state_handler_cache = {} # type: Dict[str, CachingStateHandler] self._lock = threading.Lock() self._throwing_state_handler = ThrowingStateHandler() self._credentials = credentials self._state_cache = state_cache + self._worker_id = worker_id def create_state_handler(self, api_service_descriptor): # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler @@ -926,7 +929,7 @@ def create_state_handler(self, api_service_descriptor): _LOGGER.info('State channel established.') # Add workerId to the grpc channel grpc_channel = grpc.intercept_channel( - grpc_channel, WorkerIdInterceptor()) + grpc_channel, WorkerIdInterceptor(self._worker_id)) self._state_handler_cache[url] = GlobalCachingStateHandler( self._state_cache, GrpcStateHandler( diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index 2271b4495d79..ecd4dc4e02c0 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -151,6 +151,7 @@ def __init__( bundle_process_cache=None, state_cache=None, enable_heap_dump=False, + worker_id=None, log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS): """Initialize FnApiWorkerStatusHandler. @@ -164,7 +165,8 @@ def __init__( self._state_cache = state_cache ch = GRPCChannelFactory.insecure_channel(status_address) grpc.channel_ready_future(ch).result(timeout=60) - self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor()) + self._status_channel = grpc.intercept_channel( + ch, WorkerIdInterceptor(worker_id)) self._status_stub = beam_fn_api_pb2_grpc.BeamFnWorkerStatusStub( self._status_channel) self._responses = queue.Queue()