Skip to content

Commit 817b4d6

Browse files
authored
Fix custom zero segfault (#2672)
1 parent 9dda0a1 commit 817b4d6

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

enzyme/Enzyme/CallDerivatives.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,8 +3717,24 @@ bool AdjointGenerator::handleKnownCallDerivatives(
37173717
IRBuilder<> BuilderZ(&call);
37183718
getForwardBuilder(BuilderZ);
37193719

3720-
bool forceErase = Mode == DerivativeMode::ReverseModeGradient ||
3721-
Mode == DerivativeMode::ForwardModeSplit;
3720+
bool backwardsShadow = false;
3721+
bool forwardsShadow = true;
3722+
for (auto pair : gutils->backwardsOnlyShadows) {
3723+
if (pair.second.stores.count(&call)) {
3724+
backwardsShadow = true;
3725+
forwardsShadow = pair.second.primalInitialize;
3726+
if (auto inst = dyn_cast<Instruction>(pair.first))
3727+
if (!forwardsShadow && pair.second.LI &&
3728+
pair.second.LI->contains(inst->getParent()))
3729+
backwardsShadow = false;
3730+
}
3731+
}
3732+
3733+
bool forceErase =
3734+
!((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3735+
(Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) ||
3736+
(Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
3737+
(Mode == DerivativeMode::ForwardModeSplit && backwardsShadow));
37223738

37233739
if (forceErase)
37243740
eraseIfUnused(call, /*erase*/ true, /*check*/ false);

0 commit comments

Comments
 (0)