-
Notifications
You must be signed in to change notification settings - Fork 187
Handle NVIDIA PyTorch updates #1525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8f76581
d9c89cf
290d8d4
4dbd3b9
c3290ee
273942b
6f59baf
ecbe8fd
a4b4925
29c7dbd
3c7b54f
6e96626
30388eb
a64ef0e
8393245
477fb95
14c9f7e
2b26142
c38e2b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,17 @@ import { promisify } from 'node:util'; | |
|
|
||
| import { strictIpcMain as ipcMain } from '@/infrastructure/ipcChannels'; | ||
|
|
||
| import { IPC_CHANNELS, InstallStage, ProgressStatus } from '../constants'; | ||
| import { useComfySettings } from '../config/comfySettings'; | ||
| import { | ||
| IPC_CHANNELS, | ||
| InstallStage, | ||
| NVIDIA_DRIVER_MIN_VERSION, | ||
| NVIDIA_TORCHVISION_VERSION, | ||
| NVIDIA_TORCH_RECOMMENDED_VERSION, | ||
| NVIDIA_TORCH_VERSION, | ||
| ProgressStatus, | ||
| TorchMirrorUrl, | ||
| } from '../constants'; | ||
| import { PythonImportVerificationError } from '../infrastructure/pythonImportVerificationError'; | ||
| import { useAppState } from '../main-process/appState'; | ||
| import type { AppWindow } from '../main-process/appWindow'; | ||
|
|
@@ -23,7 +33,8 @@ import { InstallWizard } from './installWizard'; | |
| import { Troubleshooting } from './troubleshooting'; | ||
|
|
||
| const execAsync = promisify(exec); | ||
| const NVIDIA_DRIVER_MIN_VERSION = '580'; | ||
| const TORCH_MIRROR_CUDA_PATH = new URL(TorchMirrorUrl.Cuda).pathname; | ||
| const TORCH_MIRROR_NIGHTLY_CUDA_PATH = new URL(TorchMirrorUrl.NightlyCuda).pathname; | ||
|
|
||
| /** | ||
| * Extracts the NVIDIA driver version from `nvidia-smi` output. | ||
|
|
@@ -91,7 +102,7 @@ export class InstallationManager implements HasTelemetry { | |
| // Convert from old format | ||
| if (state === 'upgraded') installation.upgradeConfig(); | ||
|
|
||
| // Install updated manager requirements | ||
| // Install updated requirements | ||
| if (installation.needsRequirementsUpdate) await this.updatePackages(installation); | ||
|
|
||
| // Resolve issues and re-run validation | ||
|
|
@@ -382,14 +393,180 @@ export class InstallationManager implements HasTelemetry { | |
| await installation.virtualEnvironment.installComfyUIRequirements(callbacks); | ||
| await installation.virtualEnvironment.installComfyUIManagerRequirements(callbacks); | ||
| await this.warnIfNvidiaDriverTooOld(installation); | ||
| await installation.virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks); | ||
| await this.maybeUpdateNvidiaTorch(installation, callbacks); | ||
| await installation.validate(); | ||
| } catch (error) { | ||
| log.error('Error auto-updating packages:', error); | ||
| await this.appWindow.loadPage('server-start'); | ||
| } | ||
| } | ||
|
|
||
| private async maybeUpdateNvidiaTorch(installation: ComfyInstallation, callbacks: ProcessCallbacks): Promise<void> { | ||
| const virtualEnvironment = installation.virtualEnvironment; | ||
| if (virtualEnvironment.selectedDevice !== 'nvidia') return; | ||
|
|
||
| const config = useDesktopConfig(); | ||
| const updatePolicy = config.get('torchUpdatePolicy'); | ||
| const recommendedVersion = NVIDIA_TORCH_RECOMMENDED_VERSION; | ||
| const lastPromptedVersion = config.get('torchLastPromptedVersion'); | ||
| if (updatePolicy === 'pinned' && lastPromptedVersion === recommendedVersion) { | ||
| log.info('Skipping NVIDIA PyTorch update because updates are pinned for this version.'); | ||
| return; | ||
| } | ||
|
|
||
| const installedVersions = await virtualEnvironment.getInstalledTorchPackageVersions(); | ||
| if (!installedVersions) { | ||
| log.warn('Skipping NVIDIA PyTorch update because installed versions could not be read.'); | ||
| return; | ||
| } | ||
|
|
||
| const isOutOfDate = await virtualEnvironment.isNvidiaTorchOutOfDate(installedVersions); | ||
| if (!isOutOfDate) return; | ||
|
|
||
| if (config.get('torchOutOfDateRecommendedVersion') !== recommendedVersion) { | ||
| config.set('torchOutOfDateRecommendedVersion', recommendedVersion); | ||
| config.set('torchOutOfDatePackages', installedVersions); | ||
| } | ||
|
|
||
| if (updatePolicy === 'defer' && lastPromptedVersion === recommendedVersion) { | ||
| log.info('Skipping NVIDIA PyTorch update because updates are deferred for this version.'); | ||
| return; | ||
| } | ||
|
|
||
| const updateApproved = updatePolicy === 'auto' && lastPromptedVersion === recommendedVersion; | ||
| let shouldAttemptUpdate = updateApproved; | ||
|
|
||
| if (!updateApproved) { | ||
| const currentTorch = installedVersions.torch ?? 'unknown'; | ||
| const currentTorchaudio = installedVersions.torchaudio ?? 'unknown'; | ||
| const currentTorchvision = installedVersions.torchvision ?? 'unknown'; | ||
|
|
||
| const { response } = await this.appWindow.showMessageBox({ | ||
| type: 'question', | ||
| title: 'Update PyTorch?', | ||
| message: | ||
| 'Your NVIDIA PyTorch build is out of date. We can update it to the recommended build for improved performance. This update may affect memory usage and compatibility with some custom nodes.', | ||
| detail: [ | ||
| `Current: torch ${currentTorch}, torchaudio ${currentTorchaudio}, torchvision ${currentTorchvision}`, | ||
| `Recommended: torch ${NVIDIA_TORCH_VERSION}, torchaudio ${NVIDIA_TORCH_VERSION}, torchvision ${NVIDIA_TORCHVISION_VERSION}`, | ||
| ].join('\n'), | ||
| buttons: ['Update PyTorch', 'Ask again later', 'Silence until next version', 'Silence forever'], | ||
| defaultId: 0, | ||
| cancelId: 1, | ||
| }); | ||
|
|
||
| switch (response) { | ||
| case 1: | ||
| log.info('Deferring NVIDIA PyTorch update prompt.'); | ||
| return; | ||
| case 2: | ||
| config.set('torchLastPromptedVersion', recommendedVersion); | ||
| config.set('torchUpdatePolicy', 'defer'); | ||
| config.delete('torchPinnedPackages'); | ||
| virtualEnvironment.updateTorchUpdatePolicy('defer', undefined, recommendedVersion); | ||
| return; | ||
| case 3: | ||
| config.set('torchLastPromptedVersion', recommendedVersion); | ||
| config.set('torchUpdatePolicy', 'pinned'); | ||
| config.set('torchPinnedPackages', installedVersions); | ||
| virtualEnvironment.updateTorchUpdatePolicy('pinned', installedVersions, recommendedVersion); | ||
| return; | ||
| default: | ||
| config.set('torchLastPromptedVersion', recommendedVersion); | ||
| config.set('torchUpdatePolicy', 'auto'); | ||
| config.delete('torchPinnedPackages'); | ||
| config.delete('torchUpdateFailureSilencedVersion'); | ||
| virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion); | ||
| shouldAttemptUpdate = true; | ||
| } | ||
| } else { | ||
| virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion); | ||
| } | ||
|
|
||
| if (!shouldAttemptUpdate) return; | ||
|
|
||
| const torchMirrorOverride = await this.updateTorchMirrorForRecommendedVersion(); | ||
|
|
||
| try { | ||
| await virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks, torchMirrorOverride); | ||
| config.delete('torchUpdateFailureSilencedVersion'); | ||
| } catch (error) { | ||
| log.error('Error updating NVIDIA PyTorch packages:', error); | ||
| if (config.get('torchUpdateFailureSilencedVersion') === recommendedVersion) return; | ||
|
|
||
| const { response } = await this.appWindow.showMessageBox({ | ||
| type: 'warning', | ||
| title: 'PyTorch update failed', | ||
| message: | ||
| 'We could not install the recommended NVIDIA PyTorch build. This may be because your configured torch mirror does not provide it.', | ||
| detail: 'We will retry the update on each startup.', | ||
| buttons: ['OK', "Don't show again"], | ||
| defaultId: 0, | ||
| cancelId: 0, | ||
| }); | ||
|
|
||
| if (response === 1) { | ||
| config.set('torchUpdateFailureSilencedVersion', recommendedVersion); | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
404
to
512
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: rg -n "isUsingCustomTorchMirror|ensureRecommendedNvidiaTorch|torchMirror" src/ -t tsRepository: Comfy-Org/desktop Length of output: 2404 🏁 Script executed: sed -n '694,740p' src/virtualEnvironment.tsRepository: Comfy-Org/desktop Length of output: 2001 🏁 Script executed: grep -n "isUsingCustomTorchMirror" src/install/installationManager.tsRepository: Comfy-Org/desktop Length of output: 43 🏁 Script executed: grep -rn "isUsingCustomTorchMirror" src/ -t tsRepository: Comfy-Org/desktop Length of output: 171 🏁 Script executed: rg -n "isUsingCustomTorchMirror" src/Repository: Comfy-Org/desktop Length of output: 130 Add early return for custom torch mirrors. The method Add the check after the device check: 🤖 Prompt for AI Agents |
||
|
|
||
| private async updateTorchMirrorForRecommendedVersion(): Promise<string | undefined> { | ||
| let settings; | ||
| try { | ||
| settings = useComfySettings(); | ||
| } catch (error) { | ||
| log.warn('Unable to access Comfy settings to update torch mirror.', error); | ||
| return undefined; | ||
| } | ||
|
|
||
| const currentMirror = settings.get('Comfy-Desktop.UV.TorchInstallMirror'); | ||
| const updatedMirror = this.getRecommendedTorchMirror(currentMirror); | ||
| if (!updatedMirror || updatedMirror === currentMirror) return updatedMirror ?? currentMirror; | ||
|
|
||
| settings.set('Comfy-Desktop.UV.TorchInstallMirror', updatedMirror); | ||
| try { | ||
| await settings.saveSettings(); | ||
| } catch (error) { | ||
| log.warn('Failed to persist torch mirror update.', error); | ||
| } | ||
|
|
||
| return updatedMirror; | ||
| } | ||
|
|
||
| private getRecommendedTorchMirror(mirror: string | undefined): string | undefined { | ||
| const defaultTorchMirror = String(TorchMirrorUrl.Default); | ||
| if (!mirror?.trim() || mirror === defaultTorchMirror) return TorchMirrorUrl.Cuda; | ||
|
|
||
| let parsedMirror: URL; | ||
| try { | ||
| parsedMirror = new URL(mirror); | ||
| } catch (error) { | ||
| log.warn('Unable to parse torch mirror URL for normalization.', error); | ||
| return mirror; | ||
| } | ||
|
|
||
| const path = parsedMirror.pathname; | ||
| if (!path.includes('/whl/')) return mirror; | ||
|
|
||
| let updatedPath = path; | ||
| const nightlyCudaPattern = /\/whl\/nightly\/cu\d+/i; | ||
| const cudaPattern = /\/whl\/cu\d+/i; | ||
|
|
||
| if (nightlyCudaPattern.test(updatedPath)) { | ||
| updatedPath = updatedPath.replace(nightlyCudaPattern, TORCH_MIRROR_NIGHTLY_CUDA_PATH); | ||
| } else if (cudaPattern.test(updatedPath)) { | ||
| updatedPath = updatedPath.replace(cudaPattern, TORCH_MIRROR_CUDA_PATH); | ||
| } else { | ||
| return mirror; | ||
| } | ||
|
|
||
| if (updatedPath === path) return mirror; | ||
| parsedMirror.pathname = updatedPath; | ||
|
|
||
| return parsedMirror.toString(); | ||
| } | ||
|
|
||
| /** | ||
| * Warns the user if their NVIDIA driver is too old for the required CUDA build. | ||
| * @param installation The current installation. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧹 Nitpick | 🔵 Trivial
Add brief JSDoc for new torch policy types.
These are exported public types; a short description helps consumers and aligns with the repo’s documentation expectations. As per coding guidelines, please add concise JSDoc.
✍️ Suggested doc additions
📝 Committable suggestion
🤖 Prompt for AI Agents