1212from saffier .types import DictAny
1313
1414if typing .TYPE_CHECKING : # pragma: no cover
15+ from saffier .db .connection import Database
1516 from saffier .models import Model
1617
1718
@@ -25,7 +26,7 @@ class QuerySetProps:
2526 """
2627
2728 @property
28- def database (self ):
29+ def database (self ) -> "Database" :
2930 return self .model_class ._meta .registry .database
3031
3132 @property
@@ -117,6 +118,7 @@ def _build_select(self):
117118 if self .distinct_on :
118119 expression = self ._build_select_distinct (self .distinct_on , expression = expression )
119120
121+ setattr (self , "_expression" , expression )
120122 return expression
121123
122124 def _filter_query (self , exclude : bool = False , ** kwargs ):
@@ -232,6 +234,7 @@ def _clone(self) -> "QuerySet[SaffierModel]":
232234 queryset ._order_by = copy .copy (self ._order_by )
233235 queryset ._group_by = copy .copy (self ._group_by )
234236 queryset .distinct_on = copy .copy (self .distinct_on )
237+ queryset ._expression = self ._expression
235238 return queryset
236239
237240
@@ -262,14 +265,30 @@ def __init__(
262265 self ._order_by = [] if order_by is None else order_by
263266 self ._group_by = [] if group_by is None else group_by
264267 self .distinct_on = [] if distinct_on is None else distinct_on
268+ self ._expression = None
265269
266270 def __get__ (self , instance , owner ):
267271 return self .__class__ (model_class = owner )
268272
273+ @property
274+ def sql (self ):
275+ return str (self ._expression )
276+
277+ @sql .setter
278+ def sql (self , value ):
279+ setattr (self , "_expression" , value )
280+
269281 async def __aiter__ (self ) -> typing .AsyncIterator [SaffierModel ]:
270282 for value in await self :
271283 yield value
272284
285+ def _set_query_expression (self , expression : typing .Any ) -> None :
286+ """
287+ Sets the value of the sql property to the expression used.
288+ """
289+ self .sql = expression
290+ self .model_class .raw_query = self .sql
291+
273292 def _filter_or_exclude (
274293 self ,
275294 clause : typing .Optional [sqlalchemy .sql .expression .BinaryExpression ] = None ,
@@ -389,6 +408,7 @@ async def exists(self) -> bool:
389408 """
390409 expression = self ._build_select ()
391410 expression = sqlalchemy .exists (expression ).select ()
411+ self ._set_query_expression (expression )
392412 return await self .database .fetch_val (expression )
393413
394414 async def count (self ) -> int :
@@ -397,6 +417,7 @@ async def count(self) -> int:
397417 """
398418 expression = self ._build_select ().alias ("subquery_for_count" )
399419 expression = sqlalchemy .func .count ().select ().select_from (expression )
420+ self ._set_query_expression (expression )
400421 return await self .database .fetch_val (expression )
401422
402423 async def get_or_none (self , ** kwargs ):
@@ -405,6 +426,7 @@ async def get_or_none(self, **kwargs):
405426 """
406427 queryset = self .filter (** kwargs )
407428 expression = queryset ._build_select ().limit (2 )
429+ self ._set_query_expression (expression )
408430 rows = await self .database .fetch_all (expression )
409431
410432 if not rows :
@@ -422,7 +444,12 @@ async def all(self, **kwargs):
422444 return await queryset .filter (** kwargs ).all ()
423445
424446 expression = queryset ._build_select ()
447+ self ._set_query_expression (expression )
448+
425449 rows = await queryset .database .fetch_all (expression )
450+
451+ # Attach the raw query to the object
452+ queryset .model_class .raw_query = self .sql
426453 return [
427454 queryset .model_class ._from_row (row , select_related = self ._select_related )
428455 for row in rows
@@ -437,6 +464,7 @@ async def get(self, **kwargs):
437464
438465 expression = self ._build_select ().limit (2 )
439466 rows = await self .database .fetch_all (expression )
467+ self ._set_query_expression (expression )
440468
441469 if not rows :
442470 raise DoesNotFound ()
@@ -475,6 +503,7 @@ async def create(self, **kwargs):
475503 kwargs = self ._validate_kwargs (** kwargs )
476504 instance = self .model_class (** kwargs )
477505 expression = self .table .insert ().values (** kwargs )
506+ self ._set_query_expression (expression )
478507
479508 if self .pkname not in kwargs :
480509 instance .pk = await self .database .execute (expression )
@@ -490,13 +519,57 @@ async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
490519 new_objs = [self ._validate_kwargs (** obj ) for obj in objs ]
491520
492521 expression = self .table .insert ().values (new_objs )
522+ self ._set_query_expression (expression )
493523 await self .database .execute (expression )
494524
525+ async def bulk_update (self , objs : typing .List [SaffierModel ], fields : typing .List [str ]) -> None :
526+ """
527+ Bulk updates records in a table.
528+
529+ A similar solution was suggested here: https://github.com/encode/orm/pull/148
530+
531+ It is thought to be a clean approach to a simple problem so it was added here and
532+ refactored to be compatible with Saffier.
533+ """
534+ new_fields = {}
535+ for key , field in self .model_class .fields .items ():
536+ if key in fields :
537+ new_fields [key ] = field .validator
538+
539+ validator = Schema (fields = new_fields )
540+
541+ new_objs = []
542+ for obj in objs :
543+ new_obj = {}
544+ for key , value in obj .__dict__ .items ():
545+ if key in fields :
546+ new_obj [key ] = self ._resolve_value (value )
547+ new_objs .append (new_obj )
548+
549+ new_objs = [
550+ self ._update_auto_now_fields (validator .validate (obj ), self .model_class .fields )
551+ for obj in new_objs
552+ ]
553+
554+ pk = getattr (self .table .c , self .pkname )
555+ expression = self .table .update ().where (pk == sqlalchemy .bindparam (self .pkname ))
556+ kwargs = {field : sqlalchemy .bindparam (field ) for obj in new_objs for field in obj .keys ()}
557+ pks = [{self .pkname : getattr (obj , self .pkname )} for obj in objs ]
558+
559+ query_list = []
560+ for pk , value in zip (pks , new_objs ):
561+ query_list .append ({** pk , ** value })
562+
563+ expression = expression .values (kwargs )
564+ self ._set_query_expression (expression )
565+ await self .database .execute_many (str (expression ), query_list )
566+
495567 async def delete (self ) -> None :
496568 expression = self .table .delete ()
497569 for filter_clause in self .filter_clauses :
498570 expression = expression .where (filter_clause )
499571
572+ self ._set_query_expression (expression )
500573 await self .database .execute (expression )
501574
502575 async def update (self , ** kwargs ) -> None :
@@ -509,12 +582,13 @@ async def update(self, **kwargs) -> None:
509582
510583 validator = Schema (fields = fields )
511584 kwargs = self ._update_auto_now_fields (validator .validate (kwargs ), self .model_class .fields )
512- expr = self .table .update ().values (** kwargs )
585+ expression = self .table .update ().values (** kwargs )
513586
514587 for filter_clause in self .filter_clauses :
515- expr = expr .where (filter_clause )
588+ expression = expression .where (filter_clause )
516589
517- await self .database .execute (expr )
590+ self ._set_query_expression (expression )
591+ await self .database .execute (expression )
518592
519593 async def get_or_create (
520594 self , defaults : typing .Dict [str , typing .Any ], ** kwargs
0 commit comments