Skip to content

Commit aeeaf68

Browse files
Implement auto_new attribute for #[pyclass] (#5421)
* implement auto_new on pyclasses * fix double import * add changelog fragment * fix import formatting * auto_new: unrestrict set_all and restrict extends * update ui test * cargo fmt * formatting * update new impl macro * update ui test * ensure generated `__new__` is treated as `#[new]` * support tuple structs * add ui test * fix ui test? * update UI test * avoid compile error &str case * fixup ui test * make trybuild output stable * bump to stable trybuild * add UI test for conflicting `#[new]` implementations --------- Co-authored-by: David Hewitt <mail@davidhewitt.dev>
1 parent 5578141 commit aeeaf68

File tree

9 files changed

+379
-7
lines changed

9 files changed

+379
-7
lines changed

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ portable-atomic = "1.0"
6969
assert_approx_eq = "1.1.0"
7070
chrono = "0.4.25"
7171
chrono-tz = ">= 0.10, < 0.11"
72-
# Required for "and $N others" normalization
73-
trybuild = ">=1.0.70"
72+
trybuild = ">=1.0.115"
7473
proptest = { version = "1.0", default-features = false, features = ["std"] }
7574
send_wrapper = "0.6"
7675
serde = { version = "1.0", features = ["derive"] }

