File tree Expand file tree Collapse file tree 1 file changed +18
-2
lines changed
Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -3717,8 +3717,24 @@ bool AdjointGenerator::handleKnownCallDerivatives(
37173717 IRBuilder<> BuilderZ (&call);
37183718 getForwardBuilder (BuilderZ);
37193719
3720- bool forceErase = Mode == DerivativeMode::ReverseModeGradient ||
3721- Mode == DerivativeMode::ForwardModeSplit;
3720+ bool backwardsShadow = false ;
3721+ bool forwardsShadow = true ;
3722+ for (auto pair : gutils->backwardsOnlyShadows ) {
3723+ if (pair.second .stores .count (&call)) {
3724+ backwardsShadow = true ;
3725+ forwardsShadow = pair.second .primalInitialize ;
3726+ if (auto inst = dyn_cast<Instruction>(pair.first ))
3727+ if (!forwardsShadow && pair.second .LI &&
3728+ pair.second .LI ->contains (inst->getParent ()))
3729+ backwardsShadow = false ;
3730+ }
3731+ }
3732+
3733+ bool forceErase =
3734+ !((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3735+ (Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) ||
3736+ (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
3737+ (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow));
37223738
37233739 if (forceErase)
37243740 eraseIfUnused (call, /* erase*/ true , /* check*/ false );
You can’t perform that action at this time.
0 commit comments