Skip to content

Commit d79eeff

Browse files
No public description
PiperOrigin-RevId: 862264383
1 parent cdd01c2 commit d79eeff

File tree

4 files changed

+26
-10
lines changed

4 files changed

+26
-10
lines changed

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/requirements.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ sudo apt-get install -y python3-venv python3-pip
3232
python3.10 -m venv myenv
3333
source myenv/bin/activate
3434

35+
echo "Activated python environment, installing dependencies."
36+
3537
pip install --no-cache-dir natsort absl-py opencv-python pandas pandas-gbq \
3638
google-cloud-bigquery google-auth trackpy google-cloud-storage \
3739
scikit-image scikit-learn webcolors==1.13 ffmpeg-python tritonclient[all] \

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/triton_server_inference.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,18 @@ def _resize_mask_batch(
8181
) -> np.ndarray:
8282
"""Resizes a batch of masks to the target dimensions."""
8383
target_w, target_h = target_dims
84-
return cv2.resize(
85-
masks, (target_w, target_h), interpolation=cv2.INTER_NEAREST
84+
masks_transposed = np.transpose(masks, (1, 2, 0))
85+
86+
resized_batch = cv2.resize(
87+
masks_transposed, (target_w, target_h), interpolation=cv2.INTER_NEAREST
8688
)
8789

90+
# If N=1, cv2.resize might drop the last dim, so we ensure 3D
91+
if resized_batch.ndim == 2:
92+
return resized_batch[np.newaxis, ...]
93+
94+
return np.transpose(resized_batch, (2, 0, 1))
95+
8896
def _scale_bbox_and_masks(
8997
self, results: Dict[str, Any], target_dims: Tuple[int, int]
9098
) -> Dict[str, Any]:

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/triton_server_inference_test.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,12 @@ def test_box_cxcywh_to_xyxyn(self):
9797
expected = np.array([[0.4, 0.4, 0.6, 0.6], [0.0, 0.0, 1.0, 1.0]])
9898
np.testing.assert_allclose(result, expected)
9999

100-
@mock.patch(f"{MODULE_PATH}.Image.fromarray")
101-
def test_scale_bbox_and_masks(self, mock_fromarray):
100+
@mock.patch(f"{MODULE_PATH}.cv2.resize")
101+
def test_scale_bbox_and_masks(self, mock_cv2_resize):
102102
# Arrange
103-
mock_img_instance = mock.Mock()
104-
mock_fromarray.return_value = mock_img_instance
105-
mock_img_instance.resize.return_value = np.ones(
103+
mock_cv2_resize.return_value = np.ones(
106104
(20, 10)
107-
) # w=10, h=20 -> PIL resize returns (width, height) numpy array
105+
) # h=20, w=10. cv2 resize returns h, w
108106

109107
results = {
110108
"xyxy": np.array([[0.1, 0.1, 0.5, 0.5]]),
@@ -118,8 +116,14 @@ def test_scale_bbox_and_masks(self, mock_fromarray):
118116
# Assert
119117
np.testing.assert_allclose(scaled_results["xyxy"], [[1.0, 2.0, 5.0, 10.0]])
120118
self.assertEqual(scaled_results["masks"].shape, (1, 20, 10))
121-
mock_fromarray.assert_called_once_with(results["masks"][0])
122-
mock_img_instance.resize.assert_called_once_with((10, 20))
119+
self.assertEqual(scaled_results["masks"].dtype, bool)
120+
mock_cv2_resize.assert_called_once()
121+
args, kwargs = mock_cv2_resize.call_args
122+
np.testing.assert_array_equal(args[0], np.ones((5, 5, 1)))
123+
self.assertEqual(args[1], (10, 20))
124+
self.assertEqual(
125+
kwargs["interpolation"], triton_server_inference.cv2.INTER_NEAREST
126+
)
123127

124128
@mock.patch(f"{MODULE_PATH}.cv2.imread")
125129
@mock.patch(f"{MODULE_PATH}.cv2.cvtColor")

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/server/triton_inference_server.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ command -v screen >/dev/null 2>&1 || { \
3333
sudo apt update && sudo apt install -y screen; \
3434
}
3535

36+
echo "Starting Triton server in a screen session."
37+
3638
# Start Triton server
3739
screen -dmS server bash -c '
3840
sudo docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \

0 commit comments

Comments
 (0)