Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ export enum TorchMirrorUrl {
NightlyCpu = 'https://download.pytorch.org/whl/nightly/cpu',
}

export type TorchUpdatePolicy = 'auto' | 'defer' | 'pinned';

export type TorchPinnedPackages = {
torch?: string;
torchaudio?: string;
torchvision?: string;
};
Comment on lines +184 to +190
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Add brief JSDoc for new torch policy types.

These are exported public types; a short description helps consumers and aligns with the repo’s documentation expectations. As per coding guidelines, please add concise JSDoc.

✍️ Suggested doc additions
+/** How NVIDIA PyTorch updates should be handled. */
 export type TorchUpdatePolicy = 'auto' | 'defer' | 'pinned';

+/** Pinned NVIDIA torch package versions when updates are disabled. */
 export type TorchPinnedPackages = {
   torch?: string;
   torchaudio?: string;
   torchvision?: string;
 };
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
export type TorchUpdatePolicy = 'auto' | 'defer' | 'pinned';
export type TorchPinnedPackages = {
torch?: string;
torchaudio?: string;
torchvision?: string;
};
/** How NVIDIA PyTorch updates should be handled. */
export type TorchUpdatePolicy = 'auto' | 'defer' | 'pinned';
/** Pinned NVIDIA torch package versions when updates are disabled. */
export type TorchPinnedPackages = {
torch?: string;
torchaudio?: string;
torchvision?: string;
};
🤖 Prompt for AI Agents
In `@src/constants.ts` around lines 184 - 190, Add concise JSDoc comments above
the exported types TorchUpdatePolicy and TorchPinnedPackages describing their
purpose and valid values/fields; for TorchUpdatePolicy document the allowed
string options ('auto', 'defer', 'pinned') and their behavior, and for
TorchPinnedPackages describe that it maps optional package names (torch,
torchaudio, torchvision) to version strings to pin specific versions. Keep
comments short, one-to-two sentences each and follow the existing JSDoc style in
the file.


/** Legacy NVIDIA torch mirror used by older installs (CUDA 12.9). */
export const LEGACY_NVIDIA_TORCH_MIRROR = 'https://download.pytorch.org/whl/cu129';

Expand Down Expand Up @@ -216,6 +224,10 @@ export const NVIDIA_TORCH_PACKAGES: string[] = [
`torchaudio==${NVIDIA_TORCH_VERSION}`,
`torchvision==${NVIDIA_TORCHVISION_VERSION}`,
];
/** Minimum NVIDIA driver version recommended for the current CUDA torch build. */
export const NVIDIA_DRIVER_MIN_VERSION = '580';
/** Recommended NVIDIA torch package set key (torch/torchaudio/torchvision). */
export const NVIDIA_TORCH_RECOMMENDED_VERSION = `${NVIDIA_TORCH_VERSION}|${NVIDIA_TORCHVISION_VERSION}`;

