Skip to content

Commit 1fcfff2

Browse files
authored
Add support for Krita Style & Prompt node (#2276)
Features: Prompt and Style widgets sync across the Generate, Live, Animation, and now Graph workspaces; full prompt evaluation is supported (wildcards, comments, layer extraction, style merging, LoRA extraction); extracted LoRAs are applied to the model output; an error is shown if the workflow contains multiple KritaPromptStyle nodes (only one is allowed, since prompts sync to shared fields).
1 parent 7c896ca commit 1fcfff2

File tree

5 files changed

+126
-4
lines changed

5 files changed

+126
-4
lines changed

ai_diffusion/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ class UpscaleInput:
169169
class CustomWorkflowInput:
170170
workflow: dict
171171
params: dict[str, Any]
172+
positive_evaluated: str = ""
173+
negative_evaluated: str = ""
174+
models: CheckpointInput | None = None
175+
sampling: SamplingInput | None = None
172176

173177

174178
@dataclass

ai_diffusion/custom_workflow.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .api import WorkflowInput, InpaintContext
1414
from .client import OutputBatchMode, TextOutput, ClientOutput, JobInfoOutput
1515
from .comfy_workflow import ComfyWorkflow, ComfyNode
16+
from .localization import translate as _
1617
from .connection import Connection, ConnectionState
1718
from .image import Bounds, Image, Mask
1819
from .jobs import Job, JobParams, JobQueue, JobKind
@@ -373,6 +374,7 @@ class CustomWorkspace(QObject, ObservableProperties):
373374
has_result = Property(False)
374375
outputs = Property({})
375376
params_ui_height = Property(100, persist=True)
377+
validation_error = Property("")
376378

377379
workflow_id_changed = pyqtSignal(str)
378380
graph_changed = pyqtSignal()
@@ -383,6 +385,7 @@ class CustomWorkspace(QObject, ObservableProperties):
383385
has_result_changed = pyqtSignal(bool)
384386
outputs_changed = pyqtSignal(dict)
385387
params_ui_height_changed = pyqtSignal(int)
388+
validation_error_changed = pyqtSignal(str)
386389
modified = pyqtSignal(QObject, str)
387390

388391
_live_poll_rate = 0.1
@@ -419,10 +422,20 @@ def _update_workflow(self, idx: QModelIndex, _: QModelIndex):
419422
if wf.id == self._workflow_id:
420423
self._workflow = wf
421424
self._graph = self._workflow.workflow
425+
self._validate_workflow(self._graph)
422426
self._metadata = list(workflow_parameters(self._graph))
423427
self.params = _coerce(self.params, self._metadata)
424428
self.graph_changed.emit()
425429

430+
def _validate_workflow(self, wf: ComfyWorkflow):
431+
style_and_prompt_node_count = sum(1 for _ in wf.find(type="ETN_KritaStyleAndPrompt"))
432+
if style_and_prompt_node_count > 1:
433+
self.validation_error = _(
434+
"Workflow contains multiple `Krita Style & Prompt` nodes. Only one is allowed since prompts sync across workspaces."
435+
)
436+
else:
437+
self.validation_error = ""
438+
426439
def _set_workflow_id(self, id: str):
427440
if self._workflow_id == id:
428441
return

ai_diffusion/model.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
from .api import FillMode, ImageInput, CustomWorkflowInput, UpscaleInput
2020
from .api import InpaintMode, InpaintContext, InpaintParams
2121
from .localization import translate as _
22-
from .util import clamp, ensure, trim_text, client_logger as log
22+
from .util import clamp, ensure, unique, trim_text, client_logger as log
2323
from .settings import ApplyBehavior, ApplyRegionBehavior, GenerationFinishedAction, ImageFileFormat
2424
from .settings import settings
2525
from .network import NetworkError
2626
from .image import Extent, Image, Mask, Bounds, DummyImage
2727
from .client import Client, ClientMessage, ClientEvent, ClientOutput
2828
from .client import is_style_supported, filter_supported_styles, resolve_arch
29-
from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode
29+
from .custom_workflow import (
30+
CustomWorkspace,
31+
WorkflowCollection,
32+
CustomGenerationMode,
33+
ComfyWorkflow,
34+
)
3035
from .document import Document, KritaDocument, SelectionModifiers
3136
from .layer import Layer, LayerType, RestoreActiveLayer
3237
from .pose import Pose
@@ -149,6 +154,7 @@ def __init__(self, document: Document, connection: Connection, workflows: Workfl
149154
self.jobs.selection_changed.connect(self.update_preview)
150155
connection.state_changed.connect(self._init_on_connect)
151156
connection.error_changed.connect(self._forward_error)
157+
self.custom.validation_error_changed.connect(self._forward_validation_error)
152158
Styles.list().changed.connect(self._init_on_connect)
153159
self._init_on_connect()
154160

@@ -165,6 +171,12 @@ def _init_on_connect(self):
165171
def _forward_error(self, error: str):
166172
self.report_error(error if error else no_error)
167173

174+
def _forward_validation_error(self, error: str):
175+
if error:
176+
self.report_error(Error(ErrorKind.warning, error))
177+
else:
178+
self.clear_error()
179+
168180
def generate(self):
169181
"""Enqueue image generation for the current setup."""
170182
self._generate(self.queue_mode)
@@ -488,14 +500,29 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
488500
img_input.hires_mask = mask.to_image(bounds.extent) if mask else None
489501

490502
params = self.custom.collect_parameters(self.layers, canvas_bounds, is_anim)
503+
504+
has_synced_style_and_prompt = (
505+
next(wf.find(type="ETN_KritaStyleAndPrompt"), None) is not None
506+
)
507+
custom_input = CustomWorkflowInput(wf.root, params)
508+
prompt_meta = {}
509+
if has_synced_style_and_prompt:
510+
custom_input, prompt_meta = self._prepare_synced_style_and_prompt(
511+
params, seed, custom_input, wf
512+
)
513+
491514
input = WorkflowInput(
492515
WorkflowKind.custom,
493516
img_input,
494517
sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed),
495518
inpaint=InpaintParams(InpaintMode.fill, bounds),
496-
custom_workflow=CustomWorkflowInput(wf.root, params),
519+
custom_workflow=custom_input,
497520
)
498-
job_params = JobParams(bounds, self.custom.job_name, metadata=self.custom.params)
521+
522+
metadata: dict[str, Any] = dict(self.custom.params)
523+
metadata.update(prompt_meta)
524+
525+
job_params = JobParams(bounds, self.custom.job_name, metadata=metadata)
499526
job_kind = {
500527
CustomGenerationMode.regular: JobKind.diffusion,
501528
CustomGenerationMode.live: JobKind.live_preview,
@@ -513,6 +540,46 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
513540
self.report_error(util.log_error(e))
514541
return False
515542

543+
def _prepare_synced_style_and_prompt(
544+
self,
545+
params: dict[str, Any],
546+
seed: int,
547+
custom_input: CustomWorkflowInput,
548+
wf: ComfyWorkflow,
549+
) -> tuple[CustomWorkflowInput, dict[str, Any]]:
550+
"""Prepare prompts and models for ETN_KritaStyleAndPrompt node.
551+
Returns updated CustomWorkflowInput with evaluated prompts, models, sampling, and metadata for job history.
552+
"""
553+
style = self.style
554+
555+
style_node = next(wf.find(type="ETN_KritaStyleAndPrompt"), None)
556+
is_live = style_node.input("sampler_preset", "auto") == "live" if style_node else False
557+
558+
checkpoint_input = style.get_models(self._connection.client.models.checkpoints)
559+
sampling = workflow._sampling_from_style(style, 1.0, is_live)
560+
561+
positive = self.regions.positive
562+
negative = self.regions.negative
563+
564+
cond = ConditioningInput(positive, negative)
565+
arch = resolve_arch(style, self._connection.client_if_connected)
566+
prepared = workflow.prepare_prompts(cond, style, seed, arch, FileLibrary.instance())
567+
568+
merged_loras = unique(checkpoint_input.loras + prepared.loras, key=lambda l: l.name)
569+
checkpoint_input.loras = merged_loras
570+
571+
custom_input = replace(
572+
custom_input,
573+
positive_evaluated=prepared.metadata["prompt_final"],
574+
negative_evaluated=prepared.metadata["negative_prompt_final"],
575+
models=checkpoint_input,
576+
sampling=sampling,
577+
)
578+
579+
meta = dict(prepared.metadata)
580+
meta["style"] = style.filename
581+
return custom_input, meta
582+
516583
def _get_current_image(self, bounds: Bounds):
517584
exclude = []
518585
if self.workspace is not Workspace.live:

ai_diffusion/ui/custom_workflow.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .live import LivePreviewArea
2626
from .switch import SwitchWidget
2727
from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget, ErrorBox
28+
from .region import ActiveRegionWidget, PromptHeader
2829
from .settings_widgets import ExpanderButton
2930
from . import theme
3031
from .theme import SignalBlocker
@@ -700,6 +701,16 @@ def __init__(self):
700701
self._params_scroll.setFrameShape(QFrame.Shape.NoFrame)
701702
self._params_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
702703

704+
self._style_widget = StyleSelectWidget(self)
705+
self._style_widget.setVisible(False) # Hidden until workflow has synced style
706+
707+
self._prompt_widget = ActiveRegionWidget(
708+
self._model.regions, self, header=PromptHeader.none
709+
)
710+
self._prompt_widget.setVisible(False) # Hidden until workflow has synced prompts
711+
self._prompt_widget.positive.activated.connect(self._generate)
712+
self._prompt_widget.negative.activated.connect(self._generate)
713+
703714
self._bottom = QWidget(self)
704715

705716
self._generate_button = GenerateButton(JobKind.diffusion, self._bottom)
@@ -771,6 +782,8 @@ def __init__(self):
771782
header_layout.addWidget(self._workflow_select_widgets)
772783
header_layout.addWidget(self._workflow_edit_widgets)
773784
layout.addLayout(header_layout)
785+
layout.addWidget(self._style_widget)
786+
layout.addWidget(self._prompt_widget)
774787
layout.addWidget(self._splitter)
775788
actions_layout = QHBoxLayout()
776789
actions_layout.setSpacing(0)
@@ -808,6 +821,7 @@ def model(self, model: Model):
808821
self._model = model
809822
self._model_bindings = [
810823
bind(model, "workspace", self._workspace_select, "value", Bind.one_way),
824+
bind(model, "style", self._style_widget, "value"),
811825
bind(model, "error", self._error_box, "error", Bind.one_way),
812826
bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way),
813827
bind(model.custom, "outputs", self._outputs, "value", Bind.one_way),
@@ -822,6 +836,7 @@ def model(self, model: Model):
822836
self._queue_button.model = model
823837
self._progress_bar.model = model
824838
self._history.model_ = model
839+
self._prompt_widget.region = model.regions
825840
self._update_current_workflow()
826841
self._update_ui()
827842
self._set_params_height(model.custom.params_ui_height)
@@ -874,12 +889,21 @@ def _update_current_workflow(self):
874889
if not self.model.custom.workflow:
875890
self._save_workflow_button.setEnabled(False)
876891
self._delete_workflow_button.setEnabled(False)
892+
self._style_widget.setVisible(False)
893+
self._prompt_widget.setVisible(False)
877894
return
878895
self._save_workflow_button.setEnabled(True)
879896
self._delete_workflow_button.setEnabled(
880897
self.model.custom.workflow.source is WorkflowSource.local
881898
)
882899

900+
graph = self.model.custom.graph
901+
has_synced_style_and_prompt = (
902+
graph is not None and next(graph.find(type="ETN_KritaStyleAndPrompt"), None) is not None
903+
)
904+
self._style_widget.setVisible(has_synced_style_and_prompt)
905+
self._prompt_widget.setVisible(has_synced_style_and_prompt)
906+
883907
if self._params_widget:
884908
self._params_scroll.setWidget(None)
885909
self._params_widget.deleteLater()

ai_diffusion/workflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,20 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
14201420
outputs[node.output(6)] = sampling.scheduler
14211421
outputs[node.output(7)] = sampling.total_steps
14221422
outputs[node.output(8)] = sampling.cfg_scale
1423+
1424+
case "ETN_KritaStyleAndPrompt":
1425+
checkpoint_input = ensure(input.models)
1426+
sampling = ensure(input.sampling)
1427+
model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
1428+
outputs[node.output(0)] = model
1429+
outputs[node.output(1)] = clip.model
1430+
outputs[node.output(2)] = vae
1431+
outputs[node.output(3)] = input.positive_evaluated
1432+
outputs[node.output(4)] = input.negative_evaluated
1433+
outputs[node.output(5)] = sampling.sampler
1434+
outputs[node.output(6)] = sampling.scheduler
1435+
outputs[node.output(7)] = sampling.total_steps
1436+
outputs[node.output(8)] = sampling.cfg_scale
14231437
case _:
14241438
mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()}
14251439
mapped = ComfyNode(node.id, node.type, mapped_inputs)

0 commit comments

Comments (0)