Skip to content

Commit 30cb09c

Browse files
authored
Merge pull request #152 from cnellington/dev
Release v0.2.4
2 parents 9a9549f + b9a5a0f commit 30cb09c

File tree

5 files changed

+21
-9
lines changed

5 files changed

+21
-9
lines changed

contextualized/easy/ContextualizedClassifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from contextualized.functions import LINK_FUNCTIONS
88
from contextualized.easy import ContextualizedRegressor
9+
from contextualized.regression import LOSSES
910

1011

1112
class ContextualizedClassifier(ContextualizedRegressor):
@@ -15,6 +16,7 @@ class ContextualizedClassifier(ContextualizedRegressor):
1516

1617
def __init__(self, **kwargs):
1718
kwargs["link_fn"] = LINK_FUNCTIONS["logistic"]
19+
kwargs["loss_fn"] = LOSSES["bceloss"]
1820
super().__init__(**kwargs)
1921

2022
def predict(self, C, X, individual_preds=False, **kwargs):

contextualized/easy/ContextualizedNetworks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def measure_mses(self, C, X, individual_preds=False):
8080
mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples
8181
for i in range(X.shape[-1]):
8282
for j in range(X.shape[-1]):
83-
residuals = np.tile(X[:, i], len(betas)) - (
84-
betas[:, :, i, j] * np.tile(X[:, j], len(betas)) + mus[:, :, i, j]
85-
)
83+
tiled_xi = np.array([X[:, i] for _ in range(len(betas))])
84+
tiled_xj = np.array([X[:, j] for _ in range(len(betas))])
85+
residuals = tiled_xi - betas[:, :, i, j]*tiled_xj + mus[:, :, i ,j]
8686
mses += residuals**2 / (X.shape[-1] ** 2)
8787
if not individual_preds:
8888
mses = np.mean(mses, axis=0)
@@ -164,7 +164,7 @@ def __init__(self, **kwargs):
164164
**kwargs,
165165
)
166166

167-
def predict_params(self, C, individual_preds=False, **kwargs):
167+
def predict_params(self, C, **kwargs):
168168
"""
169169
170170
:param C:
@@ -173,8 +173,12 @@ def predict_params(self, C, individual_preds=False, **kwargs):
173173
"""
174174
# Returns betas
175175
# TODO: No mus for NOTMAD at present.
176-
return super().predict_params(C, individual_preds,
177-
model_includes_mus=False, **kwargs)
176+
return super().predict_params(
177+
C,
178+
individual_preds=kwargs.get("individual_preds", False),
179+
model_includes_mus=False,
180+
uses_y=False,
181+
project_to_dag=kwargs.get("project_to_dag", True))
178182

179183
def predict_networks(self, C, with_offsets=False, project_to_dag=True, **kwargs):
180184
"""

contextualized/regression/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
MultitaskMultivariateDataset,
88
MultitaskUnivariateDataset,
99
)
10-
from contextualized.regression.losses import MSE
10+
from contextualized.regression.losses import MSE, BCELoss
1111
from contextualized.regression.regularizers import REGULARIZERS
1212
from contextualized.regression.lightning_modules import (
1313
NaiveContextualizedRegression,
@@ -25,7 +25,7 @@
2525
"multitask_multivariate": MultitaskMultivariateDataset,
2626
"multitask_univariate": MultitaskUnivariateDataset,
2727
}
28-
LOSSES = {"mse": MSE}
28+
LOSSES = {"mse": MSE, "bceloss": BCELoss}
2929
MODELS = ["multivariate", "univariate"]
3030
METAMODELS = ["simple", "subtype", "multitask", "tasksplit"]
3131
TRAINERS = {"regression_trainer": RegressionTrainer}

contextualized/regression/losses.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Losses used in regression.
33
"""
44

5+
import torch
56

67
def MSE(Y_true, Y_pred):
78
"""
@@ -20,3 +21,8 @@ def MSE(Y_true, Y_pred):
2021
"""
2122
residual = Y_true - Y_pred
2223
return residual.pow(2).mean()
24+
25+
26+
def BCELoss(Y_true, Y_pred):
27+
loss = -(Y_true * torch.log(Y_pred + 1e-8) + (1 - Y_true) * torch.log(1 - Y_pred + 1e-8))
28+
return loss.mean()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from setuptools import find_packages, setup
66

77
DESCRIPTION = "An ML toolbox for estimating context-specific parameters."
8-
VERSION = '0.2.3'
8+
VERSION = '0.2.4'
99

1010
setup(
1111
name='contextualized',

0 commit comments

Comments
 (0)