Skip to content

Commit 7b9cce9

Browse files
committed
Prevent JSON file corruption with atomic writes - updated
1 parent 2aa811e commit 7b9cce9

File tree

2 files changed

+35
-61
lines changed

2 files changed

+35
-61
lines changed

botocore/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3583,17 +3583,15 @@ def __setitem__(self, cache_key, value):
35833583
temp_fd, temp_path = tempfile.mkstemp(
35843584
dir=self._working_dir, suffix='.tmp'
35853585
)
3586-
if hasattr(os, 'fchmod'):
3587-
os.fchmod(temp_fd, 0o600)
35883586
with os.fdopen(temp_fd, 'w') as f:
35893587
temp_fd = None
35903588
f.write(file_content)
35913589
f.flush()
35923590
os.fsync(f.fileno())
3593-
3591+
35943592
os.replace(temp_path, full_key)
35953593
temp_path = None
3596-
3594+
35973595
except Exception:
35983596
if temp_fd is not None:
35993597
try:

tests/unit/test_utils.py

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import shutil
1919
import tempfile
20+
import threading
2021
from contextlib import contextmanager
2122
from sys import getrefcount
2223

@@ -56,7 +57,7 @@
5657
from botocore.utils import (
5758
ArgumentGenerator,
5859
ArnParser,
59-
CachedProperty,
60+
CachedProperty,
6061
ContainerMetadataFetcher,
6162
IMDSRegionProvider,
6263
InstanceMetadataFetcher,
@@ -3684,6 +3685,7 @@ def test_get_token_from_environment_returns_none(
36843685
monkeypatch.delenv(env_var, raising=False)
36853686
assert get_token_from_environment(signing_name) is None
36863687

3688+
36873689
class TestJSONFileCacheAtomicWrites(unittest.TestCase):
36883690
"""Test atomic write operations in JSONFileCache."""
36893691

@@ -3696,94 +3698,68 @@ def tearDown(self):
36963698

36973699
@mock.patch('os.replace')
36983700
def test_uses_tempfile_and_replace_for_atomic_write(self, mock_replace):
3699-
37003701
self.cache['test_key'] = {'data': 'test_value'}
37013702
mock_replace.assert_called_once()
37023703

3703-
temp_path, final_path = mock_replace.call_args[0]
3704-
3705-
self.assertIn('.tmp', temp_path)
3706-
self.assertTrue(final_path.endswith('test_key.json'))
3704+
call_args = mock_replace.call_args[0]
3705+
temp_path = call_args[0]
37073706

3708-
def test_concurrent_writes_same_key(self):
3707+
assert '.tmp' in temp_path
3708+
3709+
def test_concurrent_writes_to_multiple_temp_files(self):
37093710
"""Test concurrent writes to same key don't cause corruption."""
3710-
import threading
3711-
3712-
key = 'concurrent_test'
37133711
errors = []
3714-
temp_files_used = []
3715-
original_mkstemp = tempfile.mkstemp
3716-
3717-
def track_temp_files(*args, **kwargs):
3718-
fd, path = original_mkstemp(*args, **kwargs)
3719-
temp_files_used.append(path)
3720-
return fd, path
37213712

37223713
def write_worker(thread_id):
37233714
try:
3715+
key = f'concurrent_test_{thread_id}'
37243716
for i in range(3):
37253717
self.cache[key] = {'thread': thread_id, 'iteration': i}
3726-
if os.name == 'nt':
3727-
time.sleep(0.01)
37283718
except Exception as e:
37293719
errors.append(f'Thread {thread_id}: {e}')
37303720

3731-
with mock.patch('tempfile.mkstemp', side_effect=track_temp_files):
3732-
threads = [
3733-
threading.Thread(target=write_worker, args=(i,))
3734-
for i in range(3)
3735-
]
3736-
3737-
for thread in threads:
3738-
thread.start()
3739-
for thread in threads:
3740-
thread.join()
3741-
3742-
# On Windows, file locking can cause expected write errors
3743-
# so we allow errors but ensure the key exists in cache.
3744-
if errors and os.name == 'nt':
3745-
print(f"Windows file locking warnings: {errors}")
3746-
self.assertIn(key, self.cache)
3747-
else:
3748-
self.assertEqual(len(errors), 0, f'Concurrent write errors: {errors}')
3749-
3750-
# Verify each write used a separate temporary file
3751-
self.assertEqual(len(temp_files_used), 9)
3752-
self.assertEqual(
3753-
len(set(temp_files_used)),
3754-
9,
3755-
'Concurrent writes should use separate temp files',
3756-
)
3757-
3758-
# Verify final data is valid
3759-
final_data = self.cache[key]
3760-
self.assertIsInstance(final_data, dict)
3761-
self.assertIn('thread', final_data)
3762-
self.assertIn('iteration', final_data)
3721+
threads = [
3722+
threading.Thread(target=write_worker, args=(i,)) for i in range(3)
3723+
]
3724+
3725+
for thread in threads:
3726+
thread.start()
3727+
for thread in threads:
3728+
thread.join()
3729+
3730+
self.assertEqual(len(errors), 0, f'Concurrent write errors: {errors}')
3731+
3732+
for thread_id in range(3):
3733+
key = f'concurrent_test_{thread_id}'
3734+
final_data = self.cache[key]
3735+
self.assertIsInstance(final_data, dict)
3736+
self.assertEqual(final_data['thread'], thread_id)
3737+
self.assertIn('thread', final_data)
3738+
self.assertIn('iteration', final_data)
37633739

37643740
def test_atomic_write_preserves_data_on_failure(self):
37653741
"""Test write failures don't corrupt existing data."""
37663742
key = 'atomic_test'
37673743
original_data = {'status': 'original'}
3768-
3744+
37693745
self.cache[key] = original_data
3770-
3746+
37713747
# Mock write failure
37723748
original_dumps = self.cache._dumps
37733749
self.cache._dumps = mock.Mock(side_effect=ValueError('Write failed'))
3774-
3750+
37753751
with self.assertRaises(ValueError):
37763752
self.cache[key] = {'status': 'should_fail'}
3777-
3753+
37783754
self.cache._dumps = original_dumps
3779-
3755+
37803756
# Verify original data intact
37813757
self.assertEqual(self.cache[key], original_data)
37823758

37833759
def test_no_temp_files_after_write(self):
37843760
"""Test temporary files cleaned up after writes."""
37853761
self.cache['test'] = {'data': 'value'}
3786-
3762+
37873763
temp_files = [
37883764
f for f in os.listdir(self.temp_dir) if f.endswith('.tmp')
37893765
]

0 commit comments

Comments (0)