From ffa892d42affdc4510f15531e44aba4b5697b669 Mon Sep 17 00:00:00 2001
From: minalkharat-cmd
Date: Sat, 31 Jan 2026 13:33:18 +0530
Subject: [PATCH 1/2] Add missing acos_ in-place function and improve
 documentation

- Add missing in-place version acos_() function
- Improve documentation with detailed docstrings
- Add usage examples in docstrings
- Maintain compatibility with existing acos() function

Part of completing issue #883

Signed-off-by: minalkharat-cmd
---
# NOTE(review): this hunk deletes the existing
#   "TODO: use flag_gems.utils.tl_extra_shim help apis"
#   comment without making the change it describes - confirm the TODO is
#   really obsolete before dropping it.
# NOTE(review): `return _acos(...)`, `logger.debug(...)`, `y = acos_kernel(x)`
#   and `return y` are removed and re-added with the same visible text -
#   presumably a whitespace or line-ending difference; confirm it is intended.
# NOTE(review): acos_kernel is declared with the INT_TO_FLOAT promotion, while
#   acos_ writes the result back into its own input via out0=x. What happens
#   for integer input tensors is not visible in this patch - TODO confirm.

 src/flag_gems/ops/acos.py | 56 +++++++++++++++++++++++++++++++++++----
 1 file changed, 51 insertions(+), 5 deletions(-)

diff --git a/src/flag_gems/ops/acos.py b/src/flag_gems/ops/acos.py
index c34d3a420..1ec96cf38 100644
--- a/src/flag_gems/ops/acos.py
+++ b/src/flag_gems/ops/acos.py
@@ -12,11 +12,57 @@
 @pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
 @triton.jit()
 def acos_kernel(x):
-    # TODO: use flag_gems.utils.tl_extra_shim help apis
-    return _acos(x.to(tl.float32))
+    """
+    Compute arccos(x) for input x.
+
+    The arccos function returns values in the range [0, π] for input values in [-1, 1].
+
+    Args:
+        x: Input tensor (will be converted to float32 for computation)
+
+    Returns:
+        Tensor with arccos(x) computed element-wise
+    """
+    return _acos(x.to(tl.float32))
 
 
 def acos(x):
-    logger.debug("GEMS ACOS FORWARD")
-    y = acos_kernel(x)
-    return y
+    """
+    Computes the inverse cosine (arccos) of each element in input.
+
+    Args:
+        x (Tensor): Input tensor with values in [-1, 1]
+
+    Returns:
+        Tensor: Output tensor with values in [0, π]
+
+    Example:
+        >>> x = torch.tensor([0.0, 0.5, 1.0])
+        >>> torch.acos(x)
+        tensor([1.5708, 1.0472, 0.0000])
+    """
+    logger.debug("GEMS ACOS FORWARD")
+    y = acos_kernel(x)
+    return y
+
+
+def acos_(x):
+    """
+    In-place version of acos.
+
+    Computes the inverse cosine of each element in input, modifying the tensor in-place.
+
+    Args:
+        x (Tensor): Input tensor with values in [-1, 1] (modified in-place)
+
+    Returns:
+        Tensor: The modified input tensor
+
+    Example:
+        >>> x = torch.tensor([0.0, 0.5, 1.0])
+        >>> torch.acos_(x)
+        tensor([1.5708, 1.0472, 0.0000])
+    """
+    logger.debug("GEMS ACOS_ INPLACE")
+    acos_kernel(x, out0=x)
+    return x

From a501535dcd57869446b1694cb81d5f2ff8c54dce Mon Sep 17 00:00:00 2001
From: minalkharat-cmd
Date: Sat, 31 Jan 2026 14:08:16 +0530
Subject: [PATCH 2/2] Add tests, benchmarks and export acos_ operator - Completes #883
---
# NOTE(review): this series only touches the four files listed in the two
#   diffstats. Whether torch.acos_ is actually routed to flag_gems.ops.acos_
#   inside `flag_gems.use_gems()` depends on registration code that is not part
#   of these patches - confirm, otherwise test_accuracy_acos_ would compare the
#   reference implementation against itself.
# NOTE(review): the new test draws inputs with torch.randn, so a sizeable share
#   of elements lie outside [-1, 1], where acos yields NaN. The trailing `True`
#   passed to gems_assert_close is presumably equal_nan - TODO confirm.
# NOTE(review): `@pytest.mark.acos_` is a new marker name - confirm it is
#   registered if the test suite runs with strict markers.

 benchmark/test_unary_pointwise_perf.py |  1 +
 src/flag_gems/ops/__init__.py          |  3 ++-
 tests/test_unary_pointwise_ops.py      | 15 +++++++++++++++
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/benchmark/test_unary_pointwise_perf.py b/benchmark/test_unary_pointwise_perf.py
index 8038cf680..b5d74eb3b 100644
--- a/benchmark/test_unary_pointwise_perf.py
+++ b/benchmark/test_unary_pointwise_perf.py
@@ -124,6 +124,7 @@ def test_general_unary_pointwise_perf(op_name, torch_op, dtypes):
     ("tan_", torch.tan_, FLOAT_DTYPES),
     ("tanh_", torch.tanh_, FLOAT_DTYPES),
     ("atan_", torch.atan_, FLOAT_DTYPES),
+    ("acos_", torch.acos_, FLOAT_DTYPES),
     # Bitwise operations
     ("bitwise_not_", lambda a: a.bitwise_not_(), INT_DTYPES),
 ]
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 3cd2c8b1f..64aaa3a9a 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -1,5 +1,5 @@
 from flag_gems.ops.abs import abs, abs_
-from flag_gems.ops.acos import acos
+from flag_gems.ops.acos import acos, acos_
 from flag_gems.ops.add import add, add_
 from flag_gems.ops.addcdiv import addcdiv
 from flag_gems.ops.addcmul import addcmul
@@ -243,6 +243,7 @@
     "abs",
     "abs_",
     "acos",
+    "acos_",
     "add",
     "add_",
     "addcdiv",
diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py
index 0ff8c874d..61fe8a5e7 100644
--- a/tests/test_unary_pointwise_ops.py
+++ b/tests/test_unary_pointwise_ops.py
@@ -72,6 +72,21 @@ def test_accuracy_acos(shape, dtype):
     gems_assert_close(res_out, ref_out, dtype, True)
 
 
+@pytest.mark.inplace
+@pytest.mark.acos_
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_acos_(shape, dtype):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    ref_inp = to_reference(inp.clone(), True)
+
+    ref_out = torch.acos_(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.acos_(inp)
+
+    gems_assert_close(res_out, ref_out, dtype, True)
+
+
 @pytest.mark.angle
 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
 @pytest.mark.parametrize(