Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/test_unary_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_general_unary_pointwise_perf(op_name, torch_op, dtypes):
("tan_", torch.tan_, FLOAT_DTYPES),
("tanh_", torch.tanh_, FLOAT_DTYPES),
("atan_", torch.atan_, FLOAT_DTYPES),
("acos_", torch.acos_, FLOAT_DTYPES),
# Bitwise operations
("bitwise_not_", lambda a: a.bitwise_not_(), INT_DTYPES),
]
Expand Down
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flag_gems.ops.abs import abs, abs_
from flag_gems.ops.acos import acos
from flag_gems.ops.acos import acos, acos_
from flag_gems.ops.add import add, add_
from flag_gems.ops.addcdiv import addcdiv
from flag_gems.ops.addcmul import addcmul
Expand Down Expand Up @@ -243,6 +243,7 @@
"abs",
"abs_",
"acos",
"acos_",
"add",
"add_",
"addcdiv",
Expand Down
56 changes: 51 additions & 5 deletions src/flag_gems/ops/acos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,57 @@
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit()
def acos_kernel(x):
# TODO: use flag_gems.utils.tl_extra_shim help apis
return _acos(x.to(tl.float32))
"""
Compute arccos(x) for input x.

The arccos function returns values in the range [0, π] for input values in [-1, 1].

Args:
x: Input tensor (will be converted to float32 for computation)

Returns:
Tensor with arccos(x) computed element-wise
"""
return _acos(x.to(tl.float32))


def acos(x):
logger.debug("GEMS ACOS FORWARD")
y = acos_kernel(x)
return y
"""
Computes the inverse cosine (arccos) of each element in input.

Args:
x (Tensor): Input tensor with values in [-1, 1]

Returns:
Tensor: Output tensor with values in [0, π]

Example:
>>> x = torch.tensor([0.0, 0.5, 1.0])
>>> torch.acos(x)
tensor([1.5708, 1.0472, 0.0000])
"""
logger.debug("GEMS ACOS FORWARD")
y = acos_kernel(x)
return y


def acos_(x):
"""
In-place version of acos.

Computes the inverse cosine of each element in input, modifying the tensor in-place.

Args:
x (Tensor): Input tensor with values in [-1, 1] (modified in-place)

Returns:
Tensor: The modified input tensor

Example:
>>> x = torch.tensor([0.0, 0.5, 1.0])
>>> torch.acos_(x)
tensor([1.5708, 1.0472, 0.0000])
"""
logger.debug("GEMS ACOS_ INPLACE")
acos_kernel(x, out0=x)
return x
15 changes: 15 additions & 0 deletions tests/test_unary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def test_accuracy_acos(shape, dtype):
gems_assert_close(res_out, ref_out, dtype, True)


@pytest.mark.inplace
@pytest.mark.acos_
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_acos_(shape, dtype):
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp = to_reference(inp.clone(), True)

ref_out = torch.acos_(ref_inp)
with flag_gems.use_gems():
res_out = torch.acos_(inp)

gems_assert_close(res_out, ref_out, dtype, True)


@pytest.mark.angle
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize(
Expand Down