17 changes: 12 additions & 5 deletions choice_learn/basket_models/base_basket_model.py
@@ -884,7 +884,6 @@ def evaluate(
sparse=True,
from_logits=False,
epsilon=epsilon_eval,
average_on_batch=True,
name="basketwise-nll",
)
)
@@ -897,7 +896,7 @@ def evaluate(
metric.reset_state()

# for trip in trip_dataset.trips:
for data_batch, identifier_batch in trip_dataset.iter_batch_evaluate(
for data_batch, weights_batch in trip_dataset.iter_batch_evaluate(
trip_batch_size=trip_batch_size
):
# Sum of the log-likelihoods of all the baskets in the batch
@@ -913,9 +912,17 @@

for metric in exec_metrics:
# Use update_state, not append(metric(...))
metric.update_state(
y_true=data_batch[0], y_pred=predicted_probabilities, batch=identifier_batch
)
if "basketwise" in metric.name:
metric.update_state(
y_true=data_batch[0],
y_pred=predicted_probabilities,
sample_weight=weights_batch,
)
else:
metric.update_state(
y_true=data_batch[0],
y_pred=predicted_probabilities,
)

# After the loops, get the final results
return {metric.name: metric.result() for metric in exec_metrics}
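
For intuition, here is a minimal NumPy sketch (values made up, not from the PR) of what the weighted `basketwise-nll` path accumulates: `iter_batch_evaluate` gives each of a trip's n one-vs-all rows the weight 1/n, and the metric divides the plain NLL sum by the weight sum, which equals the per-trip total NLL averaged over trips.

```python
import numpy as np

# Two trips expanded one-vs-all: trip A contributes 2 rows, trip B 3 rows.
nll_rows = np.array([0.5, 0.7, 0.2, 0.4, 0.3], dtype=np.float32)
weights = np.array([1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3], dtype=np.float32)

# Accumulated metric: unweighted NLL sum over the weight sum ...
metric_value = nll_rows.sum() / weights.sum()  # 2.1 / 2.0 = 1.05

# ... equals the mean over trips of each trip's total (basket-level) NLL.
per_trip_totals = [nll_rows[:2].sum(), nll_rows[2:].sum()]  # [1.2, 0.9]
assert np.isclose(metric_value, np.mean(per_trip_totals))
```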
21 changes: 12 additions & 9 deletions choice_learn/basket_models/data/basket_dataset.py
@@ -783,29 +783,31 @@ def iter_batch_evaluate(
np.empty(0, dtype=int), # Weeks
np.empty((0, self.n_items), dtype=int), # Prices
np.empty((0, self.n_items), dtype=int), # Available items
np.empty(0, dtype=int), # Users
)

if trip_batch_size == -1:
# Get the whole dataset in one batch
identifiers = []
weights = []
for trip_index in trip_indexes:
additional_trip_data = self.get_one_vs_all_augmented_data_from_trip_index(
trip_index
)
buffer = tuple(
np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer))
)
identifiers.extend([trip_index] * len(additional_trip_data[0]))
# Each of the trip's n augmented rows gets weight 1/n
weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0]))

# Yield the whole dataset
yield buffer, np.array(identifiers)
yield buffer, np.array(weights).astype("float32")

else:
# Yield batches of size batch_size while going through all the trips
index = 0
outer_break = False
while index < num_trips:
trip_identifier = []
weights = []
trip_count = 0
buffer = (
np.empty(0, dtype=int), # Items
np.empty((0, self.max_length), dtype=int), # Baskets
@@ -816,11 +818,11 @@
np.empty((0, self.n_items), dtype=int), # Available items
np.empty(0, dtype=int), # Users
)
while np.max(trip_identifier, initial=-1) + 1 < trip_batch_size:
while trip_count < trip_batch_size:  # stop once the batch holds trip_batch_size trips
if index >= num_trips:
# Then the buffer is not full but there are no more trips to consider
# Yield the batch partially filled
yield buffer, np.array(trip_identifier)
yield buffer, np.array(weights).astype("float32")

# Exit the TWO while loops when all trips have been considered
outer_break = True
@@ -832,18 +834,19 @@
trip_indexes[index]
)
index += 1
trip_count += 1

# Fill the buffer with the new trip
buffer = tuple(
np.concatenate((buffer[i], additional_trip_data[i]))
for i in range(len(buffer))
)
trip_identifier.extend(
[np.max(trip_identifier, initial=-1) + 1] * len(additional_trip_data[0])
weights.extend(
[1 / len(additional_trip_data[0])] * len(additional_trip_data[0])
)

if outer_break:
break

# Yield the batch
yield buffer, np.array(trip_identifier)
yield buffer, np.array(weights).astype("float32")
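
A short usage sketch of the new iterator contract (`trip_dataset` is a hypothetical `TripDataset` instance; the tuple layout follows the buffer built above):

```python
import numpy as np

for data_batch, weights_batch in trip_dataset.iter_batch_evaluate(trip_batch_size=32):
    # data_batch holds the one-vs-all augmented arrays (items, baskets, ...,
    # available items, users). Each trip's n rows carry weight 1/n, so the
    # batch's weights sum to (roughly) the number of trips it contains.
    assert weights_batch.dtype == np.float32
    print(f"{len(data_batch[0])} rows from ~{weights_batch.sum():.0f} trips")
```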
69 changes: 53 additions & 16 deletions choice_learn/models/base_model.py
@@ -12,6 +12,7 @@
import tqdm

import choice_learn.tf_ops as tf_ops
from choice_learn.data import ChoiceDataset


class ChoiceModel:
@@ -254,6 +255,7 @@ def fit(
choice_dataset,
sample_weight=None,
val_dataset=None,
validation_freq=1,
verbose=0,
):
"""Train the model with a ChoiceDataset.
@@ -264,14 +266,18 @@
Input data in the form of a ChoiceDataset
sample_weight : np.ndarray, optional
Sample weight to apply, by default None
val_dataset : ChoiceDataset, optional
val_dataset : ChoiceDataset or (ChoiceDataset, sample_weight), optional
Validation ChoiceDataset used to evaluate performance at each epoch, by default None
verbose : int, optional
print level, for debugging, by default 0
epochs : int, optional
Number of epochs, default is None, meaning we use self.epochs
batch_size : int, optional
Batch size, default is None, meaning we use self.batch_size
validation_freq : int, optional
Only relevant if validation data is provided. Specifies how many training
epochs to run before a new validation run is performed, e.g. validation_freq=2
runs validation every 2 epochs. By default 1.

Returns
-------
@@ -405,24 +411,55 @@
)

# Test on val_dataset if provided
if val_dataset is not None:
if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0:
test_losses = []
for batch_nb, (
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)):

val_samples_weight = None
if isinstance(val_dataset, tuple):
if len(val_dataset) != 2:
raise ValueError(
"""if argument val_dataset is a tuple, it should be
in the form (ChoiceDataset, weights)"""
)
validation_dataset, val_samples_weight = val_dataset
elif isinstance(val_dataset, ChoiceDataset):
validation_dataset = val_dataset
else:
raise ValueError(
"""val_dataset should be a ChoiceDataset or
a tuple of (ChoiceDataset, weights)."""
)

val_iterator = validation_dataset.iter_batch(
shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size
)

for batch_nb, batch_data in enumerate(val_iterator):
weight_batch = None
if val_samples_weight is not None:
batch_features, weight_batch = batch_data
else:
batch_features = batch_data

(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
) = batch_features

self.callbacks.on_batch_begin(batch_nb)
self.callbacks.on_test_batch_begin(batch_nb)
test_losses.append(
self.batch_predict(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
)[0]["optimized_loss"]
)

loss = self.batch_predict(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
sample_weight=weight_batch,
)[0]["optimized_loss"]
test_losses.append(loss)

val_logs["val_loss"].append(test_losses[-1])
temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()}
self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs)
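
A hedged sketch of how the extended signature can be called (`model`, `train_dataset`, and `valid_dataset` are placeholders, not names from this PR; it assumes `ChoiceDataset` supports `len()`):

```python
import numpy as np

# Weight validation samples uniformly; any float array of len(valid_dataset) works.
val_weights = np.ones((len(valid_dataset),), dtype="float32")

history = model.fit(
    train_dataset,
    val_dataset=(valid_dataset, val_weights),  # tuple form: (ChoiceDataset, weights)
    validation_freq=2,  # evaluate on the validation set every 2 epochs
    verbose=1,
)
```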
54 changes: 23 additions & 31 deletions choice_learn/utils/metrics.py
@@ -24,7 +24,6 @@ def __init__(
self,
from_logits=False,
sparse=False,
average_on_batch=False,
epsilon=1e-10,
name="negative_log_likelihood",
axis=-1,
@@ -40,8 +39,6 @@
Whether y_true is given as an index or a one-hot, by default False
epsilon : float, optional
Lower bound for log(.), by default 1e-10
average_on_batch: bool, optional
Whether the metric should be averaged over each batch. Typically used to
get metrics averaged by Trip, by default False
name : str, optional
Name of operation, by default "negative_log_likelihood"
@@ -53,11 +50,10 @@
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.from_logits = from_logits
self.sparse = sparse
self.average_on_batch = average_on_batch
self.epsilon = epsilon
self.axis = axis

def update_state(self, y_true, y_pred, batch=None, sample_weight=None):
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulate statistics for the metric.

Parameters
@@ -91,16 +87,11 @@ def update_state(self, y_true, y_pred, batch=None, sample_weight=None):
axis=self.axis,
)

if batch is not None and self.average_on_batch:
for _, idx in zip(*tf.unique(batch)):
self.nll.assign(self.nll + tf.reduce_mean(nll_value[idx]))
self.n_evals.assign(self.n_evals + 1)
self.nll.assign(self.nll + tf.reduce_sum(nll_value))
if sample_weight is None:
self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])
else:
self.nll.assign(self.nll + tf.reduce_sum(nll_value))
if sample_weight is None:
self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])
else:
self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight))
self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight))
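
A small check of the weighted path, assuming the class above is exported as `NegativeLogLikelihood` and that `result()` returns `nll / n_evals` (both plausible from this file, but partly folded here):

```python
import tensorflow as tf

from choice_learn.utils.metrics import NegativeLogLikelihood

metric = NegativeLogLikelihood(sparse=True)
y_true = tf.constant([0, 2])  # chosen item indices
y_pred = tf.constant([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])  # probabilities
weights = tf.constant([0.5, 0.5])  # e.g. the two rows of a single trip

metric.update_state(y_true, y_pred, sample_weight=weights)
# Expected: (-log 0.7 - log 0.8) / (0.5 + 0.5) ≈ 0.58
print(float(metric.result()))
```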

def result(self):
"""Compute the current metric value.
@@ -118,22 +109,20 @@ class MRR(tf.keras.metrics.Metric):

def __init__(
self,
average_on_batch=False,
name="mean_reciprocal_rank",
axis=-1,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.mrr = self.add_variable(shape=(), initializer="zeros", name="mrr")
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.average_on_batch = average_on_batch
self.axis = axis

def update_state(
self,
y_true,
y_pred,
batch=None,
sample_weight=None,
):
"""Accumulate statistics for the metric.

@@ -156,15 +145,17 @@ def update_state(
[tf.range(len(y_true)), y_true], axis=1
) # Shape: (batch_size, 2)
item_ranks = tf.gather_nd(ranks, item_batch_indices) # Shape: (batch_size,)
mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis)

if batch is not None and self.average_on_batch:
self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank))
self.n_evals.assign(self.n_evals + 1)
if sample_weight is not None:
mean_rank = tf.reduce_sum(
tf.cast(1 / item_ranks, dtype=tf.float32) * sample_weight, axis=self.axis
)
self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight))
else:
self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank))
mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis)
self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])

self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank))

def result(self):
"""Compute the current metric value.

@@ -181,7 +172,6 @@ class HitRate(tf.keras.metrics.Metric):

def __init__(
self,
average_on_batch=False,
top_k: int = 10,
name=None,
axis=-1,
@@ -195,10 +185,9 @@ def __init__(
shape=(), initializer="zeros", name=f"hit_rate_at_{self.top_k}"
)
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.average_on_batch = average_on_batch
self.axis = axis

def update_state(self, y_true, y_pred, batch=None):
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulate statistics for the metric.

Parameters
@@ -223,14 +212,17 @@
),
axis=1,
)
hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis)
if batch is not None and self.average_on_batch:
self.hit_rate.assign(self.hit_rate + tf.reduce_mean(hits))
self.n_evals.assign(self.n_evals + 1)
if sample_weight is not None:
hits = tf.reduce_sum(
tf.cast(hits_per_batch, tf.float32) * sample_weight, axis=self.axis
)
self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight))
else:
self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits))
hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis)
self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])

self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits))
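
Note the contrast with the NLL above: in `MRR` and `HitRate` the numerator is weighted as well, so 1/n trip weights yield a mean of per-trip means rather than of per-trip sums. A NumPy illustration (made-up reciprocal ranks):

```python
import numpy as np

# Reciprocal ranks for 2 rows of trip A and 3 rows of trip B.
rr = np.array([1.0, 0.5, 1.0, 0.25, 0.5])
w = np.array([1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3])

weighted_mrr = (rr * w).sum() / w.sum()  # numerator and denominator both weighted
per_trip_means = [rr[:2].mean(), rr[2:].mean()]
assert np.isclose(weighted_mrr, np.mean(per_trip_means))
```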

def result(self):
"""Compute the current metric value.
