@@ -4474,17 +4474,38 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
44744474
44754475
44764476def setitem_sample_generator (op , device , dtype , requires_grad , ** kwargs ):
4477- for sample in getitem_sample_generator (op , device , dtype , requires_grad , ** kwargs ):
4478- tensor , key = sample .args
4477+ make = partial (make_tensor , device = device , dtype = dtype , requires_grad = requires_grad )
44794478
4480- indexed_tensor = tensor [key ]
4481- # getitem already has lots of cases, and doubling it is too time-consuming
4482- # value = make_tensor(indexed_tensor.shape, device=device, dtype=dtype, requires_grad=requires_grad)
4479+ def _make_setitem_sample (tensor , key ):
4480+ indexed_shape = tensor [key ].shape
4481+
4482+ # Tests for getitem are already slow, and doubling them is too time-consuming
4483+ # value = make_tensor(indexed_shape, device=device, dtype=dtype, requires_grad=requires_grad)
44834484 # yield SampleInput(tensor, key, value)
44844485
4485- pre_broadcast_shape = tuple (random .choice ((s , 1 )) for s in indexed_tensor .shape )
4486- value = make_tensor (pre_broadcast_shape , device = device , dtype = dtype , requires_grad = requires_grad )
4487- yield SampleInput (tensor , key , value )
4486+ pre_broadcast_shape = tuple (random .choice ((s , 1 )) for s in indexed_shape )
4487+ pre_broadcast_value = make_tensor (pre_broadcast_shape , device = device , dtype = dtype , requires_grad = requires_grad )
4488+ return SampleInput (tensor , key , pre_broadcast_value )
4489+
4490+ for sample in getitem_sample_generator (op , device , dtype , requires_grad , ** kwargs ):
4491+ tensor , key = sample .args
4492+ yield _make_setitem_sample (tensor , key )
4493+
4494+ # Boolean mask indexing
4495+ boolean_mask_cases = [
4496+ ((6 ,), (torch .tensor ([True , False , True , False , True , False ]),)),
4497+ ((2 , 3 ), (torch .tensor ([[True , False , True ], [False , True , False ]]),)),
4498+ ((2 , 3 , 4 ), ([False , True ], [False , True , False ], slice (None ))),
4499+ ((2 , 3 , 4 ), (torch .tensor ([True , False ]), [1 , 1 ], slice (None ))),
4500+ ((2 , 3 , 4 ), (torch .tensor ([False , False ]), [1 , 1 ], slice (None ))),
4501+ ((2 , 3 , 4 ), (1 , torch .tensor ([True , False , True ]), slice (None ))),
4502+ ((2 , 3 ), (torch .tensor ([True , False ]), None , [0 , 2 ])),
4503+ ((4 , 2 , 3 ), (Ellipsis , [False , True , False ])),
4504+ ]
4505+
4506+ for shape , key in boolean_mask_cases :
4507+ tensor = make (shape )
4508+ yield _make_setitem_sample (tensor , key )
44884509
44894510
44904511setitem_opinfo = OpInfo (
0 commit comments