/** The log files used by the desktop process. */
export enum LogFile {
Expand Down
185 changes: 181 additions & 4 deletions src/install/installationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@ import { promisify } from 'node:util';

import { strictIpcMain as ipcMain } from '@/infrastructure/ipcChannels';

import { IPC_CHANNELS, InstallStage, ProgressStatus } from '../constants';
import { useComfySettings } from '../config/comfySettings';
import {
IPC_CHANNELS,
InstallStage,
NVIDIA_DRIVER_MIN_VERSION,
NVIDIA_TORCHVISION_VERSION,
NVIDIA_TORCH_RECOMMENDED_VERSION,
NVIDIA_TORCH_VERSION,
ProgressStatus,
TorchMirrorUrl,
} from '../constants';
import { PythonImportVerificationError } from '../infrastructure/pythonImportVerificationError';
import { useAppState } from '../main-process/appState';
import type { AppWindow } from '../main-process/appWindow';
Expand All @@ -23,7 +33,8 @@ import { InstallWizard } from './installWizard';
import { Troubleshooting } from './troubleshooting';

const execAsync = promisify(exec);
const NVIDIA_DRIVER_MIN_VERSION = '580';
const TORCH_MIRROR_CUDA_PATH = new URL(TorchMirrorUrl.Cuda).pathname;
const TORCH_MIRROR_NIGHTLY_CUDA_PATH = new URL(TorchMirrorUrl.NightlyCuda).pathname;

/**
* Extracts the NVIDIA driver version from `nvidia-smi` output.
Expand Down Expand Up @@ -91,7 +102,7 @@ export class InstallationManager implements HasTelemetry {
// Convert from old format
if (state === 'upgraded') installation.upgradeConfig();

// Install updated manager requirements
// Install updated requirements
if (installation.needsRequirementsUpdate) await this.updatePackages(installation);

// Resolve issues and re-run validation
Expand Down Expand Up @@ -382,14 +393,180 @@ export class InstallationManager implements HasTelemetry {
await installation.virtualEnvironment.installComfyUIRequirements(callbacks);
await installation.virtualEnvironment.installComfyUIManagerRequirements(callbacks);
await this.warnIfNvidiaDriverTooOld(installation);
await installation.virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks);
await this.maybeUpdateNvidiaTorch(installation, callbacks);
await installation.validate();
} catch (error) {
log.error('Error auto-updating packages:', error);
await this.appWindow.loadPage('server-start');
}
}

private async maybeUpdateNvidiaTorch(installation: ComfyInstallation, callbacks: ProcessCallbacks): Promise<void> {
const virtualEnvironment = installation.virtualEnvironment;
if (virtualEnvironment.selectedDevice !== 'nvidia') return;

const config = useDesktopConfig();
const updatePolicy = config.get('torchUpdatePolicy');
const recommendedVersion = NVIDIA_TORCH_RECOMMENDED_VERSION;
const lastPromptedVersion = config.get('torchLastPromptedVersion');
if (updatePolicy === 'pinned' && lastPromptedVersion === recommendedVersion) {
log.info('Skipping NVIDIA PyTorch update because updates are pinned for this version.');
return;
}

const installedVersions = await virtualEnvironment.getInstalledTorchPackageVersions();
if (!installedVersions) {
log.warn('Skipping NVIDIA PyTorch update because installed versions could not be read.');
return;
}

const isOutOfDate = await virtualEnvironment.isNvidiaTorchOutOfDate(installedVersions);
if (!isOutOfDate) return;

if (config.get('torchOutOfDateRecommendedVersion') !== recommendedVersion) {
config.set('torchOutOfDateRecommendedVersion', recommendedVersion);
config.set('torchOutOfDatePackages', installedVersions);
}

if (updatePolicy === 'defer' && lastPromptedVersion === recommendedVersion) {
log.info('Skipping NVIDIA PyTorch update because updates are deferred for this version.');
return;
}

const updateApproved = updatePolicy === 'auto' && lastPromptedVersion === recommendedVersion;
let shouldAttemptUpdate = updateApproved;

if (!updateApproved) {
const currentTorch = installedVersions.torch ?? 'unknown';
const currentTorchaudio = installedVersions.torchaudio ?? 'unknown';
const currentTorchvision = installedVersions.torchvision ?? 'unknown';

const { response } = await this.appWindow.showMessageBox({
type: 'question',
title: 'Update PyTorch?',
message:
'Your NVIDIA PyTorch build is out of date. We can update it to the recommended build for improved performance. This update may affect memory usage and compatibility with some custom nodes.',
detail: [
`Current: torch ${currentTorch}, torchaudio ${currentTorchaudio}, torchvision ${currentTorchvision}`,
`Recommended: torch ${NVIDIA_TORCH_VERSION}, torchaudio ${NVIDIA_TORCH_VERSION}, torchvision ${NVIDIA_TORCHVISION_VERSION}`,
].join('\n'),
buttons: ['Update PyTorch', 'Ask again later', 'Silence until next version', 'Silence forever'],
defaultId: 0,
cancelId: 1,
});

switch (response) {
case 1:
log.info('Deferring NVIDIA PyTorch update prompt.');
return;
case 2:
config.set('torchLastPromptedVersion', recommendedVersion);
config.set('torchUpdatePolicy', 'defer');
config.delete('torchPinnedPackages');
virtualEnvironment.updateTorchUpdatePolicy('defer', undefined, recommendedVersion);
return;
case 3:
config.set('torchLastPromptedVersion', recommendedVersion);
config.set('torchUpdatePolicy', 'pinned');
config.set('torchPinnedPackages', installedVersions);
virtualEnvironment.updateTorchUpdatePolicy('pinned', installedVersions, recommendedVersion);
return;
default:
config.set('torchLastPromptedVersion', recommendedVersion);
config.set('torchUpdatePolicy', 'auto');
config.delete('torchPinnedPackages');
config.delete('torchUpdateFailureSilencedVersion');
virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion);
shouldAttemptUpdate = true;
}
} else {
virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion);
}

if (!shouldAttemptUpdate) return;

const torchMirrorOverride = await this.updateTorchMirrorForRecommendedVersion();

try {
await virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks, torchMirrorOverride);
config.delete('torchUpdateFailureSilencedVersion');
} catch (error) {
log.error('Error updating NVIDIA PyTorch packages:', error);
if (config.get('torchUpdateFailureSilencedVersion') === recommendedVersion) return;

const { response } = await this.appWindow.showMessageBox({
type: 'warning',
title: 'PyTorch update failed',
message:
'We could not install the recommended NVIDIA PyTorch build. This may be because your configured torch mirror does not provide it.',
detail: 'We will retry the update on each startup.',
buttons: ['OK', "Don't show again"],
defaultId: 0,
cancelId: 0,
});

if (response === 1) {
config.set('torchUpdateFailureSilencedVersion', recommendedVersion);
}
}
}
Comment on lines 404 to 512
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

rg -n "isUsingCustomTorchMirror|ensureRecommendedNvidiaTorch|torchMirror" src/ -t ts

Repository: Comfy-Org/desktop

Length of output: 2404


🏁 Script executed:

sed -n '694,740p' src/virtualEnvironment.ts

Repository: Comfy-Org/desktop

Length of output: 2001


🏁 Script executed:

grep -n "isUsingCustomTorchMirror" src/install/installationManager.ts

Repository: Comfy-Org/desktop

Length of output: 43


🏁 Script executed:

grep -rn "isUsingCustomTorchMirror" src/ -t ts

Repository: Comfy-Org/desktop

Length of output: 171


🏁 Script executed:

rg -n "isUsingCustomTorchMirror" src/

Repository: Comfy-Org/desktop

Length of output: 130


Add early return for custom torch mirrors.

The method isUsingCustomTorchMirror() exists but is unused in the torch update flow. Without a guard at the start of maybeUpdateNvidiaTorch(), users with custom mirrors will still see update prompts and attempts, contradicting the PR objective to skip updates when a custom mirror is configured.

Add the check after the device check:

if (virtualEnvironment.isUsingCustomTorchMirror()) {
  log.info('Skipping NVIDIA PyTorch update because a custom torch mirror is configured.');
  return;
}
🤖 Prompt for AI Agents
In `@src/install/installationManager.ts` around lines 400 - 504, The
maybeUpdateNvidiaTorch function lacks a guard for custom torch mirrors; after
the existing device check (virtualEnvironment.selectedDevice !== 'nvidia') add a
check using virtualEnvironment.isUsingCustomTorchMirror() and if true log a
message like "Skipping NVIDIA PyTorch update because a custom torch mirror is
configured." and return; update references: maybeUpdateNvidiaTorch,
virtualEnvironment.isUsingCustomTorchMirror(), and keep the rest of the flow
unchanged so prompts and installs are skipped when a custom mirror is set.


private async updateTorchMirrorForRecommendedVersion(): Promise<string | undefined> {
let settings;
try {
settings = useComfySettings();
} catch (error) {
log.warn('Unable to access Comfy settings to update torch mirror.', error);
return undefined;
}

const currentMirror = settings.get('Comfy-Desktop.UV.TorchInstallMirror');
const updatedMirror = this.getRecommendedTorchMirror(currentMirror);
if (!updatedMirror || updatedMirror === currentMirror) return updatedMirror ?? currentMirror;

settings.set('Comfy-Desktop.UV.TorchInstallMirror', updatedMirror);
try {
await settings.saveSettings();
} catch (error) {
log.warn('Failed to persist torch mirror update.', error);
}

return updatedMirror;
}

private getRecommendedTorchMirror(mirror: string | undefined): string | undefined {
const defaultTorchMirror = String(TorchMirrorUrl.Default);
if (!mirror?.trim() || mirror === defaultTorchMirror) return TorchMirrorUrl.Cuda;

let parsedMirror: URL;
try {
parsedMirror = new URL(mirror);
} catch (error) {
log.warn('Unable to parse torch mirror URL for normalization.', error);
return mirror;
}

const path = parsedMirror.pathname;
if (!path.includes('/whl/')) return mirror;

let updatedPath = path;
const nightlyCudaPattern = /\/whl\/nightly\/cu\d+/i;
const cudaPattern = /\/whl\/cu\d+/i;

if (nightlyCudaPattern.test(updatedPath)) {
updatedPath = updatedPath.replace(nightlyCudaPattern, TORCH_MIRROR_NIGHTLY_CUDA_PATH);
} else if (cudaPattern.test(updatedPath)) {
updatedPath = updatedPath.replace(cudaPattern, TORCH_MIRROR_CUDA_PATH);
} else {
return mirror;
}

if (updatedPath === path) return mirror;
parsedMirror.pathname = updatedPath;

return parsedMirror.toString();
}

/**
* Warns the user if their NVIDIA driver is too old for the required CUDA build.
* @param installation The current installation.
Expand Down
3 changes: 3 additions & 0 deletions src/main-process/comfyInstallation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ export class ComfyInstallation {
pythonMirror: useComfySettings().get('Comfy-Desktop.UV.PythonInstallMirror'),
pypiMirror: useComfySettings().get('Comfy-Desktop.UV.PypiInstallMirror'),
torchMirror: useComfySettings().get('Comfy-Desktop.UV.TorchInstallMirror'),
torchUpdatePolicy: useDesktopConfig().get('torchUpdatePolicy'),
torchPinnedPackages: useDesktopConfig().get('torchPinnedPackages'),
torchUpdateDecisionVersion: useDesktopConfig().get('torchLastPromptedVersion'),
});
}

Expand Down
13 changes: 13 additions & 0 deletions src/store/desktopSettings.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { TorchPinnedPackages, TorchUpdatePolicy } from '../constants';
import type { GpuType, TorchDeviceType } from '../preload';

export type DesktopInstallState = 'started' | 'installed' | 'upgraded';
Expand Down Expand Up @@ -35,4 +36,16 @@ export type DesktopSettings = {
versionConsentedMetrics?: string;
/** Whether the user has generated an image successfully. */
hasGeneratedSuccessfully?: boolean;
/** How to handle NVIDIA PyTorch updates. */
torchUpdatePolicy?: TorchUpdatePolicy;
/** The pinned NVIDIA torch package versions when updates are disabled. */
torchPinnedPackages?: TorchPinnedPackages;
/** The recommended NVIDIA torch version tied to the current update decision. */
torchLastPromptedVersion?: string;
/** The recommended NVIDIA torch version whose update failure prompt is suppressed. */
torchUpdateFailureSilencedVersion?: string;
/** The recommended NVIDIA torch version recorded when we first detected an out-of-date torch install. */
torchOutOfDateRecommendedVersion?: string;
/** The torch package versions recorded when we first detected an out-of-date torch install. */
torchOutOfDatePackages?: TorchPinnedPackages;
};
Loading
Loading