Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lib/axon/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,7 @@ defmodule Axon.Shared do

defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
[epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6)

# The select is so that we improve numerical stability by clipping
# both insignificant values of variance and NaNs to epsilon.
scale =
gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon))

scale = gamma * Nx.rsqrt(variance + epsilon)
scale * (input - mean) + bias
end

Expand Down
26 changes: 26 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1722,4 +1722,30 @@ defmodule Axon.LayersTest do
assert_all_close(expected, actual, atol: 1.0e-3)
end
end

describe "batch_norm" do
test "matches pytorch when variance < epsilon" do
input_val = -0.002805
mean = -0.008561
variance = 0.000412
weight = 1.0
bias = -0.144881
epsilon = 0.001

expected = Nx.tensor([0.0083])

actual =
Axon.Layers.batch_norm(
Nx.tensor([[[[input_val]]]]),
Nx.tensor([weight]),
Nx.tensor([bias]),
Nx.tensor([mean]),
Nx.tensor([variance]),
mode: :inference,
epsilon: epsilon
)

assert_all_close(expected, actual, atol: 1.0e-3)
end
end
end
Loading