@@ -185,8 +185,8 @@ ordered confusion matrices, see [`$CM.confmat`](@ref).
185185
186186"""
187187struct ConfusionMatrix{N,O,L}
188- mat:: Matrix{Int }
189- index_given_level:: LittleDict{L, Int , NTuple{N,L}, NTuple{N,Int}}
188+ mat:: Matrix{<:Integer }
189+ index_given_level:: LittleDict{L, I , NTuple{N,L}, NTuple{N,I}} where I <: Integer
190190end
191191
192192"""
@@ -204,14 +204,13 @@ See also [`$CM.confmat`](@ref).
204204"""
205205function ConfusionMatrix (
206206 m,
207- dic:: AbstractDict {L,I } ;
207+ dic:: LittleDict {L, I, NTuple{N,L}, NTuple{N,I} } ;
208208 checks= true ,
209209 ordered= false ,
210- ) where {L,I<: Integer }
211- s = size (m)
212- N = s[1 ]
210+ ) where {L,I<: Integer ,N}
213211 if checks
214- N == s[2 ] || throw (ArgumentError (" Expected a square matrix." ))
212+ s = size (m)
213+ N == s[1 ] == s[2 ] || throw (ArgumentError (" Expected a square matrix." ))
215214 N > 1 || throw (ArgumentError (" Expected a matrix of size ≥ 2x2." ))
216215 length (unique (keys (dic))) == N || throw (ArgumentError (
217216 " Expected dictionary with $N unique keys (levels) as " *
@@ -222,8 +221,15 @@ function ConfusionMatrix(
222221 " to be integers from 1 to $N . "
223222 ))
224223 end
225- index_given_level = freeze (dic)
226- ConfusionMatrix {N,ordered,L} (m, index_given_level)
224+ ConfusionMatrix {N,ordered,L} (m, dic)
225+ end
226+ function ConfusionMatrix (
227+ m,
228+ dic:: AbstractDict ;
229+ checks= true ,
230+ ordered= false ,
231+ )
232+ ConfusionMatrix (m, freeze (dic); checks, ordered)
227233end
228234
229235"""
@@ -251,7 +257,7 @@ function ConfusionMatrix(m, levels::AbstractVector{L}; ordered=false, checks=tru
251257 ))
252258 end
253259 index_given_level =
254- LittleDict {L, Int, Vector{L}, Vector{Int}} (levels, eachindex (levels)) |> freeze
260+ LittleDict {L, Int, Vector{L}, Vector{Int}} (levels, eachindex (levels))
255261 ConfusionMatrix (m, index_given_level; ordered, checks= false )
256262end
257263
@@ -489,8 +495,8 @@ function confmat(ŷ, y, _levels, _perm, rev)
489495 perm = permutation (_perm, rev, levels)
490496
491497 levels = apply (perm, levels)
492- indexer = LittleDict (levels[i] => i for i in eachindex (levels)) |> freeze
493-
498+ L = eltype (levels)
499+ indexer = LittleDict {L, Int, Vector{L}, Vector{Int}} (levels, eachindex (levels)) |> freeze
494500 _confmat (ŷ, y, indexer, levels, ordered)
495501end
496502
@@ -533,18 +539,26 @@ end
533539
534540
535541# ## Final method to do the computation
536-
537- function _confmat (ŷ, y, indexer, levels, ordered)
542+ function _confmat (ŷ, y, indexer:: F , levels, ordered) where F
538543 nc = length (levels)
539544 cmat = zeros (Int, nc, nc)
540545 @inbounds for i in eachindex (y)
541546 (ismissing (y[i]) || ismissing (ŷ[i])) && continue
542547 cmat[get (indexer, ŷ[i]), get (indexer, y[i])] += 1
543548 end
544- index_given_level = LittleDict (c => get (indexer, c) for c in levels) |> freeze
545549 return ConfusionMatrix (cmat, levels; ordered, checks= false )
546550end
547551
552+ function _confmat (ŷ, y, indexer:: AbstractDict{L,I} , levels, ordered) where {L,I<: Integer }
553+ nc = length (levels)
554+ cmat = zeros (Int, nc, nc)
555+ @inbounds for i in eachindex (y)
556+ (ismissing (y[i]) || ismissing (ŷ[i])) && continue
557+ cmat[get (indexer, ŷ[i]), get (indexer, y[i])] += 1
558+ end
559+ return ConfusionMatrix (cmat, indexer; ordered, checks= false )
560+ end
561+
548562
549563# DISPLAY
550564
0 commit comments