Skip to content

Commit 7e0f86a

Browse files
committed
Update code for StatsModels v0.6.0
1 parent 163b717 commit 7e0f86a

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ docs/site
88
*.jl.mem
99
deps/deps.jl
1010

11-
.ipynb_checkpoints/
11+
.ipynb_checkpoints/
12+
.vscode/settings.json

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1616
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1717
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819

1920
[compat]
2021
Distributions = ">=0.16.0"
2122
GLM = ">=1.0.0"
2223
StatsBase = ">=0.24.0"
23-
StatsModels = ">=0.2.0"
24+
StatsModels = ">=0.6.0"
2425

2526
[extras]
2627
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"

src/ordmnfit.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Fit ordered multinomial model by maximum likelihood estimation.
158158
"""
159159
polr(X::AbstractMatrix, y::AbstractVector, args...; kwargs...) =
160160
fit(AbstractOrdinalMultinomialModel, X, y, args...; kwargs...)
161-
polr(f::Formula, df, args...; kwargs...) =
161+
polr(f::FormulaTerm, df, args...; kwargs...) =
162162
fit(AbstractOrdinalMultinomialModel, f, df, args...; kwargs...)
163163

164164
"""
@@ -178,20 +178,29 @@ Fit ordered multinomial model by maximum likelihood estimation.
178178
function fit(
179179
::Type{M},
180180
X::AbstractMatrix,
181-
y::AbstractVector,
181+
y::AbstractVecOrMat,
182182
link::GLM.Link = LogitLink(),
183183
solver = NLoptSolver(algorithm=:LD_SLSQP, maxeval=4000);
184184
wts::AbstractVector = similar(X, 0)
185185
) where M <: AbstractOrdinalMultinomialModel
186+
ydata = Vector{Int}(undef, size(y, 1))
186187
# set up optimization
187-
ydata = denserank(y) # dense ranking of y, http://juliastats.github.io/StatsBase.jl/stable/ranking.html#StatsBase.denserank
188+
if size(y, 2) == 1
189+
ydata = denserank(y)
190+
else #y is encoded via dummy-encoding
191+
for i in 1:size(y, 1)
192+
idx = findfirst(view(y, i, :) .== 1)
193+
ydata[i] = idx == nothing ? 1 : idx + 1
194+
end
195+
end
196+
#ydata = denserank(y) # dense ranking of y, http://juliastats.github.io/StatsBase.jl/stable/ranking.html#StatsBase.denserank
188197
dd = OrdinalMultinomialModel(X, ydata, convert(Vector{eltype(X)}, wts), link)
189198
m = MathProgBase.NonlinearModel(solver)
190199
lb = fill(-Inf, dd.npar)
191200
ub = fill( Inf, dd.npar)
192201
MathProgBase.loadproblem!(m, dd.npar, 0, lb, ub, Float64[], Float64[], :Max, dd)
193202
# initialize from LS solution
194-
β0 = [ones(length(y)) X] \ ydata
203+
β0 = [ones(length(ydata)) X] \ ydata
195204
par0 = [β0[1] - dd.J / 2 + 1; zeros(dd.J - 2); β0[2:end]]
196205
MathProgBase.setwarmstart!(m, par0)
197206
MathProgBase.optimize!(m)
@@ -236,7 +245,7 @@ function MathProgBase.eval_grad_f(m::OrdinalMultinomialModel, grad::Vector, par:
236245
copyto!(grad, m.J, m.∇, m.J, m.p)
237246
end
238247

239-
function StatsModels.coeftable(mod::StatsModels.DataFrameRegressionModel{T, S}
248+
function StatsModels.coeftable(mod::StatsModels.TableRegressionModel{T, S}
240249
where {T <: OrdinalMultinomialModel, S <: Matrix} )
241250
ct = coeftable(mod.model)
242251
cfnames = [["intercept$i|$(i+1)" for i in 1:(mod.model.J - 1)]; coefnames(mod)]

src/ordmntest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function polrtest(nm::OrdinalMultinomialModel, Z::AbstractVecOrMat; test=:score)
88
end
99
end
1010

11-
polrtest(nm::StatsModels.DataFrameRegressionModel{<:OrdinalMultinomialModel}, Z::AbstractVecOrMat; kwargs...) =
11+
polrtest(nm::StatsModels.TableRegressionModel{<:OrdinalMultinomialModel}, Z::AbstractVecOrMat; kwargs...) =
1212
polrtest(nm.model, Z; kwargs...)
1313

1414
###########################

0 commit comments

Comments
 (0)