
Commit 4ad7186

IvyZX authored and Google-ML-Automation committed
Support core axis index in the device_id dict for async copy and semaphore.
PiperOrigin-RevId: 787329175
1 parent 413c976 commit 4ad7186
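The user-facing change: the TensorCore axis created with pltpu.create_tensorcore_mesh can now be named directly in the device_id dict of pltpu.semaphore_signal and pltpu.make_async_remote_copy, either as its own key or folded into a joint (tuple) axis. The fragments below are condensed from the tests added in this commit; the ref and semaphore names are placeholders from those tests, the code assumes a kernel body running under pl.core_map with a 'core' axis and under shard_map with a 'device' mesh axis, and it is not a standalone program.

    # Signal a semaphore on device i, core j (axis names come from the
    # surrounding device mesh and TensorCore mesh in the tests below).
    pltpu.semaphore_signal(
        sem0, 1,
        device_id={'device': i, 'core': j},
        device_id_type=pltpu.DeviceIdType.MESH)

    # Remote copy addressed per axis ...
    copy = pltpu.make_async_remote_copy(
        src_ref=src_ref, dst_ref=dst_vmem_ref,
        send_sem=send_sem, recv_sem=recv_sem,
        device_id={'device': 1, 'core': 1},
        device_id_type=pltpu.DeviceIdType.MESH)

    # ... or via a joint axis: a row-major flat index over ('device', 'core'),
    # with 'core' as the minor axis. For num_cores >= 2, num_cores + 1
    # resolves to device 1, core 1.
    copy = pltpu.make_async_remote_copy(
        src_ref=src_ref, dst_ref=dst_vmem_ref,
        send_sem=send_sem, recv_sem=recv_sem,
        device_id={('device', 'core'): num_cores + 1},
        device_id_type=pltpu.DeviceIdType.MESH)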

4 files changed, +211 −17 lines changed


jax/_src/pallas/mosaic/lowering.py

Lines changed: 38 additions & 17 deletions
@@ -3607,42 +3607,54 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr, collecti
 def _device_id_dict_to_mesh(ctx: LoweringRuleContext, device_id_dict):
   mesh_context = ctx.lowering_context.mesh_context
   assert mesh_context is not None
-  zipped_metadata = zip(mesh_context.axis_names, mesh_context.mesh_shape)
+  mesh_axis_sizes = dict(zip(mesh_context.axis_names, mesh_context.mesh_shape))
+  core_axis_name, grid_names = None, ctx.lowering_context.grid_names
+  if grid_names:
+    if len(grid_names) > 1:
+      raise NotImplementedError(
+          "Unable to determine core axis name if grid_names is more than 1."
+      )
+    mesh_axis_sizes.update(
+        dict(zip(grid_names, ctx.lowering_context.grid_sizes))
+    )
+    core_axis_name = grid_names[0]
   physical_axis_dict = {}
   # Handle joint axes (i.e., one logical axis over >1 physical axes)
   for axis, idx in device_id_dict.items():
     if isinstance(axis, tuple):
-      axis_names, mesh_shape = unzip2(
-          (name, shape) for name, shape in zipped_metadata if name in axis
-      )
-      for axis_index, axis_name in enumerate(axis_names):
-        axis_size = ir_constant(mesh_shape[axis_index])
+      axes_dimensions = [mesh_axis_sizes[name] for name in axis]
+      for axis_index, axis_name in enumerate(axis):
+        axis_size = ir_constant(mesh_axis_sizes[axis_name])
         minor_divisor = ir_constant(
-            np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32)
+            np.prod(axes_dimensions[axis_index + 1 :], dtype=np.int32)
         )
         device_idx = arith.remsi(arith.divsi(idx, minor_divisor), axis_size)
         physical_axis_dict[axis_name] = device_idx
     else:
       physical_axis_dict[axis] = idx
+  core_index = None
+  if core_axis_name and core_axis_name in physical_axis_dict:
+    core_index = physical_axis_dict.pop(grid_names[0])
   device_id = []
   for axis in mesh_context.axis_names:
     if axis in physical_axis_dict:
       device_id.append(physical_axis_dict[axis])
     else:
       device_id.append(_axis_index_rule(ctx, axis_name=axis))
-  return tuple(device_id)
+  return tuple(device_id), core_index
 
 
 def _device_id_to_logical(
     ctx: LoweringRuleContext, device_id,
     device_id_type: primitives.DeviceIdType):
+  core_index = None
   if isinstance(device_id, dict):
     if device_id_type is not primitives.DeviceIdType.MESH:
       raise ValueError(
           "`device_id_type` must be MESH if `device_id` is a dict,"
           f" got: {device_id_type = }."
       )
-    device_id = _device_id_dict_to_mesh(ctx, device_id)
+    device_id, core_index = _device_id_dict_to_mesh(ctx, device_id)
   if device_id_type is primitives.DeviceIdType.MESH:
     assert (mesh_context := ctx.lowering_context.mesh_context)
     # Mesh means we are passed the mesh coordinates for the device
@@ -3651,16 +3663,16 @@ def _device_id_to_logical(
 
     i32 = ir.IntegerType.get_signless(32)
     if len(device_ids) == 0:
-      return arith.constant(i32, 0)
+      return arith.constant(i32, 0), core_index
     return functools.reduce(
         arith.addi,
         (
             arith.muli(a, arith.constant(i32, b))
             for a, b in zip(device_ids, mesh_strides)
         ),
-    )
+    ), core_index
   elif device_id_type is primitives.DeviceIdType.LOGICAL:
-    return device_id
+    return device_id, None
   raise NotImplementedError(f"Unsupported device id type: {device_id_type}")
 
 
@@ -3704,7 +3716,13 @@ def _semaphore_signal_lowering_rule(
   )
   sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
   if device_id is not None:
-    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
+    device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type)
+    if core_id is not None:
+      if core_index is not None:
+        raise ValueError(
+            "Cannot specify both `core_index` and the core axis in `device_id`."
+        )
+      core_index = core_id
   tpu.sem_signal(sem, value, device_id=device_id, core_id=core_index)
   return []
 
@@ -3757,8 +3775,9 @@ def _dma_start_lowering_rule(
       dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms
   )
   sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
+  core_id = None
   if device_id is not None:
-    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
+    device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type)
   priority_kwarg = {"priority": priority}
   if jaxlib_version < (0, 5, 4):
     priority_kwarg = {}
@@ -3768,6 +3787,7 @@ def _dma_start_lowering_rule(
       sem,
       source_semaphore=src_sem,
      device_id=device_id,
+      core_id=core_id,
       **priority_kwarg,
   )
   return []
@@ -3796,13 +3816,14 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
   dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
   sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
 
+  core_id = None
   if device_id is not None:
-    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
+    device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type)
 
   if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 7, 27):
-    tpu.wait_dma2(sem, src, dst)
+    tpu.wait_dma2(sem, src, dst, core_id=core_id)
   else:
-    tpu.wait_dma2(sem, src, dst, device_id=device_id)
+    tpu.wait_dma2(sem, src, dst, device_id=device_id, core_id=core_id)
   return []
 
 
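As an aside, the resolution that _device_id_dict_to_mesh now performs can be modeled in plain Python: joint (tuple) axes are decomposed as a row-major flat index over the named axis sizes, and the core axis, when it names a grid axis, is popped out of the dict and returned separately. The sketch below is illustrative only, with the MLIR arith.divsi/remsi ops replaced by Python // and %, an illustrative axis_sizes mapping, and without the real code's fallback to the caller's own index for axes missing from the dict.

    import numpy as np

    def device_id_dict_to_mesh(device_id_dict, axis_sizes, core_axis_name):
      # axis_sizes maps every addressable axis (device mesh axes plus the core
      # grid axis) to its size; names and values here are illustrative.
      physical = {}
      for axis, idx in device_id_dict.items():
        if isinstance(axis, tuple):  # joint axis: decompose a row-major flat index
          dims = [axis_sizes[name] for name in axis]
          for i, name in enumerate(axis):
            minor_divisor = int(np.prod(dims[i + 1:], dtype=np.int32))
            physical[name] = (idx // minor_divisor) % dims[i]
        else:
          physical[axis] = idx
      core_index = physical.pop(core_axis_name, None)  # split the core axis out
      mesh_axes = [a for a in axis_sizes if a != core_axis_name]
      # Axes left unspecified stay None here; the real lowering fills in the
      # caller's own axis index instead.
      return tuple(physical.get(a) for a in mesh_axes), core_index

    sizes = {'device': 2, 'core': 2}
    print(device_id_dict_to_mesh({('device', 'core'): 3}, sizes, 'core'))
    # ((1,), 1): flat index 3 over a 2x2 (device, core) grid is device 1, core 1.
    print(device_id_dict_to_mesh({'device': 0, 'core': 1}, sizes, 'core'))
    # ((0,), 1)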

tests/pallas/BUILD

Lines changed: 1 addition & 0 deletions
@@ -523,6 +523,7 @@ jax_multiplatform_test(
     enable_configs = [
         "tpu_v5e_x8",
         "tpu_v5p",
+        "tpu_v5p_x4",
     ],
     deps = [
         "//jax:pallas_tpu",

tests/pallas/tpu_pallas_async_test.py

Lines changed: 78 additions & 0 deletions
@@ -454,6 +454,84 @@ def _():
     np.testing.assert_array_equal(pallas_out[:xlocal],
                                   pallas_out[xlocal:(2*xlocal)])
 
+  def test_axis_dict_with_core_single_device(self):
+    if jax.device_count() > 2 or (jax.devices()[0].num_cores) != 2:
+      self.skipTest('Testing single device two cores')
+    mesh = jax.make_mesh((jax.device_count(),), ('device',))
+    ddim = jax.device_count()
+    tcmesh = pltpu.create_tensorcore_mesh('core')
+    pspec = P('device', None)
+    sharding = jax.sharding.NamedSharding(mesh, pspec)
+
+    # Array is fully sharded.
+    xlocal, ylocal = 8, 256
+    input_arr = jnp.arange(xlocal * ddim * ylocal, dtype=jnp.int32).reshape(
+        (xlocal * ddim, ylocal)
+    )
+    input_arr = jax.device_put(input_arr, sharding)
+
+    def core_copy(refs):
+      in_ref, out_ref = refs
+
+      @pl.core_map(tcmesh, compiler_params=pltpu.CompilerParams(collective_id=7))
+      def _():
+        num_cores = jax.lax.axis_size('core')
+        slc_size = ylocal // num_cores
+        vmem_shape = (xlocal, slc_size)
+
+        # This runs on every core, for every vmem iterations
+        def alloc(out_vmem_ref, sem, send_sem, recv_sem):
+          core_index = jax.lax.axis_index('core')
+          slc = pl.ds(core_index * slc_size, slc_size)
+
+          # Make sure all cores have entered run_scoped.
+          sem0 = pltpu.get_barrier_semaphore()
+          for i in range(ddim):
+            for j in range(num_cores):
+              pltpu.semaphore_signal(
+                  sem0, 1, device_id={'device': i, 'core': j},
+                  device_id_type=pltpu.DeviceIdType.MESH)
+          pltpu.semaphore_wait(sem0, ddim * num_cores)
+
+          # Identity function by default
+          pltpu.async_copy(in_ref.at[:, slc], out_ref.at[:, slc], sem).wait()
+
+          copy_c0_to_c1 = pltpu.make_async_remote_copy(
+              src_ref=in_ref.at[:, slc],
+              dst_ref=out_vmem_ref,
+              send_sem=send_sem,
+              recv_sem=recv_sem,
+              device_id={'core': 1},
+              device_id_type=pltpu.DeviceIdType.MESH,
+          )
+
+          @pl.when(core_index == 0)
+          def _():
+            copy_c0_to_c1.start()
+            copy_c0_to_c1.wait_send()
+
+          @pl.when(core_index == 1)
+          def _():
+            copy_c0_to_c1.wait_recv()
+            pltpu.async_copy(out_vmem_ref, out_ref.at[:, slc], sem).wait()
+
+        pl.run_scoped(
+            alloc,
+            pltpu.VMEM(vmem_shape, out_ref.dtype),
+            *([pltpu.SemaphoreType.DMA] * 3),
+        )
+
+    @partial(jax.shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=False)
+    def run_core_kernel(input):
+      output = jnp.zeros_like(input)
+      _, output = pl.run_state(core_copy)((input, output))
+      return output
+    pallas_out = jax.jit(run_core_kernel)(input_arr)
+
+    # The device=0 core=1 slice was flushed with device=0 core=0 contents
+    np.testing.assert_array_equal(pallas_out[:, 128:], input_arr[:, :128])
+    np.testing.assert_array_equal(pallas_out[:, :128], input_arr[:, :128])
+
 
 def make_async_remote_copy(axis_name: str, direction: str = 'right',
                            target_memory_space=None):

tests/pallas/tpu_pallas_distributed_test.py

Lines changed: 94 additions & 0 deletions
@@ -326,6 +326,100 @@ def body(x):
     )(x)
     np.testing.assert_allclose(y, x)
 
+  @parameterized.product(joint_axis=[True, False])
+  def test_axis_dict_with_core_multi_device(self, joint_axis):
+    if jax.device_count() < 2:
+      self.skipTest('Requires at least 2 devices for DMAs.')
+    if (cdim := jax.devices()[0].num_cores) < 2:
+      self.skipTest('Requires a TPU with at least 2 cores.')
+    mesh = jax.make_mesh((jax.device_count(),), ('device',))
+    ddim = jax.device_count()
+    tcmesh = pltpu.create_tensorcore_mesh('core')
+    pspec = P('device', None)
+    sharding = jax.sharding.NamedSharding(mesh, pspec)
+
+    # Array is fully sharded.
+    xlocal, ylocal = 8, 256
+    input_arr = jnp.arange(xlocal * ddim * ylocal, dtype=jnp.int32).reshape(
+        (xlocal * ddim, ylocal)
+    )
+    input_arr = jax.device_put(input_arr, sharding)
+
+    def core_copy(refs):
+      in_ref, out_ref = refs
+
+      @pl.core_map(tcmesh, compiler_params=pltpu.CompilerParams(collective_id=7))
+      def _():
+        num_cores = jax.lax.axis_size('core')
+        slc_size = ylocal // num_cores
+        vmem_shape = (xlocal, slc_size)
+
+        # This runs on every core, for every vmem iterations
+        def alloc(out_vmem_ref, sem, send_sem, recv_sem):
+          core_index = jax.lax.axis_index('core')
+          device_index = jax.lax.axis_index('device')
+          slc = pl.ds(core_index * slc_size, slc_size)
+
+          # Make sure all cores have entered run_scoped.
+          sem0 = pltpu.get_barrier_semaphore()
+          for i in range(ddim):
+            for j in range(num_cores):
+              pltpu.semaphore_signal(
+                  sem0, 1, device_id={'device': i, 'core': j},
+                  device_id_type=pltpu.DeviceIdType.MESH)
+          pltpu.semaphore_wait(sem0, ddim * num_cores)
+
+          # Identity function by default
+          pltpu.async_copy(in_ref.at[:, slc], out_ref.at[:, slc], sem).wait()
+
+          if joint_axis:
+            device_id = {('device', 'core'): cdim + 1}
+          else:
+            device_id = {'device': 1, 'core': 1}
+          copy_d0c0_to_d1c1 = pltpu.make_async_remote_copy(
+              src_ref=in_ref.at[:, slc],
+              dst_ref=out_vmem_ref,
+              send_sem=send_sem,
+              recv_sem=recv_sem,
+              device_id=device_id,
+              device_id_type=pltpu.DeviceIdType.MESH,
+          )
+
+          @pl.when(device_index == 0)
+          def _():
+            @pl.when(core_index == 0)
+            def _():
+              copy_d0c0_to_d1c1.start()
+              copy_d0c0_to_d1c1.wait_send()
+
+          @pl.when(device_index == 1)
+          def _():
+            @pl.when(core_index == 1)
+            def _():
+              copy_d0c0_to_d1c1.wait_recv()
+              pltpu.async_copy(out_vmem_ref, out_ref.at[:, slc], sem).wait()
+
+        pl.run_scoped(
+            alloc,
+            pltpu.VMEM(vmem_shape, out_ref.dtype),
+            *([pltpu.SemaphoreType.DMA] * 3),
+        )
+
+    @partial(jax.shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=False)
+    def run_core_kernel(input):
+      output = jnp.zeros_like(input)
+      _, output = pl.run_state(core_copy)((input, output))
+      return output
+    pallas_out = jax.jit(run_core_kernel)(input_arr)
+
+    # The device=1 core=1 slice was flushed with device=0 core=0 contents
+    np.testing.assert_array_equal(pallas_out[8:16, 128:], input_arr[:8, :128])
+    # Mask that slice out and all should be the same.
+    mask = jnp.zeros((8, 128), jnp.int32)
+    masked_in = jax.lax.dynamic_update_slice(input_arr, mask, (8, 128))
+    masked_out = jax.lax.dynamic_update_slice(pallas_out, mask, (8, 128))
+    np.testing.assert_array_equal(masked_in, masked_out)
+
 
 class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
 
