Skip to content

Commit df8ceaa

Browse files
committed
Fixes for Krita Style & Prompt custom workflow node
* Remove auto-clear logic from _forward_validation_error that was incorrectly clearing all errors when validation passed. * Add missing style metadata (checkpoint, sampler, steps, guidance, loras) to custom workflow jobs for consistent "Info to Clipboard" output * Use job_params.set_style() for consistency with regular generation * Rename method to _prepare_style_and_prompt, simplify parameters
1 parent 2316bd3 commit df8ceaa

File tree

1 file changed

+21
-35
lines changed

1 file changed

+21
-35
lines changed

ai_diffusion/model.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,7 @@
2626
from .image import Extent, Image, Mask, Bounds, DummyImage
2727
from .client import Client, ClientMessage, ClientEvent, ClientOutput
2828
from .client import is_style_supported, filter_supported_styles, resolve_arch
29-
from .custom_workflow import (
30-
CustomWorkspace,
31-
WorkflowCollection,
32-
CustomGenerationMode,
33-
ComfyWorkflow,
34-
)
29+
from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode
3530
from .document import Document, KritaDocument, SelectionModifiers
3631
from .layer import Layer, LayerType, RestoreActiveLayer
3732
from .pose import Pose
@@ -73,6 +68,7 @@ class ErrorKind(Enum):
7368
insufficient_funds = 201
7469
warning = 300
7570
incompatible_lora = 301
71+
validation_warning = 302
7672

7773
@property
7874
def is_warning(self):
@@ -173,8 +169,8 @@ def _forward_error(self, error: str):
173169

174170
def _forward_validation_error(self, error: str):
175171
if error:
176-
self.report_error(Error(ErrorKind.warning, error))
177-
else:
172+
self.report_error(Error(ErrorKind.validation_warning, error))
173+
elif self.error.kind is ErrorKind.validation_warning:
178174
self.clear_error()
179175

180176
def generate(self):
@@ -498,15 +494,13 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
498494
img_input.hires_mask = mask.to_image(bounds.extent) if mask else None
499495

500496
params = self.custom.collect_parameters(self.layers, canvas_bounds, is_anim)
501-
502-
has_synced_style_and_prompt = (
503-
next(wf.find(type="ETN_KritaStyleAndPrompt"), None) is not None
504-
)
505497
custom_input = CustomWorkflowInput(wf.root, params)
506-
prompt_meta = {}
507-
if has_synced_style_and_prompt:
508-
custom_input, prompt_meta = self._prepare_synced_style_and_prompt(
509-
params, seed, custom_input, wf
498+
499+
style_and_prompt_node = next(wf.find(type="ETN_KritaStyleAndPrompt"), None)
500+
prompt_meta: dict[str, Any] = {}
501+
if style_and_prompt_node is not None:
502+
custom_input, prompt_meta = self._prepare_style_and_prompt(
503+
style_and_prompt_node, seed, custom_input
510504
)
511505

512506
input = WorkflowInput(
@@ -519,8 +513,13 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
519513

520514
metadata: dict[str, Any] = dict(self.custom.params)
521515
metadata.update(prompt_meta)
522-
523516
job_params = JobParams(bounds, self.custom.job_name, metadata=metadata)
517+
if style_and_prompt_node is not None:
518+
models = ensure(custom_input.models)
519+
job_params.set_style(self.style, models.checkpoint)
520+
job_params.metadata["loras"] = [
521+
dict(name=l.name, weight=l.strength) for l in models.loras
522+
]
524523
job_kind = {
525524
CustomGenerationMode.regular: JobKind.diffusion,
526525
CustomGenerationMode.live: JobKind.live_preview,
@@ -538,28 +537,18 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
538537
self.report_error(util.log_error(e))
539538
return False
540539

541-
def _prepare_synced_style_and_prompt(
542-
self,
543-
params: dict[str, Any],
544-
seed: int,
545-
custom_input: CustomWorkflowInput,
546-
wf: ComfyWorkflow,
540+
def _prepare_style_and_prompt(
541+
self, style_and_prompt_node: Any, seed: int, custom_input: CustomWorkflowInput
547542
) -> tuple[CustomWorkflowInput, dict[str, Any]]:
548543
"""Prepare prompts and models for ETN_KritaStyleAndPrompt node.
549544
Returns updated CustomWorkflowInput with evaluated prompts, models, sampling, and metadata for job history.
550545
"""
551546
style = self.style
552-
553-
style_node = next(wf.find(type="ETN_KritaStyleAndPrompt"), None)
554-
is_live = style_node.input("sampler_preset", "auto") == "live" if style_node else False
555-
547+
is_live = style_and_prompt_node.input("sampler_preset", "auto") == "live"
556548
checkpoint_input = style.get_models(self._connection.client.models.checkpoints)
557549
sampling = workflow._sampling_from_style(style, 1.0, is_live)
558550

559-
positive = self.regions.positive
560-
negative = self.regions.negative
561-
562-
cond = ConditioningInput(positive, negative)
551+
cond = ConditioningInput(self.regions.positive, self.regions.negative)
563552
arch = resolve_arch(style, self._connection.client_if_connected)
564553
prepared = workflow.prepare_prompts(cond, style, seed, arch, FileLibrary.instance())
565554

@@ -573,10 +562,7 @@ def _prepare_synced_style_and_prompt(
573562
models=checkpoint_input,
574563
sampling=sampling,
575564
)
576-
577-
meta = dict(prepared.metadata)
578-
meta["style"] = style.filename
579-
return custom_input, meta
565+
return custom_input, prepared.metadata
580566

581567
def _get_current_image(self, bounds: Bounds):
582568
exclude = []

0 commit comments

Comments
 (0)