
Commit 03031a5

Add rms_norm layer (#623)

1 parent d846c97

File tree: 3 files changed, 250 additions and 5 deletions
  lib/axon.ex
  lib/axon/layers.ex
  test/axon/layers_test.exs


lib/axon.ex

Lines changed: 68 additions & 0 deletions
@@ -2070,6 +2070,74 @@ defmodule Axon do
     )
   end
 
+  @doc ~S"""
+  Adds an RMS normalization layer to the network.
+
+  RMS normalization normalizes the input tensor using only the root
+  mean square, without centering by the mean. This is computationally
+  simpler than layer normalization while achieving similar results.
+
+  See `Axon.Layers.rms_norm/3` for more details.
+
+  $$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$
+
+  ## Options
+
+    * `:name` - layer name.
+
+    * `:gamma_initializer` - gamma parameter initializer. Defaults
+      to `:ones`.
+
+    * `:channel_index` - input feature index used for calculating
+      the root mean square. Defaults to `-1`.
+
+    * `:epsilon` - numerical stability term. Defaults to `1.0e-6`.
+
+    * `:shift` - numeric shift added to gamma before scaling.
+      Defaults to `0.0`.
+
+    * `:upcast` - adds explicit type casting to make sure the norm
+      is computed in high numerical precision. Either of:
+
+      * `:normalization` (default) - upcasts only the input
+        normalization part
+
+      * `:all` - upcasts both input normalization and the scaling
+        expression
+
+  ## References
+
+    * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
+
+  """
+  @doc type: :normalization
+  def rms_norm(%Axon{} = x, opts \\ []) do
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        :meta,
+        gamma_initializer: :ones,
+        channel_index: -1,
+        epsilon: 1.0e-6,
+        shift: 0.0,
+        upcast: :normalization
+      ])
+
+    channel_index = opts[:channel_index]
+    gamma_shape = &Axon.Shape.norm_param(&1, channel_index)
+    gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
+
+    layer(:rms_norm, [x, gamma],
+      name: opts[:name],
+      meta: opts[:meta],
+      epsilon: opts[:epsilon],
+      channel_index: channel_index,
+      shift: opts[:shift],
+      upcast: opts[:upcast],
+      op_name: :rms_norm
+    )
+  end
+
   @doc """
   Applies the given `Nx` expression to the input.
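For context, the new layer composes like Axon's other normalization layers, introducing a single trainable "gamma" parameter shaped along `:channel_index`. A minimal usage sketch (the surrounding model is illustrative, not part of this commit):

    model =
      Axon.input("features", shape: {nil, 8})
      |> Axon.dense(16)
      # the new layer from this commit, with its default options spelled out
      |> Axon.rms_norm(name: "rms_norm", epsilon: 1.0e-6, upcast: :normalization)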

lib/axon/layers.ex

Lines changed: 95 additions & 0 deletions
@@ -1301,6 +1301,101 @@ defmodule Axon.Layers do
     Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.shape(input))
   end
 
+  @doc ~S"""
+  Functional implementation of RMS normalization.
+
+  Normalizes the input by calculating the root mean square of the
+  input tensor along the given feature dimension `:channel_index`.
+  Unlike layer normalization, RMS normalization does not center the
+  input by subtracting the mean.
+
+  $$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$
+
+  `gamma` is often a trainable parameter. This method does not maintain
+  an EMA of variance.
+
+  ## Options
+
+    * `:epsilon` - numerical stability term. $\epsilon$ in the above
+      formulation. Defaults to `1.0e-6`.
+
+    * `:channel_index` - channel index used to determine reduction
+      axes for RMS calculation. Defaults to `-1`.
+
+    * `:shift` - numeric shift added to gamma before scaling.
+      Defaults to `0.0`.
+
+    * `:upcast` - controls type casting for numerical precision.
+      Either `:normalization` (default) to upcast only the normalization
+      part, or `:all` to upcast the entire computation.
+
+  ## References
+
+    * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
+  """
+  @doc type: :normalization
+  defn rms_norm(input, gamma, opts \\ []) do
+    opts =
+      keyword!(opts,
+        epsilon: 1.0e-6,
+        channel_index: -1,
+        shift: 0.0,
+        upcast: :normalization,
+        mode: :inference
+      )
+
+    rms_norm_impl(input, gamma, opts)
+  end
+
+  deftransformp rms_norm_impl(input, gamma, opts) do
+    case opts[:upcast] do
+      :normalization ->
+        rms_norm_upcast_normalization(input, gamma, opts)
+
+      :all ->
+        rms_norm_upcast_all(input, gamma, opts)
+
+      other ->
+        raise ArgumentError,
+              "expected :upcast to be either :all or :normalization, got: #{inspect(other)}"
+    end
+  end
+
+  defnp rms_norm_upcast_normalization(input, gamma, opts) do
+    num_channels = Nx.axis_size(input, opts[:channel_index])
+    parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index])
+    gamma = Nx.reshape(gamma, parameter_shape)
+
+    normalized_input =
+      input
+      |> Nx.as_type(:f32)
+      |> rms_normalize(opts)
+      |> Nx.as_type(Nx.type(input))
+
+    normalized_input * (opts[:shift] + gamma)
+  end
+
+  defnp rms_norm_upcast_all(input, gamma, opts) do
+    num_channels = Nx.axis_size(input, opts[:channel_index])
+    parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index])
+    gamma = Nx.reshape(gamma, parameter_shape)
+
+    input = Nx.as_type(input, :f32)
+    gamma = Nx.as_type(gamma, :f32)
+
+    normalized_input = rms_normalize(input, opts)
+    normalized_input * (opts[:shift] + gamma)
+  end
+
+  defnp rms_normalize(input, opts) do
+    variance =
+      input
+      |> Nx.pow(2)
+      |> Nx.mean(axes: [opts[:channel_index]], keep_axes: true)
+
+    input * Nx.rsqrt(variance + opts[:epsilon])
+  end
+
   @doc ~S"""
   Functional implementation of instance normalization.
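As a quick sanity check on the formula, here is a hand-worked example of the functional API (numbers computed from the definition above, not taken from the commit's test suite):

    input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
    gamma = Nx.tensor([1.0, 1.0, 1.0, 1.0])

    # mean of squares along the channel axis: (1 + 4 + 9 + 16) / 4 = 7.5
    # rsqrt(7.5 + 1.0e-6) ~= 0.36515, so each element is scaled by that factor
    Axon.Layers.rms_norm(input, gamma)
    #=> approximately [[0.3651, 0.7303, 1.0954, 1.4606]]

With the default `shift: 0.0` and unit gamma, the output is just the RMS-normalized input.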

test/axon/layers_test.exs

Lines changed: 87 additions & 5 deletions
@@ -988,7 +988,7 @@ defmodule Axon.LayersTest do
           [[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
         ]
       ]),
-      atol: 1.0e-4
+      atol: 1.0e-3
     )
 
     assert_all_close(
@@ -1000,7 +1000,7 @@ defmodule Axon.LayersTest do
           [[6.3724, 7.3724, 8.3724], [8.2449, 9.2449, 10.2449], [10.1173, 11.1173, 12.1173]]
         ]
       ]),
-      atol: 1.0e-4
+      atol: 1.0e-3
     )
 
     assert_all_close(
@@ -1012,7 +1012,7 @@ defmodule Axon.LayersTest do
           [[6.4508, 7.4508, 8.4508], [8.4016, 9.4016, 10.4016], [10.3525, 11.3525, 12.3525]]
         ]
       ]),
-      atol: 1.0e-4
+      atol: 1.0e-3
     )
   end
 
@@ -1036,7 +1036,7 @@ defmodule Axon.LayersTest do
           ]
         ]
       ]),
-      atol: 1.0e-4
+      atol: 1.0e-3
     )
 
     # Downscaling (no effect)
@@ -1052,7 +1052,7 @@ defmodule Axon.LayersTest do
           [[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
         ]
       ]),
-      atol: 1.0e-4
+      atol: 1.0e-3
     )
   end
 end
@@ -1723,6 +1723,88 @@ defmodule Axon.LayersTest do
     end
   end
 
+  describe "rms_norm" do
+    test "matches pytorch 2D input" do
+      input =
+        Nx.tensor([
+          [1.9269, 1.4873, 0.9007, -2.1055, 0.6784, -1.2345, -0.0431, -1.6047],
+          [-0.7521, 1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688, 0.7624]
+        ])
+
+      gamma =
+        Nx.tensor([0.4617, 0.2674, 0.5349, 0.8094, 1.1103, -1.6898, -0.9890, 0.9580])
+
+      expected =
+        Nx.tensor([
+          [0.6344, 0.2836, 0.3436, -1.2153, 0.5372, 1.4877, 0.0304, -1.0962],
+          [-0.3605, 0.4576, -0.2179, -1.1793, -0.8390, 0.9814, 0.7893, 0.7582]
+        ])
+
+      actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+
+    test "matches pytorch 3D input" do
+      input =
+        Nx.tensor([
+          [
+            [-1.3847, -0.8712, -0.2234, 1.7174, 0.3189, -0.4245],
+            [0.3057, -0.7746, -1.5576, 0.9956, -0.8798, -0.6011],
+            [-1.2742, 2.1228, -1.2347, -0.4879, -0.9138, -0.6581],
+            [0.0780, 0.5258, -0.4880, 1.1914, -0.8140, -0.7360]
+          ],
+          [
+            [-1.4032, 0.0360, -0.0635, 0.6756, -0.0978, 1.8446],
+            [-1.1845, 1.3835, 1.4451, 0.8564, 2.2181, 0.5232],
+            [0.3466, -0.1973, -1.0546, 1.2780, -0.1722, 0.5238],
+            [0.0566, 0.4263, 0.5750, -0.6417, -2.2064, -0.7508]
+          ]
+        ])
+
+      gamma =
+        Nx.tensor([0.4679, -0.2049, -0.7409, 0.3618, 1.9199, -0.2254])
+
+      expected =
+        Nx.tensor([
+          [
+            [-0.6502, 0.1792, 0.1661, 0.6236, 0.6144, 0.0960],
+            [0.1530, 0.1698, 1.2341, 0.3853, -1.8064, 0.1449],
+            [-0.4825, -0.3521, 0.7403, -0.1429, -1.4199, 0.1201],
+            [0.0504, -0.1488, 0.4994, 0.5955, -2.1588, 0.2291]
+          ],
+          [
+            [-0.6653, -0.0075, 0.0477, 0.2477, -0.1903, -0.4213],
+            [-0.4033, -0.2063, -0.7791, 0.2255, 3.0986, -0.0858],
+            [0.2218, 0.0553, 1.0685, 0.6324, -0.4521, -0.1614],
+            [0.0257, -0.0849, -0.4138, -0.2255, -4.1147, 0.1644]
+          ]
+        ])
+
+      actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+
+    test "matches pytorch with ones weight" do
+      input =
+        Nx.tensor([
+          [0.6127, -1.1754, -0.7646, -0.6666],
+          [0.7444, -0.6453, -1.3890, -0.2730]
+        ])
+
+      gamma =
+        Nx.tensor([1.0000, 1.0000, 1.0000, 1.0000])
+
+      expected =
+        Nx.tensor([
+          [0.7342, -1.4084, -0.9163, -0.7987],
+          [0.8632, -0.7483, -1.6108, -0.3165]
+        ])
+
+      actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+  end
+
   describe "batch_norm" do
     test "matches pytorch when variance < epsilon" do
       input_val = -0.002805
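The tests above use f32 inputs, so `:upcast` has no visible effect there; with low-precision inputs the option determines the output type. A sketch of the expected behavior, following the casts in `rms_norm_upcast_normalization/3` and `rms_norm_upcast_all/3` above (output types inferred from that code, not asserted in this commit's tests):

    input = Nx.tensor([[1.0, 2.0, 3.0]], type: :bf16)
    gamma = Nx.tensor([1.0, 1.0, 1.0], type: :bf16)

    # :normalization computes the norm in f32, then casts back to the input type
    Axon.Layers.rms_norm(input, gamma, upcast: :normalization) |> Nx.type()
    #=> {:bf, 16}

    # :all casts both input and gamma to f32 and keeps the result in f32
    Axon.Layers.rms_norm(input, gamma, upcast: :all) |> Nx.type()
    #=> {:f, 32}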
