6666# #### map
6767# ####
6868
69- # `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70- # will be useful for the gradient of `map` etc.
71-
72-
7369"""
74- unzip_map(f, args...)
70+ unzip_map(f, args...)
7571
7672For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
7773but performed using `StructArrays` for efficiency.
@@ -86,40 +82,36 @@ function unzip_map(f::F, args...) where {F}
8682end
8783
8884unzip_map(f:: F , args:: Tuple... ) where {F} = unzip(map(f, args... ))
85+ # unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))
8986
9087unzip_map(f:: F , args:: AbstractGPUArray... ) where {F} = unzip(map(f, args... ))
9188
89+ """
90+ unzip_map_reversed(f, args...)
91+
92+ For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`.
93+ But the order of evaluation is should be the reverse.
94+ Does NOT handle `zip`-like behaviour.
95+ """
9296function unzip_map_reversed(f:: F , args... ) where {F}
9397 T = Broadcast. combine_eltypes(f, args)
9498 if isconcretetype(T)
9599 T <: Tuple || throw(ArgumentError(""" unzip_map_reversed(f, args) only works on functions returning a tuple,
96100 but f = $(sprint(show, f)) returns type T = $T """ ))
97101 end
98102 len1 = length(first(args))
99- if all(a -> length(a)== len1, args)
100- rev_args = map(Iterators. reverse, args)
101- outs = StructArrays. components(StructArray(Iterators. map(f, rev_args... )))
102- else
103- len = minimum(length, args)
104- rev_args = map(a -> Iterators. reverse(@view a[begin : begin + len- 1 ]), args)
105- outs = StructArrays. components(StructArray(Iterators. map(f, rev_args... )))
106- end
107- return map(reverse!!, outs)
103+ all(a -> length(a)== len1, args) || error(" unzip_map_reversed does not handle zip-like behaviour." )
104+ return map(reverse!!, unzip_map(f, map(_safereverse, args). .. ))
108105end
109106
107+ # This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
108+ _safereverse(x) = VERSION > v" 1.7" ? Iterators. reverse(x) : reverse(x)
109+
110110function unzip_map_reversed(f:: F , args:: Tuple... ) where {F}
111- len = minimum(length, args)
112- rev_args = map(a -> reverse(a[1 : len]), args)
113- # vlen = Val(len)
114- # rev_args = map(args) do a
115- # reverse(ntuple(i -> a[i], vlen)) # does not infer better
116- # end
117- return map(reverse, unzip(map(f, rev_args... )))
111+ len1 = length(first(args))
112+ all(a -> length(a)== len1, args) || error(" unzip_map_reversed does not handle zip-like behaviour." )
113+ return map(reverse, unzip(map(f, map(reverse, args). .. )))
118114end
119- # function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
120- # rev_args = map(reverse, args)
121- # return map(reverse, unzip(map(f, rev_args...)))
122- # end
123115
124116"""
125117 reverse!!(x)
@@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray)
135127 end
136128end
137129reverse!!(x:: AbstractArray{<:AbstractZero} ) = x
130+ reverse!!(x) = reverse(x)
138131
139- frule((_, xdot), :: typeof (reverse!!), x:: AbstractArray ) = reverse!!(x), reverse!!(xdot)
132+ frule((_, xdot), :: typeof (reverse!!), x) = reverse!!(x), reverse!!(xdot)
140133
141- function rrule(:: typeof (reverse!!), x:: AbstractArray )
134+ function rrule(:: typeof (reverse!!), x)
142135 reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
143136 return reverse!!(x), reverse!!_back
144137end
@@ -181,10 +174,16 @@ end
181174 Expr(:tuple, each... )
182175end
183176
184- unzip(xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret(T, xs),) # best case, no copy
177+ function unzip(xs:: AbstractArray{Tuple{T}} ) where {T}
178+ if isbitstype(T)
179+ (reinterpret(T, xs),) # best case, no copy
180+ else
181+ (map(only, xs),)
182+ end
183+ end
185184
186185@generated function unzip(xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
187- each = if count(! Base. issingletontype, Ts. parameters) < 2
186+ each = if count(! Base. issingletontype, Ts. parameters) < 2 && all(isbitstype, Ts . parameters)
188187 # good case, no copy of data, some trivial arrays
189188 [Base. issingletontype(T) ? :(similar(xs, $ T)) : :(reinterpret($ T, xs)) for T in Ts. parameters]
190189 else
0 commit comments