Skip to content

Commit 48b0a40

Browse files
authored
Merge pull request #188 from TorchTrade/fix/187-fractional-rebalancing
Fix fractional position sizing rebalancing on repeated actions
2 parents a99530c + c0ab16a commit 48b0a40

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

tests/envs/offline/test_sequential.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,106 @@ def test_check_env_specs_passes(self, unified_env):
803803
from torchrl.envs.utils import check_env_specs
804804
check_env_specs(unified_env)
805805

806+
@pytest.mark.parametrize("action_levels,leverage,repeat_action_idx", [
    ([0, 1], 1, 1),  # Spot: repeat buy (long)
    ([-1, 0, 1], 5, 2),  # Futures: repeat long
    ([-1, 0, 1], 5, 0),  # Futures: repeat short
], ids=["spot-long", "futures-long", "futures-short"])
def test_repeated_action_does_not_rebalance(
    self, sample_ohlcv_df, action_levels, leverage, repeat_action_idx
):
    """Repeating the same action should hold, not rebalance.

    Regression test for #187: fractional position sizing recalculated
    target from drifting portfolio_value, causing constant-leverage
    rebalancing (close_partial / increase) when the agent repeated
    the same action.

    The test opens a position with ``repeat_action_idx``, then submits the
    same action index for up to 49 further steps and counts every step on
    which ``env.position.position_size`` differs from the size recorded
    right after opening. The loop stops early if the env reports ``done``.

    Args:
        sample_ohlcv_df: OHLCV data passed to the env constructor
            (presumably a pytest fixture -- defined outside this view).
        action_levels: Value for ``SequentialTradingEnvConfig.action_levels``.
        leverage: Value for ``SequentialTradingEnvConfig.leverage``.
        repeat_action_idx: Action index submitted on every step.
    """
    config = SequentialTradingEnvConfig(
        action_levels=action_levels,
        leverage=leverage,
        initial_cash=1000,
        transaction_fee=0.0,
        slippage=0.0,
        time_frames=[TimeFrame(1, TimeFrameUnit.Minute)],
        window_sizes=[10],
        execute_on=TimeFrame(1, TimeFrameUnit.Minute),
    )
    env = SequentialTradingEnv(sample_ohlcv_df, config, simple_feature_fn)
    # try/finally so the env is closed even when an assertion below fails;
    # previously env.close() was skipped on any failure.
    try:
        td = env.reset()

        # Step 1: open position
        action_td = td.clone()
        action_td["action"] = torch.tensor(repeat_action_idx)
        td = env.step(action_td)

        position_after_open = env.position.position_size
        assert position_after_open != 0, "Position should have opened"

        # Steps 2-50: repeat same action — position size must not change
        trades_executed = 0
        for _ in range(49):
            action_td = td["next"].clone()
            action_td["action"] = torch.tensor(repeat_action_idx)
            td = env.step(action_td)
            # Stop before checking: the step that ends the episode is not
            # counted as a rebalance.
            if td["next"]["done"].item():
                break
            if env.position.position_size != position_after_open:
                trades_executed += 1

        assert trades_executed == 0, (
            f"Repeating the same action should hold, not rebalance. "
            f"Position changed {trades_executed} times (issue #187)"
        )
    finally:
        env.close()
858+
859+
@pytest.mark.parametrize("action_levels,leverage,open_idx,close_idx", [
    ([0, 1], 1, 1, 0),  # Spot: long then sell
    ([-1, 0, 1], 5, 2, 1),  # Futures: long then close
    ([-1, 0, 1], 5, 0, 1),  # Futures: short then close
], ids=["spot-close", "futures-close-long", "futures-close-short"])
def test_action_change_after_repeated_holds_still_executes(
    self, sample_ohlcv_df, action_levels, leverage, open_idx, close_idx
):
    """Changing action after repeated holds must still execute.

    Regression test for #187: ensures the _prev_action_value guard
    does not accidentally lock agents into positions they cannot exit.

    The test opens a position with ``open_idx``, repeats that action for
    10 more steps, then submits ``close_idx`` once and requires the
    position size to be exactly zero afterwards.

    Args:
        sample_ohlcv_df: OHLCV data passed to the env constructor
            (presumably a pytest fixture -- defined outside this view).
        action_levels: Value for ``SequentialTradingEnvConfig.action_levels``.
        leverage: Value for ``SequentialTradingEnvConfig.leverage``.
        open_idx: Action index used to open and then hold the position.
        close_idx: Action index expected to flatten the position.
    """
    config = SequentialTradingEnvConfig(
        action_levels=action_levels,
        leverage=leverage,
        initial_cash=1000,
        transaction_fee=0.0,
        slippage=0.0,
        time_frames=[TimeFrame(1, TimeFrameUnit.Minute)],
        window_sizes=[10],
        execute_on=TimeFrame(1, TimeFrameUnit.Minute),
    )
    env = SequentialTradingEnv(sample_ohlcv_df, config, simple_feature_fn)
    # try/finally so the env is closed even when an assertion below fails;
    # previously env.close() was skipped on any failure.
    try:
        td = env.reset()

        # Open position and repeat for 10 steps
        action_td = td.clone()
        action_td["action"] = torch.tensor(open_idx)
        td = env.step(action_td)
        assert env.position.position_size != 0, "Position should have opened"

        for _ in range(10):
            action_td = td["next"].clone()
            action_td["action"] = torch.tensor(open_idx)
            td = env.step(action_td)

        # Guard against a vacuous pass: if the position were already flat
        # here, the final check would succeed without the close action
        # having executed anything.
        assert env.position.position_size != 0, (
            "Position should still be open before the close action"
        )

        # Now close — must actually execute
        action_td = td["next"].clone()
        action_td["action"] = torch.tensor(close_idx)
        td = env.step(action_td)

        assert env.position.position_size == 0, (
            "Position should have closed after action change (issue #187)"
        )
    finally:
        env.close()
905+
806906

807907
# ============================================================================
808908
# PER-TIMEFRAME FEATURE PROCESSING TESTS (Issue #177)

torchtrade/envs/offline/sequential.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def _reset_position_state(self):
458458
self.unrealized_pnl = 0.0
459459
self.unrealized_pnl_pct = 0.0
460460
self.liquidation_price = 0.0
461+
self._prev_action_value = None
461462

462463
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
463464
"""Execute one environment step."""
@@ -600,6 +601,12 @@ def _execute_fractional_action(self, action_value: float, execution_price: float
600601
Returns:
601602
trade_info: Dict with execution details
602603
"""
604+
# If action hasn't changed and we already have a position, just hold
605+
if action_value == self._prev_action_value and self.position.position_size != 0:
606+
self.position.hold_counter += 1
607+
return {"executed": False, "side": None, "fee_paid": 0.0, "liquidated": False}
608+
self._prev_action_value = action_value
609+
603610
# Calculate target position from action value
604611
target_position_size, target_notional, target_side = (
605612
self._calculate_fractional_position(action_value, execution_price)

0 commit comments

Comments
 (0)