Skip to content

Commit 47d5489

Browse files
author
shengtsui
authored
VTEN-19-Add-broadcast-to-tensor-assignment (#27)
1 parent c171251 commit 47d5489

File tree

5 files changed

+61
-9
lines changed

5 files changed

+61
-9
lines changed

docs/source/api/core/proxy_slice.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ vt::TensorSliceProxy
1010
.. doxygenstruct:: vt::NewAxisT
1111
:members:
1212

13+
.. doxygenenum:: vt::SliceType
14+
:members:
15+
1316
.. doxygenclass:: vt::Slice
1417
:members:
1518

docs/source/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ Comparison between NumPy and VTensor
5555
- vt::expand_dims_lhs<T, N, 2>(tensor)
5656
* - arr[..., None, None]
5757
- vt::expand_dims_rhs<T, N, 2>(tensor)
58+
* - arr[:, :]
59+
- tensor(vt::Slice::all(), vt::Slice::all())
5860

5961
.. list-table:: Broadcasting
6062
:header-rows: 1

lib/core/slice.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
#pragma once
22

33
#include "lib/core/assertions.hpp"
4+
#include "lib/core/broadcast.hpp"
45
#include "lib/core/tensor.hpp"
56

67
namespace vt {
78

9+
/**
10+
* @brief Slice type.
11+
*/
12+
enum class SliceType { Normal, All };
13+
814
/**
915
* @brief Slice class to represent a slice of a tensor.
1016
* Usage:
1117
* Slice(0) -> Slice from 0 to 1 with step 1
1218
* Slice(0, 10) -> Slice from 0 to 10 with step 1
1319
* Slice(0, 10, 2) -> Slice from 0 to 10 with step 2
20+
* Slice::all() -> Slice to represent the full range
1421
*/
1522
class Slice {
1623
public:
1724
size_t start;
1825
size_t end;
1926
size_t step;
27+
SliceType type = SliceType::Normal;
2028

2129
// Default constructor
2230
Slice() = default;
@@ -44,6 +52,20 @@ class Slice {
4452
* @param step: The step of the slice.
4553
*/
4654
Slice(size_t start, size_t end, size_t step) : start(start), end(end), step(step) {}
55+
56+
/**
57+
* @brief Construct a new Slice object
58+
*
59+
* @param type: The type of the slice.
60+
*/
61+
Slice(SliceType type) : type(type) {}
62+
63+
/**
64+
* @brief Return a slice object to represent the full range.
65+
*
66+
* @return Slice: The slice object.
67+
*/
68+
static Slice all() { return Slice(SliceType::All); }
4769
};
4870

4971
/**
@@ -99,8 +121,8 @@ class TensorSliceProxy : public Tensor<T, N> {
99121
*/
100122
TensorSliceProxy& operator=(const Tensor<T, N>& other) {
101123
assert_same_order_between_two_tensors(this->order(), other.order());
102-
assert(this->shape() == other.shape());
103-
thrust::copy(other.begin(), other.end(), this->begin());
124+
auto [lhs, rhs] = broadcast(*this, other);
125+
thrust::copy(rhs.begin(), rhs.end(), lhs.begin());
104126
return *this;
105127
}
106128

lib/core/tensor.hpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,16 @@ class Tensor {
417417
Shape<N> new_shape{};
418418
size_t offset = 0;
419419
for (size_t i = 0; i < N; i++) {
420-
assert(slices[i].end <= _shape[i]); // Check if the end of the slice is within the tensor dimension.
421-
assert(slices[i].start < slices[i].end); // Check if the start of the slice is less than the end.
422-
new_strides[i] = slices[i].step * _strides[i];
423-
new_shape[i] = (slices[i].end - slices[i].start + slices[i].step - 1) / slices[i].step;
424-
offset += slices[i].start * _strides[i];
420+
if (slices[i].type == SliceType::All) {
421+
new_strides[i] = _strides[i];
422+
new_shape[i] = _shape[i];
423+
} else {
424+
assert(slices[i].end <= _shape[i]); // Check if the end of the slice is within the tensor dimension.
425+
assert(slices[i].start < slices[i].end); // Check if the start of the slice is less than the end.
426+
new_strides[i] = slices[i].step * _strides[i];
427+
new_shape[i] = (slices[i].end - slices[i].start + slices[i].step - 1) / slices[i].step;
428+
offset += slices[i].start * _strides[i];
429+
}
425430
}
426431
size_t new_start = _start + offset;
427432
return Tensor<T, N>(_data, new_shape, new_strides, new_start, _order, false);
@@ -443,7 +448,13 @@ class Tensor {
443448
*
444449
* @return T*: The raw pointer of the tensor.
445450
*/
446-
T* raw_ptr() const { return thrust::raw_pointer_cast(_data->data()); }
451+
T* raw_ptr() const {
452+
if (_data == nullptr) {
453+
return nullptr;
454+
} else {
455+
return thrust::raw_pointer_cast(_data->data());
456+
}
457+
}
447458

448459
/**
449460
* @brief Return the contiguous flag of the tensor.

lib/core/tests/test_slice.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,18 @@ TEST(TensorCondProxyFromConstantF, BasicAssertions) {
135135
auto tensor = vt::arange(24, vt::Order::F)({0, 24, 2}).reshape(3, 4);
136136
tensor[tensor > 12.0f] = 1.0f;
137137
EXPECT_EQ(vt::asvector(tensor), (std::vector<float>{0, 2, 4, 6, 8, 10, 12, 1, 1, 1, 1, 1}));
138-
}
138+
}
139+
140+
TEST(BroadcastSliceAssignmentC, BasicAssertions) {
141+
auto tensor = vt::arange(12).reshape(4, 3);
142+
tensor(vt::Slice::all(), vt::Slice(0, 2, 1)) = tensor(vt::Slice::all(), vt::Slice(2, 3, 1));
143+
EXPECT_EQ(vt::asvector(tensor), (std::vector<float>{2, 2, 2, 5, 5, 5, 8, 8, 8, 11, 11, 11}));
144+
EXPECT_EQ(tensor.contiguous(), true);
145+
}
146+
147+
TEST(BroadcastSliceAssignmentF, BasicAssertions) {
148+
auto tensor = vt::arange(12, vt::Order::F).reshape(4, 3);
149+
tensor(vt::Slice::all(), vt::Slice(0, 2, 1)) = tensor(vt::Slice::all(), vt::Slice(2, 3, 1));
150+
EXPECT_EQ(vt::asvector(tensor), (std::vector<float>{8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11}));
151+
EXPECT_EQ(tensor.contiguous(), true);
152+
}

0 commit comments

Comments
 (0)