Skip to content

Commit e7c2ed7

Browse files
authored
Merge pull request #4 from JasonTam/eval-use-precompute
Eval use precomputed representations
2 parents 88f6bee + 48073bd commit e7c2ed7

File tree

7 files changed

+1418
-653
lines changed

7 files changed

+1418
-653
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ test_precision = precision_at_k(model, data['test'], k=5).mean()
4141
1. [Learning to Rank Sketchfab Models with LightFM](http://blog.ethanrosenthal.com/2016/11/07/implicit-mf-part-2/)
4242
2. [Metadata Embeddings for User and Item Cold-start Recommendations](http://building-babylon.net/2016/01/26/metadata-embeddings-for-user-and-item-cold-start-recommendations/)
4343
3. [Recommendation Systems - Learn Python for Data Science](https://www.youtube.com/watch?v=9gBC9R-msAk)
44-
4. [Using LightFM to Recommend Projects to Consultants] (https://medium.com/product-at-catalant-technologies/using-lightfm-to-recommend-projects-to-consultants-44084df7321c#.gu887ky51)
44+
4. [Using LightFM to Recommend Projects to Consultants](https://medium.com/product-at-catalant-technologies/using-lightfm-to-recommend-projects-to-consultants-44084df7321c#.gu887ky51)
4545

4646
## How to cite
4747
Please cite LightFM if it helps your research. You can use the following BibTeX entry:

lightfm/_lightfm_fast.pyx.template

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,12 +1233,15 @@ def predict_lightfm(CSRMatrix item_features,
12331233
int[::1] item_ids,
12341234
double[::1] predictions,
12351235
FastLightFM lightfm,
1236-
int num_threads):
1236+
int num_threads,
1237+
bint use_precomputed,
1238+
flt[:, ::1] user_reprs,
1239+
flt[:, ::1] item_reprs):
12371240
"""
12381241
Generate predictions.
12391242
"""
12401243

1241-
cdef int i, no_examples
1244+
cdef int i, no_examples, factor_dim
12421245
cdef flt *user_repr
12431246
cdef flt *it_repr
12441247

@@ -1250,21 +1253,25 @@ def predict_lightfm(CSRMatrix item_features,
12501253
it_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
12511254

12521255
for i in {range_block}(no_examples):
1253-
1254-
compute_representation(user_features,
1255-
lightfm.user_features,
1256-
lightfm.user_biases,
1257-
lightfm,
1258-
user_ids[i],
1259-
lightfm.user_scale,
1260-
user_repr)
1261-
compute_representation(item_features,
1262-
lightfm.item_features,
1263-
lightfm.item_biases,
1264-
lightfm,
1265-
item_ids[i],
1266-
lightfm.item_scale,
1267-
it_repr)
1256+
if use_precomputed:
1257+
for factor_dim in {range_block}(lightfm.no_components + 1):
1258+
user_repr[factor_dim] = user_reprs[user_ids[i], factor_dim]
1259+
it_repr[factor_dim] = item_reprs[item_ids[i], factor_dim]
1260+
else:
1261+
compute_representation(user_features,
1262+
lightfm.user_features,
1263+
lightfm.user_biases,
1264+
lightfm,
1265+
user_ids[i],
1266+
lightfm.user_scale,
1267+
user_repr)
1268+
compute_representation(item_features,
1269+
lightfm.item_features,
1270+
lightfm.item_biases,
1271+
lightfm,
1272+
item_ids[i],
1273+
lightfm.item_scale,
1274+
it_repr)
12681275

12691276
predictions[i] = compute_prediction_from_repr(user_repr,
12701277
it_repr,
@@ -1280,11 +1287,14 @@ def predict_ranks(CSRMatrix item_features,
12801287
CSRMatrix train_interactions,
12811288
flt[::1] ranks,
12821289
FastLightFM lightfm,
1283-
int num_threads):
1290+
int num_threads,
1291+
bint use_precomputed,
1292+
flt[:, ::1] user_reprs,
1293+
flt[:, ::1] item_reprs):
12841294
"""
12851295
"""
12861296

1287-
cdef int i, j, user_id, item_id, predictions_size, row_start, row_stop
1297+
cdef int i, j, user_id, item_id, predictions_size, row_start, row_stop, factor_dim
12881298
cdef flt *user_repr
12891299
cdef flt *it_repr
12901300
cdef flt *predictions
@@ -1315,27 +1325,35 @@ def predict_ranks(CSRMatrix item_features,
13151325
# No test interactions for this user
13161326
continue
13171327

1318-
compute_representation(user_features,
1319-
lightfm.user_features,
1320-
lightfm.user_biases,
1321-
lightfm,
1322-
user_id,
1323-
lightfm.user_scale,
1324-
user_repr)
1328+
if use_precomputed:
1329+
for factor_dim in {range_block}(lightfm.no_components + 1):
1330+
user_repr[factor_dim] = user_reprs[user_id, factor_dim]
1331+
else:
1332+
compute_representation(user_features,
1333+
lightfm.user_features,
1334+
lightfm.user_biases,
1335+
lightfm,
1336+
user_id,
1337+
lightfm.user_scale,
1338+
user_repr)
13251339

13261340
# Compute predictions for the items whose
13271341
# ranks we want to know
13281342
for i in range(row_stop - row_start):
13291343

13301344
item_id = test_interactions.indices[row_start + i]
13311345

1332-
compute_representation(item_features,
1333-
lightfm.item_features,
1334-
lightfm.item_biases,
1335-
lightfm,
1336-
item_id,
1337-
lightfm.item_scale,
1338-
it_repr)
1346+
if use_precomputed:
1347+
for factor_dim in range(lightfm.no_components + 1):
1348+
it_repr[factor_dim] = item_reprs[item_id, factor_dim]
1349+
else:
1350+
compute_representation(item_features,
1351+
lightfm.item_features,
1352+
lightfm.item_biases,
1353+
lightfm,
1354+
item_id,
1355+
lightfm.item_scale,
1356+
it_repr)
13391357

13401358
item_ids[i] = item_id
13411359
predictions[i] = compute_prediction_from_repr(user_repr,
@@ -1350,13 +1368,17 @@ def predict_ranks(CSRMatrix item_features,
13501368
# This depends on how we want to evaluate (so it's not necessary)
13511369
continue
13521370

1353-
compute_representation(item_features,
1354-
lightfm.item_features,
1355-
lightfm.item_biases,
1356-
lightfm,
1357-
item_id,
1358-
lightfm.item_scale,
1359-
it_repr)
1371+
if use_precomputed:
1372+
for factor_dim in range(lightfm.no_components + 1):
1373+
it_repr[factor_dim] = item_reprs[item_id, factor_dim]
1374+
else:
1375+
compute_representation(item_features,
1376+
lightfm.item_features,
1377+
lightfm.item_biases,
1378+
lightfm,
1379+
item_id,
1380+
lightfm.item_scale,
1381+
it_repr)
13601382
prediction = compute_prediction_from_repr(user_repr,
13611383
it_repr,
13621384
lightfm.no_components)

0 commit comments

Comments
 (0)