guide/pyclass-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
| `rename_all = "renaming_rule"` | Applies renaming rules to every getters and setters of a struct, or every variants of an enum. Possible values are: "camelCase", "kebab-case", "lowercase", "PascalCase", "SCREAMING-KEBAB-CASE", "SCREAMING_SNAKE_CASE", "snake_case", "UPPERCASE". |
2323
| `sequence` | Inform PyO3 that this class is a [`Sequence`][params-sequence], and so leave its C-API mapping length slot empty. |
2424
| `set_all` | Generates setters for all fields of the pyclass. |
25+
| `new = "from_fields"` | Generates a default `__new__` constructor with all fields as parameters in the `new()` method. |
2526
| `skip_from_py_object` | Prevents this PyClass from participating in the `FromPyObject: PyClass + Clone` blanket implementation. This allows a custom `FromPyObject` impl, even if `self` is `Clone`. |
2627
| `str` | Implements `__str__` using the `Display` implementation of the underlying Rust datatype or by passing an optional format string `str="<format string>"`. *Note: The optional format string is only allowed for structs. `name` and `rename_all` are incompatible with the optional format string. Additional details can be found in the discussion on this [PR](https://github.com/PyO3/pyo3/pull/4233).* |
2728
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |

newsfragments/5421.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement `new = "from_fields"` attribute for `#[pyclass]`

pyo3-macros-backend/src/attributes.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub mod kw {
4040
syn::custom_keyword!(sequence);
4141
syn::custom_keyword!(set);
4242
syn::custom_keyword!(set_all);
43+
syn::custom_keyword!(new);
4344
syn::custom_keyword!(signature);
4445
syn::custom_keyword!(str);
4546
syn::custom_keyword!(subclass);
@@ -311,13 +312,41 @@ impl ToTokens for TextSignatureAttributeValue {
311312
}
312313
}
313314

315+
#[derive(Clone, Debug, PartialEq, Eq)]
316+
pub enum NewImplTypeAttributeValue {
317+
FromFields,
318+
// Future variant for 'default' should go here
319+
}
320+
321+
impl Parse for NewImplTypeAttributeValue {
322+
fn parse(input: ParseStream<'_>) -> Result<Self> {
323+
let string_literal: LitStr = input.parse()?;
324+
if string_literal.value().as_str() == "from_fields" {
325+
Ok(NewImplTypeAttributeValue::FromFields)
326+
} else {
327+
bail_spanned!(string_literal.span() => "expected \"from_fields\"")
328+
}
329+
}
330+
}
331+
332+
impl ToTokens for NewImplTypeAttributeValue {
333+
fn to_tokens(&self, tokens: &mut TokenStream) {
334+
match self {
335+
NewImplTypeAttributeValue::FromFields => {
336+
tokens.extend(quote! { "from_fields" });
337+
}
338+
}
339+
}
340+
}
341+
314342
pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
315343
pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
316344
pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
317345
pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
318346
pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
319347
pub type StrFormatterAttribute = OptionalKeywordAttribute<kw::str, StringFormatter>;
320348
pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
349+
pub type NewImplTypeAttribute = KeywordAttribute<kw::new, NewImplTypeAttributeValue>;
321350
pub type SubmoduleAttribute = kw::submodule;
322351
pub type GILUsedAttribute = KeywordAttribute<kw::gil_used, LitBool>;
323352

pyo3-macros-backend/src/pyclass.rs

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, ImplItemFn, Result
1111
use crate::attributes::kw::frozen;
1212
use crate::attributes::{
1313
self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute,
14-
ModuleAttribute, NameAttribute, NameLitStr, RenameAllAttribute, StrFormatterAttribute,
14+
ModuleAttribute, NameAttribute, NameLitStr, NewImplTypeAttribute, NewImplTypeAttributeValue,
15+
RenameAllAttribute, StrFormatterAttribute,
1516
};
1617
use crate::combine_errors::CombineErrors;
1718
#[cfg(feature = "experimental-inspect")]
@@ -88,6 +89,7 @@ pub struct PyClassPyO3Options {
8889
pub rename_all: Option<RenameAllAttribute>,
8990
pub sequence: Option<kw::sequence>,
9091
pub set_all: Option<kw::set_all>,
92+
pub new: Option<NewImplTypeAttribute>,
9193
pub str: Option<StrFormatterAttribute>,
9294
pub subclass: Option<kw::subclass>,
9395
pub unsendable: Option<kw::unsendable>,
@@ -115,6 +117,7 @@ pub enum PyClassPyO3Option {
115117
RenameAll(RenameAllAttribute),
116118
Sequence(kw::sequence),
117119
SetAll(kw::set_all),
120+
New(NewImplTypeAttribute),
118121
Str(StrFormatterAttribute),
119122
Subclass(kw::subclass),
120123
Unsendable(kw::unsendable),
@@ -161,6 +164,8 @@ impl Parse for PyClassPyO3Option {
161164
input.parse().map(PyClassPyO3Option::Sequence)
162165
} else if lookahead.peek(attributes::kw::set_all) {
163166
input.parse().map(PyClassPyO3Option::SetAll)
167+
} else if lookahead.peek(attributes::kw::new) {
168+
input.parse().map(PyClassPyO3Option::New)
164169
} else if lookahead.peek(attributes::kw::str) {
165170
input.parse().map(PyClassPyO3Option::Str)
166171
} else if lookahead.peek(attributes::kw::subclass) {
@@ -243,6 +248,7 @@ impl PyClassPyO3Options {
243248
PyClassPyO3Option::RenameAll(rename_all) => set_option!(rename_all),
244249
PyClassPyO3Option::Sequence(sequence) => set_option!(sequence),
245250
PyClassPyO3Option::SetAll(set_all) => set_option!(set_all),
251+
PyClassPyO3Option::New(new) => set_option!(new),
246252
PyClassPyO3Option::Str(str) => set_option!(str),
247253
PyClassPyO3Option::Subclass(subclass) => set_option!(subclass),
248254
PyClassPyO3Option::Unsendable(unsendable) => set_option!(unsendable),
@@ -488,6 +494,13 @@ fn impl_class(
488494
}
489495
}
490496

497+
let (default_new, default_new_slot) = pyclass_new_impl(
498+
&args.options,
499+
&syn::parse_quote!(#cls),
500+
field_options.iter().map(|(f, _)| f),
501+
ctx,
502+
)?;
503+
491504
let mut default_methods = descriptors_to_items(
492505
cls,
493506
args.options.rename_all.as_ref(),
@@ -516,6 +529,7 @@ fn impl_class(
516529
slots.extend(default_richcmp_slot);
517530
slots.extend(default_hash_slot);
518531
slots.extend(default_str_slot);
532+
slots.extend(default_new_slot);
519533

520534
let impl_builder =
521535
PyClassImplsBuilder::new(cls, cls, args, methods_type, default_methods, slots).doc(doc);
@@ -543,6 +557,7 @@ fn impl_class(
543557
#default_richcmp
544558
#default_hash
545559
#default_str
560+
#default_new
546561
#default_class_getitem
547562
}
548563
})
@@ -1711,11 +1726,11 @@ fn generate_protocol_slot(
17111726
) -> syn::Result<MethodAndSlotDef> {
17121727
let spec = FnSpec::parse(
17131728
&mut method.sig,
1714-
&mut Vec::new(),
1729+
&mut method.attrs,
17151730
PyFunctionOptions::default(),
17161731
)?;
17171732
#[cfg_attr(not(feature = "experimental-inspect"), allow(unused_mut))]
1718-
let mut def = slot.generate_type_slot(&syn::parse_quote!(#cls), &spec, name, ctx)?;
1733+
let mut def = slot.generate_type_slot(cls, &spec, name, ctx)?;
17191734
#[cfg(feature = "experimental-inspect")]
17201735
def.add_introspection(introspection_data.generate(ctx, cls));
17211736
Ok(def)
@@ -2431,6 +2446,94 @@ fn pyclass_hash(
24312446
}
24322447
}
24332448

2449+
fn pyclass_new_impl<'a>(
2450+
options: &PyClassPyO3Options,
2451+
ty: &syn::Type,
2452+
fields: impl Iterator<Item = &'a &'a syn::Field>,
2453+
ctx: &Ctx,
2454+
) -> Result<(Option<ImplItemFn>, Option<MethodAndSlotDef>)> {
2455+
if options
2456+
.new
2457+
.as_ref()
2458+
.is_some_and(|o| matches!(o.value, NewImplTypeAttributeValue::FromFields))
2459+
{
2460+
ensure_spanned!(
2461+
options.extends.is_none(), options.new.span() => "The `new=\"from_fields\"` option cannot be used with `extends`.";
2462+
);
2463+
}
2464+
2465+
let mut tuple_struct: bool = false;
2466+
2467+
match &options.new {
2468+
Some(opt) => {
2469+
let mut field_idents = vec![];
2470+
let mut field_types = vec![];
2471+
for (idx, field) in fields.enumerate() {
2472+
tuple_struct = field.ident.is_none();
2473+
2474+
field_idents.push(
2475+
field
2476+
.ident
2477+
.clone()
2478+
.unwrap_or_else(|| format_ident!("_{}", idx)),
2479+
);
2480+
field_types.push(&field.ty);
2481+
}
2482+
2483+
let mut new_impl = if tuple_struct {
2484+
parse_quote_spanned! { opt.span() =>
2485+
#[new]
2486+
fn __pyo3_generated____new__( #( #field_idents : #field_types ),* ) -> Self {
2487+
Self (
2488+
#( #field_idents, )*
2489+
)
2490+
}
2491+
}
2492+
} else {
2493+
parse_quote_spanned! { opt.span() =>
2494+
#[new]
2495+
fn __pyo3_generated____new__( #( #field_idents : #field_types ),* ) -> Self {
2496+
Self {
2497+
#( #field_idents, )*
2498+
}
2499+
}
2500+
}
2501+
};
2502+
2503+
let new_slot = generate_protocol_slot(
2504+
ty,
2505+
&mut new_impl,
2506+
&__NEW__,
2507+
"__new__",
2508+
#[cfg(feature = "experimental-inspect")]
2509+
FunctionIntrospectionData {
2510+
names: &["__new__"],
2511+
arguments: field_idents
2512+
.iter()
2513+
.zip(field_types.iter())
2514+
.map(|(ident, ty)| {
2515+
FnArg::Regular(RegularArg {
2516+
name: Cow::Owned(ident.clone()),
2517+
ty,
2518+
from_py_with: None,
2519+
default_value: None,
2520+
option_wrapped_type: None,
2521+
annotation: None,
2522+
})
2523+
})
2524+
.collect(),
2525+
returns: ty.clone(),
2526+
},
2527+
ctx,
2528+
)
2529+
.unwrap();
2530+
2531+
Ok((Some(new_impl), Some(new_slot)))
2532+
}
2533+
None => Ok((None, None)),
2534+
}
2535+
}
2536+
24342537
fn pyclass_class_getitem(
24352538
options: &PyClassPyO3Options,
24362539
cls: &syn::Type,

tests/test_class_attributes.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,42 @@ fn test_renaming_all_struct_fields() {
235235
});
236236
}
237237

238+
#[pyclass(get_all, set_all, new = "from_fields")]
239+
struct AutoNewCls {
240+
a: i32,
241+
b: String,
242+
c: Option<f64>,
243+
}
244+
245+
#[test]
246+
fn new_impl() {
247+
Python::attach(|py| {
248+
// python should be able to do AutoNewCls(1, "two", 3.0)
249+
let cls = py.get_type::<AutoNewCls>();
250+
pyo3::py_run!(
251+
py,
252+
cls,
253+
"inst = cls(1, 'two', 3.0); assert inst.a == 1; assert inst.b == 'two'; assert inst.c == 3.0"
254+
);
255+
});
256+
}
257+
258+
#[pyclass(new = "from_fields", get_all)]
259+
struct Point2d(#[pyo3(name = "first")] f64, #[pyo3(name = "second")] f64);
260+
261+
#[test]
262+
fn new_impl_tuple_struct() {
263+
Python::attach(|py| {
264+
// python should be able to do AutoNewCls(1, "two", 3.0)
265+
let cls = py.get_type::<Point2d>();
266+
pyo3::py_run!(
267+
py,
268+
cls,
269+
"inst = cls(0.2, 0.3); assert inst.first == 0.2; assert inst.second == 0.3"
270+
);
271+
});
272+
}
273+
238274
macro_rules! test_case {
239275
($struct_name: ident, $rule: literal, $field_name: ident, $renamed_field_name: literal, $test_name: ident) => {
240276
#[pyclass(get_all, set_all, rename_all = $rule)]

tests/test_compile_error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ fn test_compile_errors() {
99
#[cfg(not(feature = "experimental-inspect"))]
1010
t.compile_fail("tests/ui/invalid_property_args.rs");
1111
t.compile_fail("tests/ui/invalid_proto_pymethods.rs");
12+
#[cfg(not(all(Py_LIMITED_API, not(Py_3_10))))] // to avoid PyFunctionArgument for &str
1213
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
1314
t.compile_fail("tests/ui/invalid_pyclass_doc.rs");
1415
t.compile_fail("tests/ui/invalid_pyclass_enum.rs");
@@ -19,6 +20,7 @@ fn test_compile_errors() {
1920
#[cfg(Py_3_9)]
2021
t.compile_fail("tests/ui/pyclass_generic_enum.rs");
2122
#[cfg(not(feature = "experimental-inspect"))]
23+
#[cfg(not(all(Py_LIMITED_API, not(Py_3_10))))] // to avoid PyFunctionArgument for &str
2224
t.compile_fail("tests/ui/invalid_pyfunction_argument.rs");
2325
t.compile_fail("tests/ui/invalid_pyfunction_definition.rs");
2426
t.compile_fail("tests/ui/invalid_pyfunction_signatures.rs");

tests/ui/invalid_pyclass_args.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,22 @@ struct StructImplicitFromPyObjectDeprecated {
200200
b: String,
201201
}
202202

203+
#[pyclass(new = "from_fields")]
204+
struct NonPythonField {
205+
field: Box<dyn std::error::Error + Send + Sync>,
206+
}
207+
208+
#[pyclass(new = "from_fields")]
209+
struct NewFromFieldsWithManualNew {
210+
field: i32,
211+
}
212+
213+
#[pymethods]
214+
impl NewFromFieldsWithManualNew {
215+
#[new]
216+
fn new(field: i32) -> Self {
217+
Self { field }
218+
}
219+
}
220+
203221
fn main() {}

0 commit comments

Comments
 (0)