@@ -988,7 +988,7 @@ defmodule Axon.LayersTest do
988988 [ [ 6.1974 , 7.1974 , 8.1974 ] , [ 7.8947 , 8.8947 , 9.8947 ] , [ 9.5921 , 10.5921 , 11.5921 ] ]
989989 ]
990990 ] ) ,
991- atol: 1.0e-4
991+ atol: 1.0e-3
992992 )
993993
994994 assert_all_close (
@@ -1000,7 +1000,7 @@ defmodule Axon.LayersTest do
10001000 [ [ 6.3724 , 7.3724 , 8.3724 ] , [ 8.2449 , 9.2449 , 10.2449 ] , [ 10.1173 , 11.1173 , 12.1173 ] ]
10011001 ]
10021002 ] ) ,
1003- atol: 1.0e-4
1003+ atol: 1.0e-3
10041004 )
10051005
10061006 assert_all_close (
@@ -1012,7 +1012,7 @@ defmodule Axon.LayersTest do
10121012 [ [ 6.4508 , 7.4508 , 8.4508 ] , [ 8.4016 , 9.4016 , 10.4016 ] , [ 10.3525 , 11.3525 , 12.3525 ] ]
10131013 ]
10141014 ] ) ,
1015- atol: 1.0e-4
1015+ atol: 1.0e-3
10161016 )
10171017 end
10181018
@@ -1036,7 +1036,7 @@ defmodule Axon.LayersTest do
10361036 ]
10371037 ]
10381038 ] ) ,
1039- atol: 1.0e-4
1039+ atol: 1.0e-3
10401040 )
10411041
10421042 # Downscaling (no effect)
@@ -1052,7 +1052,7 @@ defmodule Axon.LayersTest do
10521052 [ [ 6.1974 , 7.1974 , 8.1974 ] , [ 7.8947 , 8.8947 , 9.8947 ] , [ 9.5921 , 10.5921 , 11.5921 ] ]
10531053 ]
10541054 ] ) ,
1055- atol: 1.0e-4
1055+ atol: 1.0e-3
10561056 )
10571057 end
10581058 end
@@ -1723,6 +1723,88 @@ defmodule Axon.LayersTest do
17231723 end
17241724 end
17251725
1726+ describe "rms_norm" do
1727+ test "matches pytorch 2D input" do
1728+ input =
1729+ Nx . tensor ( [
1730+ [ 1.9269 , 1.4873 , 0.9007 , - 2.1055 , 0.6784 , - 1.2345 , - 0.0431 , - 1.6047 ] ,
1731+ [ - 0.7521 , 1.6487 , - 0.3925 , - 1.4036 , - 0.7279 , - 0.5594 , - 0.7688 , 0.7624 ]
1732+ ] )
1733+
1734+ gamma =
1735+ Nx . tensor ( [ 0.4617 , 0.2674 , 0.5349 , 0.8094 , 1.1103 , - 1.6898 , - 0.9890 , 0.9580 ] )
1736+
1737+ expected =
1738+ Nx . tensor ( [
1739+ [ 0.6344 , 0.2836 , 0.3436 , - 1.2153 , 0.5372 , 1.4877 , 0.0304 , - 1.0962 ] ,
1740+ [ - 0.3605 , 0.4576 , - 0.2179 , - 1.1793 , - 0.8390 , 0.9814 , 0.7893 , 0.7582 ]
1741+ ] )
1742+
1743+ actual = Axon.Layers . rms_norm ( input , gamma , epsilon: 1.0e-6 , channel_index: - 1 )
1744+ assert_all_close ( expected , actual , atol: 1.0e-3 )
1745+ end
1746+
1747+ test "matches pytorch 3D input" do
1748+ input =
1749+ Nx . tensor ( [
1750+ [
1751+ [ - 1.3847 , - 0.8712 , - 0.2234 , 1.7174 , 0.3189 , - 0.4245 ] ,
1752+ [ 0.3057 , - 0.7746 , - 1.5576 , 0.9956 , - 0.8798 , - 0.6011 ] ,
1753+ [ - 1.2742 , 2.1228 , - 1.2347 , - 0.4879 , - 0.9138 , - 0.6581 ] ,
1754+ [ 0.0780 , 0.5258 , - 0.4880 , 1.1914 , - 0.8140 , - 0.7360 ]
1755+ ] ,
1756+ [
1757+ [ - 1.4032 , 0.0360 , - 0.0635 , 0.6756 , - 0.0978 , 1.8446 ] ,
1758+ [ - 1.1845 , 1.3835 , 1.4451 , 0.8564 , 2.2181 , 0.5232 ] ,
1759+ [ 0.3466 , - 0.1973 , - 1.0546 , 1.2780 , - 0.1722 , 0.5238 ] ,
1760+ [ 0.0566 , 0.4263 , 0.5750 , - 0.6417 , - 2.2064 , - 0.7508 ]
1761+ ]
1762+ ] )
1763+
1764+ gamma =
1765+ Nx . tensor ( [ 0.4679 , - 0.2049 , - 0.7409 , 0.3618 , 1.9199 , - 0.2254 ] )
1766+
1767+ expected =
1768+ Nx . tensor ( [
1769+ [
1770+ [ - 0.6502 , 0.1792 , 0.1661 , 0.6236 , 0.6144 , 0.0960 ] ,
1771+ [ 0.1530 , 0.1698 , 1.2341 , 0.3853 , - 1.8064 , 0.1449 ] ,
1772+ [ - 0.4825 , - 0.3521 , 0.7403 , - 0.1429 , - 1.4199 , 0.1201 ] ,
1773+ [ 0.0504 , - 0.1488 , 0.4994 , 0.5955 , - 2.1588 , 0.2291 ]
1774+ ] ,
1775+ [
1776+ [ - 0.6653 , - 0.0075 , 0.0477 , 0.2477 , - 0.1903 , - 0.4213 ] ,
1777+ [ - 0.4033 , - 0.2063 , - 0.7791 , 0.2255 , 3.0986 , - 0.0858 ] ,
1778+ [ 0.2218 , 0.0553 , 1.0685 , 0.6324 , - 0.4521 , - 0.1614 ] ,
1779+ [ 0.0257 , - 0.0849 , - 0.4138 , - 0.2255 , - 4.1147 , 0.1644 ]
1780+ ]
1781+ ] )
1782+
1783+ actual = Axon.Layers . rms_norm ( input , gamma , epsilon: 1.0e-6 , channel_index: - 1 )
1784+ assert_all_close ( expected , actual , atol: 1.0e-3 )
1785+ end
1786+
1787+ test "matches pytorch with ones weight" do
1788+ input =
1789+ Nx . tensor ( [
1790+ [ 0.6127 , - 1.1754 , - 0.7646 , - 0.6666 ] ,
1791+ [ 0.7444 , - 0.6453 , - 1.3890 , - 0.2730 ]
1792+ ] )
1793+
1794+ gamma =
1795+ Nx . tensor ( [ 1.0000 , 1.0000 , 1.0000 , 1.0000 ] )
1796+
1797+ expected =
1798+ Nx . tensor ( [
1799+ [ 0.7342 , - 1.4084 , - 0.9163 , - 0.7987 ] ,
1800+ [ 0.8632 , - 0.7483 , - 1.6108 , - 0.3165 ]
1801+ ] )
1802+
1803+ actual = Axon.Layers . rms_norm ( input , gamma , epsilon: 1.0e-6 , channel_index: - 1 )
1804+ assert_all_close ( expected , actual , atol: 1.0e-3 )
1805+ end
1806+ end
1807+
17261808 describe "batch_norm" do
17271809 test "matches pytorch when variance < epsilon" do
17281810 input_val = - 0.002805
0 commit comments