
Precision mismatch when integrating this library with automatic mixed precision #63

@Sylence8

Description

After enabling AMP, I used the DWTForward and DWTInverse classes provided by the library.
This leads to a situation in dwt/lowlevel.py where, for example, at line 356:

    lo = sfb1d(low, lh, h0_col, h1_col, mode=mode, dim=2)

the tensor low has float16 precision while the other inputs remain float32.
This causes a type mismatch during the backward pass, resulting in the following error:

 y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
[rank0]: RuntimeError: expected scalar type Half but found Float
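For context, a minimal script along the following lines reproduces the failure; the wavelet, tensor shape, and device are placeholders, and my guess is that the transforms' backward runs outside the autocast region, so the half-precision tensors meet the float32 filter buffers there:

```python
import torch
from pytorch_wavelets import DWTForward, DWTInverse

# Wavelet choice and tensor shape are placeholders; any setup that runs the
# transforms inside an autocast region should hit the same code path.
xfm = DWTForward(J=1, wave='haar').cuda()
ifm = DWTInverse(wave='haar').cuda()
x = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)

with torch.autocast('cuda'):
    yl, yh = xfm(x)        # coefficients come back as float16 under autocast
    rec = ifm((yl, yh))    # forward pass succeeds (autocast casts the conv inputs)
    loss = rec.float().mean()

loss.backward()  # fails: the backward sees float16 tensors next to float32 filters
```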

Is there a recommended way to resolve this issue, or is there any plan to update the library to address it? 🤔
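One possible workaround, until the library handles this itself, seems to be keeping the wavelet transforms out of the autocast region and running them in float32, then casting back afterwards. A sketch (the wrapper name and the cast back to the incoming dtype are my own choices, not part of the library):

```python
import torch

def dwt_roundtrip_fp32(xfm, ifm, x):
    """Run DWTForward/DWTInverse in float32 even when AMP is enabled elsewhere.

    xfm / ifm are ordinary pytorch_wavelets modules; x may be float16 under autocast.
    """
    with torch.autocast('cuda', enabled=False):   # opt out of AMP for these ops
        yl, yh = xfm(x.float())                   # force float32 coefficients
        rec = ifm((yl, yh))
    return rec.to(x.dtype)                        # match the surrounding precision again
```

This trades a little speed and memory in the wavelet layers for consistent dtypes, but it avoids touching the library internals.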
