From 4dff960d31bebd65add3710c68ec95565d409879 Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 00:35:26 +1100 Subject: [PATCH 001/154] fixes gorgonia/gorgonia#189 by creating using an alloc-ful (as opposed to inplace) transpose. The old transpose has been placed behind a build tag --- .travis.yml | 2 +- defaultengine_matop_transpose.go | 396 ++++++++--------------- defaultengine_matop_transpose_inplace.go | 261 +++++++++++++++ dense.go | 2 + dense_format.go | 5 + example_extension_matop_test.go | 2 +- 6 files changed, 407 insertions(+), 261 deletions(-) create mode 100644 defaultengine_matop_transpose_inplace.go diff --git a/.travis.yml b/.travis.yml index 1f1604d..24da83a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ env: - GOARCH=amd64 - BLAS_LIB=OpenBLAS - TRAVISTEST=true - - CUDA=8.0.61-1 + - CUDA=9.1.85-1 before_install: - go get github.com/mattn/goveralls diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index d2beba6..e66c4a6 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -1,259 +1,137 @@ -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Transpose(a Tensor, expStrides []int) error { - if !a.IsNativelyAccessible() { - return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") - } - if dt, ok := a.(DenseTensor); ok { - e.denseTranspose(dt, expStrides) - return nil - } - return errors.Errorf("Tranpose for tensor of %T not supported", a) -} - -func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { - if a.rtype() == String.Type { - e.denseTransposeString(a, expStrides) - return - } - - switch a.rtype().Size() { - case 1: - e.denseTranspose1(a, expStrides) - case 2: - e.denseTranspose2(a, expStrides) - case 4: - e.denseTranspose4(a, expStrides) - case 8: - e.denseTranspose8(a, expStrides) - default: - e.denseTransposeArbitrary(a, expStrides) - } -} - -func (e StdEng) 
denseTranspose1(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp byte - var i int - - data := a.hdr().Uint8s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint16 - var i int - - data := a.hdr().Uint16s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint32 - var i int - - data := a.hdr().Uint32s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= 
size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint64 - var i int - - data := a.hdr().Uint64s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp string - var i int - - data := a.hdr().Strings() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = "" - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - rtype := a.rtype() - typeSize := int(rtype.Size()) - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - saved := make([]byte, typeSize, typeSize) - tmp := 
make([]byte, typeSize, typeSize) - var i int - - data := storage.AsByteSlice(a.hdr(), rtype) - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - start := typeSize * i - - if track.IsSet(i) && track.IsSet(dest) { - copy(data[start:start+typeSize], saved) - for i := range saved { - saved[i] = 0 - } - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - copy(tmp, data[start:start+typeSize]) - copy(data[start:start+typeSize], saved) - saved = tmp - - i = dest - } -} +// +build !inplacetranspose + +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Transpose(a Tensor, expStrides []int) error { + if !a.IsNativelyAccessible() { + return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") + } + if dt, ok := a.(DenseTensor); ok { + e.denseTranspose(dt, expStrides) + return nil + } + return errors.Errorf("Tranpose for tensor of %T not supported", a) +} + +func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { + if a.rtype() == String.Type { + e.denseTransposeString(a, expStrides) + return + } + + switch a.rtype().Size() { + case 1: + e.denseTranspose1(a, expStrides) + case 2: + e.denseTranspose2(a, expStrides) + case 4: + e.denseTranspose4(a, expStrides) + case 8: + e.denseTranspose8(a, expStrides) + default: + e.denseTransposeArbitrary(a, expStrides) + } +} + +func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u8s := tmpArr.Uint8s() + + orig := a.hdr().Uint8s() + it := NewFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u8s[j] = orig[i] + j++ + } + copy(orig, u8s) +} + +func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u16s := tmpArr.Uint16s() + + orig := a.hdr().Uint16s() + it := NewFlatIterator(a.Info()) + 
var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u16s[j] = orig[i] + j++ + } + copy(orig, u16s) +} + +func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u32s := tmpArr.Uint32s() + + orig := a.hdr().Uint32s() + it := NewFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u32s[j] = orig[i] + j++ + } + copy(orig, u32s) +} + +func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + u64s := tmpArr.Uint64s() + + orig := a.hdr().Uint64s() + it := NewFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + u64s[j] = orig[i] + j++ + } + copy(orig, u64s) +} + +func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + strs := tmpArr.Strings() + + orig := a.hdr().Strings() + it := NewFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + strs[j] = orig[i] + j++ + } + copy(orig, strs) +} + +func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { + rtype := a.rtype() + typeSize := int(rtype.Size()) + var tmpArr array + e.makeArray(&tmpArr, a.Dtype(), a.Size()) + // arbs := storage.AsByteSlice(tmpArr.hdr(), rtype) + arbs := tmpArr.byteSlice() + + orig := storage.AsByteSlice(a.hdr(), rtype) + it := NewFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + srcStart := i * typeSize + srcEnd := srcStart + typeSize + dstStart := j * typeSize + dstEnd := dstStart + typeSize + + copy(arbs[dstStart:dstEnd], orig[srcStart:srcEnd]) + j++ + } + copy(orig, arbs) +} diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go new file mode 100644 index 0000000..09f49b7 --- /dev/null +++ 
b/defaultengine_matop_transpose_inplace.go @@ -0,0 +1,261 @@ +// +build inplacetranspose + +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Transpose(a Tensor, expStrides []int) error { + if !a.IsNativelyAccessible() { + return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") + } + if dt, ok := a.(DenseTensor); ok { + e.denseTranspose(dt, expStrides) + return nil + } + return errors.Errorf("Tranpose for tensor of %T not supported", a) +} + +func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { + if a.rtype() == String.Type { + e.denseTransposeString(a, expStrides) + return + } + + switch a.rtype().Size() { + case 1: + e.denseTranspose1(a, expStrides) + case 2: + e.denseTranspose2(a, expStrides) + case 4: + e.denseTranspose4(a, expStrides) + case 8: + e.denseTranspose8(a, expStrides) + default: + e.denseTransposeArbitrary(a, expStrides) + } +} + +func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp byte + var i int + + data := a.hdr().Uint8s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change 
+ + var saved, tmp uint16 + var i int + + data := a.hdr().Uint16s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint32 + var i int + + data := a.hdr().Uint32s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint64 + var i int + + data := a.hdr().Uint64s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been 
moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp string + var i int + + data := a.hdr().Strings() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = "" + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + rtype := a.rtype() + typeSize := int(rtype.Size()) + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + saved := make([]byte, typeSize, typeSize) + tmp := make([]byte, typeSize, typeSize) + var i int + + data := storage.AsByteSlice(a.hdr(), rtype) + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + start := typeSize * i + + if track.IsSet(i) && track.IsSet(dest) { + copy(data[start:start+typeSize], saved) + for i := range saved { + saved[i] = 0 + } + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + copy(tmp, data[start:start+typeSize]) + copy(data[start:start+typeSize], saved) + saved = tmp + + i = dest + } +} diff --git a/dense.go b/dense.go index 3152912..6a026b6 100644 --- a/dense.go +++ b/dense.go @@ -2,6 +2,7 @@ package tensor import ( "fmt" + "log" "unsafe" "github.com/pkg/errors" @@ -128,6 +129,7 @@ func (t *Dense) Reshape(dims ...int) error { } if t.old != nil { + log.Println("Transposing") t.Transpose() } diff --git a/dense_format.go b/dense_format.go index ab4cfef..9d31994 100644 --- a/dense_format.go +++ b/dense_format.go @@ -248,6 +248,11 @@ func (f *fmtState) writeVElision() { 
// // Special care also needs be taken for the verb 's' - it prints a super compressed version of the tensor, only printing 4 cols and 4 rows. func (t *Dense) Format(s fmt.State, c rune) { + if c == 'i' { + fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\t", t.AP, t.old, t.transposeWith) + return + } + f := newFmtState(s, c) if t.IsScalar() { o := f.originalFmt() diff --git a/example_extension_matop_test.go b/example_extension_matop_test.go index 1d9e856..0855b77 100644 --- a/example_extension_matop_test.go +++ b/example_extension_matop_test.go @@ -50,7 +50,7 @@ func Example_TransposeExtension() { // // After: // ⎡{a: 0, b: 0, c: 0, d: 0, e: 0} {a: 2, b: 2, c: 2, d: 2, e: 2}⎤ - // ⎣{a: 2, b: 2, c: 2, d: 2, e: 2} {a: 3, b: 3, c: 3, d: 3, e: 3}⎦ + // ⎣{a: 1, b: 1, c: 1, d: 1, e: 1} {a: 3, b: 3, c: 3, d: 3, e: 3}⎦ } func Example_stackExtension() { From 1ec1189bd68e17f483f9ab3a62c1dde17611ab6c Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 00:40:44 +1100 Subject: [PATCH 002/154] Removed debug related traces --- dense.go | 1 - 1 file changed, 1 deletion(-) diff --git a/dense.go b/dense.go index 6a026b6..18705b9 100644 --- a/dense.go +++ b/dense.go @@ -129,7 +129,6 @@ func (t *Dense) Reshape(dims ...int) error { } if t.old != nil { - log.Println("Transposing") t.Transpose() } From 0e4c58492405e5568ff975d03894223f3129cb58 Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 00:41:35 +1100 Subject: [PATCH 003/154] wtf gofmt stopped working --- dense.go | 1 - 1 file changed, 1 deletion(-) diff --git a/dense.go b/dense.go index 18705b9..3152912 100644 --- a/dense.go +++ b/dense.go @@ -2,7 +2,6 @@ package tensor import ( "fmt" - "log" "unsafe" "github.com/pkg/errors" From 51c3d81d1d4cc8ea6c653e288877a6a939e31e06 Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 09:30:16 +1100 Subject: [PATCH 004/154] added some minor optimizations based on profiling --- array.go | 11 +---------- dense.go | 1 + iterator.go | 14 ++++++++++---- 3 files 
changed, 12 insertions(+), 14 deletions(-) diff --git a/array.go b/array.go index 62fa8a2..4162280 100644 --- a/array.go +++ b/array.go @@ -321,16 +321,7 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { panic("Cannot copy Dense arrays of different types") } - // do not use requiresIterator because requiresIterator has particular optimizations for operations (like add, sub etc) - var dstOK, srcOK bool - if dstView, ok := dst.(View); ok && dstView.IsMaterializable() { - srcOK = true - } - if srcView, ok := src.(View); ok && srcView.IsMaterializable() { - dstOK = true - } - - if !dstOK && !srcOK { + if !dst.RequiresIterator() && !src.RequiresIterator() { return copyDense(dst, src), nil } diff --git a/dense.go b/dense.go index 3152912..95e3f53 100644 --- a/dense.go +++ b/dense.go @@ -584,6 +584,7 @@ func (t *Dense) slice(start, end int) { t.array = t.array.slice(start, end) } +// RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion func (t *Dense) RequiresIterator() bool { if t.len() == 1 { return false diff --git a/iterator.go b/iterator.go index 0801748..c1605e4 100644 --- a/iterator.go +++ b/iterator.go @@ -126,6 +126,9 @@ type FlatIterator struct { track []int done bool reverse bool // if true, iterator starts at end of array and runs backwards + + isScalar bool + isVector bool } // NewFlatIterator creates a new FlatIterator. 
@@ -142,6 +145,9 @@ func NewFlatIterator(ap *AP) *FlatIterator { track: make([]int, len(ap.shape)), size: ap.shape.TotalSize(), strides0: strides0, + + isScalar: ap.IsScalar(), + isVector: ap.IsVector(), } } @@ -182,10 +188,10 @@ func (it *FlatIterator) Next() (int, error) { } switch { - case it.IsScalar(): + case it.isScalar: it.done = true return 0, nil - case it.IsVector(): + case it.isVector: if it.reverse { return it.singlePrevious() } @@ -212,10 +218,10 @@ func (it *FlatIterator) NextValid() (int, int, error) { return -1, 1, noopError{} } switch { - case it.IsScalar(): + case it.isScalar: it.done = true return 0, 0, nil - case it.IsVector(): + case it.isVector: if it.reverse { a, err := it.singlePrevious() return a, -1, err From 787b5248475879bfbc8f141fcd834ae9634d752f Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 11:09:19 +1100 Subject: [PATCH 005/154] Resolved subtle bug in inplace transpose Added tests to test for all the things --- .travis.yml | 1 + .travis/test.sh | 16 + defaultengine_matop_transpose_inplace.go | 521 +++++++++++------------ 3 files changed, 277 insertions(+), 261 deletions(-) create mode 100644 .travis/test.sh diff --git a/.travis.yml b/.travis.yml index 24da83a..7c0fa86 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,7 @@ before_install: go_import_path: gorgonia.org/tensor script: + source ${TRAVIS_BUILD_DIR}/.travis/test.sh - $HOME/gopath/bin/goveralls -service=travis-ci -package=gorgonia.org/tensor -covermode=atomic matrix: diff --git a/.travis/test.sh b/.travis/test.sh new file mode 100644 index 0000000..2e00d07 --- /dev/null +++ b/.travis/test.sh @@ -0,0 +1,16 @@ +set -ex + +go env + +go test -v -a -covermode=atomic -coverprofile=test.cover . +go test -tags='avx' -a -covermode=atomic -coverprofile=avx.cover . +go test -tags='sse' -a -covermode=atomic -coverprofile=sse.cover . +go test -tags='inplacetranspose' -a -covermode=atomic -coverprofile=inplacetranspose.cover . 
+ +# because coveralls only accepts one coverage file at one time... we combine them into one gigantic one +covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover) +echo "mode: set" > ./final.cover +tail -q -n +2 "${covers[@]}" >> ./final.cover +goveralls -coverprofile=./final.cover -service=travis-ci + +set +ex \ No newline at end of file diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 09f49b7..6725aea 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -1,261 +1,260 @@ -// +build inplacetranspose - -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Transpose(a Tensor, expStrides []int) error { - if !a.IsNativelyAccessible() { - return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") - } - if dt, ok := a.(DenseTensor); ok { - e.denseTranspose(dt, expStrides) - return nil - } - return errors.Errorf("Tranpose for tensor of %T not supported", a) -} - -func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { - if a.rtype() == String.Type { - e.denseTransposeString(a, expStrides) - return - } - - switch a.rtype().Size() { - case 1: - e.denseTranspose1(a, expStrides) - case 2: - e.denseTranspose2(a, expStrides) - case 4: - e.denseTranspose4(a, expStrides) - case 8: - e.denseTranspose8(a, expStrides) - default: - e.denseTransposeArbitrary(a, expStrides) - } -} - -func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp byte - var i int - - data := a.hdr().Uint8s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) 
{ - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint16 - var i int - - data := a.hdr().Uint16s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint32 - var i int - - data := a.hdr().Uint32s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp uint64 - var i 
int - - data := a.hdr().Uint64s() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = 0 - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - var saved, tmp string - var i int - - data := a.hdr().Strings() - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - - if track.IsSet(i) && track.IsSet(dest) { - data[i] = saved - saved = "" - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - tmp = data[i] - data[i] = saved - saved = tmp - - i = dest - } -} - -func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { - axes := a.transposeAxes() - size := a.len() - rtype := a.rtype() - typeSize := int(rtype.Size()) - - // first we'll create a bit-map to track which elements have been moved to their correct places - track := NewBitMap(size) - track.Set(0) - track.Set(size - 1) // first and last element of a transposedon't change - - saved := make([]byte, typeSize, typeSize) - tmp := make([]byte, typeSize, typeSize) - var i int - - data := storage.AsByteSlice(a.hdr(), rtype) - for i = 1; ; { - dest := a.transposeIndex(i, axes, expStrides) - start := typeSize * i - - if track.IsSet(i) && track.IsSet(dest) { - copy(data[start:start+typeSize], saved) - for i := range saved { - saved[i] = 0 - } - for i < size && track.IsSet(i) { - i++ - } - if i >= size { - break - } - continue - } - track.Set(i) - copy(tmp, data[start:start+typeSize]) - 
copy(data[start:start+typeSize], saved) - saved = tmp - - i = dest - } -} +// +build inplacetranspose + +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Transpose(a Tensor, expStrides []int) error { + if !a.IsNativelyAccessible() { + return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") + } + if dt, ok := a.(DenseTensor); ok { + e.denseTranspose(dt, expStrides) + return nil + } + return errors.Errorf("Tranpose for tensor of %T not supported", a) +} + +func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { + if a.rtype() == String.Type { + e.denseTransposeString(a, expStrides) + return + } + + switch a.rtype().Size() { + case 1: + e.denseTranspose1(a, expStrides) + case 2: + e.denseTranspose2(a, expStrides) + case 4: + e.denseTranspose4(a, expStrides) + case 8: + e.denseTranspose8(a, expStrides) + default: + e.denseTransposeArbitrary(a, expStrides) + } +} + +func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp byte + var i int + + data := a.hdr().Uint8s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a 
transposedon't change + + var saved, tmp uint16 + var i int + + data := a.hdr().Uint16s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint32 + var i int + + data := a.hdr().Uint32s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp uint64 + var i int + + data := a.hdr().Uint64s() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = 0 + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + // log.Printf("i: %d start %d, end %d | tmp %v saved %v", i, start, end, tmp, saved) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { + axes := 
a.transposeAxes() + size := a.len() + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + var saved, tmp string + var i int + + data := a.hdr().Strings() + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + + if track.IsSet(i) && track.IsSet(dest) { + data[i] = saved + saved = "" + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + tmp = data[i] + data[i] = saved + saved = tmp + + i = dest + } +} + +func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { + axes := a.transposeAxes() + size := a.len() + rtype := a.rtype() + typeSize := int(rtype.Size()) + + // first we'll create a bit-map to track which elements have been moved to their correct places + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) // first and last element of a transposedon't change + + saved := make([]byte, typeSize, typeSize) + tmp := make([]byte, typeSize, typeSize) + var i int + data := storage.AsByteSlice(a.hdr(), rtype) + for i = 1; ; { + dest := a.transposeIndex(i, axes, expStrides) + start := typeSize * i + end := start + typeSize + + if track.IsSet(i) && track.IsSet(dest) { + copy(data[start:end], saved) + for i := range saved { + saved[i] = 0 + } + for i < size && track.IsSet(i) { + i++ + } + if i >= size { + break + } + continue + } + track.Set(i) + copy(tmp, data[start:end]) + copy(data[start:end], saved) + copy(saved, tmp) + i = dest + } +} From ff9e67b0c7201ce91671a70376e09f863f00471e Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 12:18:31 +1100 Subject: [PATCH 006/154] added @stuartcarnie's optimization --- CONTRIBUTORS.md | 1 + iterator.go | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9e473c5..57adb1d 100644 --- 
a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -3,6 +3,7 @@ * Xuanyi Chew (@chewxy) - initial package * Naseer Dari (@ndari) - errors and error handling * Joe Kabaka (@kabaka0) - masked array functionality +* Stuart Carnie (@stuartcarnie) - performance optimization for iterators # Contributors diff --git a/iterator.go b/iterator.go index c1605e4..d4cc3e3 100644 --- a/iterator.go +++ b/iterator.go @@ -298,20 +298,36 @@ func (it *FlatIterator) singlePrevious() (int, error) { } func (it *FlatIterator) ndNext() (int, error) { - it.lastIndex = it.nextIndex - for i := len(it.shape) - 1; i >= 0; i-- { - it.track[i]++ - if it.track[i] == it.shape[i] { + // the reason for this weird looking bits of code is because the SSA compiler doesn't + // know how to optimize for this bit of code, not keeping things in registers correctly + // @stuartcarnie optimized this iout to great effect + + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + // the following 3 lines causes the compiler to perform bounds check here, + // instead of being done in the loop + coord := it.shape[:v+1] + track := it.track[:v+1] + strides := it.strides[:v+1] + for i := v; i >= 0; i-- { + track[i]++ + shapeI := coord[i] + strideI := strides[i] + + if track[i] == shapeI { if i == 0 { it.done = true } - it.track[i] = 0 - it.nextIndex -= (it.shape[i] - 1) * it.strides[i] + track[i] = 0 + nextIndex -= (shapeI - 1) * strideI continue } - it.nextIndex += it.strides[i] + nextIndex += strideI break } + it.nextIndex = nextIndex return it.lastIndex, nil } From 31339cb797ce5ef9e3485eee4711ace3dd902436 Mon Sep 17 00:00:00 2001 From: chewxy Date: Sun, 21 Jan 2018 12:23:21 +1100 Subject: [PATCH 007/154] Added benchmark for iterator --- benchmark_dense_matop_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/benchmark_dense_matop_test.go b/benchmark_dense_matop_test.go index 15a77ab..2c5b8a7 100644 --- a/benchmark_dense_matop_test.go +++ 
b/benchmark_dense_matop_test.go @@ -69,3 +69,24 @@ func BenchmarkGetWithIterator(b *testing.B) { } _ = f } + +func BenchmarkComplicatedGet(b *testing.B) { + T := New(WithShape(101, 1, 36, 5), Of(Float64)) + T.T(0, 2, 1, 3) + data := T.Data().([]float64) + var f float64 + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := IteratorFromDense(T) + var next int + + var err error + for next, err = it.Start(); err == nil; next, err = it.Next() { + f = data[next] + } + if _, ok := err.(NoOpError); !ok { + b.Error("Error: %v", err) + } + } + _ = f +} From 1ca1fbc9472bc1287683ad1a368ee9471a5a7a2c Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Feb 2018 11:52:52 +1100 Subject: [PATCH 008/154] Pulled iterator_native from the upcoming v0.9.0 update. --- iterator_native.go | 1151 +++++++++++++++++++++++++++++++++++++++ iterator_native_test.go | 585 ++++++++++++++++++++ 2 files changed, 1736 insertions(+) create mode 100644 iterator_native.go create mode 100644 iterator_native_test.go diff --git a/iterator_native.go b/iterator_native.go new file mode 100644 index 0000000..9801e63 --- /dev/null +++ b/iterator_native.go @@ -0,0 +1,1151 @@ +// Code generated by genlib2. DO NOT EDIT. + +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" +) + +func checkNativeIterable(t *Dense, dims int, dt Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.shape.Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.DataOrder().isColMajor() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. 
Got %v", dt, t.Dtype()) + } + + return nil +} + +/* Native Iterables for bool */ + +// NativeVectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// NativeMatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]bool)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorB converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorB(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]bool)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for int */ + +// NativeVectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// NativeMatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]int)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorI converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorI(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]int)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for int8 */ + +// NativeVectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// NativeMatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]int8)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorI8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorI8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]int8)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for int16 */ + +// NativeVectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// NativeMatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]int16)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorI16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorI16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]int16)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for int32 */ + +// NativeVectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// NativeMatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]int32)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorI32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorI32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]int32)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for int64 */ + +// NativeVectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// NativeMatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]int64)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorI64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorI64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]int64)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for uint */ + +// NativeVectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// NativeMatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]uint)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorU converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorU(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]uint)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for uint8 */ + +// NativeVectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// NativeMatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]uint8)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorU8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorU8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]uint8)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for uint16 */ + +// NativeVectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// NativeMatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]uint16)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorU16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorU16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]uint16)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for uint32 */ + +// NativeVectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// NativeMatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]uint32)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorU32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorU32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]uint32)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for uint64 */ + +// NativeVectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// NativeMatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]uint64)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorU64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorU64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]uint64)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for float32 */ + +// NativeVectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// NativeMatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]float32)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorF32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorF32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]float32)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for float64 */ + +// NativeVectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// NativeMatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]float64)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorF64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorF64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]float64)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for complex64 */ + +// NativeVectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// NativeMatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]complex64)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorC64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorC64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]complex64)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for complex128 */ + +// NativeVectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// NativeMatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]complex128)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorC128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorC128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]complex128)(unsafe.Pointer(hdr)) + } + } + return +} + +/* Native Iterables for string */ + +// NativeVectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return an error. +func NativeVectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// NativeMatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func NativeMatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]string)(unsafe.Pointer(hdr)) + } + return +} + +// Native3TensorStr converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func Native3TensorStr(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]string)(unsafe.Pointer(hdr)) + } + } + return +} diff --git a/iterator_native_test.go b/iterator_native_test.go new file mode 100644 index 0000000..d2c1724 --- /dev/null +++ b/iterator_native_test.go @@ -0,0 +1,585 @@ +// Code generated by genlib2. DO NOT EDIT. + +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NativeVectorB(t *testing.T) { + assert := assert.New(t) + T := New(Of(Bool), WithShape(6)) + it, err := NativeVectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixB(t *testing.T) { + assert := assert.New(t) + T := New(Of(Bool), WithShape(2, 3)) + it, err := NativeMatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorB(t *testing.T) { + assert := assert.New(t) + T := New(Of(Bool), WithShape(2, 3, 4)) + it, err := Native3TensorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorI(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := NativeVectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixI(t *testing.T) { + assert := assert.New(t) + T := 
New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorI(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorI8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := NativeVectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixI8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorI8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorI16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := NativeVectorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixI16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorI16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + 
assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorI32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := NativeVectorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixI32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorI32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorI64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := NativeVectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixI64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorI64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorU(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := NativeVectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixU(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint, 0, 6)), 
WithShape(2, 3)) + it, err := NativeMatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorU(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorU8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := NativeVectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixU8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorU8(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorU16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := NativeVectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixU16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorU16(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + 
assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorU32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := NativeVectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixU32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorU32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorU64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := NativeVectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixU64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorU64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorF32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := NativeVectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixF32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, 
err := NativeMatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorF32(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorF64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := NativeVectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixF64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorF64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorC64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := NativeVectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixC64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorC64(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + 
assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorC128(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := NativeVectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixC128(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := NativeMatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorC128(t *testing.T) { + assert := assert.New(t) + T := New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := Native3TensorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_NativeVectorStr(t *testing.T) { + assert := assert.New(t) + T := New(Of(String), WithShape(6)) + it, err := NativeVectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_NativeMatrixStr(t *testing.T) { + assert := assert.New(t) + T := New(Of(String), WithShape(2, 3)) + it, err := NativeMatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func TestNative3TensorStr(t *testing.T) { + assert := assert.New(t) + T := New(Of(String), WithShape(2, 3, 4)) + it, err := Native3TensorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} From a7ab578330499437eb69d30474b10744b906a87e Mon Sep 17 00:00:00 2001 From: Jim Walker Date: Mon, 5 Mar 2018 15:54:18 -0700 Subject: [PATCH 009/154] update godoc tag --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 69d8b06..62b3fbd 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Package `tensor` 
[![GoDoc](https://godoc.org/github.com/chewxy/gorgonia/tensor?status.svg)](https://godoc.org/github.com/chewxy/gorgonia/tensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) # +# Package `tensor` [![GoDoc](https://godoc.org/github.com/gorgonia/tensor?status.svg)](https://godoc.org/github.com/gorgonia/tensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) # Package `tensor` is a package that provides efficient, generic (by some definitions of generic) n-dimensional arrays in Go. Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations. The main purpose of this package is to support the operations required by [Gorgonia](https://github.com/chewxy/gorgonia). 
From 73439fee57e93b272f4315f5a2cde306419efe3c Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 20 Mar 2018 09:49:22 +1100 Subject: [PATCH 010/154] Native iterators go into their own package now Couldn't cherry pick commits from v0.9.0 to move over, so decided to checkout individual files instead --- genlib2/agg2_body.go | 8 +- genlib2/cmp_tests.go | 8 +- genlib2/dense_compat.go | 10 +- genlib2/dense_cons.go | 2 +- genlib2/dense_getset.go | 4 +- genlib2/dense_io.go | 443 ++++++--- genlib2/main.go | 37 +- genlib2/native_iterator.go | 180 ++++ genlib2/native_select.go | 142 +++ genlib2/package.go | 5 + iterator_native_test.go | 585 ------------ native/doc.go | 8 + native/example_test.go | 80 ++ .../iterator_native.go | 199 +++-- native/iterator_native2.go | 635 +++++++++++++ native/iterator_native2_test.go | 842 ++++++++++++++++++ native/iterator_native_test.go | 634 +++++++++++++ 17 files changed, 2999 insertions(+), 823 deletions(-) create mode 100644 genlib2/native_iterator.go create mode 100644 genlib2/native_select.go delete mode 100644 iterator_native_test.go create mode 100644 native/doc.go create mode 100644 native/example_test.go rename iterator_native.go => native/iterator_native.go (80%) create mode 100644 native/iterator_native2.go create mode 100644 native/iterator_native2_test.go create mode 100644 native/iterator_native_test.go diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 81141fc..6c1716c 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -5,7 +5,7 @@ import "text/template" // level 2 aggregation (tensor.StdEng) templates const cmpPrepRaw = `var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), false, opts...); err != nil{ + if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -14,7 +14,7 @@ const 
cmpPrepRaw = `var safe, same bool ` const arithPrepRaw = `var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), true, opts...); err != nil{ + if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } ` @@ -69,7 +69,7 @@ const prepUnaryRaw = `if err = unaryCheck(a, {{.TypeClassCheck | lower}}Types); } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -235,7 +235,7 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created {{if not .VV -}} // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ,dataReuse,dataA) diff --git a/genlib2/cmp_tests.go b/genlib2/cmp_tests.go index e0189e7..8d3d8f6 100644 --- a/genlib2/cmp_tests.go +++ b/genlib2/cmp_tests.go @@ -100,7 +100,7 @@ const transitivityBodyRaw = `transFn := func(q *Dense) bool { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) } ` @@ -146,7 +146,7 @@ const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) } ` @@ -182,7 
+182,7 @@ const symmetryBodyRaw = `symFn := func(q *Dense) bool { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for {{.Name}} failed: %v", err) + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) } ` @@ -216,7 +216,7 @@ const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for {{.Name}} failed: %v", err) + t.Errorf("Symmetry test for {{.Name}} failed: %v", err) } ` diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 2c96721..fb353d0 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -194,14 +194,12 @@ func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { // checks: if !t.IsNativelyAccessible() { - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } if !t.IsMatrix() { // error - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) } fo := ParseFuncOpts(opts...) 
@@ -220,7 +218,7 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { case !t.IsMaterializable(): data = convToFloat64s(t) default: - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) var next int for next, err = it.Next(); err == nil; next, err = it.Next() { if err = handleNoOp(err); err != nil { @@ -235,6 +233,8 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { retVal = mat.NewDense(r, c, data) return } + + ` var ( diff --git a/genlib2/dense_cons.go b/genlib2/dense_cons.go index 031e73a..fee0df5 100644 --- a/genlib2/dense_cons.go +++ b/genlib2/dense_cons.go @@ -68,7 +68,7 @@ func I(dt Dtype, r, c, k int) *Dense{ panic(err) } var nexts []int - iter := NewFlatIterator(s.AP) + iter := newFlatIterator(&s.AP) nexts, err = iter.Slice(rs{i, s.Size(), c + 1}) switch s.t.Kind() { diff --git a/genlib2/dense_getset.go b/genlib2/dense_getset.go index 4d2b415..2f3a38b 100644 --- a/genlib2/dense_getset.go +++ b/genlib2/dense_getset.go @@ -47,10 +47,10 @@ const copyIterRaw = `func copyDenseIter(dest, src *Dense, diter, siter *FlatIter } if diter == nil { - diter = NewFlatIterator(dest.AP) + diter = newFlatIterator(&dest.AP) } if siter == nil { - siter = NewFlatIterator(src.AP) + siter = newFlatIterator(&src.AP) } isMasked:= src.IsMasked() diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index e6e98fe..2cad2d2 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -9,21 +9,45 @@ import ( const writeNpyRaw = ` type binaryWriter struct { io.Writer - error + err error seq int } -func (w binaryWriter) w(x interface{}) { - if w.error != nil { +func (w *binaryWriter) w(x interface{}) { + if w.err != nil { return } - binary.Write(w, binary.LittleEndian, x) + w.err = binary.Write(w, binary.LittleEndian, x) w.seq++ } -func (w binaryWriter) Error() string { - return fmt.Sprintf("Error at sequence %d : %v", w.seq, w.error.Error()) +func (w *binaryWriter) Err() error { + if w.err == nil { + return nil + } + return 
errors.Wrapf(w.err, "Sequence %d", w.seq) +} + +type binaryReader struct { + io.Reader + err error + seq int +} + +func (r *binaryReader) Read(data interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.Reader, binary.LittleEndian, data) + r.seq++ +} + +func (r *binaryReader) Err() error { + if r.err == nil { + return nil + } + return errors.Wrapf(r.err, "Sequence %d", r.seq) } // WriteNpy writes the *Tensor as a numpy compatible serialized file. @@ -54,8 +78,8 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.w(byte(1)) // major version bw.w(byte(0)) // minor version bw.w(uint16(len(header))) // 4 bytes to denote header length - if bw.error != nil { - return bw + if err = bw.Err() ; err != nil { + return err } bw.Write([]byte(header)) @@ -76,10 +100,7 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { } } - if bw.error != nil { - return bw - } - return nil + return bw.Err() } ` @@ -203,7 +224,7 @@ func (t *Dense) GobDecode(p []byte) (err error){ } } - t.AP = NewAP(shape, strides) + t.AP.Init(shape, strides) t.AP.o = o t.AP.Δ = tr @@ -216,84 +237,69 @@ func (t *Dense) GobDecode(p []byte) (err error){ if err = decoder.Decode(&data); err != nil { return } + t.fromSlice(data) t.addMask(mask) t.fix() + if t.e == nil { + t.e = StdEng{} + } return t.sanity() } ` +const npyDescRE = `var npyDescRE = regexp.MustCompile(` + "`" + `'descr':` + `\` + `s*'([^']*)'` + "`" + ")" +const rowOrderRE = `var rowOrderRE = regexp.MustCompile(` + "`" + `'fortran_order':\s*(False|True)` + "`)" +const shapeRE = `var shapeRE = regexp.MustCompile(` + "`" + `'shape':\s*\(([^\(]*)\)` + "`)" const readNpyRaw = `// ReadNpy reads NumPy formatted files into a *Dense func (t *Dense) ReadNpy(r io.Reader) (err error){ + br := binaryReader{Reader: r} var magic [6]byte - if _, err = r.Read(magic[:]); err != nil { - return - } - if string(magic[:]) != "\x93NUMPY" { - err = errors.Errorf("Not a numpy file. 
Got %q as the magic number instead", string(magic[:])) - return + if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" { + return errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) } - var version byte - if err = binary.Read(r, binary.LittleEndian, &version); err != nil { - return - } - if version != 1 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + var version, minor byte + if br.Read(&version); version != 1 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } - var minor byte - if err = binary.Read(r, binary.LittleEndian, &minor); err != nil { - return - } - if minor != 0 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + if br.Read(&minor); minor != 0 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } var headerLen uint16 - if err = binary.Read(r, binary.LittleEndian, &headerLen); err != nil { - return - } - + br.Read(&headerLen) header := make([]byte, int(headerLen)) - if _, err = r.Read(header); err != nil { + br.Read(header) + if err = br.Err(); err != nil { return } - desc := regexp.MustCompile(` + "`'descr':" + `\s` + "*'([^']*)'`" + `) - match := desc.FindSubmatch(header) - if match == nil { - err = errors.New("No dtype information in npy file") - return + // extract stuff from header + var match [][]byte + if match = npyDescRE.FindSubmatch(header); match == nil { + return errors.New("No dtype information in npy file") } // TODO: check for endianness. 
For now we assume everything is little endian - var dt Dtype - if dt, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { return } - t.t = dt - rowOrder := regexp.MustCompile(` + "`'fortran_order':" + `\s` + "*(False|True)`" + `) - match = rowOrder.FindSubmatch(header) - if match == nil { - err = errors.New("No Row Order information found in the numpy file") - return + if match = rowOrderRE.FindSubmatch(header); match == nil { + return errors.New("No Row Order information found in the numpy file") } if string(match[1]) != "False" { - err = errors.New("Cannot yet read from Fortran Ordered Numpy files") - return + return errors.New("Cannot yet read from Fortran Ordered Numpy files") } - shpRe := regexp.MustCompile(` + "`'shape':" + `\s*\(([^\(]*)\)` + "`" + `) - match = shpRe.FindSubmatch(header) - if match == nil { - err = errors.New("No shape information found in npy file") - return + if match = shapeRE.FindSubmatch(header); match == nil { + return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") + + var shape Shape for _, s := range sizesStr { s = strings.Trim(s, " ") @@ -311,7 +317,6 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ if t.e == nil { t.e = StdEng{} } - t.makeArray(size) switch t.t.Kind() { @@ -319,21 +324,24 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ case reflect.{{reflectKind .}}: data := t.{{sliceOf .}} for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil{ - return - } + br.Read(&data[i]) } {{end -}} } - t.AP = BorrowAP(len(shape)) + if err = br.Err(); err != nil { + return err + } + + t.AP.zeroWithDims(len(shape)) t.setShape(shape...) 
t.fix() return t.sanity() } ` -const readCSVRaw = `// convFromStrs conversts a []string to a slice of the Dtype provided -func convFromStrs(to Dtype, record []string) (interface{}, error) { +const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +// If into is nil, then a backing slice will be created. +func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { {{range .Kinds -}} @@ -341,6 +349,13 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { {{if isOrd . -}} case reflect.{{reflectKind .}}: retVal := make([]{{asType .}}, len(record)) + var backing []{{asType .}} + if into == nil { + backing = make([]{{asType .}}, 0, len(record)) + }else{ + backing = into.([]{{asType .}}) + } + for i, v := range record { {{if eq .String "float64" -}} if retVal[i], err = strconv.ParseFloat(v, 64); err != nil { @@ -366,10 +381,20 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { retVal[i] = {{asType .}}(u) {{end -}} } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil {{end -}} {{end -}} {{end -}} + case reflect.String: + var backing []string + if into == nil { + backing = make([]string, 0, len(record)) + }else{ + backing = into.([]string) + } + backing = append(backing, record...) 
+ return backing, nil default: return nil,errors.Errorf(methodNYI, "convFromStrs", to) } @@ -388,62 +413,223 @@ func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) { cr := csv.NewReader(r) var record []string - var row interface{} var rows, cols int + var backing interface{} + for { + record, err = cr.Read() + if err == io.EOF{ + break + } else if err != nil { + return + } + if backing, err = convFromStrs(as, record, backing); err != nil { + return + } + cols = len(record) + rows++ + } + t.fromSlice(backing) + t.AP.zero() + t.AP.SetShape(rows, cols) + return nil + return errors.Errorf("not yet handled") +} +` - switch as.Kind() { - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - case reflect.{{reflectKind .}}: - var backing []{{asType .}} - for { - record, err = cr.Read() - if err == io.EOF{ - break - } - - if err != nil { - return - } +var fbEncodeDecodeRaw = `// FBEncode encodes to a byte slice using flatbuffers. +// +// Only natively accessible data can be encided +func (t *Dense) FBEncode() ([]byte, error) { + builder := flatbuffers.NewBuilder(1024) + + fb.DenseStartShapeVector(builder, len(t.shape)) + for i := len(t.shape) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.shape[i])) + } + shape := builder.EndVector(len(t.shape)) + + fb.DenseStartStridesVector(builder, len(t.strides)) + for i := len(t.strides) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.strides[i])) + } + strides := builder.EndVector(len(t.strides)) + + var o uint32 + switch { + case t.o.isRowMajor() && t.o.isContiguous(): + o = 0 + case t.o.isRowMajor() && !t.o.isContiguous(): + o = 1 + case t.o.isColMajor() && t.o.isContiguous(): + o = 2 + case t.o.isColMajor() && !t.o.isContiguous(): + o = 3 + } + + var triangle int32 + switch t.Δ { + case NotTriangle: + triangle = fb.TriangleNOT_TRIANGLE + case Upper: + triangle = fb.TriangleUPPER + case Lower: + triangle = fb.TriangleLOWER + case Symmetric: + triangle = fb.TriangleSYMMETRIC + } + + dt := 
builder.CreateString(t.Dtype().String()) + data := t.byteSlice() + + fb.DenseStartDataVector(builder, len(data)) + for i := len(data) - 1; i >= 0; i-- { + builder.PrependUint8(data[i]) + } + databyte := builder.EndVector(len(data)) + + fb.DenseStart(builder) + fb.DenseAddShape(builder, shape) + fb.DenseAddStrides(builder, strides) + fb.DenseAddO(builder, o) + fb.DenseAddT(builder, triangle) + fb.DenseAddType(builder, dt) + fb.DenseAddData(builder, databyte) + serialized := fb.DenseEnd(builder) + builder.Finish(serialized) + + return builder.FinishedBytes(), nil +} - if row, err = convFromStrs({{asType . | strip | title}}, record); err != nil { - return - } - backing = append(backing, row.([]{{asType .}})...) - cols = len(record) - rows++ +// FBDecode decodes a byteslice from a flatbuffer table into a *Dense +func (t *Dense) FBDecode(buf []byte) error { + serialized := fb.GetRootAsDense(buf, 0) + + o := serialized.O() + switch o { + case 0: + t.o = 0 + case 1: + t.o = MakeDataOrder(NonContiguous) + case 2: + t.o = MakeDataOrder(ColMajor) + case 3: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + + tri := serialized.T() + switch tri { + case fb.TriangleNOT_TRIANGLE: + t.Δ = NotTriangle + case fb.TriangleUPPER: + t.Δ = Upper + case fb.TriangleLOWER: + t.Δ = Lower + case fb.TriangleSYMMETRIC: + t.Δ = Symmetric + } + + t.shape = Shape(BorrowInts(serialized.ShapeLength())) + for i := 0; i < serialized.ShapeLength(); i++ { + t.shape[i] = int(int32(serialized.Shape(i))) + } + + t.strides = BorrowInts(serialized.StridesLength()) + for i := 0; i < serialized.ShapeLength(); i++ { + t.strides[i] = int(serialized.Strides(i)) + } + typ := string(serialized.Type()) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - {{end -}} - {{end -}} - {{end -}} - case reflect.String: - var backing []string - for { - record, err = cr.Read() - if err == io.EOF{ - break - } + } 
- if err != nil { - return - } - backing = append(backing, record...) - cols = len(record) - rows++ + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) + + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, serialized.DataBytes()) + t.forcefix() + return t.sanity() +} +` + +var pbEncodeDecodeRaw = `// PBEncode encodes the Dense into a protobuf byte slice. +func (t *Dense) PBEncode() ([]byte, error) { + var toSerialize pb.Dense + toSerialize.Shape = make([]int32, len(t.shape)) + for i, v := range t.shape { + toSerialize.Shape[i] = int32(v) + } + toSerialize.Strides = make([]int32, len(t.strides)) + for i, v := range t.strides { + toSerialize.Strides[i] = int32(v) + } + + switch { + case t.o.isRowMajor() && t.o.isContiguous(): + toSerialize.O = pb.RowMajorContiguous + case t.o.isRowMajor() && !t.o.isContiguous(): + toSerialize.O = pb.RowMajorNonContiguous + case t.o.isColMajor() && t.o.isContiguous(): + toSerialize.O = pb.ColMajorContiguous + case t.o.isColMajor() && !t.o.isContiguous(): + toSerialize.O = pb.ColMajorNonContiguous + } + toSerialize.T = pb.Triangle(t.Δ) + toSerialize.Type = t.t.String() + data := t.byteSlice() + toSerialize.Data = make([]byte, len(data)) + copy(toSerialize.Data, data) + return toSerialize.Marshal() +} + +// PBDecode unmarshalls a protobuf byteslice into a *Dense. 
+func (t *Dense) PBDecode(buf []byte) error { + var toSerialize pb.Dense + if err := toSerialize.Unmarshal(buf); err != nil { + return err + } + t.shape = make(Shape, len(toSerialize.Shape)) + for i, v := range toSerialize.Shape { + t.shape[i] = int(v) + } + t.strides = make([]int, len(toSerialize.Strides)) + for i, v := range toSerialize.Strides { + t.strides[i] = int(v) + } + + switch toSerialize.O { + case pb.RowMajorContiguous: + case pb.RowMajorNonContiguous: + t.o = MakeDataOrder(NonContiguous) + case pb.ColMajorContiguous: + t.o = MakeDataOrder(ColMajor) + case pb.ColMajorNonContiguous: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + t.Δ = Triangle(toSerialize.T) + typ := string(toSerialize.Type) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - default: - return errors.Errorf("%v not yet handled", as) } - return errors.Errorf("not yet handled") + + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) + + // allocated data. 
Now time to actually copy over the data + db := t.byteSlice() + copy(db, toSerialize.Data) + return t.sanity() } ` @@ -464,15 +650,30 @@ func init() { func generateDenseIO(f io.Writer, generic Kinds) { mk := Kinds{Kinds: filter(generic.Kinds, isNumber)} - // writes + fmt.Fprintln(f, "/* GOB SERIALIZATION */\n") + gobEncode.Execute(f, mk) + gobDecode.Execute(f, mk) + fmt.Fprint(f, "\n") + + fmt.Fprintln(f, "/* NPY SERIALIZATION */\n") + fmt.Fprintln(f, npyDescRE) + fmt.Fprintln(f, rowOrderRE) + fmt.Fprintln(f, shapeRE) fmt.Fprintln(f, writeNpyRaw) + readNpy.Execute(f, mk) fmt.Fprint(f, "\n") + + fmt.Fprintln(f, "/* CSV SERIALIZATION */\n") fmt.Fprintln(f, writeCSVRaw) + readCSV.Execute(f, mk) + fmt.Fprint(f, "\n") + + fmt.Fprintln(f, "/* FB SERIALIZATION */\n") + fmt.Fprintln(f, fbEncodeDecodeRaw) + fmt.Fprint(f, "\n") + + fmt.Fprintln(f, "/* PB SERIALIZATION */\n") + fmt.Fprintln(f, pbEncodeDecodeRaw) fmt.Fprint(f, "\n") - gobEncode.Execute(f, mk) - // reads - readNpy.Execute(f, mk) - gobDecode.Execute(f, mk) - readCSV.Execute(f, mk) } diff --git a/genlib2/main.go b/genlib2/main.go index 7207327..fafd74c 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -6,16 +6,18 @@ import ( "log" "os" "os/exec" + "os/user" "path" "path/filepath" "reflect" + "runtime" "strings" ) const genmsg = "Code generated by genlib2. DO NOT EDIT." 
var ( - gopath, tensorPkgLoc, execLoc, storageLoc string + gopath, tensorPkgLoc, nativePkgLoc, execLoc, storageLoc string ) type Kinds struct { @@ -24,7 +26,24 @@ type Kinds struct { func init() { gopath = os.Getenv("GOPATH") + + // now that go can have a default gopath, this checks that path + if gopath == "" { + usr, err := user.Current() + if err != nil { + log.Fatal(err) + } + gopath = path.Join(usr.HomeDir, "go") + stat, err := os.Stat(gopath) + if err != nil { + log.Fatal(err) + } + if !stat.IsDir() { + log.Fatal("You need to define a $GOPATH") + } + } tensorPkgLoc = path.Join(gopath, "src/gorgonia.org/tensor") + nativePkgLoc = path.Join(gopath, "src/gorgonia.org/tensor/native") execLoc = path.Join(gopath, "src/gorgonia.org/tensor/internal/execution") storageLoc = path.Join(gopath, "src/gorgonia.org/tensor/internal/storage") } @@ -93,6 +112,12 @@ func main() { pipeline(tensorPkgLoc, "api_unary_generated_test.go", Kinds{allKinds}, generateAPIUnaryTests) pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests) pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests) + + // native iterators + pipeline(nativePkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators) + pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests) + pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect) + pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests) } func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { @@ -115,7 +140,12 @@ func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) log.Fatalf("Go imports failed with %v for %q", err, fullpath) } - cmd = exec.Command("sed", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + // account for 
differences in the postix from the linux sed + if runtime.GOOS == "darwin" || strings.HasSuffix(runtime.GOOS, "bsd") { + cmd = exec.Command("sed", "-i", "", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + } else { + cmd = exec.Command("sed", "-E", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) + } if err = cmd.Run(); err != nil { if err.Error() != "exit status 4" { // exit status 4 == not found log.Fatalf("sed failed with %v for %q", err.Error(), fullpath) @@ -136,6 +166,9 @@ func pregenerate() error { if err := cleanup(execLoc); err != nil { return err } + if err := cleanup(nativePkgLoc); err != nil { + return err + } return cleanup(tensorPkgLoc) } diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go new file mode 100644 index 0000000..5b586e9 --- /dev/null +++ b/genlib2/native_iterator.go @@ -0,0 +1,180 @@ +package main + +import ( + "fmt" + "io" + "text/template" +) + +const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} +` + +const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}} +// If the *Dense does not represent a vector of the wanted type, it will return an error. 
+func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { + if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil { + return nil, err + } + return t.{{sliceOf .}}, nil +} + +// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// If the *Dense does not represent a matrix of the wanted type, it will return an error. +func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { + if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil { + return nil, err + } + + data := t.{{sliceOf .}} + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]{{asType .}}, rows) + for i := range retVal { + start := i * rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i] = *(*[]{{asType .}})(unsafe.Pointer(hdr)) + } + return +} + +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { + if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { + return nil, err + } + + data := t.{{sliceOf .}} + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]{{asType .}}, layers) + for i := range retVal { + retVal[i] = make([][]{{asType .}}, rows) + for j := range retVal[i] { + start := i*layerStride + j*rowStride + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[start])), + Len: cols, + Cap: cols, + } + retVal[i][j] = *(*[]{{asType .}})(unsafe.Pointer(hdr)) + } + } + return +} +` + +const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable . 
-}} + T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(6)) + {{else -}} + T = New(Of({{reflectKind .}}), WithShape(6)) + {{end -}} + it, err := Vector{{short .}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_Matrix{{short .}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable . -}} + T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(2, 3)) + {{else -}} + T = New(Of({{reflectKind .}}), WithShape(2, 3)) + {{end -}} + it, err := Matrix{{short .}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3{{short .}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + {{if isRangeable . -}} + T = New(WithBacking(Range({{reflectKind .}}, 0, 24)), WithShape(2, 3, 4)) + {{else -}} + T = New(Of({{reflectKind .}}), WithShape(2, 3, 4)) + {{end -}} + it, err := Tensor3{{short .}}(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} +` + +var ( + NativeIter *template.Template + NativeIterTest *template.Template +) + +func init() { + NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw)) + NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw)) +} + +func generateNativeIterators(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + fmt.Fprintf(f, "%v\n", checkNativeiterable) + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) + NativeIter.Execute(f, k) + fmt.Fprint(f, "\n\n") + } +} + +func generateNativeIteratorTests(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeIterTest.Execute(f, k) + fmt.Fprint(f, "\n\n") + } +} diff --git a/genlib2/native_select.go b/genlib2/native_select.go new 
file mode 100644 index 0000000..a386eaa --- /dev/null +++ b/genlib2/native_select.go @@ -0,0 +1,142 @@ +package main + +import ( + "fmt" + "io" + "text/template" +) + +const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} +` +const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64. +func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) { + if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]{{asType .}}, 1) + retVal[0] = t.{{sliceOf .}} + case 2: + if axis == 0 { + return Matrix{{short .}}(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.{{sliceOf .}} + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]{{asType .}}, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]{{asType .}})(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} +` +const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]{{asType .}} + T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) + if x, err = Select{{short .}}(T, 1); err != nil { + t.Fatal(err) + 
} + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) + if x, err = Select{{short .}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) + if x, err = Select{{short .}}(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of({{reflectKind .}}), WithShape(2, 3), ) + if x, err = Select{{short .}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of({{reflectKind .}}), WithShape(2, 3), ) + if x, err = Select{{short .}}(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} )) + if x, err = Select{{short .}}(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = Select{{short .}}(T, 10); err == nil{ + t.Fatal("Expected errors") + } +} +` + +var ( + NativeSelect *template.Template + NativeSelectTest *template.Template +) + +func init() { + NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw)) + NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw)) +} + +func generateNativeSelect(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + fmt.Fprintf(f, "%v\n", checkNativeSelectable) + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) + NativeSelect.Execute(f, k) + fmt.Fprint(f, "\n\n") + } +} + +func generateNativeSelectTests(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeSelectTest.Execute(f, k) + fmt.Fprint(f, "\n\n") + } 
+} diff --git a/genlib2/package.go b/genlib2/package.go index e78e7c0..4380b6b 100644 --- a/genlib2/package.go +++ b/genlib2/package.go @@ -9,6 +9,8 @@ func writePkgName(f io.Writer, pkg string) { switch pkg { case tensorPkgLoc: fmt.Fprintf(f, "// %s\n\npackage tensor\n\n", genmsg) + case nativePkgLoc: + fmt.Fprintf(f, "// %s\n\npackage native\n\n", genmsg) case execLoc: fmt.Fprintf(f, "// %s\n\npackage execution\n\n", genmsg) case storageLoc: @@ -17,3 +19,6 @@ func writePkgName(f io.Writer, pkg string) { fmt.Fprintf(f, "// %s\n\npackage unknown\n\n", genmsg) } } + +const importUnqualifiedTensor = `import . "gorgonia.org/tensor" +` diff --git a/iterator_native_test.go b/iterator_native_test.go deleted file mode 100644 index d2c1724..0000000 --- a/iterator_native_test.go +++ /dev/null @@ -1,585 +0,0 @@ -// Code generated by genlib2. DO NOT EDIT. - -package tensor - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_NativeVectorB(t *testing.T) { - assert := assert.New(t) - T := New(Of(Bool), WithShape(6)) - it, err := NativeVectorB(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixB(t *testing.T) { - assert := assert.New(t) - T := New(Of(Bool), WithShape(2, 3)) - it, err := NativeMatrixB(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorB(t *testing.T) { - assert := assert.New(t) - T := New(Of(Bool), WithShape(2, 3, 4)) - it, err := Native3TensorB(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorI(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int, 0, 6)), WithShape(6)) - it, err := NativeVectorI(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixI(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int, 0, 6)), 
WithShape(2, 3)) - it, err := NativeMatrixI(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorI(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorI(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorI8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) - it, err := NativeVectorI8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixI8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixI8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorI8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorI8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorI16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) - it, err := NativeVectorI16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixI16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixI16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorI16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorI16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - 
assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorI32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) - it, err := NativeVectorI32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixI32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixI32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorI32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorI32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorI64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) - it, err := NativeVectorI64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixI64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixI64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorI64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorI64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorU(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) - it, err := NativeVectorU(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixU(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) - it, err := 
NativeMatrixU(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorU(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorU(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorU8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) - it, err := NativeVectorU8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixU8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixU8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorU8(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorU8(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorU16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) - it, err := NativeVectorU16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixU16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixU16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorU16(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorU16(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - 
-func Test_NativeVectorU32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) - it, err := NativeVectorU32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixU32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixU32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorU32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorU32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorU64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) - it, err := NativeVectorU64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixU64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixU64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorU64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorU64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorF32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) - it, err := NativeVectorF32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixF32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixF32(T) - if err 
!= nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorF32(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorF32(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorF64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) - it, err := NativeVectorF64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixF64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixF64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorF64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorF64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorC64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) - it, err := NativeVectorC64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixC64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixC64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorC64(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorC64(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} 
- -func Test_NativeVectorC128(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) - it, err := NativeVectorC128(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixC128(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) - it, err := NativeMatrixC128(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorC128(t *testing.T) { - assert := assert.New(t) - T := New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) - it, err := Native3TensorC128(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} - -func Test_NativeVectorStr(t *testing.T) { - assert := assert.New(t) - T := New(Of(String), WithShape(6)) - it, err := NativeVectorStr(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(6, len(it)) -} - -func Test_NativeMatrixStr(t *testing.T) { - assert := assert.New(t) - T := New(Of(String), WithShape(2, 3)) - it, err := NativeMatrixStr(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) -} - -func TestNative3TensorStr(t *testing.T) { - assert := assert.New(t) - T := New(Of(String), WithShape(2, 3, 4)) - it, err := Native3TensorStr(T) - if err != nil { - t.Fatal(err) - } - - assert.Equal(2, len(it)) - assert.Equal(3, len(it[0])) - assert.Equal(4, len(it[0][0])) -} diff --git a/native/doc.go b/native/doc.go new file mode 100644 index 0000000..516fbe2 --- /dev/null +++ b/native/doc.go @@ -0,0 +1,8 @@ +// package native is a utility package for gorgonia.org/tensor. +// +// Amongst other things, it provides iterators that use Go slice semantics, while keeping a reference to the underlying memory. +// This means you can update the slices and the changes will be reflected back into the original tensor. 
+// +// There is of course a cost of using the native iterators and selectors - allocation costs. +// For best performance, don't use these in a tight loop. +package native diff --git a/native/example_test.go b/native/example_test.go new file mode 100644 index 0000000..94b324a --- /dev/null +++ b/native/example_test.go @@ -0,0 +1,80 @@ +package native + +import ( + "fmt" + + . "gorgonia.org/tensor" +) + +// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels) +// NativeIterators are useful for this purpose. +func Example_iterator() { + var T *Dense + T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) + x, err := MatrixF64(T) + if err != nil { + fmt.Printf("ERR: %v", err) + } + + for _, row := range x { + fmt.Printf("%v\n", row) + } + + // Output: + // [0 1 2] + // [3 4 5] +} + +// The NativeSelect function squashes the dimensions, and returns an iterator in native Go slice semantics. +func Exampleselect() { + // Selection is a bit of an interesting use case. Sometimes you don't want to iterate through the layers. + // + // For example, in a number of use cases where you have a 4-Tensor, you'd typically reshape it to some + // 2D matrix which can then be plugged into BLAS algorithms directly. Sometimes you wouldn't need to reshape. + // All you have to do is squash the dimensions inwards. This function does that. + // + // The best way to explain the Select functions is through concrete examples. + // Imagine a tensor with (2,3,4,5) shape. Arbitrarily, we call them (NCHW) - Batch Size, Channel Count, Height, Width. + // If we want to select all the channels, across all batches, then `NativeSelectX(T, 1)` would yield all channels. The resulting matrix will be (6, 20) + // If we want to select all the heights, across all channels and batches, then `NativeSelectX(T, 2) will yield all heights. 
The resulting matrix will be (24, 5) + // + // If for some reason the format was in NHWC, then you would need to reshape. This wouldn't be useful. + + var T *Dense + T = New(WithShape(2, 3, 4, 5), WithBacking(Range(Float64, 0, 2*3*4*5))) + x, err := SelectF64(T, 1) + if err != nil { + fmt.Printf("ERR %v", err) + } + for _, row := range x { + fmt.Printf("%3.0f\n", row) + } + + // Output: + // [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] + // [ 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39] + // [ 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59] + // [ 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79] + // [ 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99] + // [100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119] +} + +// The iterators are iteratos in the truest sense. The data isn't copied, as this example shows +func Example_clobber() { + var T *Dense + T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) + fmt.Printf("Before :\n%v", T) + + xx, _ := MatrixF64(T) + xx[1][1] = 10000 + fmt.Printf("After :\n%v", T) + + // Output: + // Before : + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // After : + // ⎡ 0 1 2⎤ + // ⎣ 3 10000 5⎦ + +} diff --git a/iterator_native.go b/native/iterator_native.go similarity index 80% rename from iterator_native.go rename to native/iterator_native.go index 9801e63..a360aeb 100644 --- a/iterator_native.go +++ b/native/iterator_native.go @@ -1,12 +1,13 @@ // Code generated by genlib2. DO NOT EDIT. -package tensor +package native import ( "reflect" "unsafe" "github.com/pkg/errors" + . "gorgonia.org/tensor" ) func checkNativeIterable(t *Dense, dims int, dt Dtype) error { @@ -15,11 +16,11 @@ func checkNativeIterable(t *Dense, dims int, dt Dtype) error { return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } - if t.shape.Dims() != dims { + if t.Shape().Dims() != dims { return errors.Errorf("Cannot convert *Dense to native iterator. 
Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) } - if t.DataOrder().isColMajor() || t.RequiresIterator() { + if t.F() || t.RequiresIterator() { return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") } @@ -32,18 +33,18 @@ func checkNativeIterable(t *Dense, dims int, dt Dtype) error { /* Native Iterables for bool */ -// NativeVectorB converts a *Dense into a []bool +// VectorB converts a *Dense into a []bool // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorB(t *Dense) (retVal []bool, err error) { +func VectorB(t *Dense) (retVal []bool, err error) { if err = checkNativeIterable(t, 1, Bool); err != nil { return nil, err } return t.Bools(), nil } -// NativeMatrixB converts a *Dense into a [][]bool +// MatrixB converts a *Dense into a [][]bool // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixB(t *Dense) (retVal [][]bool, err error) { +func MatrixB(t *Dense) (retVal [][]bool, err error) { if err = checkNativeIterable(t, 2, Bool); err != nil { return nil, err } @@ -68,9 +69,9 @@ func NativeMatrixB(t *Dense) (retVal [][]bool, err error) { return } -// Native3TensorB converts a *Dense into a [][][]bool. +// Tensor3B converts a *Dense into a [][][]bool. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorB(t *Dense) (retVal [][][]bool, err error) { +func Tensor3B(t *Dense) (retVal [][][]bool, err error) { if err = checkNativeIterable(t, 3, Bool); err != nil { return nil, err } @@ -102,18 +103,18 @@ func Native3TensorB(t *Dense) (retVal [][][]bool, err error) { /* Native Iterables for int */ -// NativeVectorI converts a *Dense into a []int +// VectorI converts a *Dense into a []int // If the *Dense does not represent a vector of the wanted type, it will return an error. 
-func NativeVectorI(t *Dense) (retVal []int, err error) { +func VectorI(t *Dense) (retVal []int, err error) { if err = checkNativeIterable(t, 1, Int); err != nil { return nil, err } return t.Ints(), nil } -// NativeMatrixI converts a *Dense into a [][]int +// MatrixI converts a *Dense into a [][]int // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixI(t *Dense) (retVal [][]int, err error) { +func MatrixI(t *Dense) (retVal [][]int, err error) { if err = checkNativeIterable(t, 2, Int); err != nil { return nil, err } @@ -138,9 +139,9 @@ func NativeMatrixI(t *Dense) (retVal [][]int, err error) { return } -// Native3TensorI converts a *Dense into a [][][]int. +// Tensor3I converts a *Dense into a [][][]int. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorI(t *Dense) (retVal [][][]int, err error) { +func Tensor3I(t *Dense) (retVal [][][]int, err error) { if err = checkNativeIterable(t, 3, Int); err != nil { return nil, err } @@ -172,18 +173,18 @@ func Native3TensorI(t *Dense) (retVal [][][]int, err error) { /* Native Iterables for int8 */ -// NativeVectorI8 converts a *Dense into a []int8 +// VectorI8 converts a *Dense into a []int8 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorI8(t *Dense) (retVal []int8, err error) { +func VectorI8(t *Dense) (retVal []int8, err error) { if err = checkNativeIterable(t, 1, Int8); err != nil { return nil, err } return t.Int8s(), nil } -// NativeMatrixI8 converts a *Dense into a [][]int8 +// MatrixI8 converts a *Dense into a [][]int8 // If the *Dense does not represent a matrix of the wanted type, it will return an error. 
-func NativeMatrixI8(t *Dense) (retVal [][]int8, err error) { +func MatrixI8(t *Dense) (retVal [][]int8, err error) { if err = checkNativeIterable(t, 2, Int8); err != nil { return nil, err } @@ -208,9 +209,9 @@ func NativeMatrixI8(t *Dense) (retVal [][]int8, err error) { return } -// Native3TensorI8 converts a *Dense into a [][][]int8. +// Tensor3I8 converts a *Dense into a [][][]int8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorI8(t *Dense) (retVal [][][]int8, err error) { +func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { if err = checkNativeIterable(t, 3, Int8); err != nil { return nil, err } @@ -242,18 +243,18 @@ func Native3TensorI8(t *Dense) (retVal [][][]int8, err error) { /* Native Iterables for int16 */ -// NativeVectorI16 converts a *Dense into a []int16 +// VectorI16 converts a *Dense into a []int16 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorI16(t *Dense) (retVal []int16, err error) { +func VectorI16(t *Dense) (retVal []int16, err error) { if err = checkNativeIterable(t, 1, Int16); err != nil { return nil, err } return t.Int16s(), nil } -// NativeMatrixI16 converts a *Dense into a [][]int16 +// MatrixI16 converts a *Dense into a [][]int16 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixI16(t *Dense) (retVal [][]int16, err error) { +func MatrixI16(t *Dense) (retVal [][]int16, err error) { if err = checkNativeIterable(t, 2, Int16); err != nil { return nil, err } @@ -278,9 +279,9 @@ func NativeMatrixI16(t *Dense) (retVal [][]int16, err error) { return } -// Native3TensorI16 converts a *Dense into a [][][]int16. +// Tensor3I16 converts a *Dense into a [][][]int16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Native3TensorI16(t *Dense) (retVal [][][]int16, err error) { +func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { if err = checkNativeIterable(t, 3, Int16); err != nil { return nil, err } @@ -312,18 +313,18 @@ func Native3TensorI16(t *Dense) (retVal [][][]int16, err error) { /* Native Iterables for int32 */ -// NativeVectorI32 converts a *Dense into a []int32 +// VectorI32 converts a *Dense into a []int32 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorI32(t *Dense) (retVal []int32, err error) { +func VectorI32(t *Dense) (retVal []int32, err error) { if err = checkNativeIterable(t, 1, Int32); err != nil { return nil, err } return t.Int32s(), nil } -// NativeMatrixI32 converts a *Dense into a [][]int32 +// MatrixI32 converts a *Dense into a [][]int32 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixI32(t *Dense) (retVal [][]int32, err error) { +func MatrixI32(t *Dense) (retVal [][]int32, err error) { if err = checkNativeIterable(t, 2, Int32); err != nil { return nil, err } @@ -348,9 +349,9 @@ func NativeMatrixI32(t *Dense) (retVal [][]int32, err error) { return } -// Native3TensorI32 converts a *Dense into a [][][]int32. +// Tensor3I32 converts a *Dense into a [][][]int32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorI32(t *Dense) (retVal [][][]int32, err error) { +func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { if err = checkNativeIterable(t, 3, Int32); err != nil { return nil, err } @@ -382,18 +383,18 @@ func Native3TensorI32(t *Dense) (retVal [][][]int32, err error) { /* Native Iterables for int64 */ -// NativeVectorI64 converts a *Dense into a []int64 +// VectorI64 converts a *Dense into a []int64 // If the *Dense does not represent a vector of the wanted type, it will return an error. 
-func NativeVectorI64(t *Dense) (retVal []int64, err error) { +func VectorI64(t *Dense) (retVal []int64, err error) { if err = checkNativeIterable(t, 1, Int64); err != nil { return nil, err } return t.Int64s(), nil } -// NativeMatrixI64 converts a *Dense into a [][]int64 +// MatrixI64 converts a *Dense into a [][]int64 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixI64(t *Dense) (retVal [][]int64, err error) { +func MatrixI64(t *Dense) (retVal [][]int64, err error) { if err = checkNativeIterable(t, 2, Int64); err != nil { return nil, err } @@ -418,9 +419,9 @@ func NativeMatrixI64(t *Dense) (retVal [][]int64, err error) { return } -// Native3TensorI64 converts a *Dense into a [][][]int64. +// Tensor3I64 converts a *Dense into a [][][]int64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorI64(t *Dense) (retVal [][][]int64, err error) { +func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { if err = checkNativeIterable(t, 3, Int64); err != nil { return nil, err } @@ -452,18 +453,18 @@ func Native3TensorI64(t *Dense) (retVal [][][]int64, err error) { /* Native Iterables for uint */ -// NativeVectorU converts a *Dense into a []uint +// VectorU converts a *Dense into a []uint // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorU(t *Dense) (retVal []uint, err error) { +func VectorU(t *Dense) (retVal []uint, err error) { if err = checkNativeIterable(t, 1, Uint); err != nil { return nil, err } return t.Uints(), nil } -// NativeMatrixU converts a *Dense into a [][]uint +// MatrixU converts a *Dense into a [][]uint // If the *Dense does not represent a matrix of the wanted type, it will return an error. 
-func NativeMatrixU(t *Dense) (retVal [][]uint, err error) { +func MatrixU(t *Dense) (retVal [][]uint, err error) { if err = checkNativeIterable(t, 2, Uint); err != nil { return nil, err } @@ -488,9 +489,9 @@ func NativeMatrixU(t *Dense) (retVal [][]uint, err error) { return } -// Native3TensorU converts a *Dense into a [][][]uint. +// Tensor3U converts a *Dense into a [][][]uint. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorU(t *Dense) (retVal [][][]uint, err error) { +func Tensor3U(t *Dense) (retVal [][][]uint, err error) { if err = checkNativeIterable(t, 3, Uint); err != nil { return nil, err } @@ -522,18 +523,18 @@ func Native3TensorU(t *Dense) (retVal [][][]uint, err error) { /* Native Iterables for uint8 */ -// NativeVectorU8 converts a *Dense into a []uint8 +// VectorU8 converts a *Dense into a []uint8 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorU8(t *Dense) (retVal []uint8, err error) { +func VectorU8(t *Dense) (retVal []uint8, err error) { if err = checkNativeIterable(t, 1, Uint8); err != nil { return nil, err } return t.Uint8s(), nil } -// NativeMatrixU8 converts a *Dense into a [][]uint8 +// MatrixU8 converts a *Dense into a [][]uint8 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixU8(t *Dense) (retVal [][]uint8, err error) { +func MatrixU8(t *Dense) (retVal [][]uint8, err error) { if err = checkNativeIterable(t, 2, Uint8); err != nil { return nil, err } @@ -558,9 +559,9 @@ func NativeMatrixU8(t *Dense) (retVal [][]uint8, err error) { return } -// Native3TensorU8 converts a *Dense into a [][][]uint8. +// Tensor3U8 converts a *Dense into a [][][]uint8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Native3TensorU8(t *Dense) (retVal [][][]uint8, err error) { +func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { if err = checkNativeIterable(t, 3, Uint8); err != nil { return nil, err } @@ -592,18 +593,18 @@ func Native3TensorU8(t *Dense) (retVal [][][]uint8, err error) { /* Native Iterables for uint16 */ -// NativeVectorU16 converts a *Dense into a []uint16 +// VectorU16 converts a *Dense into a []uint16 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorU16(t *Dense) (retVal []uint16, err error) { +func VectorU16(t *Dense) (retVal []uint16, err error) { if err = checkNativeIterable(t, 1, Uint16); err != nil { return nil, err } return t.Uint16s(), nil } -// NativeMatrixU16 converts a *Dense into a [][]uint16 +// MatrixU16 converts a *Dense into a [][]uint16 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixU16(t *Dense) (retVal [][]uint16, err error) { +func MatrixU16(t *Dense) (retVal [][]uint16, err error) { if err = checkNativeIterable(t, 2, Uint16); err != nil { return nil, err } @@ -628,9 +629,9 @@ func NativeMatrixU16(t *Dense) (retVal [][]uint16, err error) { return } -// Native3TensorU16 converts a *Dense into a [][][]uint16. +// Tensor3U16 converts a *Dense into a [][][]uint16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorU16(t *Dense) (retVal [][][]uint16, err error) { +func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { if err = checkNativeIterable(t, 3, Uint16); err != nil { return nil, err } @@ -662,18 +663,18 @@ func Native3TensorU16(t *Dense) (retVal [][][]uint16, err error) { /* Native Iterables for uint32 */ -// NativeVectorU32 converts a *Dense into a []uint32 +// VectorU32 converts a *Dense into a []uint32 // If the *Dense does not represent a vector of the wanted type, it will return an error. 
-func NativeVectorU32(t *Dense) (retVal []uint32, err error) { +func VectorU32(t *Dense) (retVal []uint32, err error) { if err = checkNativeIterable(t, 1, Uint32); err != nil { return nil, err } return t.Uint32s(), nil } -// NativeMatrixU32 converts a *Dense into a [][]uint32 +// MatrixU32 converts a *Dense into a [][]uint32 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixU32(t *Dense) (retVal [][]uint32, err error) { +func MatrixU32(t *Dense) (retVal [][]uint32, err error) { if err = checkNativeIterable(t, 2, Uint32); err != nil { return nil, err } @@ -698,9 +699,9 @@ func NativeMatrixU32(t *Dense) (retVal [][]uint32, err error) { return } -// Native3TensorU32 converts a *Dense into a [][][]uint32. +// Tensor3U32 converts a *Dense into a [][][]uint32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorU32(t *Dense) (retVal [][][]uint32, err error) { +func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { if err = checkNativeIterable(t, 3, Uint32); err != nil { return nil, err } @@ -732,18 +733,18 @@ func Native3TensorU32(t *Dense) (retVal [][][]uint32, err error) { /* Native Iterables for uint64 */ -// NativeVectorU64 converts a *Dense into a []uint64 +// VectorU64 converts a *Dense into a []uint64 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorU64(t *Dense) (retVal []uint64, err error) { +func VectorU64(t *Dense) (retVal []uint64, err error) { if err = checkNativeIterable(t, 1, Uint64); err != nil { return nil, err } return t.Uint64s(), nil } -// NativeMatrixU64 converts a *Dense into a [][]uint64 +// MatrixU64 converts a *Dense into a [][]uint64 // If the *Dense does not represent a matrix of the wanted type, it will return an error. 
-func NativeMatrixU64(t *Dense) (retVal [][]uint64, err error) { +func MatrixU64(t *Dense) (retVal [][]uint64, err error) { if err = checkNativeIterable(t, 2, Uint64); err != nil { return nil, err } @@ -768,9 +769,9 @@ func NativeMatrixU64(t *Dense) (retVal [][]uint64, err error) { return } -// Native3TensorU64 converts a *Dense into a [][][]uint64. +// Tensor3U64 converts a *Dense into a [][][]uint64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorU64(t *Dense) (retVal [][][]uint64, err error) { +func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { if err = checkNativeIterable(t, 3, Uint64); err != nil { return nil, err } @@ -802,18 +803,18 @@ func Native3TensorU64(t *Dense) (retVal [][][]uint64, err error) { /* Native Iterables for float32 */ -// NativeVectorF32 converts a *Dense into a []float32 +// VectorF32 converts a *Dense into a []float32 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorF32(t *Dense) (retVal []float32, err error) { +func VectorF32(t *Dense) (retVal []float32, err error) { if err = checkNativeIterable(t, 1, Float32); err != nil { return nil, err } return t.Float32s(), nil } -// NativeMatrixF32 converts a *Dense into a [][]float32 +// MatrixF32 converts a *Dense into a [][]float32 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixF32(t *Dense) (retVal [][]float32, err error) { +func MatrixF32(t *Dense) (retVal [][]float32, err error) { if err = checkNativeIterable(t, 2, Float32); err != nil { return nil, err } @@ -838,9 +839,9 @@ func NativeMatrixF32(t *Dense) (retVal [][]float32, err error) { return } -// Native3TensorF32 converts a *Dense into a [][][]float32. +// Tensor3F32 converts a *Dense into a [][][]float32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Native3TensorF32(t *Dense) (retVal [][][]float32, err error) { +func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { if err = checkNativeIterable(t, 3, Float32); err != nil { return nil, err } @@ -872,18 +873,18 @@ func Native3TensorF32(t *Dense) (retVal [][][]float32, err error) { /* Native Iterables for float64 */ -// NativeVectorF64 converts a *Dense into a []float64 +// VectorF64 converts a *Dense into a []float64 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorF64(t *Dense) (retVal []float64, err error) { +func VectorF64(t *Dense) (retVal []float64, err error) { if err = checkNativeIterable(t, 1, Float64); err != nil { return nil, err } return t.Float64s(), nil } -// NativeMatrixF64 converts a *Dense into a [][]float64 +// MatrixF64 converts a *Dense into a [][]float64 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixF64(t *Dense) (retVal [][]float64, err error) { +func MatrixF64(t *Dense) (retVal [][]float64, err error) { if err = checkNativeIterable(t, 2, Float64); err != nil { return nil, err } @@ -908,9 +909,9 @@ func NativeMatrixF64(t *Dense) (retVal [][]float64, err error) { return } -// Native3TensorF64 converts a *Dense into a [][][]float64. +// Tensor3F64 converts a *Dense into a [][][]float64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorF64(t *Dense) (retVal [][][]float64, err error) { +func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { if err = checkNativeIterable(t, 3, Float64); err != nil { return nil, err } @@ -942,18 +943,18 @@ func Native3TensorF64(t *Dense) (retVal [][][]float64, err error) { /* Native Iterables for complex64 */ -// NativeVectorC64 converts a *Dense into a []complex64 +// VectorC64 converts a *Dense into a []complex64 // If the *Dense does not represent a vector of the wanted type, it will return an error. 
-func NativeVectorC64(t *Dense) (retVal []complex64, err error) { +func VectorC64(t *Dense) (retVal []complex64, err error) { if err = checkNativeIterable(t, 1, Complex64); err != nil { return nil, err } return t.Complex64s(), nil } -// NativeMatrixC64 converts a *Dense into a [][]complex64 +// MatrixC64 converts a *Dense into a [][]complex64 // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixC64(t *Dense) (retVal [][]complex64, err error) { +func MatrixC64(t *Dense) (retVal [][]complex64, err error) { if err = checkNativeIterable(t, 2, Complex64); err != nil { return nil, err } @@ -978,9 +979,9 @@ func NativeMatrixC64(t *Dense) (retVal [][]complex64, err error) { return } -// Native3TensorC64 converts a *Dense into a [][][]complex64. +// Tensor3C64 converts a *Dense into a [][][]complex64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorC64(t *Dense) (retVal [][][]complex64, err error) { +func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { if err = checkNativeIterable(t, 3, Complex64); err != nil { return nil, err } @@ -1012,18 +1013,18 @@ func Native3TensorC64(t *Dense) (retVal [][][]complex64, err error) { /* Native Iterables for complex128 */ -// NativeVectorC128 converts a *Dense into a []complex128 +// VectorC128 converts a *Dense into a []complex128 // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorC128(t *Dense) (retVal []complex128, err error) { +func VectorC128(t *Dense) (retVal []complex128, err error) { if err = checkNativeIterable(t, 1, Complex128); err != nil { return nil, err } return t.Complex128s(), nil } -// NativeMatrixC128 converts a *Dense into a [][]complex128 +// MatrixC128 converts a *Dense into a [][]complex128 // If the *Dense does not represent a matrix of the wanted type, it will return an error. 
-func NativeMatrixC128(t *Dense) (retVal [][]complex128, err error) { +func MatrixC128(t *Dense) (retVal [][]complex128, err error) { if err = checkNativeIterable(t, 2, Complex128); err != nil { return nil, err } @@ -1048,9 +1049,9 @@ func NativeMatrixC128(t *Dense) (retVal [][]complex128, err error) { return } -// Native3TensorC128 converts a *Dense into a [][][]complex128. +// Tensor3C128 converts a *Dense into a [][][]complex128. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Native3TensorC128(t *Dense) (retVal [][][]complex128, err error) { +func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { if err = checkNativeIterable(t, 3, Complex128); err != nil { return nil, err } @@ -1082,18 +1083,18 @@ func Native3TensorC128(t *Dense) (retVal [][][]complex128, err error) { /* Native Iterables for string */ -// NativeVectorStr converts a *Dense into a []string +// VectorStr converts a *Dense into a []string // If the *Dense does not represent a vector of the wanted type, it will return an error. -func NativeVectorStr(t *Dense) (retVal []string, err error) { +func VectorStr(t *Dense) (retVal []string, err error) { if err = checkNativeIterable(t, 1, String); err != nil { return nil, err } return t.Strings(), nil } -// NativeMatrixStr converts a *Dense into a [][]string +// MatrixStr converts a *Dense into a [][]string // If the *Dense does not represent a matrix of the wanted type, it will return an error. -func NativeMatrixStr(t *Dense) (retVal [][]string, err error) { +func MatrixStr(t *Dense) (retVal [][]string, err error) { if err = checkNativeIterable(t, 2, String); err != nil { return nil, err } @@ -1118,9 +1119,9 @@ func NativeMatrixStr(t *Dense) (retVal [][]string, err error) { return } -// Native3TensorStr converts a *Dense into a [][][]string. +// Tensor3Str converts a *Dense into a [][][]string. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Native3TensorStr(t *Dense) (retVal [][][]string, err error) { +func Tensor3Str(t *Dense) (retVal [][][]string, err error) { if err = checkNativeIterable(t, 3, String); err != nil { return nil, err } diff --git a/native/iterator_native2.go b/native/iterator_native2.go new file mode 100644 index 0000000..85045ce --- /dev/null +++ b/native/iterator_native2.go @@ -0,0 +1,635 @@ +// Code generated by genlib2. DO NOT EDIT. + +package native + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + . "gorgonia.org/tensor" +) + +func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} + +/* Native Select for bool */ + +// SelectB creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return MatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]bool)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// SelectI creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return MatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]int)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// SelectI8 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return MatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]int8)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// SelectI16 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return MatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]int16)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// SelectI32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return MatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]int32)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// SelectI64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return MatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]int64)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// SelectU creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return MatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]uint)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// SelectU8 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return MatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]uint8)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// SelectU16 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return MatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]uint16)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// SelectU32 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return MatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]uint32)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// SelectU64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return MatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]uint64)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// SelectF32 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return MatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]float32)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// SelectF64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return MatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]float64)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// SelectC64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return MatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]complex64)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// SelectC128 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return MatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]complex128)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// SelectStr creates a slice of flat data types. See Example of NativeSelectF64. +func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return MatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + hdr := &reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(&data[i])), + Len: stride, + Cap: stride, + } + retVal = append(retVal, *(*[]string)(unsafe.Pointer(hdr))) + r++ + } + return retVal, nil + + } + return +} diff --git a/native/iterator_native2_test.go b/native/iterator_native2_test.go new file mode 100644 index 0000000..df56b5e --- /dev/null +++ b/native/iterator_native2_test.go @@ -0,0 +1,842 @@ +// Code generated by genlib2. DO NOT EDIT. + +package native + +import ( + "testing" + + "github.com/stretchr/testify/assert" + . 
"gorgonia.org/tensor" +) + +func TestSelectB(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]bool + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = SelectB(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = SelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(false)) + if x, err = SelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectB(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = SelectI(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = SelectI(T, 1); err != nil { + t.Fatal(err) + 
} + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int(0))) + if x, err = SelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int8 + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = SelectI8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = SelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int8(0))) + if x, err = SelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int16 + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = SelectI16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, 
len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = SelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int16(0))) + if x, err = SelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int32 + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = SelectI32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = SelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int32(0))) + if x, err = SelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int64 + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + 
assert.Equal(20, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = SelectI64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = SelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int64(0))) + if x, err = SelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectI64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = SelectU(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = SelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint(0))) + if x, err = SelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU(T, 10); err 
== nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint8 + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = SelectU8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = SelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint8(0))) + if x, err = SelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint16 + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = SelectU16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint16), 
WithShape(2, 3)) + if x, err = SelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint16(0))) + if x, err = SelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint32 + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = SelectU32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = SelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint32(0))) + if x, err = SelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint64 + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = 
New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = SelectU64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = SelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint64(0))) + if x, err = SelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectU64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float32 + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = SelectF32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = SelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float32(0))) + if x, err = SelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectF32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float64 
+ T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = SelectF64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = SelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = SelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float64(0))) + if x, err = SelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectF64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex64 + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = SelectC64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = SelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, 
len(x[0])) + + T = New(FromScalar(complex64(0))) + if x, err = SelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectC64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex128 + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = SelectC128(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = SelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex128(0))) + if x, err = SelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectC128(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestSelectStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]string + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = SelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = SelectStr(T, 3); err != nil { + 
t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = SelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar("")) + if x, err = SelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = SelectStr(T, 10); err == nil { + t.Fatal("Expected errors") + } +} diff --git a/native/iterator_native_test.go b/native/iterator_native_test.go new file mode 100644 index 0000000..09236a0 --- /dev/null +++ b/native/iterator_native_test.go @@ -0,0 +1,634 @@ +// Code generated by genlib2. DO NOT EDIT. + +package native + +import ( + "testing" + + "github.com/stretchr/testify/assert" + . "gorgonia.org/tensor" +) + +func Test_VectorB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(6)) + it, err := VectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3)) + it, err := MatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3B(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3, 4)) + it, err := Tensor3B(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := VectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = 
New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := MatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := VectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := MatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := VectorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := MatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) 
+ assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := VectorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := MatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := VectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := MatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3I64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3I64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := VectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = 
New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) + it, err := MatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := VectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := MatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := VectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := MatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, 
len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := VectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := MatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := VectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := MatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3U64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3U64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := VectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + 
T = New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, err := MatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3F32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3F32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := VectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := MatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3F64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3F64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := VectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := MatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3C64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3C64(T) + if err != nil { + 
t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := VectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := MatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3C128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := Tensor3C128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_VectorStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(6)) + it, err := VectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_MatrixStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3)) + it, err := MatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_Tensor3Str(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3, 4)) + it, err := Tensor3Str(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} From d319726786ce9ce11db7fc0d0ac0c2a1f02de71c Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sat, 26 May 2018 15:18:21 +1000 Subject: [PATCH 011/154] Update .travis.yml Remove Go <1.8 --- .travis.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 7c0fa86..8706540 100644 --- 
a/.travis.yml +++ b/.travis.yml @@ -5,8 +5,6 @@ branches: only: - master go: - - 1.6.x - - 1.7.x - 1.8.x - 1.9.x - tip @@ -29,4 +27,4 @@ script: matrix: allow_failures: - - go: tip \ No newline at end of file + - go: tip From 9cc480c10113d9b0cc40c62739bfe67cf8f8d521 Mon Sep 17 00:00:00 2001 From: Andrew Snodgrass Date: Fri, 17 Aug 2018 14:10:40 -0600 Subject: [PATCH 012/154] Improved multi arch support --- divmod.s => divmod_amd64.s | 0 mathutils.go | 2 +- mathutils_go.go | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename divmod.s => divmod_amd64.s (100%) diff --git a/divmod.s b/divmod_amd64.s similarity index 100% rename from divmod.s rename to divmod_amd64.s diff --git a/mathutils.go b/mathutils.go index 88ebbae..8060ad3 100644 --- a/mathutils.go +++ b/mathutils.go @@ -1,4 +1,4 @@ -// +build !noasm +// +build amd64,!noasm package tensor diff --git a/mathutils_go.go b/mathutils_go.go index 1a5f2c1..299bcbe 100644 --- a/mathutils_go.go +++ b/mathutils_go.go @@ -1,4 +1,4 @@ -// +build noasm +// +build !amd64 noasm package tensor From 72d4a63c5ebbf8c5dc2a589ee476cdde2b80785a Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 19 Aug 2018 11:45:45 +1000 Subject: [PATCH 013/154] V0.9.0 (#22) # v0.9.0 # The changes made in this PR is aimed at better supporting v0.9.0 of Gorgonia itself. Along the way there are some new features and optimizations, as well as some bug fixes. The majority of the work in supporting v0.9.0 of Gorgonia is to shore up the underlying architecture to support CUDA related engines. This means moving more things to rely on `Engine` while keeping the engine interface overheads low. Additionally this also means better support for column major data layouts. * Heavier reliance on `Engine` for most functions. This allows for extensibility on the data structure. 
* Long standing bugbear - concepts of `RowVec` and `ColVec` has been removed (thanks to @matdodgson) - Touch points: `ap.go`, `iterator.go`, `iterator_mult.go`.`shape.go`, and the tests that were correct prior to this change have semantic meaning changes too. - **POTENTIAL TECH DEBT**: `iterator_mult.go` - the solution of filling with ones is a little too dodgy for my liking. The alternative would be to change `BroadcastStrides` which will change even more things (`Concat`, `Stack` etc) * **Optimization**: - `AP` has been depointerized in `*Dense` (thanks to @docmerlin). This reduces *some* amount of GC pointer chasing, but not all - allocation is slightly improved. (`(array).fromSliceOrArrayer`, `(array).fix()` and `(array).forcefix()` are part of the improvement around the logic of allocating data. * **Bug fixes**: - Fixes subtle errors in linear algebra functions. The result is a slightly longer function but easier to reason with. - Fixes some subtle bugs in `Concat` - see also gorgonia/gorgonia#218 - Fixed some small bugs with regards to `SampleIndex` that only show up when the slices have extreme lengths. This API should have been deprecated 2 years ago, but eh... it touched a lot of external projects. * **API changes**: - `Diag` is made available. Relies heavily on an `Engine`'s implementation - `NewFlatIterator` is unexported. - `NewAP` is unexported. - `MakeAP` is used instead. - `(Tensor).DataOrder()` is added to the definiiton of what a `Tensor` is. - `(Shape).IsScalarEquiv()` is a new method. This corresponds to the change of semantics of what a `Shape` should be. - `(Shape).CalcStrides()` is exported now. This enables users to correctly calculate strides that are consistent to what the package expects. - `(Shape).CalcStridesColMajor()` is exported as the method to calculate the strides of a Col-Major `*Dense`. * **New Interfaces**: - `NonStdEngine` is an `Engine that does not allocate using the default allocator. 
This allows for both embedding a `DefaultEngine` while overriding the allocation behaviour. - `Diager` - any engine that can return a tensor that only contains the diagonal values of the input - `NaNChecker` and `InfChecker` - engines that can check a tensor for NaN and Inf * **New Features**: * Added full support for colmajor tensors. (fixes #10) - TODO: colmajor iterator's prev() method (see #34) - Added serialization to Protobuf and Flatbuffers * TODO: Add example for serialization (see #35 and #36) - Added more support for sparse CS tensors. * **New Subpackages**: * `native` is a subpackage that essentially gives users a native, Go-based iterator. Basically the ability to go from a `*Dense` to a `[][]T` or `[][][]T` **without extra allocations** (for the data). This was pulled into `master` earlier, but as of v0.9.0, the generic version is available too. * **Semantic Changes**: - `Shape` has semantic changes regarding whether or not a shape is scalar. A scalar shape is defined to be `Shape{}` or `Shape{1}` only. Formerly, `Shape{1,1}` was also considered to be scalar. Now they're considered to be `ScalarEquivalent` (along with `Shape{1, 1, .... , 1}`) - A `Dtype` that is is orderable is also now comparable for equality. If `RegisterOrd` is called with a new `Dtype`, it is also automatically registered as `Eq`. 
* **Cosmetic Changes**: - README has been updated to point to correct doc pages --- .travis.yml | 1 + .travis/test.sh | 3 +- CONTRIBUTORS.md | 3 +- README.md | 15 +- ap.go | 157 ++- ap_test.go | 56 +- api_arith_test.go | 10 +- api_cmp_generated_test.go | 56 +- api_matop.go | 7 + api_utils.go | 12 +- array.go | 93 +- benchmark_dense_matop_test.go | 58 +- consopt.go | 76 +- defaultengine.go | 7 +- defaultengine_argmethods.go | 22 +- defaultengine_arith.go | 24 +- defaultengine_cmp.go | 36 +- defaultengine_linalg.go | 158 ++- defaultengine_mapreduce.go | 20 +- defaultengine_matop_misc.go | 130 +- defaultengine_matop_stack.go | 13 +- defaultengine_matop_transpose.go | 12 +- defaultengine_matop_transpose_inplace.go | 18 + defaultengine_misc.go | 2 +- defaultengine_prep.go | 29 +- defaultengine_unary.go | 28 +- defaultenginefloat32.go | 10 +- defaultenginefloat64.go | 10 +- dense.go | 58 +- dense_assign.go | 8 +- dense_cmp_test.go | 56 +- dense_colmajor_linalg_test.go | 483 +++++++ dense_compat.go | 8 +- dense_format.go | 4 +- dense_generated.go | 2 +- dense_io.go | 1070 +++++++-------- dense_io_test.go | 49 +- dense_linalg.go | 15 +- dense_linalg_test.go | 109 +- dense_matop.go | 23 +- dense_matop_memmove.go | 11 +- dense_matop_test.go | 120 +- dense_norms.go | 4 +- dense_svd_test.go | 29 +- engine.go | 24 + example_dense_linalg_test.go | 151 +++ example_dense_matop_test.go | 2 + example_iterator_test.go | 48 +- example_tensor_basics_test.go | 120 +- flags.go | 51 +- flags_test.go | 47 +- interfaces.go | 2 +- internal/IDLs/generated.fbs | 38 + internal/IDLs/generated.proto | 52 + internal/serialization/README.md | 33 + internal/serialization/doc.go | 2 + internal/serialization/fb/AP.go | 110 ++ internal/serialization/fb/Dense.go | 152 +++ internal/serialization/fb/MaskedDense.go | 198 +++ internal/serialization/fb/Triangle.go | 18 + internal/serialization/pb/dense.go | 45 + internal/serialization/pb/generated.pb.go | 1457 +++++++++++++++++++++ iterator.go | 80 +- 
iterator_mult.go | 20 +- iterator_test.go | 119 +- native/example_test.go | 7 +- native/generic.go | 72 + native/generic_test.go | 67 + perf.go | 91 +- shape.go | 47 +- shape_test.go | 22 +- sparse.go | 28 +- tensor.go | 2 +- testutils_test.go | 1 + types.go | 23 + utils.go | 9 +- 76 files changed, 5010 insertions(+), 1243 deletions(-) create mode 100644 dense_colmajor_linalg_test.go create mode 100644 example_dense_linalg_test.go create mode 100644 internal/IDLs/generated.fbs create mode 100755 internal/IDLs/generated.proto create mode 100644 internal/serialization/README.md create mode 100644 internal/serialization/doc.go create mode 100644 internal/serialization/fb/AP.go create mode 100644 internal/serialization/fb/Dense.go create mode 100644 internal/serialization/fb/MaskedDense.go create mode 100644 internal/serialization/fb/Triangle.go create mode 100644 internal/serialization/pb/dense.go create mode 100644 internal/serialization/pb/generated.pb.go create mode 100644 native/generic.go create mode 100644 native/generic_test.go diff --git a/.travis.yml b/.travis.yml index 8706540..9a3402d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ branches: go: - 1.8.x - 1.9.x + - 1.10.x - tip env: diff --git a/.travis/test.sh b/.travis/test.sh index 2e00d07..37fdd87 100644 --- a/.travis/test.sh +++ b/.travis/test.sh @@ -6,9 +6,10 @@ go test -v -a -covermode=atomic -coverprofile=test.cover . go test -tags='avx' -a -covermode=atomic -coverprofile=avx.cover . go test -tags='sse' -a -covermode=atomic -coverprofile=sse.cover . go test -tags='inplacetranspose' -a -covermode=atomic -coverprofile=inplacetranspose.cover . +go test -a -covermode=atomic -coverprofile=native.cover ./native/. # because coveralls only accepts one coverage file at one time... 
we combine them into one gigantic one -covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover) +covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover ./native.cover) echo "mode: set" > ./final.cover tail -q -n +2 "${covers[@]}" >> ./final.cover goveralls -coverprofile=./final.cover -service=travis-ci diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 57adb1d..d94f24c 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -4,6 +4,7 @@ * Naseer Dari (@ndari) - errors and error handling * Joe Kabaka (@kabaka0) - masked array functionality * Stuart Carnie (@stuartcarnie) - performance optimization for iterators +* Jorge Landivar (@docmerlin) - performance optimization for `*Dense` # Contributors @@ -13,8 +14,8 @@ * David Soller | @3ygun * Davor Kapsa | @dvrkps * James Michael DuPont | @h4ck3rm1k3 -* Jorge Landivar | @docmerlin * Yuanlin Lian | @alienchow +* Andrew SnodGrass | @pointlander diff --git a/README.md b/README.md index 62b3fbd..086cdc4 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ -# Package `tensor` [![GoDoc](https://godoc.org/github.com/gorgonia/tensor?status.svg)](https://godoc.org/github.com/gorgonia/tensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) # +# Package `tensor` [![GoDoc](https://godoc.org/gorgonia.org/tensor?status.svg)](https://godoc.org/gorgonia.org/tensor) [![GitHub version](https://badge.fury.io/gh/gorgonia%2Ftensor.svg)](https://badge.fury.io/gh/gorgonia%2Ftensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) [![Go Report 
Card](https://goreportcard.com/badge/gorgonia.org/tensor)](https://goreportcard.com/report/gorgonia.org/tensor) [![unstable](http://badges.github.io/stability-badges/dist/unstable.svg)](http://github.com/badges/stability-badges)# + Package `tensor` is a package that provides efficient, generic (by some definitions of generic) n-dimensional arrays in Go. Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations. -The main purpose of this package is to support the operations required by [Gorgonia](https://github.com/chewxy/gorgonia). +The main purpose of this package is to support the operations required by [Gorgonia](https://gorgonia.org/gorgonia). ## Introduction ## In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. @@ -50,15 +51,15 @@ The `*Dense` tensor is the primary tensor and is represented by a singular flat ### Compressed Sparse Column Matrix ### -Coming soon +Documentation Coming soon ### Compressed Sparse Row Matrix ### -Coming soon +Documentation Coming soon ## Usage ## -To install: `go get -u "github.com/chewxy/gorgonia/tensor"` +To install: `go get -u "gorgonia.org/tensor"` To create a matrix with package `tensor` is easy: @@ -129,7 +130,7 @@ b.SetAt(1000, 0, 1, 2) fmt.Printf("b:\n%v", b) ``` -There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/github.com/chewxy/gorgonia/tensor) page +There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/gorgonia.org/tensor) page @@ -198,7 +199,7 @@ The above call will use `myEngine` to allocate memory instead. 
This is useful in ### Other failed designs ### -The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/chewxy/gorgonia/blob/master/tensor/ALTERNATIVEDESIGNS.md) +The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/tensor/blob/master/ALTERNATIVEDESIGNS.md) ## Generic Features ## diff --git a/ap.go b/ap.go index b4b9176..83df9c5 100644 --- a/ap.go +++ b/ap.go @@ -26,13 +26,30 @@ type AP struct { Δ Triangle } -// NewAP creates a new AP, given the shape and strides -func NewAP(shape Shape, strides []int) *AP { - ap := borrowAP() +func makeAP(size int) AP { + return AP{ + shape: Shape(BorrowInts(size)), + strides: BorrowInts(size), + } +} + +// MakeAP creates an AP, given the shape and strides. +func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP { + return AP{ + shape: shape, + strides: strides, + o: o, + Δ: Δ, + fin: true, + } +} + +// Init initalizes an already created AP with a shape and stries. +// It will panic if AP is nil. +func (ap *AP) Init(shape Shape, strides []int) { ap.shape = shape ap.strides = strides ap.fin = true - return ap } // SetShape is for very specific times when modifying the AP is necessary, such as reshaping and doing I/O related stuff @@ -46,6 +63,9 @@ func (ap *AP) SetShape(s ...int) { if !ap.fin { // scalars are a special case, we don't want to remove it completely if len(s) == 0 { + if ap.shape == nil || ap.strides == nil { + ap.shape = Shape{} + } ap.shape = ap.shape[:0] ap.strides = ap.strides[:0] return @@ -102,9 +122,54 @@ func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() } // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 } -// Clone clones the *AP. Clearly. 
-func (ap *AP) Clone() (retVal *AP) { - retVal = BorrowAP(len(ap.shape)) +// IsZero tell us if the ap has zero size +func (ap *AP) IsZero() bool { + return len(ap.shape) == 0 && len(ap.strides) == 0 && !ap.fin && ap.o == 0 && ap.Δ == 0 +} + +// Zero zeros out an AP. +func (ap *AP) zero() { + // log.Printf("ZEROING. Called by %v", string(debug.Stack())) + + // Jorge's original implementation for zeroing a AP is as below + // but to cater for the (*Dense).fix() method of the *Dense + // a nil shape is used to signal unsetness + // so we cannot just truncate the shape even though it would be a lot more efficient + + // ap.shape = ap.shape[:0] + // ap.strides = ap.strides[:0] + ReturnInts([]int(ap.shape)) + ReturnInts(ap.strides) + ap.zeroOnly() +} + +// side effect free zeroing +func (ap *AP) zeroOnly() { + ap.shape = nil + ap.strides = nil + + ap.fin = false + ap.o = 0 + ap.Δ = 0 +} + +func (ap *AP) zeroWithDims(dims int) { + //ap.shape = BorrowInts(dims) + //ap.strides = BorrowInts(dims) + if cap(ap.shape) >= dims { + ap.shape = ap.shape[:dims] + } + ap.shape = BorrowInts(dims) + if cap(ap.strides) >= dims { + ap.strides = ap.strides[:dims] + } + ap.strides = BorrowInts(dims) +} + +// Clone clones the *AP. Clearly. It returns AP +func (ap *AP) Clone() (retVal AP) { + retVal = makeAP(cap(ap.shape)) + copy(retVal.shape, ap.shape) copy(retVal.strides, ap.strides) @@ -118,21 +183,25 @@ func (ap *AP) Clone() (retVal *AP) { return } +func (ap *AP) CloneTo(dest *AP) { + dest.shape = append(dest.shape[:0], ap.shape...) + dest.strides = append(dest.strides[:0], ap.strides...) + dest.fin = ap.fin + dest.o = ap.o + dest.Δ = ap.Δ +} + // DataOrder returns the data order of the AP. 
func (ap *AP) DataOrder() DataOrder { return ap.o } // C returns true if the access pattern is C-contiguous array -func (ap *AP) C() bool { - return ap.o.isRowMajor() && ap.o.isContiguous() -} +func (ap *AP) C() bool { return ap.o.IsRowMajor() && ap.o.IsContiguous() } // F returns true if the access pattern is Fortran contiguous array -func (ap *AP) F() bool { - return ap.o.isColMajor() && ap.o.isContiguous() -} +func (ap *AP) F() bool { return ap.o.IsColMajor() && ap.o.IsContiguous() } // S returns the metadata of the sliced tensor. -func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err error) { +func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) { if len(slices) > len(ap.shape) { // error err = errors.Errorf(dimMismatch, len(ap.shape), len(slices)) @@ -146,7 +215,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e var outerDim int order := ap.o - if ap.o.isRowMajor() || ap.IsVector() { + if ap.o.IsRowMajor() || ap.IsVector() { outerDim = 0 } else { outerDim = len(ap.shape) - 1 @@ -160,12 +229,13 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e size := ap.shape[i] var stride int - if ap.IsVector() { - // handles non-vanilla vectors - stride = ap.strides[0] - } else { - stride = ap.strides[i] - } + stride = ap.strides[i] + // if ap.IsVector() { + // // handles non-vanilla vectors + // stride = ap.strides[0] + // } else { + // stride = ap.strides[i] + // } var start, end, step int if start, end, step, err = SliceDetails(sl, size); err != nil { @@ -196,37 +266,29 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e if ndEnd-ndStart == 1 { // scalars are a special case - newAP = borrowAP() + newAP = AP{} newAP.SetShape() // make it a Scalar newAP.lock() } else { // drop any dimension with size 1, except the last dimension + offset := 0 for d := 0; d < dims; d++ { - if newShape[d] == 1 /*&& d != t.dims-1 && dims 
> 2*/ { + if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { newShape = append(newShape[:d], newShape[d+1:]...) newStrides = append(newStrides[:d], newStrides[d+1:]...) d-- dims-- + offset++ } } - - //fix up strides - if newShape.IsColVec() { - stride0 := newStrides[0] - ReturnInts(newStrides) - newStrides = BorrowInts(1) - newStrides[0] = stride0 - } - - newAP = NewAP(newShape, newStrides) - newAP.o = order + newAP = MakeAP(newShape, newStrides, order, ap.Δ) } return } // T returns the transposed metadata based on the given input -func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { +func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { // prep axes if len(axes) > 0 && len(axes) != ap.Dims() { err = errors.Errorf(dimMismatch, ap.Dims(), len(axes)) @@ -244,7 +306,7 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { // if axes is 0, 1, 2, 3... then no op if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 { - return ap, a, noopError{} + return ap.Clone(), a, noopError{} } currentShape := ap.shape @@ -270,12 +332,8 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { } } - retVal = borrowAP() - retVal.shape = shape - retVal.strides = strides - if ap.IsVector() { - retVal.strides = retVal.strides[:1] - } + o := MakeDataOrder(ap.o, Transposed) + retVal = MakeAP(shape, strides, o, ap.Δ) retVal.fin = true return } @@ -286,14 +344,21 @@ func (ap *AP) unlock() { ap.fin = false } func (ap *AP) calcStrides() []int { switch { - case ap.o.isRowMajor(): - return ap.shape.calcStrides() - case ap.o.isColMajor(): - return ap.shape.calcStridesColMajor() + case ap.o.IsRowMajor(): + return ap.shape.CalcStrides() + case ap.o.IsColMajor(): + return ap.shape.CalcStridesColMajor() } panic("unreachable") } +// setDataOrder is a method such that any tensor that embeds *AP will have the same method +func (ap *AP) setDataOrder(o DataOrder) { + if 
!o.HasSameOrder(ap.o) { + ap.o = ap.o.toggleColMajor() + } +} + // TransposeIndex returns the new index given the old index func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int { oldCoord, err := Itol(i, oldShape, oldStrides) diff --git a/ap_test.go b/ap_test.go index 37e0a28..091d6e6 100644 --- a/ap_test.go +++ b/ap_test.go @@ -32,46 +32,40 @@ func sli(start int, opt ...int) dummySlice { return dummySlice{start: start, end: end, step: step} } -func dummyScalar1() *AP { - return &AP{} -} +func dummyScalar1() AP { return AP{} } -func dummyScalar2() *AP { - return &AP{ - shape: Shape{1}, - } -} +func dummyScalar2() AP { return AP{shape: Shape{1}} } -func dummyColVec() *AP { - return &AP{ +func dummyColVec() AP { + return AP{ shape: Shape{5, 1}, strides: []int{1}, } } -func dummyRowVec() *AP { - return &AP{ +func dummyRowVec() AP { + return AP{ shape: Shape{1, 5}, strides: []int{1}, } } -func dummyVec() *AP { - return &AP{ +func dummyVec() AP { + return AP{ shape: Shape{5}, strides: []int{1}, } } -func twothree() *AP { - return &AP{ +func twothree() AP { + return AP{ shape: Shape{2, 3}, strides: []int{3, 1}, } } -func twothreefour() *AP { - return &AP{ +func twothreefour() AP { + return AP{ shape: Shape{2, 3, 4}, strides: []int{12, 4, 1}, } @@ -83,7 +77,7 @@ func TestAccessPatternBasics(t *testing.T) { ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) @@ -100,21 +94,21 @@ func TestAccessPatternBasics(t *testing.T) { ap.unlock() ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) - if ap.String() != "Shape: (1, 2), Stride: [1], Lock: false" { - t.Error("AP formatting error. 
Got %q", ap.String()) + if ap.String() != "Shape: (1, 2), Stride: [2 1], Lock: false" { + t.Errorf("AP formatting error. Got %q", ap.String()) } ap2 := ap.Clone() - assert.Equal(ap, ap2) + assert.Equal(*ap, ap2) } func TestAccessPatternIsX(t *testing.T) { assert := assert.New(t) - var ap *AP + var ap AP ap = dummyScalar1() assert.True(ap.IsScalar()) @@ -151,7 +145,7 @@ func TestAccessPatternIsX(t *testing.T) { func TestAccessPatternT(t *testing.T) { assert := assert.New(t) - var ap, apT *AP + var ap, apT AP var axes []int var err error @@ -216,16 +210,22 @@ var sliceTests = []struct { {"A[1:3]", Shape{4, 5}, []Slice{sli(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{sli(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, sli(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, + + // tensor + {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true}, + {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, 0, 2, Shape{1, 2}, []int{4, 1}, false}, + {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, + {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, } func TestAccessPatternS(t *testing.T) { assert := assert.New(t) - var ap, apS *AP + var ap, apS AP var ndStart, ndEnd int var err error for _, sts := range sliceTests { - ap = NewAP(sts.shape, sts.shape.calcStrides()) + ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 0) if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { t.Errorf("%v errored: %v", sts.name, err) continue @@ -234,7 +234,7 @@ func TestAccessPatternS(t *testing.T) { assert.Equal(sts.correctEnd, ndEnd, "Wrong end: %v. 
Want %d Got %d", sts.name, sts.correctEnd, ndEnd) assert.True(sts.correctShape.Eq(apS.shape), "Wrong shape: %v. Want %v. Got %v", sts.name, sts.correctShape, apS.shape) assert.Equal(sts.correctStride, apS.strides, "Wrong strides: %v. Want %v. Got %v", sts.name, sts.correctStride, apS.strides) - assert.Equal(sts.contiguous, apS.DataOrder().isContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) + assert.Equal(sts.contiguous, apS.DataOrder().IsContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) } } diff --git a/api_arith_test.go b/api_arith_test.go index d7bd9a5..687e4b7 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -26,7 +26,7 @@ func TestMod(t *testing.T) { // scalar if res, err = Mod(a, 1.0); err != nil { - t.Fatal("Error: %v", err) + t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) } @@ -41,10 +41,10 @@ func TestFMA(t *testing.T) { y2 := y.Clone().(*Dense) we, willFailEq := willerr(a, numberTypes, nil) - // _, ok1 := q.Engine().(FMAer) - // _, ok2 := q.Engine().(Muler) - // _, ok3 := q.Engine().(Adder) - // we = we || (!ok1 && (!ok2 || !ok3)) + _, ok1 := q.Engine().(FMAer) + _, ok2 := q.Engine().(Muler) + _, ok3 := q.Engine().(Adder) + we = we || (!ok1 && (!ok2 || !ok3)) f, err := FMA(a, x, y) if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { diff --git a/api_cmp_generated_test.go b/api_cmp_generated_test.go index 163ae5c..002587b 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -62,7 +62,7 @@ func TestGt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -120,7 +120,7 @@ func TestGte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte 
failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -178,7 +178,7 @@ func TestLt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -236,7 +236,7 @@ func TestLte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -294,7 +294,7 @@ func TestEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -328,7 +328,7 @@ func TestEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe(t *testing.T) { @@ -363,7 +363,7 @@ func TestNe(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGt_assame(t *testing.T) { @@ -422,7 +422,7 @@ func TestGt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -482,7 +482,7 @@ func TestGte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - 
t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -542,7 +542,7 @@ func TestLt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -602,7 +602,7 @@ func TestLte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -662,7 +662,7 @@ func TestEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -699,7 +699,7 @@ func TestEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe_assame(t *testing.T) { @@ -737,7 +737,7 @@ func TestNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGtScalar(t *testing.T) { @@ -792,7 +792,7 @@ func TestGtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -848,7 +848,7 @@ func TestGteScalar(t *testing.T) { return true } if err := quick.Check(transFn, 
&quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -904,7 +904,7 @@ func TestLtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -960,7 +960,7 @@ func TestLteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1016,7 +1016,7 @@ func TestEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1048,7 +1048,7 @@ func TestEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar(t *testing.T) { @@ -1081,7 +1081,7 @@ func TestNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestGtScalar_assame(t *testing.T) { @@ -1138,7 +1138,7 @@ func TestGtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -1196,7 +1196,7 @@ func 
TestGteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -1254,7 +1254,7 @@ func TestLtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -1312,7 +1312,7 @@ func TestLteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1370,7 +1370,7 @@ func TestEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1405,7 +1405,7 @@ func TestEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar_assame(t *testing.T) { @@ -1441,6 +1441,6 @@ func TestNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } diff --git a/api_matop.go b/api_matop.go index 0db34e7..8c687b2 100644 --- a/api_matop.go +++ b/api_matop.go @@ -108,3 +108,10 @@ func Materialize(t Tensor) Tensor { return t } } + +func Diag(t Tensor) (retVal Tensor, err error) { + if d, ok 
:= t.Engine().(Diager); ok { + return d.Diag(t) + } + return nil, errors.Errorf("Unable to perform diagonalization of tensor ") +} diff --git a/api_utils.go b/api_utils.go index 12bea19..3cf55f0 100644 --- a/api_utils.go +++ b/api_utils.go @@ -53,11 +53,11 @@ func SortIndex(in interface{}) (out []int) { // SampleIndex samples a slice or a Tensor. // TODO: tidy this up. func SampleIndex(in interface{}) int { - var l int + // var l int switch list := in.(type) { case []int: var sum, i int - l = len(list) + // l = len(list) r := rand.Int() for { sum += list[i] @@ -69,7 +69,7 @@ func SampleIndex(in interface{}) int { case []float64: var sum float64 var i int - l = len(list) + // l = len(list) r := rand.Float64() for { sum += list[i] @@ -85,7 +85,7 @@ func SampleIndex(in interface{}) int { var sum float64 r := rand.Float64() data := list.Float64s() - l = len(data) + // l = len(data) for { datum := data[i] if math.IsNaN(datum) || math.IsInf(datum, 0) { @@ -102,7 +102,7 @@ func SampleIndex(in interface{}) int { var sum float32 r := rand.Float32() data := list.Float32s() - l = len(data) + // l = len(data) for { datum := data[i] if math32.IsNaN(datum) || math32.IsInf(datum, 0) { @@ -121,5 +121,5 @@ func SampleIndex(in interface{}) int { default: panic("Not yet implemented") } - return l - 1 + return -1 } diff --git a/array.go b/array.go index 4162280..321c9cf 100644 --- a/array.go +++ b/array.go @@ -18,10 +18,8 @@ type array struct { // makeHeader makes a array Header func makeHeader(t Dtype, length int) storage.Header { - size := int(calcMemSize(t, length)) - s := make([]byte, size) return storage.Header{ - Ptr: unsafe.Pointer(&s[0]), + Ptr: malloc(t, length), L: length, C: length, } @@ -75,6 +73,7 @@ func arrayFromSlice(x interface{}) array { } } +// fromSlice populates the value from a slice func (a *array) fromSlice(x interface{}) { xT := reflect.TypeOf(x) if xT.Kind() != reflect.Slice { @@ -91,20 +90,45 @@ func (a *array) fromSlice(x interface{}) { a.v = x } +// 
fromSliceOrTensor populates the value from a slice or anything that can form an array +func (a *array) fromSliceOrArrayer(x interface{}) { + if T, ok := x.(arrayer); ok { + xp := T.arrPtr() + + // if the underlying array hasn't been allocated, or not enough has been allocated + if a.Ptr == nil || a.L < xp.L || a.C < xp.C { + a.t = xp.t + a.L = xp.L + a.C = xp.C + a.Ptr = malloc(a.t, a.L) + } + + a.t = xp.t + a.L = xp.L + a.C = xp.C + copyArray(a, T.arrPtr()) + a.v = nil // tell the GC to release whatever a.v may hold + a.forcefix() // fix it such that a.v has a value and is not nil + return + } + a.fromSlice(x) +} + +// fix fills the a.v empty interface{} if it's not nil func (a *array) fix() { if a.v == nil { - shdr := reflect.SliceHeader{ - Data: uintptr(a.Ptr), - Len: a.L, - Cap: a.C, - } - sliceT := reflect.SliceOf(a.t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - a.v = val.Interface() + a.forcefix() } } +// forcefix fills the a.v empty interface{}. No checks are made if the thing is empty +func (a *array) forcefix() { + sliceT := reflect.SliceOf(a.t.Type) + ptr := unsafe.Pointer(&a.Header) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + a.v = val.Interface() +} + // byteSlice casts the underlying slice into a byte slice. 
Useful for copying and zeroing, but not much else func (a array) byteSlice() []byte { return storage.AsByteSlice(&a.Header, a.t.Type) @@ -132,6 +156,7 @@ func (a *array) sliceInto(i, j int, res *array) { res.fix() } +// slice slices an array func (a array) slice(start, end int) array { if end > a.L { panic("Index out of range") @@ -240,6 +265,13 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ +// malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory +func malloc(t Dtype, length int) unsafe.Pointer { + size := int(calcMemSize(t, length)) + s := make([]byte, size) + return unsafe.Pointer(&s[0]) +} + // calcMemSize calulates the memory size of an array (given its size) func calcMemSize(dt Dtype, size int) int64 { return int64(dt.Size()) * int64(size) @@ -288,6 +320,7 @@ func copyDense(dst, src DenseTensor) int { // return copyArray(dst.arr(), src.arr()) } +// copyDenseSliced copies a DenseTensor, but both are sliced func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, send int) int { if dst.Dtype() != src.Dtype() { panic("Cannot copy DenseTensors of different types") @@ -316,12 +349,14 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, return copyArraySliced(dst.arr(), dstart, dend, src.arr(), sstart, send) } +// copyDenseIter copies a DenseTensor, with iterator func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { if dst.Dtype() != src.Dtype() { panic("Cannot copy Dense arrays of different types") } - if !dst.RequiresIterator() && !src.RequiresIterator() { + // if they all don't need iterators, and have the same data order + if !dst.RequiresIterator() && !src.RequiresIterator() && dst.DataOrder().HasSameOrder(src.DataOrder()) { return copyDense(dst, src), nil } @@ -336,6 +371,7 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { siter = FlatIteratorFromDense(src) } + 
// if it's a masked tensor, we copy the mask as well if ms, ok := src.(MaskedTensor); ok && ms.IsMasked() { if md, ok := dst.(MaskedTensor); ok { dmask := md.Mask() @@ -388,12 +424,34 @@ func getPointer(a interface{}) unsafe.Pointer { case string: return unsafe.Pointer(&at) case uintptr: - return unsafe.Pointer(&at) + return unsafe.Pointer(at) case unsafe.Pointer: return at // POINTERS + case *bool: + return unsafe.Pointer(at) + case *int: + return unsafe.Pointer(at) + case *int8: + return unsafe.Pointer(at) + case *int16: + return unsafe.Pointer(at) + case *int32: + return unsafe.Pointer(at) + case *int64: + return unsafe.Pointer(at) + case *uint: + return unsafe.Pointer(at) + case *uint8: + return unsafe.Pointer(at) + case *uint16: + return unsafe.Pointer(at) + case *uint32: + return unsafe.Pointer(at) + case *uint64: + return unsafe.Pointer(at) case *float32: return unsafe.Pointer(at) case *float64: @@ -402,11 +460,18 @@ func getPointer(a interface{}) unsafe.Pointer { return unsafe.Pointer(at) case *complex128: return unsafe.Pointer(at) + case *string: + return unsafe.Pointer(at) + case *uintptr: + return unsafe.Pointer(*at) + case *unsafe.Pointer: + return *at } panic("Cannot get pointer") } +// scalarToHeader creates a Header from a scalar value func scalarToHeader(a interface{}) *storage.Header { hdr := borrowHeader() hdr.Ptr = getPointer(a) diff --git a/benchmark_dense_matop_test.go b/benchmark_dense_matop_test.go index 2c5b8a7..2a4ee4a 100644 --- a/benchmark_dense_matop_test.go +++ b/benchmark_dense_matop_test.go @@ -1,6 +1,9 @@ package tensor -import "testing" +import ( + "math/rand" + "testing" +) func BenchmarkDense_Transpose(b *testing.B) { T := New(WithShape(100, 100, 2), WithBacking(Range(Byte, 0, 100*100*2))) @@ -64,7 +67,7 @@ func BenchmarkGetWithIterator(b *testing.B) { f = data[next] } if _, ok := err.(NoOpError); !ok { - b.Error("Error: %v", err) + b.Errorf("Error: %v", err) } } _ = f @@ -85,8 +88,57 @@ func BenchmarkComplicatedGet(b *testing.B) 
{ f = data[next] } if _, ok := err.(NoOpError); !ok { - b.Error("Error: %v", err) + b.Errorf("Error: %v", err) } } _ = f } + +var atCoords [10000][2]int + +func init() { + for i := range atCoords { + atCoords[i][0] = rand.Intn(100) + atCoords[i][1] = rand.Intn(100) + } +} + +var at1, at2 float64 + +// func BenchmarkAtWithNativeIterator(b *testing.B) { +// T := New(WithShape(100, 100), Of(Float64)) +// it, err := NativeMatrixF64(T) +// if err != nil { +// b.Fatalf("Error: %v", err) +// } + +// var j int +// for i := 0; i < b.N; i++ { + +// if j >= len(atCoords) { +// j = 0 +// } + +// at := atCoords[j] +// at1 = it[at[0]][at[1]] +// j++ +// } +// } + +func BenchmarkAt(b *testing.B) { + T := New(WithShape(100, 100), Of(Float64)) + var j int + for i := 0; i < b.N; i++ { + if j >= len(atCoords) { + j = 0 + } + + at := atCoords[j] + _, err := T.At(at[0], at[1]) + if err != nil { + b.Errorf("Error: %v", err) + } + + j++ + } +} diff --git a/consopt.go b/consopt.go index 9118cad..ee4b4cf 100644 --- a/consopt.go +++ b/consopt.go @@ -10,6 +10,7 @@ type ConsOpt func(Tensor) // Of is a construction option for a Tensor. func Of(a Dtype) ConsOpt { + Register(a) f := func(t Tensor) { switch tt := t.(type) { case *Dense: @@ -172,9 +173,11 @@ func WithEngine(e Engine) ConsOpt { if e != nil && !e.AllocAccessible() { tt.flag = MakeMemoryFlag(tt.flag, NativelyInaccessible) } - // if oe, ok := e.(standardEngine); ok { - // tt.oe = oe - // } + + tt.oe = nil + if oe, ok := e.(standardEngine); ok { + tt.oe = oe + } case *CS: tt.e = e if e != nil && !e.AllocAccessible() { @@ -185,14 +188,75 @@ func WithEngine(e Engine) ConsOpt { return f } -func AsFortran() ConsOpt { +// AsFortran creates a *Dense with a col-major layout. +// If the optional backing argument is passed, the backing is assumed to be C-order (row major), and +// it will be transposed before being used. 
+func AsFortran(backing interface{}) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - if tt.AP == nil { - // create AP + if backing != nil { + // put the data into the tensor, then make a clone tensor to transpose + tt.fromSliceOrArrayer(backing) + // create a temporary tensor, to which the transpose will be done + tmp := NewDense(tt.Dtype(), tt.shape.Clone()) + copyArray(tmp.arrPtr(), tt.arrPtr()) + tmp.T() + tmp.Transpose() + // copy the data back to the current tensor + copyArray(tt.arrPtr(), tmp.arrPtr()) + // cleanup: return the temporary tensor back to the pool + ReturnTensor(tmp) } + tt.AP.o = MakeDataOrder(tt.AP.o, ColMajor) + if tt.AP.shape != nil { + ReturnInts(tt.AP.strides) + tt.AP.strides = nil + tt.AP.strides = tt.AP.calcStrides() + } + case *CS: + panic("AsFortran is not an available option for Compressed Sparse layouts") + } + } + return f +} + +func AsDenseDiag(backing interface{}) ConsOpt { + f := func(t Tensor) { + switch tt := t.(type) { + case *Dense: + if bt, ok := backing.(Tensor); ok { + backing = bt.Data() + } + xT := reflect.TypeOf(backing) + if xT.Kind() != reflect.Slice { + panic("Expected a slice") + } + xV := reflect.ValueOf(backing) + l := xV.Len() + // elT := xT.Elem() + + sli := reflect.MakeSlice(xT, l*l, l*l) + + shape := Shape{l, l} + strides := shape.CalcStrides() + for i := 0; i < l; i++ { + idx, err := Ltoi(shape, strides, i, i) + if err != nil { + panic(err) + } + + at := sli.Index(idx) + xi := xV.Index(i) + at.Set(xi) + } + + tt.fromSliceOrArrayer(sli.Interface()) + tt.setShape(l, l) + + default: + panic("AsDenseDiag is not available as an option for CS") } } return f diff --git a/defaultengine.go b/defaultengine.go index 6dd1f45..cace41a 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -66,12 +66,7 @@ func (e StdEng) Memcpy(dst, src Memory) error { func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } -func (e StdEng) WorksWith(order DataOrder) bool { - if order.isColMajor() 
{ - return false - } - return true -} +func (e StdEng) WorksWith(order DataOrder) bool { return true } func (e StdEng) checkAccessible(t Tensor) error { if !t.IsNativelyAccessible() { diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 3cedd84..5632fa6 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -59,17 +59,21 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e if _, ok := err.(NoOpError); !ok && err != nil { return } else if ok { - newAP = t.Info().Clone() + t.Info().CloneTo(&newAP) } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() @@ -144,15 +148,19 @@ func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e } else if ok { newAP = t.Info().Clone() } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 01d9784..3017aaa 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -16,7 +16,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), 
a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -83,7 +83,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -150,7 +150,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -217,7 +217,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -284,7 +284,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ 
-351,7 +351,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -418,7 +418,7 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -506,7 +506,7 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -594,7 +594,7 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -682,7 +682,7 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, 
err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -770,7 +770,7 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -858,7 +858,7 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 98f61e1..b3651d7 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -18,7 +18,7 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -98,7 +98,7 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, 
_, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -178,7 +178,7 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -258,7 +258,7 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -338,7 +338,7 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -418,7 +418,7 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -498,7 +498,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, 
opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -564,7 +564,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -610,7 +610,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -676,7 +676,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -722,7 +722,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -788,7 +788,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and 
B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -834,7 +834,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -900,7 +900,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -942,7 +942,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -1008,7 +1008,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1050,7 +1050,7 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = 
handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -1116,7 +1116,7 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 486e7a0..45a8527 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -286,13 +286,14 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var rd *Dense if rd, err = a.TensorMul(b, axesA, axesB); err != nil { + panic(err) return } if reuse != nil { copyDense(reuse, rd) - ReturnAP(reuse.Info()) - reuse.setAP(rd.Info().Clone()) + ap := rd.Info().Clone() + reuse.setAP(&ap) defer ReturnTensor(rd) // swap out the underlying data and metadata // reuse.data, rd.data = rd.data, reuse.data @@ -403,12 +404,35 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { n := ad.oshape()[1] tA := blas.NoTrans - if ad.oldAP() != nil { + do := a.DataOrder() + z := ad.oldAP().IsZero() + + var lda int + switch { + case do.IsRowMajor() && z: + lda = n + case do.IsRowMajor() && !z: + tA = blas.Trans + lda = n + case do.IsColMajor() && z: tA = blas.Trans + lda = m + m, n = n, m + case do.IsColMajor() && !z: + lda = m + m, n = n, m } - lda := ad.ostrides()[0] + incX, incY := 1, 1 // step size + // ASPIRATIONAL TODO: different incX and incY + // TECHNICAL DEBT. TECHDEBT. 
TECH DEBT + // Example use case: + // log.Printf("a %v %v", ad.Strides(), ad.ostrides()) + // log.Printf("b %v", b.Strides()) + // incX := a.Strides()[0] + // incY = b.Strides()[0] + switch A := ad.Data().(type) { case []float64: x := bd.Float64s() @@ -438,49 +462,61 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { return errors.Wrapf(err, opFail, "StdEng.MatMul") } - tA, tB := blas.NoTrans, blas.NoTrans - if ad.oldAP() != nil { - tA = blas.Trans - } - - // Special case if b is (1, N) - if bd.oldAP() != nil || bd.IsRowVec() { - tB = blas.Trans - } + ado := a.DataOrder() + bdo := b.DataOrder() + cdo := prealloc.DataOrder() + // get result shapes. k is the shared dimension + // a is (m, k) + // b is (k, n) + // c is (m, n) var m, n, k int m = ad.Shape()[0] k = ad.Shape()[1] n = bd.Shape()[1] // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides() - lda := ad.ostrides()[0] - ldb := bd.ostrides()[0] - ldc := pd.ostrides()[0] + // lda in colmajor = number of rows; + // lda in row major = number of cols + var lda, ldb, ldc int + switch { + case ado.IsColMajor(): + lda = m + case ado.IsRowMajor(): + lda = k + } - // special case: if a is (1, N) x (N, M), then we can just use GEMV - if ad.IsRowVec() { - tB = blas.Trans - if bd.oldAP() != nil { - tB = blas.NoTrans + switch { + case bdo.IsColMajor(): + ldb = bd.Shape()[0] + case bdo.IsRowMajor(): + ldb = n + } + + switch { + case cdo.IsColMajor(): + ldc = prealloc.Shape()[0] + case cdo.IsRowMajor(): + ldc = prealloc.Shape()[1] + } + + // check for trans + tA, tB := blas.NoTrans, blas.NoTrans + if !ad.oldAP().IsZero() { + tA = blas.Trans + if ado.IsRowMajor() { + lda = m + } else { + lda = k } - m = bd.Shape()[0] - n = bd.Shape()[1] - switch A := ad.Data().(type) { - case []float64: - B := bd.Float64s() - C := pd.Float64s() - alpha, beta := float64(1), float64(0) - whichblas.Dgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - case []float32: - B := 
bd.Float32s() - C := pd.Float32s() - alpha, beta := float32(1), float32(0) - whichblas.Sgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - default: - return errors.Errorf(typeNYI, "matMul a is row vec", ad.Data()) + } + if !bd.oldAP().IsZero() { + tB = blas.Trans + if bdo.IsRowMajor() { + ldb = bd.Shape()[0] + } else { + ldb = bd.Shape()[1] } - return } switch A := ad.Data().(type) { @@ -488,12 +524,20 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { B := bd.Float64s() C := pd.Float64s() alpha, beta := float64(1), float64(0) - whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Dgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } case []float32: B := bd.Float32s() C := pd.Float32s() alpha, beta := float32(1), float32(0) - whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Sgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } default: return errors.Errorf(typeNYI, "matMul", ad.Data()) } @@ -510,12 +554,40 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { m := ad.Size() n := bd.Size() + pdo := pd.DataOrder() // the stride of a Vector is always going to be [1], // incX := t.Strides()[0] // incY := other.Strides()[0] incX, incY := 1, 1 - lda := pd.Strides()[0] + // lda := pd.Strides()[0] + var lda int + switch { + case pdo.IsColMajor(): + aShape := a.Shape().Clone() + bShape := b.Shape().Clone() + if err = a.Reshape(aShape[0], 1); err != nil { + return err + } + if err = b.Reshape(1, bShape[0]); err != nil { + return err + } + + if err = e.MatMul(a, b, prealloc); err != nil { + return err + } + + if err = b.Reshape(bShape...); err != nil { + return + } + if err = a.Reshape(aShape...); err != nil { + return + } 
+ return nil + + case pdo.IsRowMajor(): + lda = pd.Shape()[1] + } switch x := ad.Data().(type) { case []float64: @@ -559,13 +631,13 @@ func (e StdEng) checkTwoFloatTensors(a, b Tensor) (ad, bd DenseTensor, err error func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { if err = e.checkAccessible(a); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(b); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(ret); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: ret is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") } if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 203a839..4964ab4 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -16,7 +16,7 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e var reuse DenseTensor var safe, _, incr bool - if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return } switch { @@ -102,18 +102,18 @@ func (e StdEng) Reduce(fn interface{}, a Tensor, axis int, defaultValue interfac // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if 
at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, fn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -147,18 +147,18 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, firstFn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -328,7 +328,7 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe // FUNC PREP var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, _, 
err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { err = errors.Wrap(err, "Unable to prep unary tensor") return } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 23607c6..9faed77 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,6 +1,12 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) + +var ( + _ Diager = StdEng{} +) func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { switch tt := t.(type) { @@ -104,9 +110,18 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen all[0] = a copy(all[1:], Ts) + // TODO: OPIMIZATION + // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) + // just flat copy + // + + // isOuter is true when the axis is the outermost axis + // isInner is true when the axis is the inner most axis + isOuter := axis == 0 + isInner := axis == (a.Shape().Dims() - 1) + // special case var start, end int - for _, T := range all { end += T.Shape()[axis] slices := make([]Slice, axis+1) @@ -117,15 +132,124 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") } - if v.IsVector() && T.IsMatrix() && axis == 0 { + switch { + case v.IsVector() && T.IsMatrix() && axis == 0: v.reshape(v.shape[0], 1) + case T.IsRowVec() && axis == 0: + T.reshape(T.Shape()[1]) + case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): + copyArray(v.arrPtr(), T.arrPtr()) + if mt, ok := T.(MaskedTensor); ok { + copy(v.mask, mt.Mask()) + } + continue + default: + diff := retVal.Shape().Dims() - v.Shape().Dims() + if diff > 0 && isOuter { + newShape := make(Shape, v.Shape().Dims()+diff) + for i := 0; i < diff; i++ { + newShape[i] = 1 + } + copy(newShape[diff:], v.Shape()) + v.reshape(newShape...) 
+ } else if diff > 0 && isInner { + newShape := v.Shape().Clone() + newStrides := v.strides + for i := 0; i < diff; i++ { + newShape = append(newShape, 1) + newStrides = append(newStrides, 1) + } + v.shape = newShape + v.strides = newStrides + } + } + + var vmask, Tmask []bool + vmask = v.mask + v.mask = nil + if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { + Tmask = mt.Mask() + mt.SetMask(nil) + } if err = assignArray(v, T); err != nil { return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") } + // if it's a masked tensor, we copy the mask as well + if Tmask != nil { + if vmask != nil { + if cap(vmask) < len(Tmask) { + vmask2 := make([]bool, len(Tmask)) + copy(vmask2, vmask) + vmask = vmask2 + } + copy(vmask, Tmask) + v.SetMask(vmask) + } + // mt.SetMask(Tmask) + } + start = end } return retVal, nil } + +func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { + a, ok := t.(DenseTensor) + if !ok { + return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") + } + + if a.Dims() != 2 { + err = errors.Errorf(dimMismatch, 2, a.Dims()) + return + } + + if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { + return nil, errors.Wrap(err, "Diagonal") + } + + rstride := a.Strides()[0] + cstride := a.Strides()[1] + + r := a.Shape()[0] + c := a.Shape()[1] + + m := MinInt(r, c) + stride := rstride + cstride + + b := a.Clone().(DenseTensor) + b.Zero() + + switch a.rtype().Size() { + case 1: + bdata := b.hdr().Uint8s() + adata := a.hdr().Uint8s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 2: + bdata := b.hdr().Uint16s() + adata := a.hdr().Uint16s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 4: + bdata := b.hdr().Uint32s() + adata := a.hdr().Uint32s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 8: + bdata := b.hdr().Uint64s() + adata := a.hdr().Uint64s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + default: + return nil, errors.Errorf(typeNYI, 
"Arbitrary sized diag") + } + return b, nil +} diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 1a43a7e..368ddb5 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -28,15 +28,13 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV info := t.Info() var newStrides []int - if info.o.isColMajor() { - newStrides = newShape.calcStridesColMajor() + if info.o.IsColMajor() { + newStrides = newShape.CalcStridesColMajor() } else { - newStrides = newShape.calcStrides() + newStrides = newShape.CalcStrides() } - ap := NewAP(newShape, newStrides) - ap.o = info.o - ap.Δ = info.Δ + ap := MakeAP(newShape, newStrides, info.o, info.Δ) allNoMat := !t.RequiresIterator() for _, ot := range others { @@ -46,8 +44,7 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV } retVal = recycledDense(t.Dtype(), ap.Shape(), WithEngine(e)) - ReturnAP(retVal.Info()) - retVal.setAP(ap) + retVal.setAP(&ap) // the "viewStack" method is the more generalized method // and will work for all Tensors, regardless of whether it's a view diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index e66c4a6..8f7c86c 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -44,7 +44,7 @@ func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { u8s := tmpArr.Uint8s() orig := a.hdr().Uint8s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u8s[j] = orig[i] @@ -59,7 +59,7 @@ func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { u16s := tmpArr.Uint16s() orig := a.hdr().Uint16s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u16s[j] = orig[i] @@ -74,7 +74,7 @@ func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { u32s := 
tmpArr.Uint32s() orig := a.hdr().Uint32s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u32s[j] = orig[i] @@ -89,7 +89,7 @@ func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { u64s := tmpArr.Uint64s() orig := a.hdr().Uint64s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u64s[j] = orig[i] @@ -104,7 +104,7 @@ func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { strs := tmpArr.Strings() orig := a.hdr().Strings() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { strs[j] = orig[i] @@ -122,7 +122,7 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { arbs := tmpArr.byteSlice() orig := storage.AsByteSlice(a.hdr(), rtype) - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { srcStart := i * typeSize diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 6725aea..d8a87e4 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -51,6 +51,9 @@ func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint8s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -87,6 +90,9 @@ func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint16s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -123,6 +129,9 @@ func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint32s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -159,6 +168,9 
@@ func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint64s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) if track.IsSet(i) && track.IsSet(dest) { @@ -195,6 +207,9 @@ func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { var i int data := a.hdr().Strings() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -233,6 +248,9 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { tmp := make([]byte, typeSize, typeSize) var i int data := storage.AsByteSlice(a.hdr(), rtype) + if len(data) < 4*typeSize { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) start := typeSize * i diff --git a/defaultengine_misc.go b/defaultengine_misc.go index b4bf21c..bb70e57 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -12,7 +12,7 @@ func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal T var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } diff --git a/defaultengine_prep.go b/defaultengine_prep.go index cb358a7..c203253 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -3,9 +3,10 @@ package tensor import ( "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" + // "log" ) -func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { +func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) 
reuseT, incr := fo.IncrReuse() @@ -16,7 +17,7 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) if toReuse { if reuse, err = getDenseTensor(reuseT); err != nil { returnOpOpt(fo) - err = errors.Wrapf(err, "Cannot reuse a different type of Tensor in a *Dense-Scalar operation") + err = errors.Wrapf(err, "Cannot reuse a Tensor that isn't a DenseTensor. Got %T instead", reuseT) return } @@ -40,6 +41,11 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) return } + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -101,7 +107,11 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea } // iter - useIter = a.RequiresIterator() || b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) + useIter = a.RequiresIterator() || + b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + !a.DataOrder().HasSameOrder(b.DataOrder()) || + (reuse != nil && (!a.DataOrder().HasSameOrder(reuse.DataOrder()) || !b.DataOrder().HasSameOrder(reuse.DataOrder()))) if useIter { ait = a.Iterator() bit = b.Iterator() @@ -109,6 +119,7 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea iit = reuse.Iterator() } } + // log.Printf("Use Itrer %v ", useIter) // swap if _, ok := a.(*CS); ok { @@ -133,12 +144,14 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse if a.IsScalar() { return } - if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = a.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && reuse.DataOrder().HasSameOrder(a.DataOrder())) + if useIter { ait = a.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } @@ -155,12 +168,14 @@ func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse if b.IsScalar() { return } - if 
b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && reuse.DataOrder().HasSameOrder(b.DataOrder())) + if useIter { bit = b.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 4da968a..986e246 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -14,7 +14,7 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -82,7 +82,7 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -150,7 +150,7 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -218,7 +218,7 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = 
handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -286,7 +286,7 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -354,7 +354,7 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -422,7 +422,7 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -490,7 +490,7 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle 
funcOpts") } @@ -558,7 +558,7 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -626,7 +626,7 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -694,7 +694,7 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -762,7 +762,7 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -830,7 +830,7 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), 
a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -898,7 +898,7 @@ func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index f260479..82d48f2 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -10,7 +10,7 @@ import ( "gorgonia.org/vecf32" ) -func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -30,6 +30,12 @@ func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) 
+ } + } returnOpOpt(fo) return @@ -175,7 +181,7 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), opts...); err != nil { + if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if err = e.checkThree(a, b, reuse); err != nil { diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 6fe2786..b0d9466 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -10,7 +10,7 @@ import ( "gorgonia.org/vecf64" ) -func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -30,6 +30,12 @@ func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -175,7 +181,7 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), opts...); err != nil { + if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if err = e.checkThree(a, b, reuse); err != nil { diff --git a/dense.go b/dense.go index 95e3f53..2824261 100644 --- a/dense.go +++ b/dense.go @@ -13,7 +13,7 @@ const ( // Dense represents a dense tensor - this is the most common form of tensors. It can be used to represent vectors, matrices.. 
etc type Dense struct { - *AP + AP array flag MemoryFlag @@ -21,7 +21,7 @@ type Dense struct { oe standardEngine // optimized engine // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes - old *AP + old AP transposeWith []int // if viewOf != nil, then this *Dense is a view. @@ -54,7 +54,7 @@ func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) retVal.array.t = dt retVal.array.L = size retVal.array.C = size - retVal.AP = BorrowAP(shape.Dims()) + retVal.AP.zeroWithDims(shape.Dims()) for _, opt := range opts { opt(retVal) @@ -78,9 +78,14 @@ func (t *Dense) addMask(mask []bool) { } func (t *Dense) makeArray(size int) { - if am, ok := t.e.(arrayMaker); ok { - am.makeArray(&t.array, t.t, size) + + switch te := t.e.(type) { + case NonStdEngine: + t.flag = MakeMemoryFlag(t.flag, ManuallyManaged) + case arrayMaker: + te.makeArray(&t.array, t.t, size) return + default: } mem, err := t.e.Alloc(calcMemSize(t.t, size)) @@ -97,7 +102,7 @@ func (t *Dense) makeArray(size int) { } // Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging. -func (t *Dense) Info() *AP { return t.AP } +func (t *Dense) Info() *AP { return &t.AP } // Dtype returns the data type of the *Dense tensor. func (t *Dense) Dtype() Dtype { return t.t } @@ -123,11 +128,11 @@ func (t *Dense) Engine() Engine { return t.e } // Reshape reshapes a *Dense. 
If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens func (t *Dense) Reshape(dims ...int) error { - if t.viewOf != 0 && t.o.isNotContiguous() { + if t.viewOf != 0 && t.o.IsNotContiguous() { return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") } - if t.old != nil { + if !t.old.IsZero() { t.Transpose() } @@ -159,7 +164,7 @@ func (t *Dense) IsView() bool { // IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing func (t *Dense) IsMaterializable() bool { - return t.viewOf != 0 || t.old != nil + return t.viewOf != 0 || !t.old.IsZero() } // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) @@ -172,15 +177,16 @@ func (t *Dense) IsNativelyAccessible() bool { return t.flag.nativelyAccessible() func (t *Dense) Clone() interface{} { if t.e != nil { retVal := new(Dense) - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.t = t.t retVal.e = t.e retVal.oe = t.oe retVal.flag = t.flag retVal.makeArray(t.L) - if t.old != nil { + if !t.old.IsZero() { retVal.old = t.old.Clone() + t.old.CloneTo(&retVal.old) } copyDense(retVal, t) retVal.lock() @@ -246,13 +252,9 @@ func (t *Dense) setShape(s ...int) { return } -func (t *Dense) setAP(ap *AP) { t.AP = ap } +func (t *Dense) setAP(ap *AP) { t.AP = *ap } func (t *Dense) fix() { - if t.AP == nil { - return - } - if t.e == nil { t.e = StdEng{} } @@ -298,31 +300,33 @@ func (t *Dense) makeMask() { // sanity is a function that sanity checks that a tensor is correct. 
func (t *Dense) sanity() error { - if t.AP != nil && t.Shape() == nil && t.array.Ptr == nil { + if !t.AP.IsZero() && t.Shape() == nil && t.array.Ptr == nil { return errors.New(emptyTensor) } size := t.L expected := t.Size() if t.viewOf == 0 && size != expected && !t.IsScalar() { - return errors.Errorf(shapeMismatch, t.Shape(), size) + return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") } // TODO: sanity check for views return nil } -func (t *Dense) isTransposed() bool { return t.old == nil } +// isTransposed returns true if the *Dense holds a transposed array. +func (t *Dense) isTransposed() bool { return t.old.IsZero() } // oshape returns the original shape func (t *Dense) oshape() Shape { - if t.old != nil { + if !t.old.IsZero() { return t.old.Shape() } return t.Shape() } +// ostrides returns the original strides func (t *Dense) ostrides() []int { - if t.old != nil { + if !t.old.IsZero() { return t.old.Strides() } return t.Strides() @@ -333,14 +337,14 @@ func (t *Dense) ShallowClone() *Dense { retVal := borrowDense() retVal.e = t.e retVal.oe = t.oe - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.flag = t.flag retVal.array = t.array return retVal } -func (t *Dense) oldAP() *AP { return t.old } -func (t *Dense) setOldAP(ap *AP) { t.old = ap } +func (t *Dense) oldAP() *AP { return &t.old } +func (t *Dense) setOldAP(ap *AP) { t.old = *ap } func (t *Dense) transposeAxes() []int { return t.transposeWith } func (t *Dense) parentTensor() *Dense { if t.viewOf != 0 { @@ -537,7 +541,7 @@ func (t *Dense) Memset(x interface{}) error { return errors.Errorf(inaccessibleData, t) } if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) return t.array.memsetIter(x, it) } return t.array.Memset(x) @@ -560,7 +564,7 @@ func (t *Dense) Eq(other interface{}) bool { func (t *Dense) Zero() { if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) if err := t.zeroIter(it); err 
!= nil { panic(err) } @@ -590,7 +594,7 @@ func (t *Dense) RequiresIterator() bool { return false } // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required - if !t.o.isContiguous() || t.old != nil || t.IsMasked() { + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { return true } return false diff --git a/dense_assign.go b/dense_assign.go index 0fdc1d4..8b2783e 100644 --- a/dense_assign.go +++ b/dense_assign.go @@ -84,11 +84,11 @@ func assignArray(dest, src DenseTensor) (err error) { return } dap := dest.Info() - sap := NewAP(tmpShape, newStrides) - sap.o = src.Info().o + sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ) - diter := NewFlatIterator(dap) - siter := NewFlatIterator(sap) + diter := newFlatIterator(dap) + siter := newFlatIterator(&sap) _, err = copyDenseIter(dest, src, diter, siter) + sap.zeroOnly() // cleanup, but not entirely because tmpShape and tmpStrides are separately cleaned up. Don't double free return } diff --git a/dense_cmp_test.go b/dense_cmp_test.go index 4c1db8e..a0bc5b6 100644 --- a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -62,7 +62,7 @@ func TestDense_Gt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -120,7 +120,7 @@ func TestDense_Gte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -178,7 +178,7 @@ func TestDense_Lt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -236,7 
+236,7 @@ func TestDense_Lte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -294,7 +294,7 @@ func TestDense_ElEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -328,7 +328,7 @@ func TestDense_ElEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe(t *testing.T) { @@ -363,7 +363,7 @@ func TestDense_ElNe(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_Gt_assame(t *testing.T) { @@ -422,7 +422,7 @@ func TestDense_Gt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -482,7 +482,7 @@ func TestDense_Gte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -542,7 +542,7 @@ func TestDense_Lt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for 
Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -602,7 +602,7 @@ func TestDense_Lte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -662,7 +662,7 @@ func TestDense_ElEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -699,7 +699,7 @@ func TestDense_ElEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe_assame(t *testing.T) { @@ -737,7 +737,7 @@ func TestDense_ElNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_GtScalar(t *testing.T) { @@ -792,7 +792,7 @@ func TestDense_GtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -848,7 +848,7 @@ func TestDense_GteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -904,7 +904,7 @@ func TestDense_LtScalar(t *testing.T) { return true } if err := 
quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -960,7 +960,7 @@ func TestDense_LteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1016,7 +1016,7 @@ func TestDense_ElEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1048,7 +1048,7 @@ func TestDense_ElEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar(t *testing.T) { @@ -1081,7 +1081,7 @@ func TestDense_ElNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestDense_GtScalar_assame(t *testing.T) { @@ -1138,7 +1138,7 @@ func TestDense_GtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -1196,7 +1196,7 @@ func TestDense_GteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + 
t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -1254,7 +1254,7 @@ func TestDense_LtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -1312,7 +1312,7 @@ func TestDense_LteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1370,7 +1370,7 @@ func TestDense_ElEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1405,7 +1405,7 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar_assame(t *testing.T) { @@ -1441,6 +1441,6 @@ func TestDense_ElNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } diff --git a/dense_colmajor_linalg_test.go b/dense_colmajor_linalg_test.go new file mode 100644 index 0000000..feccfc5 --- /dev/null +++ b/dense_colmajor_linalg_test.go @@ -0,0 +1,483 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var colMajorTraceTests = []struct { + data interface{} + + correct interface{} + err bool +}{ + {[]int{0, 1, 2, 3, 4, 5}, int(4), 
false}, + {[]int8{0, 1, 2, 3, 4, 5}, int8(4), false}, + {[]int16{0, 1, 2, 3, 4, 5}, int16(4), false}, + {[]int32{0, 1, 2, 3, 4, 5}, int32(4), false}, + {[]int64{0, 1, 2, 3, 4, 5}, int64(4), false}, + {[]uint{0, 1, 2, 3, 4, 5}, uint(4), false}, + {[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false}, + {[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false}, + {[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false}, + {[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false}, + {[]float32{0, 1, 2, 3, 4, 5}, float32(4), false}, + {[]float64{0, 1, 2, 3, 4, 5}, float64(4), false}, + {[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false}, + {[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false}, + {[]bool{true, false, true, false, true, false}, nil, true}, +} + +func TestColMajor_Dense_Trace(t *testing.T) { + assert := assert.New(t) + for i, tts := range colMajorTraceTests { + T := New(WithShape(2, 3), AsFortran(tts.data)) + trace, err := T.Trace() + + if checkErr(t, tts.err, err, "Trace", i) { + continue + } + assert.Equal(tts.correct, trace) + + // + T = New(WithBacking(tts.data)) + _, err = T.Trace() + if err == nil { + t.Error("Expected an error when Trace() on non-matrices") + } + } +} + +var colMajorInnerTests = []struct { + a, b interface{} + shapeA, shapeB Shape + + correct interface{} + err bool +}{ + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false}, + + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, 
float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false}, + + // stupids: type differences + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + + // differing size + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true}, + + // A is not a matrix + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true}, +} + +func TestColMajor_Dense_Inner(t *testing.T) { + for i, its := range colMajorInnerTests { + a := New(WithShape(its.shapeA...), AsFortran(its.a)) + b := New(WithShape(its.shapeB...), AsFortran(its.b)) + + T, err := a.Inner(b) + if checkErr(t, its.err, err, "Inner", i) { + continue + } + + assert.Equal(t, its.correct, T) + } +} + +var colMajorMatVecMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, 
false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + + // stupids : unpossible shapes (wrong A) + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad A shape + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad B shape + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 
3}, Shape{3, 2}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + //stupids: bad incr shape + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B (non-Float) + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 
3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch, reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + // stupids: type mismatch, incr + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch, incr not a Number + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, +} + +func TestColMajor_Dense_MatVecMul(t *testing.T) { + assert := assert.New(t) + for i, mvmt := range colMajorMatVecMulTests { + a := New(WithShape(mvmt.shapeA...), AsFortran(mvmt.a)) + b := New(WithShape(mvmt.shapeB...), AsFortran(mvmt.b)) + + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } + + T, err := a.MatVecMul(b) + if checkErr(t, mvmt.err, err, "Safe", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, T.Data()) + + // incr + incr := New(WithShape(mvmt.shapeI...), AsFortran(mvmt.incr)) + T, err = a.MatVecMul(b, WithIncr(incr)) + if checkErr(t, mvmt.errIncr, err, "WithIncr", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(mvmt.shapeR...), 
AsFortran(mvmt.reuse)) + T, err = a.MatVecMul(b, WithReuse(reuse)) + if checkErr(t, mvmt.errReuse, err, "WithReuse", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, T.Data()) + + // reuse AND incr + T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) { + continue + } + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncrReuse, T.Data()) + } +} + +var colMajorMatMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Edge cases - Row Vecs (Float64) + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float64{0, 0, 0, 1, 0, 2}, []float64{100, 103, 101, 105, 102, 107}, []float64{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 
1}, Shape{1, 1}, + []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, + + // Edge cases - Row Vecs (Float32) + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float32{0, 0, 0, 1, 0, 2}, []float32{100, 103, 101, 105, 102, 107}, []float32{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, + []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, + + // stupids - bad shape (not matrices): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (incompatible shapes): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (bad reuse shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids - bad 
shape (bad incr shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids - type mismatch (a,b) + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - type mismatch (a,b) + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (b not float) + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (a not float) + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids: type mismatch (incr) + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids: type mismatch (reuse) + 
{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids: type mismatch (reuse) + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, +} + +func TestColMajorDense_MatMul(t *testing.T) { + assert := assert.New(t) + for i, mmt := range colMajorMatMulTests { + a := New(WithShape(mmt.shapeA...), AsFortran(mmt.a)) + b := New(WithShape(mmt.shapeB...), AsFortran(mmt.b)) + + T, err := a.MatMul(b) + if checkErr(t, mmt.err, err, "Safe", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mmt.correct, T.Data(), "Test %d", i) + + // incr + incr := New(WithShape(mmt.shapeI...), AsFortran(mmt.incr)) + T, err = a.MatMul(b, WithIncr(incr)) + if checkErr(t, mmt.errIncr, err, "WithIncr", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(mmt.shapeR...), AsFortran(mmt.reuse)) + T, err = a.MatMul(b, WithReuse(reuse)) + + if checkErr(t, mmt.errReuse, err, "WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correct, T.Data()) + + // reuse AND incr + T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncrReuse, T.Data()) + } +} + +var colMajorOuterTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, 
false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // Float32s + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float32{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // stupids - a or b not vector + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - bad incr shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, true, false}, + + // stupids - bad reuse shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, true}, + + // stupids - b not Float + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 
2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a not Float + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a-b type mismatch + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids a-b type mismatch + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, +} + +func TestColMajor_Dense_Outer(t *testing.T) { + assert := assert.New(t) + for i, ot := range colMajorOuterTests { + a := New(WithShape(ot.shapeA...), AsFortran(ot.a)) + b := New(WithShape(ot.shapeB...), AsFortran(ot.b)) + + T, err := a.Outer(b) + if checkErr(t, ot.err, err, "Safe", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // incr + incr := New(WithShape(ot.shapeI...), AsFortran(ot.incr)) + T, err = a.Outer(b, WithIncr(incr)) + if checkErr(t, ot.errIncr, err, "WithIncr", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + 
assert.Equal(ot.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(ot.shapeR...), AsFortran(ot.reuse)) + T, err = a.Outer(b, WithReuse(reuse)) + if checkErr(t, ot.errReuse, err, "WithReuse", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // reuse AND incr + T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse)) + if err != nil { + t.Errorf("Reuse and Incr error'd %+v", err) + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correctIncrReuse, T.Data()) + } +} diff --git a/dense_compat.go b/dense_compat.go index 151ae0a..a1b90ab 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -394,14 +394,12 @@ func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { // checks: if !t.IsNativelyAccessible() { - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } if !t.IsMatrix() { // error - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) } fo := ParseFuncOpts(opts...) 
@@ -420,7 +418,7 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { case !t.IsMaterializable(): data = convToFloat64s(t) default: - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) var next int for next, err = it.Next(); err == nil; next, err = it.Next() { if err = handleNoOp(err); err != nil { diff --git a/dense_format.go b/dense_format.go index 9d31994..859477f 100644 --- a/dense_format.go +++ b/dense_format.go @@ -249,7 +249,7 @@ func (f *fmtState) writeVElision() { // Special care also needs be taken for the verb 's' - it prints a super compressed version of the tensor, only printing 4 cols and 4 rows. func (t *Dense) Format(s fmt.State, c rune) { if c == 'i' { - fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\t", t.AP, t.old, t.transposeWith) + fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\tENGINE: %T\n", t.AP, t.old, t.transposeWith, t.e) return } @@ -353,7 +353,7 @@ func (t *Dense) Format(s fmt.State, c rune) { } // standard stuff - it := NewIterator(t.AP) + it := NewIterator(&t.AP) coord := it.Coord() firstRow := true diff --git a/dense_generated.go b/dense_generated.go index 93ea20b..6349bfb 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -88,7 +88,7 @@ func I(dt Dtype, r, c, k int) *Dense { panic(err) } var nexts []int - iter := NewFlatIterator(s.AP) + iter := newFlatIterator(&s.AP) nexts, err = iter.Slice(rs{i, s.Size(), c + 1}) switch s.t.Kind() { diff --git a/dense_io.go b/dense_io.go index 84896eb..c55c66a 100644 --- a/dense_io.go +++ b/dense_io.go @@ -14,26 +14,140 @@ import ( "strconv" "strings" + flatbuffers "github.com/google/flatbuffers/go" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/serialization/fb" + "gorgonia.org/tensor/internal/serialization/pb" ) +/* GOB SERIALIZATION */ + +// GobEncode implements gob.GobEncoder +func (t *Dense) GobEncode() (p []byte, err error) { + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + + if err = encoder.Encode(t.Shape()); err 
!= nil { + return + } + + if err = encoder.Encode(t.Strides()); err != nil { + return + } + + if err = encoder.Encode(t.AP.o); err != nil { + return + } + + if err = encoder.Encode(t.AP.Δ); err != nil { + return + } + + if err = encoder.Encode(t.mask); err != nil { + return + } + + data := t.Data() + if err = encoder.Encode(&data); err != nil { + return + } + + return buf.Bytes(), err +} + +// GobDecode implements gob.GobDecoder +func (t *Dense) GobDecode(p []byte) (err error) { + buf := bytes.NewBuffer(p) + decoder := gob.NewDecoder(buf) + + var shape Shape + if err = decoder.Decode(&shape); err != nil { + return + } + + var strides []int + if err = decoder.Decode(&strides); err != nil { + return + } + + var o DataOrder + var tr Triangle + if err = decoder.Decode(&o); err == nil { + if err = decoder.Decode(&tr); err != nil { + return + } + } + + t.AP.Init(shape, strides) + t.AP.o = o + t.AP.Δ = tr + + var mask []bool + if err = decoder.Decode(&mask); err != nil { + return + } + + var data interface{} + if err = decoder.Decode(&data); err != nil { + return + } + + t.fromSlice(data) + t.addMask(mask) + t.fix() + if t.e == nil { + t.e = StdEng{} + } + return t.sanity() +} + +/* NPY SERIALIZATION */ + +var npyDescRE = regexp.MustCompile(`'descr':\s*'([^']*)'`) +var rowOrderRE = regexp.MustCompile(`'fortran_order':\s*(False|True)`) +var shapeRE = regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) + type binaryWriter struct { io.Writer - error + err error seq int } -func (w binaryWriter) w(x interface{}) { - if w.error != nil { +func (w *binaryWriter) w(x interface{}) { + if w.err != nil { return } - binary.Write(w, binary.LittleEndian, x) + w.err = binary.Write(w, binary.LittleEndian, x) w.seq++ } -func (w binaryWriter) Error() string { - return fmt.Sprintf("Error at sequence %d : %v", w.seq, w.error.Error()) +func (w *binaryWriter) Err() error { + if w.err == nil { + return nil + } + return errors.Wrapf(w.err, "Sequence %d", w.seq) +} + +type binaryReader struct { + 
io.Reader + err error + seq int +} + +func (r *binaryReader) Read(data interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.Reader, binary.LittleEndian, data) + r.seq++ +} + +func (r *binaryReader) Err() error { + if r.err == nil { + return nil + } + return errors.Wrapf(r.err, "Sequence %d", r.seq) } // WriteNpy writes the *Tensor as a numpy compatible serialized file. @@ -64,8 +178,8 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.w(byte(1)) // major version bw.w(byte(0)) // minor version bw.w(uint16(len(header))) // 4 bytes to denote header length - if bw.error != nil { - return bw + if err = bw.Err(); err != nil { + return err } bw.Write([]byte(header)) @@ -86,176 +200,57 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { } } - if bw.error != nil { - return bw - } - return nil -} - -// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. -// If tensor is masked, invalid values are replaced by the default fill value. -func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { - // checks: - if !t.IsMatrix() { - // error - err = errors.Errorf("Cannot write *Dense to CSV. 
Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return - } - format := "%v" - if len(formats) > 0 { - format = formats[0] - } - - cw := csv.NewWriter(w) - it := IteratorFromDense(t) - coord := it.Coord() - - // rows := t.Shape()[0] - cols := t.Shape()[1] - record := make([]string, 0, cols) - var i, k, lastCol int - isMasked := t.IsMasked() - fillval := t.FillValue() - fillstr := fmt.Sprintf(format, fillval) - for i, err = it.Next(); err == nil; i, err = it.Next() { - record = append(record, fmt.Sprintf(format, t.Get(i))) - if isMasked { - if t.mask[i] { - record[k] = fillstr - } - k++ - } - if lastCol == cols-1 { - if err = cw.Write(record); err != nil { - // TODO: wrap errors - return - } - cw.Flush() - record = record[:0] - } - - // cleanup - switch { - case t.IsRowVec(): - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - case t.IsColVec(): - // lastRow = coord[len(coord)-1] - lastCol = coord[len(coord)-2] - case t.IsVector(): - lastCol = coord[len(coord)-1] - default: - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - } - } - return nil -} - -// GobEncode implements gob.GobEncoder -func (t *Dense) GobEncode() (p []byte, err error) { - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - - if err = encoder.Encode(t.Shape()); err != nil { - return - } - - if err = encoder.Encode(t.Strides()); err != nil { - return - } - - if err = encoder.Encode(t.AP.o); err != nil { - return - } - - if err = encoder.Encode(t.AP.Δ); err != nil { - return - } - - if err = encoder.Encode(t.mask); err != nil { - return - } - - data := t.Data() - if err = encoder.Encode(&data); err != nil { - return - } - - return buf.Bytes(), err + return bw.Err() } // ReadNpy reads NumPy formatted files into a *Dense func (t *Dense) ReadNpy(r io.Reader) (err error) { + br := binaryReader{Reader: r} var magic [6]byte - if _, err = r.Read(magic[:]); err != nil { - return - } - if string(magic[:]) != "\x93NUMPY" { 
- err = errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) - return + if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" { + return errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) } - var version byte - if err = binary.Read(r, binary.LittleEndian, &version); err != nil { - return - } - if version != 1 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + var version, minor byte + if br.Read(&version); version != 1 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } - var minor byte - if err = binary.Read(r, binary.LittleEndian, &minor); err != nil { - return - } - if minor != 0 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + if br.Read(&minor); minor != 0 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } var headerLen uint16 - if err = binary.Read(r, binary.LittleEndian, &headerLen); err != nil { - return - } - + br.Read(&headerLen) header := make([]byte, int(headerLen)) - if _, err = r.Read(header); err != nil { + br.Read(header) + if err = br.Err(); err != nil { return } - desc := regexp.MustCompile(`'descr':\s*'([^']*)'`) - match := desc.FindSubmatch(header) - if match == nil { - err = errors.New("No dtype information in npy file") - return + // extract stuff from header + var match [][]byte + if match = npyDescRE.FindSubmatch(header); match == nil { + return errors.New("No dtype information in npy file") } // TODO: check for endianness. 
For now we assume everything is little endian - var dt Dtype - if dt, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { return } - t.t = dt - rowOrder := regexp.MustCompile(`'fortran_order':\s*(False|True)`) - match = rowOrder.FindSubmatch(header) - if match == nil { - err = errors.New("No Row Order information found in the numpy file") - return + if match = rowOrderRE.FindSubmatch(header); match == nil { + return errors.New("No Row Order information found in the numpy file") } if string(match[1]) != "False" { - err = errors.New("Cannot yet read from Fortran Ordered Numpy files") - return + return errors.New("Cannot yet read from Fortran Ordered Numpy files") } - shpRe := regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) - match = shpRe.FindSubmatch(header) - if match == nil { - err = errors.New("No shape information found in npy file") - return + if match = shapeRE.FindSubmatch(header); match == nil { + return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") + var shape Shape for _, s := range sizesStr { s = strings.Trim(s, " ") @@ -273,163 +268,166 @@ func (t *Dense) ReadNpy(r io.Reader) (err error) { if t.e == nil { t.e = StdEng{} } - t.makeArray(size) switch t.t.Kind() { case reflect.Int: data := t.Ints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int8: data := t.Int8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int16: data := t.Int16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int32: data := t.Int32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + 
br.Read(&data[i]) } case reflect.Int64: data := t.Int64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint: data := t.Uints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint8: data := t.Uint8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint16: data := t.Uint16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint32: data := t.Uint32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint64: data := t.Uint64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float32: data := t.Float32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float64: data := t.Float64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex64: data := t.Complex64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex128: data := t.Complex128s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } } - t.AP = BorrowAP(len(shape)) + if err = br.Err(); err != nil { + return err + } + + t.AP.zeroWithDims(len(shape)) t.setShape(shape...) 
t.fix() return t.sanity() } -// GobDecode implements gob.GobDecoder -func (t *Dense) GobDecode(p []byte) (err error) { - buf := bytes.NewBuffer(p) - decoder := gob.NewDecoder(buf) - - var shape Shape - if err = decoder.Decode(&shape); err != nil { - return - } +/* CSV SERIALIZATION */ - var strides []int - if err = decoder.Decode(&strides); err != nil { +// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. +// If tensor is masked, invalid values are replaced by the default fill value. +func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { + // checks: + if !t.IsMatrix() { + // error + err = errors.Errorf("Cannot write *Dense to CSV. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) return } - - var o DataOrder - var tr Triangle - if err = decoder.Decode(&o); err == nil { - if err = decoder.Decode(&tr); err != nil { - return - } + format := "%v" + if len(formats) > 0 { + format = formats[0] } - t.AP = NewAP(shape, strides) - t.AP.o = o - t.AP.Δ = tr + cw := csv.NewWriter(w) + it := IteratorFromDense(t) + coord := it.Coord() - var mask []bool - if err = decoder.Decode(&mask); err != nil { - return - } + // rows := t.Shape()[0] + cols := t.Shape()[1] + record := make([]string, 0, cols) + var i, k, lastCol int + isMasked := t.IsMasked() + fillval := t.FillValue() + fillstr := fmt.Sprintf(format, fillval) + for i, err = it.Next(); err == nil; i, err = it.Next() { + record = append(record, fmt.Sprintf(format, t.Get(i))) + if isMasked { + if t.mask[i] { + record[k] = fillstr + } + k++ + } + if lastCol == cols-1 { + if err = cw.Write(record); err != nil { + // TODO: wrap errors + return + } + cw.Flush() + record = record[:0] + } - var data interface{} - if err = decoder.Decode(&data); err != nil { - return + // cleanup + switch { + case t.IsRowVec(): + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + 
case t.IsColVec(): + // lastRow = coord[len(coord)-1] + lastCol = coord[len(coord)-2] + case t.IsVector(): + lastCol = coord[len(coord)-1] + default: + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + } } - t.fromSlice(data) - t.addMask(mask) - t.fix() - return t.sanity() + return nil } -// convFromStrs conversts a []string to a slice of the Dtype provided -func convFromStrs(to Dtype, record []string) (interface{}, error) { +// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +// If into is nil, then a backing slice will be created. +func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { case reflect.Int: retVal := make([]int, len(record)) + var backing []int + if into == nil { + backing = make([]int, 0, len(record)) + } else { + backing = into.([]int) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 0); err != nil { @@ -437,9 +435,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int8: retVal := make([]int8, len(record)) + var backing []int8 + if into == nil { + backing = make([]int8, 0, len(record)) + } else { + backing = into.([]int8) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 8); err != nil { @@ -447,9 +453,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int8(i64) } - return retVal, nil + backing = append(backing, retVal...) 
+ return backing, nil case reflect.Int16: retVal := make([]int16, len(record)) + var backing []int16 + if into == nil { + backing = make([]int16, 0, len(record)) + } else { + backing = into.([]int16) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 16); err != nil { @@ -457,9 +471,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int16(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int32: retVal := make([]int32, len(record)) + var backing []int32 + if into == nil { + backing = make([]int32, 0, len(record)) + } else { + backing = into.([]int32) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 32); err != nil { @@ -467,9 +489,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int32(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int64: retVal := make([]int64, len(record)) + var backing []int64 + if into == nil { + backing = make([]int64, 0, len(record)) + } else { + backing = into.([]int64) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 64); err != nil { @@ -477,9 +507,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int64(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint: retVal := make([]uint, len(record)) + var backing []uint + if into == nil { + backing = make([]uint, 0, len(record)) + } else { + backing = into.([]uint) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 0); err != nil { @@ -487,9 +525,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint(u) } - return retVal, nil + backing = append(backing, retVal...) 
+ return backing, nil case reflect.Uint8: retVal := make([]uint8, len(record)) + var backing []uint8 + if into == nil { + backing = make([]uint8, 0, len(record)) + } else { + backing = into.([]uint8) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 8); err != nil { @@ -497,9 +543,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint8(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint16: retVal := make([]uint16, len(record)) + var backing []uint16 + if into == nil { + backing = make([]uint16, 0, len(record)) + } else { + backing = into.([]uint16) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 16); err != nil { @@ -507,9 +561,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint16(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint32: retVal := make([]uint32, len(record)) + var backing []uint32 + if into == nil { + backing = make([]uint32, 0, len(record)) + } else { + backing = into.([]uint32) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 32); err != nil { @@ -517,9 +579,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint32(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint64: retVal := make([]uint64, len(record)) + var backing []uint64 + if into == nil { + backing = make([]uint64, 0, len(record)) + } else { + backing = into.([]uint64) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 64); err != nil { @@ -527,9 +597,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint64(u) } - return retVal, nil + backing = append(backing, retVal...) 
+ return backing, nil case reflect.Float32: retVal := make([]float32, len(record)) + var backing []float32 + if into == nil { + backing = make([]float32, 0, len(record)) + } else { + backing = into.([]float32) + } + for i, v := range record { var f float64 if f, err = strconv.ParseFloat(v, 32); err != nil { @@ -537,15 +615,33 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = float32(f) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Float64: retVal := make([]float64, len(record)) + var backing []float64 + if into == nil { + backing = make([]float64, 0, len(record)) + } else { + backing = into.([]float64) + } + for i, v := range record { if retVal[i], err = strconv.ParseFloat(v, 64); err != nil { return nil, err } } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil + case reflect.String: + var backing []string + if into == nil { + backing = make([]string, 0, len(record)) + } else { + backing = into.([]string) + } + backing = append(backing, record...) + return backing, nil default: return nil, errors.Errorf(methodNYI, "convFromStrs", to) } @@ -564,307 +660,223 @@ func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) { cr := csv.NewReader(r) var record []string - var row interface{} var rows, cols int - - switch as.Kind() { - case reflect.Int: - var backing []int - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int, record); err != nil { - return - } - backing = append(backing, row.([]int)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int8: - var backing []int8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int8, record); err != nil { - return - } - backing = append(backing, row.([]int8)...) 
- cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int16: - var backing []int16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int16, record); err != nil { - return - } - backing = append(backing, row.([]int16)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int32: - var backing []int32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int32, record); err != nil { - return - } - backing = append(backing, row.([]int32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int64: - var backing []int64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int64, record); err != nil { - return - } - backing = append(backing, row.([]int64)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint: - var backing []uint - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint, record); err != nil { - return - } - backing = append(backing, row.([]uint)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint8: - var backing []uint8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint8, record); err != nil { - return - } - backing = append(backing, row.([]uint8)...) 
- cols = len(record) - rows++ + var backing interface{} + for { + record, err = cr.Read() + if err == io.EOF { + break + } else if err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint16: - var backing []uint16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint16, record); err != nil { - return - } - backing = append(backing, row.([]uint16)...) - cols = len(record) - rows++ + if backing, err = convFromStrs(as, record, backing); err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint32: - var backing []uint32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } + cols = len(record) + rows++ + } + t.fromSlice(backing) + t.AP.zero() + t.AP.SetShape(rows, cols) + return nil + return errors.Errorf("not yet handled") +} - if row, err = convFromStrs(Uint32, record); err != nil { - return - } - backing = append(backing, row.([]uint32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint64: - var backing []uint64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } +/* FB SERIALIZATION */ - if err != nil { - return - } +// FBEncode encodes to a byte slice using flatbuffers. 
+// +// Only natively accessible data can be encided +func (t *Dense) FBEncode() ([]byte, error) { + builder := flatbuffers.NewBuilder(1024) + + fb.DenseStartShapeVector(builder, len(t.shape)) + for i := len(t.shape) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.shape[i])) + } + shape := builder.EndVector(len(t.shape)) + + fb.DenseStartStridesVector(builder, len(t.strides)) + for i := len(t.strides) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.strides[i])) + } + strides := builder.EndVector(len(t.strides)) + + var o uint32 + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + o = 0 + case t.o.IsRowMajor() && !t.o.IsContiguous(): + o = 1 + case t.o.IsColMajor() && t.o.IsContiguous(): + o = 2 + case t.o.IsColMajor() && !t.o.IsContiguous(): + o = 3 + } + + var triangle int32 + switch t.Δ { + case NotTriangle: + triangle = fb.TriangleNOT_TRIANGLE + case Upper: + triangle = fb.TriangleUPPER + case Lower: + triangle = fb.TriangleLOWER + case Symmetric: + triangle = fb.TriangleSYMMETRIC + } + + dt := builder.CreateString(t.Dtype().String()) + data := t.byteSlice() + + fb.DenseStartDataVector(builder, len(data)) + for i := len(data) - 1; i >= 0; i-- { + builder.PrependUint8(data[i]) + } + databyte := builder.EndVector(len(data)) + + fb.DenseStart(builder) + fb.DenseAddShape(builder, shape) + fb.DenseAddStrides(builder, strides) + fb.DenseAddO(builder, o) + fb.DenseAddT(builder, triangle) + fb.DenseAddType(builder, dt) + fb.DenseAddData(builder, databyte) + serialized := fb.DenseEnd(builder) + builder.Finish(serialized) + + return builder.FinishedBytes(), nil +} - if row, err = convFromStrs(Uint64, record); err != nil { - return - } - backing = append(backing, row.([]uint64)...) 
- cols = len(record) - rows++ +// FBDecode decodes a byteslice from a flatbuffer table into a *Dense +func (t *Dense) FBDecode(buf []byte) error { + serialized := fb.GetRootAsDense(buf, 0) + + o := serialized.O() + switch o { + case 0: + t.o = 0 + case 1: + t.o = MakeDataOrder(NonContiguous) + case 2: + t.o = MakeDataOrder(ColMajor) + case 3: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + + tri := serialized.T() + switch tri { + case fb.TriangleNOT_TRIANGLE: + t.Δ = NotTriangle + case fb.TriangleUPPER: + t.Δ = Upper + case fb.TriangleLOWER: + t.Δ = Lower + case fb.TriangleSYMMETRIC: + t.Δ = Symmetric + } + + t.shape = Shape(BorrowInts(serialized.ShapeLength())) + for i := 0; i < serialized.ShapeLength(); i++ { + t.shape[i] = int(int32(serialized.Shape(i))) + } + + t.strides = BorrowInts(serialized.StridesLength()) + for i := 0; i < serialized.ShapeLength(); i++ { + t.strides[i] = int(serialized.Strides(i)) + } + typ := string(serialized.Type()) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float32: - var backing []float32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + } - if err != nil { - return - } + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) - if row, err = convFromStrs(Float32, record); err != nil { - return - } - backing = append(backing, row.([]float32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float64: - var backing []float64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, serialized.DataBytes()) + t.forcefix() + return t.sanity() +} - if err != nil { - return - } +/* PB SERIALIZATION */ + +// PBEncode encodes the Dense into a protobuf byte slice. 
+func (t *Dense) PBEncode() ([]byte, error) { + var toSerialize pb.Dense + toSerialize.Shape = make([]int32, len(t.shape)) + for i, v := range t.shape { + toSerialize.Shape[i] = int32(v) + } + toSerialize.Strides = make([]int32, len(t.strides)) + for i, v := range t.strides { + toSerialize.Strides[i] = int32(v) + } + + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + toSerialize.O = pb.RowMajorContiguous + case t.o.IsRowMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.RowMajorNonContiguous + case t.o.IsColMajor() && t.o.IsContiguous(): + toSerialize.O = pb.ColMajorContiguous + case t.o.IsColMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.ColMajorNonContiguous + } + toSerialize.T = pb.Triangle(t.Δ) + toSerialize.Type = t.t.String() + data := t.byteSlice() + toSerialize.Data = make([]byte, len(data)) + copy(toSerialize.Data, data) + return toSerialize.Marshal() +} - if row, err = convFromStrs(Float64, record); err != nil { - return - } - backing = append(backing, row.([]float64)...) - cols = len(record) - rows++ +// PBDecode unmarshalls a protobuf byteslice into a *Dense. 
+func (t *Dense) PBDecode(buf []byte) error { + var toSerialize pb.Dense + if err := toSerialize.Unmarshal(buf); err != nil { + return err + } + t.shape = make(Shape, len(toSerialize.Shape)) + for i, v := range toSerialize.Shape { + t.shape[i] = int(v) + } + t.strides = make([]int, len(toSerialize.Strides)) + for i, v := range toSerialize.Strides { + t.strides[i] = int(v) + } + + switch toSerialize.O { + case pb.RowMajorContiguous: + case pb.RowMajorNonContiguous: + t.o = MakeDataOrder(NonContiguous) + case pb.ColMajorContiguous: + t.o = MakeDataOrder(ColMajor) + case pb.ColMajorNonContiguous: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + t.Δ = Triangle(toSerialize.T) + typ := string(toSerialize.Type) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.String: - var backing []string - for { - record, err = cr.Read() - if err == io.EOF { - break - } + } - if err != nil { - return - } - backing = append(backing, record...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - default: - return errors.Errorf("%v not yet handled", as) + if t.e == nil { + t.e = StdEng{} } - return errors.Errorf("not yet handled") + t.makeArray(t.shape.TotalSize()) + + // allocated data. 
Now time to actually copy over the data + db := t.byteSlice() + copy(db, toSerialize.Data) + return t.sanity() } diff --git a/dense_io_test.go b/dense_io_test.go index 01de3f0..3c75973 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -23,7 +23,7 @@ func TestSaveLoadNumpy(t *testing.T) { script := "import numpy as np\nx = np.load('test.npy')\nprint(x)" - cmd := exec.Command("python2") + cmd := exec.Command("python") stdin, err := cmd.StdinPipe() if err != nil { t.Error(err) @@ -204,5 +204,52 @@ func TestDense_GobEncodeDecode(t *testing.T) { assert.Equal(T.mask, T3.mask) } +} + +func TestDense_FBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.FBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.FBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } +} + +func TestDense_PBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.PBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.PBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } } diff --git a/dense_linalg.go b/dense_linalg.go index ca07663..6493808 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). 
It only works for matrices func (t *Dense) Trace() (retVal interface{}, err error) { @@ -87,6 +89,9 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e @@ -133,10 +138,12 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e - if mm, ok := e.(MatMuler); ok { if err = mm.MatMul(t, other, retVal); err != nil { return @@ -170,6 +177,9 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e @@ -310,7 +320,6 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err return } doOther.Transpose() - if err = doOther.Reshape(newShapeO...); err != nil { return } diff --git a/dense_linalg_test.go b/dense_linalg_test.go index bfd316c..a9a24dc 100644 --- a/dense_linalg_test.go +++ b/dense_linalg_test.go @@ -10,6 +10,7 @@ import ( type linalgTest struct { a, b interface{} shapeA, shapeB Shape + transA, transB bool reuse, incr interface{} shapeR, shapeI Shape @@ -118,89 +119,94 @@ func TestDense_Inner(t *testing.T) { var matVecMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 
[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, // stupids : unpossible shapes (wrong A) - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 
129}, Shape{2}, true, false, false}, //stupids: bad A shape - {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad B shape - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, //stupids: bad incr shape - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, 
Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B (non-Float) - {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch, reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, // stupids: type mismatch, incr - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch, incr not a Number - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), 
Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, } @@ -211,12 +217,19 @@ func TestDense_MatVecMul(t *testing.T) { a := New(WithBacking(mvmt.a), WithShape(mvmt.shapeA...)) b := New(WithBacking(mvmt.b), WithShape(mvmt.shapeB...)) + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } T, err := a.MatVecMul(b) if checkErr(t, mvmt.err, err, "Safe", i) { continue } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // incr @@ -227,6 +240,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correctIncr, T.Data()) // reuse @@ -237,6 +251,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // reuse AND incr @@ -251,89 +266,89 @@ func TestDense_MatVecMul(t *testing.T) { var matMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Edge cases - Row Vecs (Float64) - {Range(Float64, 0, 2), 
Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, []float64{0, 0, 0, 0, 1, 2}, []float64{100, 101, 102, 103, 105, 107}, []float64{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, // Edge cases - Row Vecs (Float32) - {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, []float32{0, 0, 0, 0, 1, 2}, []float32{100, 101, 102, 103, 105, 107}, []float32{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, 
Shape{1, 1}, []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, // stupids - bad shape (not matrices): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (incompatible shapes): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (bad reuse shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, // stupids - bad shape (bad incr shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids - type mismatch (a,b) - {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 
114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - type mismatch (a,b) - {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (b not float) - {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (a not float) - {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids: type mismatch (incr) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids: type mismatch (reuse) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, 
Shape{2, 2}, false, false, true}, // stupids: type mismatch (reuse) - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, } @@ -382,55 +397,55 @@ func TestDense_MatMul(t *testing.T) { var outerTests = []linalgTest{ // Float64s - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, // Float32s - {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, // stupids - a or b not vector - {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - bad incr shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 
52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, true, false}, // stupids - bad reuse shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, true}, // stupids - b not Float - {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a not Float - {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a-b type mismatch - {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids a-b type mismatch 
- {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, diff --git a/dense_matop.go b/dense_matop.go index 46e8a55..1a3b815 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -5,17 +5,16 @@ import "github.com/pkg/errors" // T performs a thunked transpose. It doesn't actually do anything, except store extra information about the post-transposed shapes and strides // Usually this is more than enough, as BLAS will handle the rest of the transpose func (t *Dense) T(axes ...int) (err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { return handleNoOp(err) } // is there any old transposes that need to be done first? // this is important, because any old transposes for dim >=3 are merely permutations of the strides - if t.old != nil { + if !t.old.IsZero() { if t.IsVector() { // the transform that was calculated was a waste of time - return it to the pool then untranspose - ReturnAP(transform) t.UT() return } @@ -31,7 +30,6 @@ func (t *Dense) T(axes ...int) (err error) { // if it is reversed, well, we just restore the backed up one if isReversed { - ReturnAP(transform) t.UT() return } @@ -58,18 +56,17 @@ func (t *Dense) T(axes ...int) (err error) { // // Nothing will happen if there was no previous transpose func (t *Dense) UT() { - if t.old != nil { - ReturnAP(t.AP) + if !t.old.IsZero() { ReturnInts(t.transposeWith) t.AP = t.old - t.old = nil + t.old.zeroOnly() t.transposeWith = nil } } // SafeT is exactly like T(), except it returns a new *Dense. The data is also copied over, unmoved. 
func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { if err = handleNoOp(err); err != nil { return @@ -82,7 +79,7 @@ func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { retVal.e = t.e retVal.oe = t.oe retVal.AP = transform - retVal.old = t.AP.Clone() + t.AP.CloneTo(&retVal.old) retVal.transposeWith = axes return @@ -209,7 +206,7 @@ func (t *Dense) CopyTo(other *Dense) error { // // The method treats as equivalent to a colon slice. T.Slice(nil) is equivalent to T[:] in Numpy syntax func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { - var newAP *AP + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { @@ -236,15 +233,14 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { // The underlying data is the same. // This method will override ALL the metadata in view. func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { - var newAP *AP + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { return } - ReturnAP(view.AP) - view.AP = nil + view.AP.zero() view.array.v = nil // reset view.t = t.t @@ -314,6 +310,7 @@ func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) func (t *Dense) transposeIndex(i int, transposePat, strides []int) int { oldCoord, err := Itol(i, t.oshape(), t.ostrides()) if err != nil { + err = errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. 
original strides %v", i, t.oshape(), t.ostrides()) panic(err) } diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index 05033ef..fe05f2a 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -9,7 +9,7 @@ import "github.com/pkg/errors" // https://en.wikipedia.org/wiki/In-place_matrix_transposition func (t *Dense) Transpose() error { // if there is no oldinfo, that means the current info is the latest, and not the transpose - if t.old == nil { + if t.old.IsZero() { return nil } @@ -18,8 +18,7 @@ func (t *Dense) Transpose() error { } defer func() { - ReturnAP(t.old) - t.old = nil + t.old.zero() t.transposeWith = nil }() @@ -27,10 +26,10 @@ func (t *Dense) Transpose() error { // important! because the strides would have changed once the underlying data changed var expStrides []int - if t.AP.o.isColMajor() { - expStrides = expShape.calcStridesColMajor() + if t.AP.o.IsColMajor() { + expStrides = expShape.CalcStridesColMajor() } else { - expStrides = expShape.calcStrides() + expStrides = expShape.CalcStrides() } defer ReturnInts(expStrides) defer func() { diff --git a/dense_matop_test.go b/dense_matop_test.go index 51ee94a..ca02d3e 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -135,10 +135,10 @@ var transposeTests = []struct { correctData interface{} }{ {"c.T()", Shape{4, 1}, nil, []float64{0, 1, 2, 3}, - Shape{1, 4}, []int{1}, []int{1}, []float64{0, 1, 2, 3}}, + Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}}, {"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3}, - Shape{4, 1}, []int{1}, []int{1}, []float32{0, 1, 2, 3}}, + Shape{4, 1}, []int{4, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, {"v.T()", Shape{4}, nil, []int{0, 1, 2, 3}, Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}}, @@ -216,10 +216,10 @@ func TestDense_Transpose(t *testing.T) { } assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. 
Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides, T.Strides()) + assert.Equal(tts.correctStrides, T.Strides(), "Transpose %v. Expected stride: %v. Got %v", tts.name, tts.correctStrides, T.Strides()) T.Transpose() assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides2, T.Strides(), "Transpose %v - Wrong strides", tts.name) + assert.Equal(tts.correctStrides2, T.Strides(), "Transpose2 %v - Expected stride %v. Got %v", tts.name, tts.correctStrides2, T.Strides()) assert.Equal(tts.correctData, T.Data(), "Transpose %v", tts.name) } @@ -236,7 +236,7 @@ func TestDense_Transpose(t *testing.T) { t.Errorf("Stacked .T() #1 for vector. Error: %v", err) goto matrev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(T.IsColVec()) @@ -251,7 +251,7 @@ matrev: t.Errorf("Stacked .T() #2 for matrix reverse. Error: %v", err) goto matnorev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(Shape{2, 3}.Eq(T.Shape())) @@ -278,12 +278,12 @@ func TestTUT(t *testing.T) { T = New(Of(Float64), WithShape(2, 3, 4)) T.T() T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) T.T(2, 0, 1) T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) } @@ -493,16 +493,16 @@ var denseSliceTests = []struct { // colvec {"c[0]", Range(Int64, 0, 5), Shape{5, 1}, []Slice{ss(0)}, ScalarShape(), nil, int64(0)}, - {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1}, []float32{0, 1}}, - {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2}, []float64{0, 1, 2, 3, 4}}, + {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1, 1}, []float32{0, 1}}, + {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, 
Shape{2, 1}, []int{2, 1}, []float64{0, 1, 2, 3, 4}}, // // rowvec {"r[0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{ss(0)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 5, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[:, 0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, ss(0)}, ScalarShape(), nil, float64(0)}, - {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{1}, []float64{0, 1}}, - {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{2}, []float64{1, 2, 3, 4}}, + {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{5, 1}, []float64{0, 1}}, + {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{5, 2}, []float64{1, 2, 3, 4}}, // // matrix {"A[0]", Range(Float64, 0, 6), Shape{2, 3}, []Slice{ss(0)}, Shape{1, 3}, []int{1}, Range(Float64, 0, 3)}, @@ -540,7 +540,7 @@ func TestDense_Slice(t *testing.T) { assert.True(Shape{2}.Eq(V.Shape())) assert.Equal([]int{3}, V.Strides()) assert.Equal([]float32{0, 1, 2, 3}, V.Data()) - assert.Nil(V.(*Dense).old) + assert.True(V.(*Dense).old.IsZero()) // slice a sliced V, err = V.Slice(makeRS(1, 2)) @@ -623,49 +623,61 @@ func TestDense_RollAxis(t *testing.T) { } var concatTests = []struct { - name string - dt Dtype - a interface{} - shape Shape - axis int + name string + dt Dtype + a interface{} + b interface{} + shape Shape + shapeB Shape + axis int correctShape Shape correctData interface{} }{ // Float64 - {"vector", Float64, nil, Shape{2}, 0, Shape{4}, []float64{0, 1, 0, 1}}, - {"matrix; axis 0 ", Float64, nil, Shape{2, 2}, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float64, nil, Shape{2, 2}, 1, Shape{2, 4}, 
[]float64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float64, nil, nil, Shape{2}, nil, 0, Shape{4}, []float64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, // Float32 - {"vector", Float32, nil, Shape{2}, 0, Shape{4}, []float32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Float32, nil, Shape{2, 2}, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float32, nil, Shape{2, 2}, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float32, nil, nil, Shape{2}, nil, 0, Shape{4}, []float32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, // Int - {"vector", Int, nil, Shape{2}, 0, Shape{4}, []int{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int, nil, Shape{2, 2}, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int, nil, Shape{2, 2}, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int, nil, nil, Shape{2}, nil, 0, Shape{4}, []int{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, // Int64 - {"vector", Int64, nil, Shape{2}, 0, Shape{4}, []int64{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int64, nil, Shape{2, 2}, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int64, nil, Shape{2, 2}, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int64, nil, nil, Shape{2}, nil, 0, Shape{4}, []int64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int64, nil, nil, Shape{2, 2}, nil, 1, 
Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, // Int32 - {"vector", Int32, nil, Shape{2}, 0, Shape{4}, []int32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int32, nil, Shape{2, 2}, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int32, nil, Shape{2, 2}, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int32, nil, nil, Shape{2}, nil, 0, Shape{4}, []int32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, // Byte - {"vector", Byte, nil, Shape{2}, 0, Shape{4}, []byte{0, 1, 0, 1}}, - {"matrix; axis 0 ", Byte, nil, Shape{2, 2}, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Byte, nil, Shape{2, 2}, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Byte, nil, nil, Shape{2}, nil, 0, Shape{4}, []byte{0, 1, 0, 1}}, + {"matrix; axis 0 ", Byte, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Byte, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, // Bool - {"vector", Bool, []bool{true, false}, Shape{2}, 0, Shape{4}, []bool{true, false, true, false}}, - {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, - {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, + {"vector", Bool, []bool{true, false}, nil, Shape{2}, nil, 0, Shape{4}, []bool{true, false, true, false}}, + {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, + {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, 
true, false}}, + + // gorgonia/gorgonia#218 related + {"matrix; axis 0", Float64, nil, nil, Shape{2, 2}, Shape{1, 2}, 0, Shape{3, 2}, []float64{0, 1, 2, 3, 0, 1}}, + {"matrix; axis 1", Float64, nil, nil, Shape{2, 2}, Shape{2, 1}, 1, Shape{2, 3}, []float64{0, 1, 0, 2, 3, 1}}, + {"colvec matrix, axis 0", Float64, nil, nil, Shape{2, 1}, Shape{1, 1}, 0, Shape{3, 1}, []float64{0, 1, 0}}, + {"rowvec matrix, axis 1", Float64, nil, nil, Shape{1, 2}, Shape{1, 1}, 1, Shape{1, 3}, []float64{0, 1, 0}}, + + {"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}}, + {"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}}, + // {"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{222}}, } func TestDense_Concat(t *testing.T) { @@ -676,15 +688,24 @@ func TestDense_Concat(t *testing.T) { if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } assert.True(cts.correctShape.Eq(T2.Shape())) @@ -694,24 +715,31 @@ func TestDense_Concat(t 
*testing.T) { //Masked case for _, cts := range concatTests { - var T0, T1 *Dense if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) T0.MaskedEqual(castToDt(0.0, cts.dt)) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) T0.MaskedEqual(castToDt(0.0, cts.dt)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } + T1.MaskedEqual(castToDt(0.0, cts.dt)) T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } diff --git a/dense_norms.go b/dense_norms.go index ad75c0f..63d460a 100644 --- a/dense_norms.go +++ b/dense_norms.go @@ -94,8 +94,8 @@ func (t *Dense) Norm(ord NormOrder, axes ...int) (retVal *Dense, err error) { if len(axes) == 0 { if ord.IsUnordered() || (ord.IsFrobenius() && dims == 2) || (ord == Norm(2) && dims == 1) { backup := t.AP - ap := BorrowAP(1) - defer ReturnAP(ap) + ap := makeAP(1) + defer ap.zero() ap.unlock() ap.SetShape(t.Size()) diff --git a/dense_svd_test.go b/dense_svd_test.go index 36e4e16..89c5306 100644 --- a/dense_svd_test.go +++ b/dense_svd_test.go @@ -1,6 +1,7 @@ package tensor import ( + "fmt" "testing" "github.com/pkg/errors" @@ -103,6 +104,27 @@ func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) { return nil } +func Example_DenseSVD() { + T := New( + WithShape(4, 5), + 
WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}), + ) + _, u, _, _ := T.SVD(true, true) + uT := u.Clone().(*Dense) + uT.T() + eye, err := u.MatMul(uT) + fmt.Println(eye) + fmt.Println(err) + + // Output: + // ⎡1 0 0 0⎤ + // ⎢0 1 0 0⎥ + // ⎢0 0 1 0⎥ + // ⎣0 0 0 1⎦ + // + // +} + func TestDense_SVD(t *testing.T) { var T, T2, s, u, v *Dense var err error @@ -134,7 +156,6 @@ func TestDense_SVD(t *testing.T) { t.Errorf("Expected v = %v. Got %v instead", stts.correctVData, v.Data()) } } - // standard tests for i, stfs := range svdtestsFull { T = New(WithShape(stfs...), WithBacking(Random(Float64, stfs.TotalSize()))) @@ -143,14 +164,14 @@ func TestDense_SVD(t *testing.T) { // full if s, u, v, err = T.SVD(true, true); err != nil { t.Error(err) + fmt.Println(err) continue } - if err = testSVD(T, T2, s, u, v, "full", i); err != nil { t.Error(err) + fmt.Println(err) continue } - // thin if s, u, v, err = T.SVD(true, false); err != nil { t.Error(err) @@ -183,8 +204,8 @@ func TestDense_SVD(t *testing.T) { if !allClose(s.Data(), svd.Values(nil), closeenoughf64) { t.Errorf("Singular value mismatch between Full and None decomposition. Expected %v. Got %v instead", svd.Values(nil), s.Data()) } - } + } // this is illogical T = New(Of(Float64), WithShape(2, 2)) if _, _, _, err = T.SVD(false, true); err == nil { diff --git a/engine.go b/engine.go index af56f6b..9e3ede7 100644 --- a/engine.go +++ b/engine.go @@ -59,6 +59,11 @@ type arrayMaker interface { makeArray(arr *array, t Dtype, size int) } +// NonStdEngine are any engines that do not allocate using the default built in allocator +type NonStdEngine interface { + NonStdAlloc() // noop +} + /* Data Agnostic Execution Engine Methods */ // Transposer is any engine that can perform an unsafe transpose of a tensor. 
@@ -86,6 +91,11 @@ type Repeater interface { Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) } +// Diager is any engine that can return a tensor that only contains the diagonal values of the input +type Diager interface { + Diag(a Tensor) (Tensor, error) +} + /* NUMBER INTERFACES All these are expected to be unsafe on the first tensor */ @@ -369,6 +379,20 @@ type Argminer interface { Argmin(t Tensor, axis int) (Tensor, error) } +// NaNChecker checks that the tensor contains a NaN +// Errors are to be returned if the concept of NaN does not apply to the data type. +// Other errors may also occur. See specific implementations for details +type NaNChecker interface { + HasNaN(t Tensor) (bool, error) +} + +// InfChecker checks that the tensor contains a Inf. +// Errors are to be returned if the concept of Inf does not apply to the data type. +// Other errors may also occur. See specific implementations for details +type InfChecker interface { + HasInf(t Tensor) (bool, error) +} + /* Internal interfaces for faster shit */ type denseArgmaxer interface { diff --git a/example_dense_linalg_test.go b/example_dense_linalg_test.go new file mode 100644 index 0000000..d558481 --- /dev/null +++ b/example_dense_linalg_test.go @@ -0,0 +1,151 @@ +package tensor + +import ( + "fmt" +) + +func ExampleDense_MatMul() { + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 15), WithBacking(Range(Float64, 0, 150))) + T1 := New(WithShape(15, 10), WithBacking(Range(Float64, 150, 0))) + T2, err := MatMul(T0, T1) + handleErr(err) + + fmt.Printf("T2:\n%v", T2) + + // Output: + // T2: + // ⎡ 5600 5495 5390 5285 ... 4970 4865 4760 4655⎤ + // ⎢ 23600 23270 22940 22610 ... 21620 21290 20960 20630⎥ + // ⎢ 41600 41045 40490 39935 ... 38270 37715 37160 36605⎥ + // ⎢ 59600 58820 58040 57260 ... 54920 54140 53360 52580⎥ + // . + // . + // . + // ⎢113600 112145 110690 109235 ... 104870 103415 101960 100505⎥ + // ⎢131600 129920 128240 126560 ... 
121520 119840 118160 116480⎥ + // ⎢149600 147695 145790 143885 ... 138170 136265 134360 132455⎥ + // ⎣167600 165470 163340 161210 ... 154820 152690 150560 148430⎦ + +} + +func ExampleDense_MatVecMul() { + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(2, 3), WithBacking(Range(Float64, 1, 7))) + T1 := New(WithShape(3), WithBacking(Range(Float64, 0, 3))) + T2, err := T0.MatVecMul(T1) + handleErr(err) + + fmt.Printf("T2:\n%v\n", T2) + + // Output: + // T2: + // [ 8 17] +} + +func ExampleDense_MatVecMul_rowMajorSliced() { + // ASPIRATIONAL TODO: IncX and incY of differering values + + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 12), WithBacking(Range(Float64, 1, 121))) + T1 := New(WithShape(3, 3), WithBacking(Range(Float64, 1, 10))) + T2, err := T0.Slice(makeRS(1, 3), makeRS(3, 6)) + handleErr(err) + T3, err := T1.Slice(nil, makeRS(1, 2)) + handleErr(err) + + // here the + formatting option is used because you should know that after this particular slice, the result will be a vector + fmt.Printf("T2:\n%+v", T2) + fmt.Printf("T3:\n%+v\n", T3) + + // here we print the underlying slice of T3 just to show that it's actually a much larger slice + fmt.Printf("Underlying Slice: %v\n", T3.Data()) + + T4, err := T2.(*Dense).MatVecMul(T3) + handleErr(err) + + fmt.Printf("T4:\n%v\n", T4) + + // Outputz: + // T2: + // Matrix (2, 3) [10 1] + // ⎡14 15 16⎤ + // ⎣24 25 26⎦ + // T3: + // Vector (3) [3] + // [2 5 8] + // Underlying Slice: [2 3 4 5 6 7 8] + // T4: + // [261 441] + +} + +func ExampleDense_MatMul_sliced() { + //ASPIRATIONAL TODO: incX and incY of different sizes + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 15), WithBacking(Range(Float64, 0, 150))) + T1 := New(WithShape(15, 10), WithBacking(Range(Float64, 150, 0))) + T2, err := MatMul(T0, T1) + handleErr(err) + + fmt.Printf("T2:\n%v", T2) + + // Slice T0 to only take a 
(2, 3) on the upper quadrant + // T3 := T0[0:3, 0:2] + T3, err := T0.Slice(makeRS(0, 3), makeRS(0, 2)) + handleErr(err) + fmt.Printf("T3:\n%v", T3) + + T4, err := T1.Slice(makeRS(13, 15), makeRS(8, 10)) + handleErr(err) + fmt.Printf("T4:\n%v", T4) + + T5, err := T3.(*Dense).MatMul(T4) + handleErr(err) + fmt.Printf("T3xT4:\n%v", T5) + + // Outputz: + // T2: + // ⎡ 5600 5495 5390 5285 ... 4970 4865 4760 4655⎤ + // ⎢ 23600 23270 22940 22610 ... 21620 21290 20960 20630⎥ + // ⎢ 41600 41045 40490 39935 ... 38270 37715 37160 36605⎥ + // ⎢ 59600 58820 58040 57260 ... 54920 54140 53360 52580⎥ + // . + // . + // . + // ⎢113600 112145 110690 109235 ... 104870 103415 101960 100505⎥ + // ⎢131600 129920 128240 126560 ... 121520 119840 118160 116480⎥ + // ⎢149600 147695 145790 143885 ... 138170 136265 134360 132455⎥ + // ⎣167600 165470 163340 161210 ... 154820 152690 150560 148430⎦ + // T3: + // ⎡ 0 1⎤ + // ⎢15 16⎥ + // ⎣30 31⎦ + // T4: + // ⎡12 11⎤ + // ⎣ 2 1⎦ + // T3xT4: + // ⎡ 2 1⎤ + // ⎢212 181⎥ + // ⎣422 361⎦ +} diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 02b1dc8..97a2cb8 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -176,6 +176,8 @@ func ExampleDense_Vstack() { T3 = T1.Clone().(*Dense) if T2, err = T.Vstack(T1, T3); err == nil { fmt.Printf("T.Vstack(T1, T3):\n%v\n", T2) + } else { + fmt.Printf("====\nerr %v\n%v\n===\n", err, T3.Shape()) } // Let's look at failure conditions diff --git a/example_iterator_test.go b/example_iterator_test.go index 896a097..a6b31da 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -2,31 +2,51 @@ package tensor import "fmt" -func Example_iterator() { - fmt.Println("Row Major") - T := New(WithShape(2, 3), Of(Float64)) +// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor +func Example_iteratorRowmajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) 
for i, err := it.Start(); err == nil; i, err = it.Next() { fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) } - /* - // FOR WHEN COL MAJOR IS SUPPORTED - fmt.Println("Col Major") - T = New(WithShape(2, 3), Of(Float64), AsFortran()) - it = IteratorFromDense(T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - */ // Output: - // Row Major + // T: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // // i: 0, coord: [0 1] // i: 1, coord: [0 2] // i: 2, coord: [1 0] // i: 3, coord: [1 1] // i: 4, coord: [1 2] // i: 5, coord: [0 0] + +} + +// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly +// this example shows the order of the iteration. +func Example_iteratorcolMajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // + // i: 0, coord: [0 1] + // i: 2, coord: [0 2] + // i: 4, coord: [1 0] + // i: 1, coord: [1 1] + // i: 3, coord: [1 2] + // i: 5, coord: [0 0] + } diff --git a/example_tensor_basics_test.go b/example_tensor_basics_test.go index d588a54..49c008b 100644 --- a/example_tensor_basics_test.go +++ b/example_tensor_basics_test.go @@ -2,16 +2,19 @@ package tensor import "fmt" +// This example showcases the very basics of the package. 
func Example_basics() { + // Create a (2, 2)-Matrix of integers a := New(WithShape(2, 2), WithBacking([]int{1, 2, 3, 4})) fmt.Printf("a:\n%v\n", a) + // Create a (2, 3, 4)-tensor of float32s b := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) fmt.Printf("b:\n%1.1f", b) // Accessing data x, _ := b.At(0, 1, 2) // in Numpy syntax: b[0,1,2] - fmt.Printf("x: %v\n\n", x) + fmt.Printf("x: %1.1f\n\n", x) // Setting data b.SetAt(float32(1000), 0, 1, 2) @@ -31,7 +34,7 @@ func Example_basics() { // ⎢16.0 17.0 18.0 19.0⎥ // ⎣20.0 21.0 22.0 23.0⎦ // - // x: 6 + // x: 6.0 // // b: // ⎡ 0 1 2 3⎤ @@ -42,3 +45,116 @@ func Example_basics() { // ⎢ 16 17 18 19⎥ // ⎣ 20 21 22 23⎦ } + +// This example showcases interactions between different data orders +func Example_differingDataOrders() { + T0 := New(WithShape(2, 3), WithBacking(Range(Int, 0, 6))) // Create a (2, 3)-matrix with the standard row-major backing + T1 := New(WithShape(2, 3), WithBacking(Range(Int, 0, 6)), AsFortran(nil)) // Create a (2, 3)-matrix with a col-major backing + T2, _ := Add(T0, T1) + fmt.Printf("T0:\n%vT1:\n%vT2:\n%vT2 Data Order: %v\n\n", T0, T1, T2, T2.DataOrder()) + + // the result's data order is highly dependent on the order of operation. 
It will take after the first operand + T0 = New(WithShape(2, 3), WithBacking(Range(Int, 1, 7)), AsFortran(nil)) // Create a (2, 3)-matrix with a col-major backing + T1 = New(WithShape(2, 3), WithBacking(Range(Int, 1, 7))) // Create a (2, 3)-matrix with the standard row-major backing + T2, _ = Add(T0, T1) + fmt.Printf("T0:\n%vT1:\n%vT2:\n%vT2 Data Order: %v\n\n", T0, T1, T2, T2.DataOrder()) + + reuse := New(WithShape(2, 3), WithBacking([]int{1000, 1000, 1000, 1000, 1000, 1000})) + fmt.Printf("reuse Data Order: %v\n", reuse.DataOrder()) + T2, _ = Add(T0, T1, WithReuse(reuse)) + fmt.Printf("T2:\n%vT2 Data Order: %v\n\n", T2, T2.DataOrder()) + + // Output: + // T0: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // T1: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // T2: + // ⎡ 0 3 6⎤ + // ⎣ 4 7 10⎦ + // T2 Data Order: Contiguous, RowMajor + // + // + // T0: + // ⎡1 3 5⎤ + // ⎣2 4 6⎦ + // T1: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // T2: + // ⎡ 2 5 8⎤ + // ⎣ 6 9 12⎦ + // T2 Data Order: Contiguous, ColMajor + // + // + // reuse Data Order: Contiguous, RowMajor + // T2: + // ⎡ 2 5 8⎤ + // ⎣ 6 9 12⎦ + // T2 Data Order: Contiguous, ColMajor + +} + +// The AsFortran construction option is a bit finnicky. +func Example_asFortran() { + // Here the data is passed in and directly used without changing the underlying data + T0 := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + fmt.Printf("T0:\n%vData: %v\n\n", T0, T0.Data()) + + // Here the data is passed into the AsFortran construction option, and it assumes that the data is already in + // row-major form. Therefore a transpose will be performed. 
+ T1 := New(WithShape(2, 3), AsFortran([]float64{0, 1, 2, 3, 4, 5})) + fmt.Printf("T1:\n%vData: %v\n\n", T1, T1.Data()) + + // Further example of how AsFortran works: + orig := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) + T2 := New(WithShape(2, 3), AsFortran(orig)) + fmt.Printf("Original\n%vData: %v\n", orig, orig.Data()) + fmt.Printf("T2:\n%vData: %v\n", T2, T2.Data()) + + // Output: + // T0: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // Data: [0 1 2 3 4 5] + // + // T1: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 3 1 4 2 5] + // + // Original + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 1 2 3 4 5] + // T2: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 3 1 4 2 5] +} + +// The AsDenseDiag construction option creates a dense diagonal matrix from the input, either a slice or a tensor. +// The resulting shape is automatically inferred from the input vector. +// +// This is like Numpy's `diag()` function, except not stupid. Numpy's `diag()` has been a cause of errors because it's somewhat isometric: +// >>> np.diag(np.diag(np.array([1,2,3]))) +// array([1,2,3]) +func Example_asDenseDiag() { + T := New(WithShape(3), WithBacking([]int{1, 2, 3})) + T1 := New(AsDenseDiag(T)) + fmt.Printf("T1:\n%v", T1) + + T2 := New(AsDenseDiag([]float64{3.14, 6.28, 11111})) + fmt.Printf("T2:\n%v", T2) + // Output: + // T1: + //⎡1 0 0⎤ + //⎢0 2 0⎥ + //⎣0 0 3⎦ + // T2: + // ⎡ 3.14 0 0⎤ + // ⎢ 0 6.28 0⎥ + // ⎣ 0 0 11111⎦ +} diff --git a/flags.go b/flags.go index dfe551e..e8a00d0 100644 --- a/flags.go +++ b/flags.go @@ -13,8 +13,13 @@ const ( // A data can either be Contiguous (0) or NonContiguous (2). // The way DataOrder was designed causes the default to be Contiguous. NonContiguous + + // Transposed indicates that the data has been transposed + Transposed ) +var dataOrderNames = []rune("NonContiguous, RowMajorᵀNonContiguous, ColMajorᵀ") + // MakeDataOrder makes a data order. 
Typical examples: // MakeDataOrder(DataOrder(0)) // Row Major, contiguous // MakeDataOrder(NonContiguous // Row Major, non-contiguous @@ -30,13 +35,47 @@ func MakeDataOrder(fs ...DataOrder) (retVal DataOrder) { return } -func (f DataOrder) isColMajor() bool { return (f & ColMajor) != 0 } -func (f DataOrder) isRowMajor() bool { return !f.isColMajor() } -func (f DataOrder) isContiguous() bool { return !f.isNotContiguous() } -func (f DataOrder) isNotContiguous() bool { return (f & NonContiguous) != 0 } +// IsColMajor returns true if the data order describes a col-major data +func (f DataOrder) IsColMajor() bool { return (f & ColMajor) != 0 } + +// IsRowMajor returns true if the data order describes a row-major data +func (f DataOrder) IsRowMajor() bool { return !f.IsColMajor() } + +// IsContiguous returns true if the data order describes a contiguous data. +func (f DataOrder) IsContiguous() bool { return !f.IsNotContiguous() } + +// IsNotContiguous returns true if the data order describes a noncontiguous data. 
+func (f DataOrder) IsNotContiguous() bool { return (f & NonContiguous) != 0 } + +// IsTransposed returns true if the data order describes whether the data has been tranposed (but not moved) +func (f DataOrder) IsTransposed() bool { return (f & Transposed) != 0 } + func (f DataOrder) toggleColMajor() DataOrder { return f ^ (ColMajor) } -func (f DataOrder) hasSameOrder(other DataOrder) bool { - return (f.isColMajor() && other.isColMajor()) || (f.isRowMajor() && other.isRowMajor()) + +func (f DataOrder) clearTransposed() DataOrder { return f &^ (Transposed) } + +func (f DataOrder) HasSameOrder(other DataOrder) bool { + return (f.IsColMajor() && other.IsColMajor()) || (f.IsRowMajor() && other.IsRowMajor()) +} + +func (f DataOrder) String() string { + var start, end int + if f.IsRowMajor() { + end = 23 + if f.IsContiguous() { + start = 3 + } + } else { + end = 47 + start = 24 + if f.IsContiguous() { + start = 27 + } + } + if f.IsTransposed() { + end++ + } + return string(dataOrderNames[start:end]) } // Triangle is a flag representing the "triangle"ness of a matrix diff --git a/flags_test.go b/flags_test.go index 98a8772..83dd3be 100644 --- a/flags_test.go +++ b/flags_test.go @@ -35,29 +35,56 @@ func TestMemoryFlag(t *testing.T) { func TestDataOrder(t *testing.T) { var defaultFlag DataOrder - if defaultFlag.isColMajor() || defaultFlag.isNotContiguous() { - t.Errorf("Expected default flag to be row major and contiguous") + if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { + t.Error("Expected default flag to be row major and contiguous and not transposed") } - if !(defaultFlag.isRowMajor() && defaultFlag.isContiguous()) { - t.Errorf("Expected default flag to be row major and contiguous") + if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { + t.Error("Expected default flag to be row major and contiguous") + } + if defaultFlag.String() != "Contiguous, RowMajor" { + t.Errorf("Expected string is \"Contiguous, RowMajor\". 
Got %q", defaultFlag.String()) + } + + ncrm := MakeDataOrder(NonContiguous) + if ncrm.IsColMajor() || ncrm.IsContiguous() { + t.Error("Expected noncontiguous row major.") + } + if ncrm.String() != "NonContiguous, RowMajor" { + t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) } cm := ColMajor - if cm.isRowMajor() { - t.Errorf("colMajor cannot be rowMajor") + if cm.IsRowMajor() { + t.Error("colMajor cannot be rowMajor") + } + if cm.IsNotContiguous() { + t.Error("ColMajor by default is contiguous") } - if cm.isNotContiguous() { - t.Errorf("ColMajor by default is contiguous") + if cm.String() != "Contiguous, ColMajor" { + t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) } // check toggle rm := cm.toggleColMajor() - if rm.isColMajor() { + if rm.IsColMajor() { t.Errorf("toggled cm should be rm") } cm = rm.toggleColMajor() - if cm.isRowMajor() { + if cm.IsRowMajor() { t.Errorf("toggled rm should be cm") } + + transposed := MakeDataOrder(Transposed) + if !transposed.IsTransposed() { + t.Error("Expected transposed flag to be set") + } + if transposed.String() != "Contiguous, RowMajorᵀ" { + t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". 
Got %q", defaultFlag.String()) + } + untransposed := transposed.clearTransposed() + if untransposed != defaultFlag { + t.Error("Expected default flag after untransposing") + } + } diff --git a/interfaces.go b/interfaces.go index 40be33d..c4a11c2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -76,7 +76,6 @@ type DenseTensor interface { Tensor Info() *AP - DataOrder() DataOrder IsMatrix() bool IsVector() bool IsRowVec() bool @@ -89,6 +88,7 @@ type DenseTensor interface { rtype() reflect.Type reshape(dims ...int) error + setDataOrder(o DataOrder) isTransposed() bool ostrides() []int oshape() Shape diff --git a/internal/IDLs/generated.fbs b/internal/IDLs/generated.fbs new file mode 100644 index 0000000..47ffce2 --- /dev/null +++ b/internal/IDLs/generated.fbs @@ -0,0 +1,38 @@ +// Generated from generated.proto + +namespace gorgonia.org.tensor.internal.serialization.pb; + +enum Triangle : int { + NOT_TRIANGLE = 0, + UPPER = 1, + LOWER = 2, + SYMMETRIC = 3, +} + +table AP { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; +} + +table Dense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; +} + +table MaskedDense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; + mask:[bool]; + mask_is_soft:[bool]; +} + diff --git a/internal/IDLs/generated.proto b/internal/IDLs/generated.proto new file mode 100755 index 0000000..c737106 --- /dev/null +++ b/internal/IDLs/generated.proto @@ -0,0 +1,52 @@ +syntax = "proto3"; +package gorgonia.org.tensor.internal.serialization.pb; + +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +option (gogoproto.protosizer_all) = true; +option (gogoproto.sizer_all) = false; +option go_package = "pb"; + +message AP { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 
1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; +} + +message Dense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; +} + +message MaskedDense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; + repeated bool mask = 7; + repeated bool mask_is_soft = 8; +} + +enum Triangle { + option (gogoproto.enumdecl) = false; + option (gogoproto.goproto_enum_prefix) = false; + option (gogoproto.goproto_enum_stringer) = false; + NOT_TRIANGLE = 0 [(gogoproto.enumvalue_customname) = "NotTriangle"]; + UPPER = 1 [(gogoproto.enumvalue_customname) = "Upper"]; + LOWER = 2 [(gogoproto.enumvalue_customname) = "Lower"]; + SYMMETRIC = 3 [(gogoproto.enumvalue_customname) = "Symmetric"]; +} + diff --git a/internal/serialization/README.md b/internal/serialization/README.md new file mode 100644 index 0000000..d3d8149 --- /dev/null +++ b/internal/serialization/README.md @@ -0,0 +1,33 @@ +# Serialization # + +This pseudopackage of sorts handles serialization. The "Canonical" serialized data structure is found in the `pb` subdirectory. + +# Protobuf generation + +Proteus needs to be installed, as does its dependencies. + + +1. `cd pb` +2. `rm generated*` +3. `proteus -f ../../IDLs -p gorgonia.org/tensor/internal/serialization/pb` +4. `cd ../../IDLs` +5. `find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` +6. `rm -rf gorgonia.org` + + +# FlatBuffer generation +1. generate protobuf first +2. 
delete the `import "github.com/gogo/protobuf/gogoproto/gogo.proto";` line from the generated protobuf file +3. `flatc --proto PATH/TO/generated.proto` +4. place the `generated.fbs` file in the IDLs directory +4. restore the import line in the `generated.proto` file +5. From this directory: `flatc --go-namespace fb -g PATH/TO/generated.fbs` + + +# Notes # + +`find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` is used to flatten and put all the stuff in the root IDLs directory. + +# The Serialization Story # + +To serialize, we copy/convert/coerce the data to the internal/serialization data structures, then call the `Marshall` methods from there \ No newline at end of file diff --git a/internal/serialization/doc.go b/internal/serialization/doc.go new file mode 100644 index 0000000..c4cb59b --- /dev/null +++ b/internal/serialization/doc.go @@ -0,0 +1,2 @@ +// package serialization provides the data structures for serialization +package serialization diff --git a/internal/serialization/fb/AP.go b/internal/serialization/fb/AP.go new file mode 100644 index 0000000..b3ca806 --- /dev/null +++ b/internal/serialization/fb/AP.go @@ -0,0 +1,110 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type AP struct { + _tab flatbuffers.Table +} + +func GetRootAsAP(buf []byte, offset flatbuffers.UOffsetT) *AP { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &AP{} + x.Init(buf, n+offset) + return x +} + +func (rcv *AP) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *AP) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *AP) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 
{ + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *AP) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func APStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func APAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func APStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func APStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func APAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func APEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Dense.go b/internal/serialization/fb/Dense.go new file mode 100644 index 0000000..2a961ee --- /dev/null +++ 
b/internal/serialization/fb/Dense.go @@ -0,0 +1,152 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Dense struct { + _tab flatbuffers.Table +} + +func GetRootAsDense(buf []byte, offset flatbuffers.UOffsetT) *Dense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Dense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *Dense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Dense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Dense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *Dense) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *Dense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func 
(rcv *Dense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Dense) DataLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func DenseStart(builder *flatbuffers.Builder) { + builder.StartObject(6) +} +func DenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func DenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func DenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func DenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func DenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func DenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func DenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func DenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/MaskedDense.go 
b/internal/serialization/fb/MaskedDense.go new file mode 100644 index 0000000..271e77e --- /dev/null +++ b/internal/serialization/fb/MaskedDense.go @@ -0,0 +1,198 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type MaskedDense struct { + _tab flatbuffers.Table +} + +func GetRootAsMaskedDense(buf []byte, offset flatbuffers.UOffsetT) *MaskedDense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &MaskedDense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *MaskedDense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *MaskedDense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *MaskedDense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *MaskedDense) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateT(n int32) bool { + return 
rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *MaskedDense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) DataLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Mask(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoft(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoftLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func MaskedDenseStart(builder *flatbuffers.Builder) { + builder.StartObject(8) +} +func MaskedDenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func MaskedDenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + 
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func MaskedDenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func MaskedDenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func MaskedDenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func MaskedDenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func MaskedDenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMask(builder *flatbuffers.Builder, mask flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(6, flatbuffers.UOffsetT(mask), 0) +} +func MaskedDenseStartMaskVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMaskIsSoft(builder *flatbuffers.Builder, maskIsSoft flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(7, flatbuffers.UOffsetT(maskIsSoft), 0) +} +func MaskedDenseStartMaskIsSoftVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Triangle.go b/internal/serialization/fb/Triangle.go new file mode 100644 index 0000000..599a06b --- /dev/null +++ b/internal/serialization/fb/Triangle.go @@ -0,0 +1,18 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +const ( + TriangleNOT_TRIANGLE = 0 + TriangleUPPER = 1 + TriangleLOWER = 2 + TriangleSYMMETRIC = 
3 +) + +var EnumNamesTriangle = map[int]string{ + TriangleNOT_TRIANGLE:"NOT_TRIANGLE", + TriangleUPPER:"UPPER", + TriangleLOWER:"LOWER", + TriangleSYMMETRIC:"SYMMETRIC", +} + diff --git a/internal/serialization/pb/dense.go b/internal/serialization/pb/dense.go new file mode 100644 index 0000000..950c3ff --- /dev/null +++ b/internal/serialization/pb/dense.go @@ -0,0 +1,45 @@ +package pb + +//proteus:generate +type DataOrder byte + +// the reason for spreading the states out is because proteus cannot handle non-iota states +const ( + RowMajorContiguous = iota + RowMajorNonContiguous + ColMajorContiguous + ColMajorNonContiguous +) + +//proteus:generate +type Triangle byte + +const ( + NotTriangle Triangle = iota + Upper + Lower + Symmetric +) + +//proteus:generate +type AP struct { + Shape []int32 + Strides []int32 + + O DataOrder + T Triangle +} + +//proteus:generate +type Dense struct { + AP + Type string // type name + Data []byte +} + +//proteus:generate +type MaskedDense struct { + Dense + Mask []bool + MaskIsSoft []bool +} diff --git a/internal/serialization/pb/generated.pb.go b/internal/serialization/pb/generated.pb.go new file mode 100644 index 0000000..831ce90 --- /dev/null +++ b/internal/serialization/pb/generated.pb.go @@ -0,0 +1,1457 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: gorgonia.org/tensor/internal/serialization/pb/generated.proto + +/* + Package pb is a generated protocol buffer package. + + It is generated from these files: + gorgonia.org/tensor/internal/serialization/pb/generated.proto + + It has these top-level messages: + AP + Dense + MaskedDense +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. 
+var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +var Triangle_name = map[int32]string{ + 0: "NOT_TRIANGLE", + 1: "UPPER", + 2: "LOWER", + 3: "SYMMETRIC", +} +var Triangle_value = map[string]int32{ + "NOT_TRIANGLE": 0, + "UPPER": 1, + "LOWER": 2, + "SYMMETRIC": 3, +} + +func (Triangle) EnumDescriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *AP) Reset() { *m = AP{} } +func (m *AP) String() string { return proto.CompactTextString(m) } +func (*AP) ProtoMessage() {} +func (*AP) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *Dense) Reset() { *m = Dense{} } +func (m *Dense) String() string { return proto.CompactTextString(m) } +func (*Dense) ProtoMessage() {} +func (*Dense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{1} } + +func (m *MaskedDense) Reset() { *m = MaskedDense{} } +func (m *MaskedDense) String() string { return proto.CompactTextString(m) } +func (*MaskedDense) ProtoMessage() {} +func (*MaskedDense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{2} } + +func init() { + proto.RegisterType((*AP)(nil), "gorgonia.org.tensor.internal.serialization.pb.AP") + proto.RegisterType((*Dense)(nil), "gorgonia.org.tensor.internal.serialization.pb.Dense") + proto.RegisterType((*MaskedDense)(nil), "gorgonia.org.tensor.internal.serialization.pb.MaskedDense") + proto.RegisterEnum("gorgonia.org.tensor.internal.serialization.pb.Triangle", Triangle_name, Triangle_value) +} +func (m *AP) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + 
return nil, err + } + return dAtA[:n], nil +} + +func (m *AP) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA2 := make([]byte, len(m.Shape)*10) + var j1 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j1++ + } + dAtA2[j1] = uint8(num) + j1++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j1)) + i += copy(dAtA[i:], dAtA2[:j1]) + } + if len(m.Strides) > 0 { + dAtA4 := make([]byte, len(m.Strides)*10) + var j3 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j3)) + i += copy(dAtA[i:], dAtA4[:j3]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + return i, nil +} + +func (m *Dense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Dense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA6 := make([]byte, len(m.Shape)*10) + var j5 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA6[j5] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j5++ + } + dAtA6[j5] = uint8(num) + j5++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j5)) + i += copy(dAtA[i:], dAtA6[:j5]) + } + if len(m.Strides) > 0 { + dAtA8 := make([]byte, len(m.Strides)*10) + var j7 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA8[j7] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j7++ + } + dAtA8[j7] = 
uint8(num) + j7++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j7)) + i += copy(dAtA[i:], dAtA8[:j7]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) + } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + return i, nil +} + +func (m *MaskedDense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MaskedDense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA10 := make([]byte, len(m.Shape)*10) + var j9 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA10[j9] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j9++ + } + dAtA10[j9] = uint8(num) + j9++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j9)) + i += copy(dAtA[i:], dAtA10[:j9]) + } + if len(m.Strides) > 0 { + dAtA12 := make([]byte, len(m.Strides)*10) + var j11 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA12[j11] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j11++ + } + dAtA12[j11] = uint8(num) + j11++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j11)) + i += copy(dAtA[i:], dAtA12[:j11]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) 
+ } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if len(m.Mask) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Mask))) + for _, b := range m.Mask { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + if len(m.MaskIsSoft) > 0 { + dAtA[i] = 0x42 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.MaskIsSoft))) + for _, b := range m.MaskIsSoft { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + return i, nil +} + +func encodeVarintGenerated(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *AP) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + return n +} + +func (m *Dense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + return n +} + +func (m *MaskedDense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l 
+= sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + if len(m.Mask) > 0 { + n += 1 + sovGenerated(uint64(len(m.Mask))) + len(m.Mask)*1 + } + if len(m.MaskIsSoft) > 0 { + n += 1 + sovGenerated(uint64(len(m.MaskIsSoft))) + len(m.MaskIsSoft)*1 + } + return n +} + +func sovGenerated(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozGenerated(x uint64) (n int) { + return sovGenerated(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *AP) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: AP: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: AP: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 
ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field 
O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Dense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Dense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Dense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift 
:= uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + 
return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *MaskedDense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MaskedDense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MaskedDense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return 
io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + 
if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 7: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Mask", wireType) + } + case 8: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return 
io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field MaskIsSoft", wireType) + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipGenerated(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthGenerated + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, 
ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipGenerated(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthGenerated = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowGenerated = fmt.Errorf("proto: integer overflow") +) + +func init() { + proto.RegisterFile("gorgonia.org/tensor/internal/serialization/pb/generated.proto", fileDescriptorGenerated) +} + +var fileDescriptorGenerated = []byte{ + // 482 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x93, 0x4f, 0x6b, 0x13, 0x41, + 0x18, 0xc6, 0x77, 0x92, 0x6e, 0x9b, 0x4c, 0x13, 0x0d, 0x43, 0x0f, 0x43, 0x94, 0xcd, 0xd8, 0xd3, + 0x22, 0x74, 0x17, 0xf4, 0x20, 0x08, 0x1e, 0x5a, 0x1b, 0x24, 0x90, 0x7f, 0x4c, 0x52, 0x44, 0x2f, + 0x61, 0xb6, 0x3b, 0xd9, 0x0e, 0x4d, 0x76, 0x96, 0x99, 0x29, 0x52, 0xef, 0x42, 0xcd, 0x27, 0xf0, + 0x12, 0xa8, 0xda, 0x83, 0x1f, 0xc3, 0xa3, 0x17, 0xc1, 0x4f, 0x20, 0x92, 0x7e, 0x01, 0xcf, 0x9e, + 0x64, 0x27, 0x44, 0xe2, 0xd1, 0x9b, 0x3d, 0xcd, 0xf3, 0xfc, 0x66, 0x9e, 0x77, 0xde, 0x97, 0x61, + 0xe0, 0x93, 0x44, 0xaa, 0x44, 0xa6, 0x82, 0x05, 0x52, 0x25, 0xa1, 0xe1, 0xa9, 0x96, 0x2a, 0x14, + 0xa9, 0xe1, 0x2a, 0x65, 0x93, 0x50, 0x73, 0x25, 0xd8, 0x44, 0xbc, 0x66, 0x46, 0xc8, 0x34, 0xcc, + 0xa2, 0x30, 0xe1, 0x29, 0x57, 0xcc, 0xf0, 0x38, 0xc8, 0x94, 0x34, 0x12, 0xed, 0xad, 0xc7, 0x83, + 0x65, 0x3c, 0x58, 0xc5, 0x83, 0xbf, 0xe2, 0x41, 0x16, 0xd5, 0xf7, 0x12, 0x61, 0x4e, 0xce, 0xa2, + 0xe0, 0x58, 0x4e, 0xc3, 0x44, 0x26, 0x32, 0xb4, 0x55, 
0xa2, 0xb3, 0xb1, 0x75, 0xd6, 0x58, 0xb5, + 0xac, 0xbe, 0xfb, 0x01, 0xc0, 0xc2, 0x7e, 0x1f, 0xed, 0x40, 0x57, 0x9f, 0xb0, 0x8c, 0x63, 0x40, + 0x8a, 0xbe, 0x4b, 0x97, 0x06, 0x61, 0xb8, 0xa5, 0x8d, 0x12, 0x31, 0xd7, 0xb8, 0x60, 0xf9, 0xca, + 0xa2, 0x3b, 0x10, 0x48, 0x5c, 0x24, 0xc0, 0xaf, 0x1e, 0x54, 0x7f, 0x7d, 0x6f, 0x94, 0x0f, 0x99, + 0x61, 0x3d, 0x15, 0x73, 0x45, 0x81, 0x44, 0x4d, 0x08, 0x0c, 0xde, 0x20, 0xc0, 0xbf, 0xf5, 0xe0, + 0x51, 0xf0, 0x4f, 0xdd, 0x07, 0x43, 0x25, 0x58, 0x9a, 0x4c, 0x38, 0x05, 0xe6, 0x71, 0xe9, 0xe2, + 0xb2, 0xe1, 0xfc, 0x7c, 0xdf, 0x70, 0x76, 0xbf, 0x02, 0xe8, 0x1e, 0xf2, 0x54, 0xf3, 0xff, 0xb1, + 0x4f, 0x84, 0xe0, 0x86, 0x39, 0xcf, 0x38, 0x76, 0x09, 0xf0, 0xcb, 0xd4, 0xea, 0x9c, 0xc5, 0xcc, + 0x30, 0xbc, 0x49, 0x80, 0x5f, 0xa1, 0x56, 0xaf, 0xcd, 0xf3, 0xb6, 0x00, 0xb7, 0x3b, 0x4c, 0x9f, + 0xf2, 0xf8, 0xc6, 0x4f, 0x95, 0xb3, 0x29, 0xd3, 0xa7, 0x78, 0x8b, 0x14, 0xfd, 0x12, 0xb5, 0x1a, + 0x11, 0x58, 0xc9, 0xd7, 0x91, 0xd0, 0x23, 0x2d, 0xc7, 0x06, 0x97, 0xec, 0x1e, 0xcc, 0x59, 0x4b, + 0x0f, 0xe4, 0x78, 0xed, 0x6d, 0xef, 0xbf, 0x01, 0xb0, 0xb4, 0xba, 0x17, 0xdd, 0x83, 0x95, 0x6e, + 0x6f, 0x38, 0x1a, 0xd2, 0xd6, 0x7e, 0xf7, 0x59, 0xbb, 0x59, 0x73, 0xea, 0xb7, 0x67, 0x73, 0xb2, + 0xdd, 0x95, 0xe6, 0xcf, 0x91, 0x1d, 0xe8, 0x1e, 0xf5, 0xfb, 0x4d, 0x5a, 0x03, 0xf5, 0xf2, 0x6c, + 0x4e, 0xdc, 0xa3, 0x2c, 0xe3, 0x2a, 0xa7, 0xed, 0xde, 0xf3, 0x26, 0xad, 0x15, 0x96, 0xb4, 0x2d, + 0x5f, 0x71, 0x85, 0xee, 0xc2, 0xf2, 0xe0, 0x45, 0xa7, 0xd3, 0x1c, 0xd2, 0xd6, 0xd3, 0x5a, 0xb1, + 0x5e, 0x9d, 0xcd, 0x49, 0x79, 0x70, 0x3e, 0x9d, 0x72, 0xa3, 0xc4, 0x71, 0xbd, 0x72, 0xf1, 0xd1, + 0x73, 0x3e, 0x5d, 0x79, 0xce, 0xe7, 0x2b, 0xcf, 0x39, 0xc0, 0x5f, 0x16, 0x1e, 0xf8, 0xb6, 0xf0, + 0xc0, 0x8f, 0x85, 0xe7, 0xbc, 0xbb, 0xf6, 0x9c, 0xcb, 0x6b, 0x0f, 0xbc, 0x2c, 0x64, 0x51, 0xb4, + 0x69, 0x7f, 0xca, 0xc3, 0xdf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x0f, 0xff, 0xbb, 0x8f, 0xc8, 0x03, + 0x00, 0x00, +} diff --git a/iterator.go b/iterator.go index 70fa810..0db1158 100644 --- 
a/iterator.go +++ b/iterator.go @@ -1,6 +1,8 @@ package tensor -import "runtime" +import ( + "runtime" +) func requiresOrderedIterator(e Engine, t Tensor) bool { if t.IsScalar() { @@ -70,7 +72,7 @@ func NewIterator(aps ...*AP) Iterator { case 0: return nil case 1: - return NewFlatIterator(aps[0]) + return newFlatIterator(aps[0]) default: return NewMultIterator(aps...) } @@ -111,8 +113,8 @@ func iteratorLoadAP(it Iterator, ap *AP) { /* FLAT ITERATOR */ -// FlatIterator is an iterator that iterates over Tensors. It utilizes the *AP -// of a Tensor to determine what the next index is. +// FlatIterator is an iterator that iterates over Tensors according to the data's layout. +// It utilizes the *AP of a Tensor to determine what the next index is. // This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course // (such as, not allowing negative indices) type FlatIterator struct { @@ -129,16 +131,20 @@ type FlatIterator struct { isScalar bool isVector bool + + outerFirst bool } -// NewFlatIterator creates a new FlatIterator. -func NewFlatIterator(ap *AP) *FlatIterator { +// newFlatIterator creates a new FlatIterator. 
+func newFlatIterator(ap *AP) *FlatIterator { var strides0 int - if ap.IsVector() { - strides0 = ap.strides[0] - } else if ap.o.isColMajor() { + + if len(ap.strides) == 1 { strides0 = ap.strides[0] } + // else if ap.o.isColMajor() { + // strides0 = ap.strides[len(ap.strides)-1] + // } return &FlatIterator{ AP: ap, @@ -147,13 +153,13 @@ func NewFlatIterator(ap *AP) *FlatIterator { strides0: strides0, isScalar: ap.IsScalar(), - isVector: ap.IsVector(), + isVector: len(ap.strides) == 1, } } // FlatIteratorFromDense creates a new FlatIterator from a dense tensor func FlatIteratorFromDense(tt DenseTensor) *FlatIterator { - return NewFlatIterator(tt.Info()) + return newFlatIterator(tt.Info()) } // SetReverse initializes iterator to run backwards @@ -200,6 +206,9 @@ func (it *FlatIterator) Next() (int, error) { if it.reverse { return it.ndPrevious() } + if it.outerFirst { + return it.colMajorNDNext() + } return it.ndNext() } } @@ -233,6 +242,11 @@ func (it *FlatIterator) NextValid() (int, int, error) { a, err := it.ndPrevious() return a, -1, err } + + if it.outerFirst { + a, err := it.colMajorNDNext() + return a, 1, err + } a, err := it.ndNext() return a, 1, err } @@ -293,7 +307,6 @@ func (it *FlatIterator) singlePrevious() (int, error) { if tracked < 0 { it.done = true } - return it.lastIndex, nil } @@ -332,7 +345,39 @@ func (it *FlatIterator) ndNext() (int, error) { } func (it *FlatIterator) colMajorNDNext() (int, error) { - return 0, nil + // the reason for this weird looking bits of code is because the SSA compiler doesn't + // know how to optimize for this bit of code, not keeping things in registers correctly + // @stuartcarnie optimized this iout to great effect + + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + // the following 3 lines causes the compiler to perform bounds check here, + // instead of being done in the loop + coord := it.shape[:v+1] + track := it.track[:v+1] + strides := it.strides[:v+1] + for i := 0; i <= v; i++ 
{ + track[i]++ + shapeI := coord[i] + strideI := strides[i] + + if track[i] == shapeI { + if i == v { + it.done = true + } + track[i] = 0 + + nextIndex -= (shapeI - 1) * strideI + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil + } func (it *FlatIterator) ndPrevious() (int, error) { @@ -353,6 +398,7 @@ func (it *FlatIterator) ndPrevious() (int, error) { return it.lastIndex, nil } +// TODO v0.9.0 func (it *FlatIterator) colMajorNDPrevious() (int, error) { return 0, nil } @@ -424,10 +470,12 @@ func (it *FlatIterator) Reset() { switch { case it.IsScalar(): it.nextIndex = 0 - case it.IsRowVec(): - it.nextIndex = (it.shape[1] - 1) * it.strides[0] - case it.IsColVec(), it.IsVector(): + case it.isVector: it.nextIndex = (it.shape[0] - 1) * it.strides[0] + // case it.IsRowVec(): + // it.nextIndex = (it.shape[1] - 1) * it.strides[1] + // case it.IsColVec(): + // it.nextIndex = (it.shape[0] - 1) * it.strides[0] default: it.nextIndex = 0 for i := range it.track { diff --git a/iterator_mult.go b/iterator_mult.go index a0af458..74f9a4b 100644 --- a/iterator_mult.go +++ b/iterator_mult.go @@ -97,16 +97,19 @@ func NewMultIterator(aps ...*AP) *MultIterator { ReturnInts(apStrides) // Borrowed in BroadcastStrides but returned here - dangerous pattern? 
nBlocks++ } - ap2 := NewAP(it.shape[:maxDims], it.strides[offset:offset+maxDims]) - ap2.o = ap.o - ap2.Δ = ap.Δ - + ap2 := MakeAP(it.shape[:maxDims], it.strides[offset:offset+maxDims], ap.o, ap.Δ) it.whichBlock[i] = f - it.fitArr[nBlocks-1] = NewFlatIterator(ap2) + it.fitArr[nBlocks-1] = newFlatIterator(&ap2) } it.fitArr = it.fitArr[:nBlocks] it.strides = it.strides[:nBlocks*maxDims] + // fill 0s with 1s + for i := range it.strides { + if it.strides[i] == 0 { + it.strides[i] = 1 + } + } it.fit0 = it.fitArr[0] for _, f := range it.fitArr { @@ -120,7 +123,7 @@ func NewMultIterator(aps ...*AP) *MultIterator { // MultIteratorFromDense creates a new MultIterator from a list of dense tensors func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { - aps := BorrowAPList(len(tts)) + aps := make([]*AP, len(tts)) hasMask := BorrowBools(len(tts)) defer ReturnBools(hasMask) @@ -155,7 +158,6 @@ func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { } } it.numMasked = numMasked - ReturnAPList(aps) return it } @@ -221,7 +223,9 @@ func (it *MultIterator) Next() (int, error) { } it.done = false for _, f := range it.fitArr { - f.Next() + if _, err := f.Next(); err != nil { + return -1, err + } it.done = it.done || f.done } for i, j := range it.whichBlock { diff --git a/iterator_test.go b/iterator_test.go index 1d7f170..d0ca6de 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -6,6 +6,12 @@ import ( "github.com/stretchr/testify/assert" ) +// newAP is a helper function now +func newAP(shape Shape, strides []int) *AP { + ap := MakeAP(shape, strides, 0, 0) + return &ap +} + var flatIterTests1 = []struct { shape Shape strides []int @@ -14,8 +20,8 @@ var flatIterTests1 = []struct { }{ {ScalarShape(), []int{}, []int{0}}, // scalar {Shape{5}, []int{1}, []int{0, 1, 2, 3, 4}}, // vector - {Shape{5, 1}, []int{1}, []int{0, 1, 2, 3, 4}}, // colvec - {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec + {Shape{5, 1}, []int{1, 1}, []int{0, 1, 2, 3, 4}}, // colvec + 
{Shape{1, 5}, []int{5, 1}, []int{0, 1, 2, 3, 4}}, // rowvec {Shape{2, 3}, []int{3, 1}, []int{0, 1, 2, 3, 4, 5}}, // basic mat {Shape{3, 2}, []int{1, 3}, []int{0, 3, 1, 4, 2, 5}}, // basic mat, transposed {Shape{2}, []int{2}, []int{0, 2}}, // basic 2x2 mat, sliced: Mat[:, 1] @@ -27,6 +33,11 @@ var flatIterTests1 = []struct { {Shape{4, 2, 3}, []int{1, 12, 4}, []int{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, // basic 3-Tensor (under (2, 0, 1) transpose) {Shape{3, 2, 4}, []int{4, 12, 1}, []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, // basic 3-Tensor (under (1, 0, 2) transpose) {Shape{4, 3, 2}, []int{1, 4, 12}, []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, // basic 3-Tensor (under (2, 1, 0) transpose) + + // ARTIFICIAL CASES - TODO + // These cases should be impossible to reach in normal operation + // You would have to specially construct these + // {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec - NEARLY IMPOSSIBLE CASE- TODO } var flatIterSlices = []struct { @@ -49,8 +60,8 @@ func TestFlatIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) } @@ -73,8 +84,8 @@ func TestFlatIteratorReverse(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) it.SetReverse() for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -108,7 +119,7 @@ func TestMultIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts[0] = nexts[0][:0] err = nil - ap[0] = NewAP(fit.shape, fit.strides) + 
ap[0] = newAP(fit.shape, fit.strides) it = NewMultIterator(ap[0]) if reverse { it.SetReverse() @@ -124,43 +135,45 @@ func TestMultIterator(t *testing.T) { nexts[0][i], nexts[0][j] = nexts[0][j], nexts[0][i] } } - assert.Equal(fit.correct, nexts[0], "Repeating flat test %d", i) + assert.Equal(fit.correct, nexts[0], "Repeating flat test %d. Reverse? %v", i, reverse) } // Test multiple iterators simultaneously - var choices = []int{0, 0, 9, 9, 0, 9} - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - nexts[j] = nexts[j][:0] - err = nil - ap[j] = NewAP(fit.shape, fit.strides) - } - it = NewMultIterator(ap...) - if reverse { - it.SetReverse() - } - for _, err := it.Next(); err == nil; _, err = it.Next() { + /* + var choices = []int{0, 0, 9, 9, 0, 9} for j := 0; j < 6; j++ { - nexts[j] = append(nexts[j], it.LastIndex(j)) + fit := flatIterTests1[choices[j]] + nexts[j] = nexts[j][:0] + err = nil + ap[j] = newAP(fit.shape, fit.strides) } - - if _, ok := err.(NoOpError); err != nil && !ok { - t.Error(err) + it = NewMultIterator(ap...) 
+ if reverse { + it.SetReverse() } - } + for _, err := it.Next(); err == nil; _, err = it.Next() { + for j := 0; j < 6; j++ { + nexts[j] = append(nexts[j], it.LastIndex(j)) + } - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - if reverse { - for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { - nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + if _, ok := err.(NoOpError); err != nil && !ok { + t.Error(err) } } - if ap[j].IsScalar() { - assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) - } else { - assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + + for j := 0; j < 6; j++ { + fit := flatIterTests1[choices[j]] + if reverse { + for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { + nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + } + } + if ap[j].IsScalar() { + assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) + } else { + assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + } } - } + */ } } @@ -177,7 +190,7 @@ func TestIteratorInterface(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) + ap = newAP(fit.shape, fit.strides) it = NewIterator(ap) for next, err := it.Start(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -223,8 +236,8 @@ func TestFlatIterator_Chan(t *testing.T) { // basic stuff for i, fit := range flatIterTests1 { nexts = nexts[:0] - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) ch := it.Chan() for next := range ch { nexts = append(nexts, next) @@ -242,8 +255,8 @@ func TestFlatIterator_Slice(t *testing.T) { var nexts []int for i, fit := range flatIterTests1 { - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) nexts, err = it.Slice(nil) if _, ok := err.(NoOpError); err != nil && !ok { t.Error(err) 
@@ -276,8 +289,8 @@ func TestFlatIterator_Coord(t *testing.T) { // var nexts []int var donecount int - ap = NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it = NewFlatIterator(ap) + ap = newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it = newFlatIterator(ap) var correct = [][]int{ {0, 0, 1}, @@ -315,8 +328,8 @@ func TestFlatIterator_Coord(t *testing.T) { // really this is just for completeness sake func TestFlatIterator_Reset(t *testing.T) { assert := assert.New(t) - ap := NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it := NewFlatIterator(ap) + ap := newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it := newFlatIterator(ap) it.Next() it.Next() @@ -349,7 +362,7 @@ type oldFlatIterator struct { done bool } -// NewFlatIterator creates a new FlatIterator +// newFlatIterator creates a new FlatIterator func newOldFlatIterator(ap *AP) *oldFlatIterator { return &oldFlatIterator{ AP: ap, @@ -406,7 +419,7 @@ func BenchmarkOldFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := newOldFlatIterator(ap) for n := 0; n < b.N; n++ { @@ -426,8 +439,8 @@ func BenchmarkFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it := NewFlatIterator(ap) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it := newFlatIterator(ap) for n := 0; n < b.N; n++ { for _, err := it.Next(); err == nil; _, err = it.Next() { @@ -450,8 +463,8 @@ func BenchmarkFlatIteratorParallel6(b *testing.B) { it := make([]*FlatIterator, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it[j] = NewFlatIterator(ap[j]) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it[j] = newFlatIterator(ap[j]) } for n := 0; n < b.N; n++ { @@ -476,7 +489,7 @@ func 
BenchmarkFlatIteratorMulti1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewMultIterator(ap) @@ -496,7 +509,7 @@ func BenchmarkFlatIteratorGeneric1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewIterator(ap) @@ -519,7 +532,7 @@ func BenchmarkFlatIteratorMulti6(b *testing.B) { ap := make([]*AP, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) } it := NewMultIterator(ap...) diff --git a/native/example_test.go b/native/example_test.go index 94b324a..740d103 100644 --- a/native/example_test.go +++ b/native/example_test.go @@ -6,8 +6,9 @@ import ( . "gorgonia.org/tensor" ) -// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels) -// NativeIterators are useful for this purpose. +// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels). +// Iterators are useful for this purpose. This package provides iterators for the standard types +// However, custom types are also available. See Vector, Matrix and Tensor3 examples. func Example_iterator() { var T *Dense T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) @@ -26,7 +27,7 @@ func Example_iterator() { } // The NativeSelect function squashes the dimensions, and returns an iterator in native Go slice semantics. -func Exampleselect() { +func Example_select() { // Selection is a bit of an interesting use case. Sometimes you don't want to iterate through the layers. 
// // For example, in a number of use cases where you have a 4-Tensor, you'd typically reshape it to some diff --git a/native/generic.go b/native/generic.go new file mode 100644 index 0000000..79d8dc3 --- /dev/null +++ b/native/generic.go @@ -0,0 +1,72 @@ +package native + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +func Vector(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 1, t.Dtype()); err != nil { + return nil, err + } + return t.Data(), nil +} + +func Matrix(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 2, t.Dtype()); err != nil { + return nil, err + } + + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + ptr := t.Uintptr() + for i := 0; i < rows; i++ { + e := retVal.Index(i) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + return retVal.Interface(), nil +} + +func Tensor3(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 3, t.Dtype()); err != nil { + return nil, err + } + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(reflect.SliceOf(typ))), layers, layers) + ptr := t.Uintptr() + for i := 0; i < layers; i++ { + el := retVal.Index(i) + inner := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + for j := 0; j < rows; j++ { + e := inner.Index(j) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*layerStride+j*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + sh := (*reflect.SliceHeader)(unsafe.Pointer(el.Addr().Pointer())) + sh.Data = 
inner.Index(0).Addr().Pointer() + sh.Len = rows + sh.Cap = rows + } + return retVal.Interface(), nil +} diff --git a/native/generic_test.go b/native/generic_test.go new file mode 100644 index 0000000..cf09802 --- /dev/null +++ b/native/generic_test.go @@ -0,0 +1,67 @@ +package native_test + +import ( + "fmt" + + "gorgonia.org/tensor" + . "gorgonia.org/tensor/native" +) + +type MyType int + +func Example_vector() { + backing := []MyType{ + 0, 1, 2, 3, + } + T := tensor.New(tensor.WithShape(4), tensor.WithBacking(backing)) + val, err := Vector(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([]MyType) + fmt.Println(it) + + // Output: + // [0 1 2 3] +} + +func Example_matrix() { + backing := []MyType{ + 0, 1, + 2, 3, + 4, 5, + } + T := tensor.New(tensor.WithShape(3, 2), tensor.WithBacking(backing)) + val, err := Matrix(T) + if err != nil { + fmt.Printf("error: %v", err) + } + + it := val.([][]MyType) + fmt.Println(it) + + // Output: + // [[0 1] [2 3] [4 5]] +} + +func Example_tensor3() { + backing := []MyType{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + } + T := tensor.New(tensor.WithShape(2, 3, 4), tensor.WithBacking(backing)) + val, err := Tensor3(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([][][]MyType) + fmt.Println(it) + + //Output: + // [[[0 1 2 3] [4 5 6 7] [8 9 10 11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23]]] +} diff --git a/perf.go b/perf.go index 573d8be..2d20df2 100644 --- a/perf.go +++ b/perf.go @@ -83,19 +83,13 @@ func ReturnTensor(t Tensor) { } switch tt := t.(type) { case *Dense: - if tt.old != nil { - ReturnAP(tt.old) - tt.old = nil - } + tt.AP.zero() if tt.transposeWith != nil { ReturnInts(tt.transposeWith) tt.transposeWith = nil } - // return AP - ReturnAP(tt.AP) - // array reset tt.t = Dtype{} tt.array.Ptr = nil @@ -109,7 +103,7 @@ func ReturnTensor(t Tensor) { tt.flag = 0 // other reset - tt.old = nil + tt.old.zero() tt.viewOf = 0 
tt.transposeWith = nil @@ -124,63 +118,14 @@ func ReturnTensor(t Tensor) { } } -/* AP POOL */ - -var apPool = make(chan *AP, PoolSize) - -func borrowAP() *AP { - select { - case ap := <-apPool: - return ap - default: - return new(AP) - } - // return apPool.Get().(*AP) -} - -// BorrowAP gets an AP from the pool. USE WITH CAUTION. -func BorrowAP(dims int) *AP { - ap := borrowAP() - ap.shape = BorrowInts(dims) - ap.strides = BorrowInts(dims) - ap.shape = ap.shape[:cap(ap.shape)] - ap.strides = ap.strides[:cap(ap.strides)] - return ap -} - -// ReturnAP returns the AP to the pool. USE WITH CAUTION. -func ReturnAP(ap *AP) { - ReturnInts([]int(ap.shape)) - ReturnInts(ap.strides) - ap.fin = false - - ap.o = 0 - ap.Δ = 0 - - if len(apPool) < cap(apPool) { - apPool <- ap - } - // apPool.Put(ap) -} - /* ---------------------------------------------------------------- ------------------ Create Pools ------------------------------------------------------------------*/ /* APLIST POOL */ -var apListPool [maxAPDims]sync.Pool - // Init function func init() { - for i := range apListPool { - size := i - apListPool[i].New = func() interface{} { return make([]*AP, size) } - } - - // for i := 0; i < PoolSize; i++ { - // intsPool <- make([]int, 8, 8) - // } for i := range intsPool { size := i @@ -222,11 +167,13 @@ func BorrowInts(size int) []int { if retVal == nil { return make([]int, size) } + // log.Printf("Borrowing %p. Called by %v", retVal, string(debug.Stack())) return retVal.([]int)[:size] } // ReturnInts returns a slice from the pool. USE WITH CAUTION. func ReturnInts(is []int) { + // log.Printf("Returning %p. Called by %v", is, string(debug.Stack())) if is == nil { return } @@ -293,36 +240,6 @@ func ReturnBools(is []bool) { // boolsPool[size].Put(is) } -// BorrowAPList gets an APList from the pool. USE WITH CAUTION. 
-func BorrowAPList(size int) []*AP { - if size >= 8 { - return make([]*AP, size) - } - - retVal := apListPool[size].Get() - if retVal == nil { - return make([]*AP, size) - } - return retVal.([]*AP) -} - -// ReturnAPList returns the APList to the pool. USE WITH CAUTION. -func ReturnAPList(aps []*AP) { - if aps == nil { - return - } - size := cap(aps) - if size >= 8 { - return - } - aps = aps[:cap(aps)] - for i := range aps { - aps[i] = nil - } - - apListPool[size].Put(aps) -} - // var optPool = make(chan *OpOpt, PoolSize) // var optPool = newRingbuffer(PoolSize) var optPool = &sync.Pool{ diff --git a/shape.go b/shape.go index ba0b18f..cecb41d 100644 --- a/shape.go +++ b/shape.go @@ -24,17 +24,18 @@ func (s Shape) TotalSize() int { return ProdInts([]int(s)) } -func (s Shape) calcStrides() []int { +// CalcStrides calculates the default strides for a shape +func (s Shape) CalcStrides() []int { if s.IsScalar() { return nil } retVal := BorrowInts(len(s)) - if s.IsVector() { - retVal[0] = 1 - retVal = retVal[:1] - return retVal - } + // if s.IsVector() { + // retVal[0] = 1 + // retVal = retVal[:1] + // return retVal + // } acc := 1 for i := len(s) - 1; i >= 0; i-- { @@ -48,9 +49,9 @@ func (s Shape) calcStrides() []int { return retVal } -// calcStridesWithMask is similar to calcStrides, except that it has an argument, masks. It is used to mask out given dimensions +// CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. 
It is used to mask out given dimensions // during calculation of stride -func (s Shape) calcStridesWithMask(mask []bool) []int { +func (s Shape) CalcStridesWithMask(mask []bool) []int { if s.IsScalar() { return nil } @@ -84,7 +85,8 @@ func (s Shape) calcStridesWithMask(mask []bool) []int { return retVal } -func (s Shape) calcStridesColMajor() []int { +// CalcStridesColMajor is like CalcStrides, but assumes a col major layout +func (s Shape) CalcStridesColMajor() []int { if s.IsScalar() { return nil } @@ -152,7 +154,23 @@ func (s Shape) Clone() Shape { } // IsScalar returns true if the access pattern indicates it's a scalar value -func (s Shape) IsScalar() bool { return len(s) == 0 || (len(s) == 1 && s[0] == 1) } +func (s Shape) IsScalar() bool { + return len(s) == 0 || (len(s) == 1 && s[0] == 1) +} + +// IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value +func (s Shape) IsScalarEquiv() bool { + if len(s) == 0 { + return true + } + isEquiv := true + for i := range s { + if s[i] != 1 { + return false + } + } + return isEquiv +} // IsVector returns whether the access pattern falls into one of three possible definitions of vectors: // vanilla vector (not a row or a col) @@ -172,6 +190,9 @@ func (s Shape) IsMatrix() bool { return len(s) == 2 } // Dims returns the number of dimensions in the shape func (s Shape) Dims() int { return len(s) } +// DimSize returns the size of the dimension wanted. +// +// This method implemnents the DimSizer interface in Gorgonia. 
func (s Shape) DimSize(d int) (size int, err error) { if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { err = errors.Errorf(dimMismatch, len(s), d) @@ -221,12 +242,14 @@ func (s Shape) S(slices ...Slice) (retVal Shape, err error) { } // drop any dimension with size 1, except the last dimension + offset := 0 dims := s.Dims() for d := 0; d < dims; d++ { - if retVal[d] == 1 /*&& d != t.dims-1 && dims > 2*/ { + if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { retVal = append(retVal[:d], retVal[d+1:]...) d-- dims-- + offset++ } } @@ -326,7 +349,7 @@ func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { } else { // validate that the rest of the dimensions match up if newShape[d] != shp[d] { - err = errors.Errorf(dimMismatch, newShape[d], shp[d]) + err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d) return } } diff --git a/shape_test.go b/shape_test.go index 51fe64a..aa9e1de 100644 --- a/shape_test.go +++ b/shape_test.go @@ -90,36 +90,36 @@ func TestShapeCalcStride(t *testing.T) { // scalar shape s = Shape{} - assert.Nil(s.calcStrides()) + assert.Nil(s.CalcStrides()) s = Shape{1} - assert.Nil(s.calcStrides()) + assert.Nil(s.CalcStrides()) // vector shape s = Shape{2, 1} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1, 1}, s.CalcStrides()) s = Shape{1, 2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) s = Shape{2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1}, s.CalcStrides()) // matrix strides s = Shape{2, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) s = Shape{5, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) // 3D strides s = Shape{2, 3, 4} - assert.Equal([]int{12, 4, 1}, s.calcStrides()) + assert.Equal([]int{12, 4, 1}, s.CalcStrides()) // stupid 
shape s = Shape{-2, 1, 2} fail := func() { - s.calcStrides() + s.CalcStrides() } assert.Panics(fail) } @@ -191,6 +191,10 @@ var shapeSliceTests = []struct { {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, + {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false}, + {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false}, + {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, Shape{1, 2, 2}, false}, + {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false}, } func TestShape_Slice(t *testing.T) { diff --git a/sparse.go b/sparse.go index abb36c1..5de67d4 100644 --- a/sparse.go +++ b/sparse.go @@ -31,7 +31,7 @@ type coo struct { func (c *coo) Len() int { return c.data.L } func (c *coo) Less(i, j int) bool { - if c.o.isColMajor() { + if c.o.IsColMajor() { return c.colMajorLess(i, j) } return c.rowMajorLess(i, j) @@ -182,13 +182,14 @@ func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { return t } -func (t *CS) Shape() Shape { return t.s } -func (t *CS) Strides() []int { return nil } -func (t *CS) Dtype() Dtype { return t.t } -func (t *CS) Dims() int { return 2 } -func (t *CS) Size() int { return t.s.TotalSize() } -func (t *CS) DataSize() int { return t.L } -func (t *CS) Engine() Engine { return t.e } +func (t *CS) Shape() Shape { return t.s } +func (t *CS) Strides() []int { return nil } +func (t *CS) Dtype() Dtype { return t.t } +func (t *CS) Dims() int { return 2 } +func (t *CS) Size() int { return t.s.TotalSize() } +func (t *CS) DataSize() int { return t.L } +func (t *CS) Engine() Engine { return t.e } +func (t *CS) DataOrder() DataOrder { return t.o } func (t *CS) Slice(...Slice) (View, error) { return nil, errors.Errorf("Slice for sparse tensors not implemented yet") @@ -232,11 +233,12 @@ func (t *CS) 
T(axes ...int) error { } UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() + t.o = MakeDataOrder(t.o, Transposed) return errors.Errorf(methodNYI, "T") } // UT untransposes the CS -func (t *CS) UT() { t.T() } +func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } // Transpose is a no-op. The data does not move func (t *CS) Transpose() error { return nil } @@ -307,7 +309,7 @@ func (t *CS) Iterator() Iterator { return NewFlatSparseIterator(t) } func (t *CS) at(coord ...int) (int, bool) { var r, c int - if t.o.isColMajor() { + if t.o.IsColMajor() { r = coord[1] c = coord[0] } else { @@ -330,7 +332,7 @@ func (t *CS) Dense() *Dense { } d := recycledDense(t.t, t.Shape().Clone()) - if t.o.isColMajor() { + if t.o.IsColMajor() { for i := 0; i < len(t.indptr)-1; i++ { for j := t.indptr[i]; j < t.indptr[i+1]; j++ { d.SetAt(t.Get(j), t.indices[j], i) @@ -361,14 +363,14 @@ func (t *CS) Indices() []int { } func (t *CS) AsCSR() { - if t.o.isRowMajor() { + if t.o.IsRowMajor() { return } t.o.toggleColMajor() } func (t *CS) AsCSC() { - if t.o.isColMajor() { + if t.o.IsColMajor() { return } t.o.toggleColMajor() diff --git a/tensor.go b/tensor.go index d1b348a..a04e425 100644 --- a/tensor.go +++ b/tensor.go @@ -36,6 +36,7 @@ type Tensor interface { // Data access related RequiresIterator() bool Iterator() Iterator + DataOrder() DataOrder // ops Slicer @@ -86,7 +87,6 @@ type Tensor interface { // New creates a new Dense Tensor. 
For sparse arrays use their relevant construction function func New(opts ...ConsOpt) *Dense { d := borrowDense() - d.AP = new(AP) for _, opt := range opts { opt(d) } diff --git a/testutils_test.go b/testutils_test.go index e219ab1..71a43a4 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -240,6 +240,7 @@ func allClose(a, b interface{}, approxFn ...interface{}) bool { return reflect.DeepEqual(a, b) } } + func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) { switch { case expected: diff --git a/types.go b/types.go index fd8e189..69740cf 100644 --- a/types.go +++ b/types.go @@ -299,6 +299,7 @@ func RegisterFloat(a Dtype) { RegisterOrd(a) } +// RegisterOrd registers a dtype as a type that can be typed func RegisterOrd(a Dtype) { for _, dt := range ordTypes.set { if dt == a { @@ -306,8 +307,10 @@ func RegisterOrd(a Dtype) { } } ordTypes.set = append(ordTypes.set, a) + RegisterEq(a) } +// RegisterEq registers a dtype as a type that can be compared for equality func RegisterEq(a Dtype) { for _, dt := range eqTypes.set { if dt == a { @@ -315,6 +318,26 @@ func RegisterEq(a Dtype) { } } eqTypes.set = append(eqTypes.set, a) + Register(a) +} + +// Register registers a new Dtype +func Register(a Dtype) { + for _, dt := range allTypes.set { + if a == dt { + return + } + } + allTypes.set = append(allTypes.set, a) +} + +func dtypeID(a Dtype) int { + for i, v := range allTypes.set { + if a == v { + return i + } + } + return -1 } // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte. 
diff --git a/utils.go b/utils.go index 9dcd936..8e62448 100644 --- a/utils.go +++ b/utils.go @@ -213,7 +213,6 @@ func UnsafePermute(pattern []int, xs ...[]int) (err error) { return nil } - // CheckSlice checks a slice to see if it's sane func CheckSlice(s Slice, size int) error { start := s.Start() @@ -282,9 +281,8 @@ func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { } // clean up any funny things that may be in the reuse - if oldAP := reuse.oldAP(); oldAP != nil { - ReturnAP(oldAP) - reuse.setOldAP(nil) + if oldAP := reuse.oldAP(); !oldAP.IsZero() { + oldAP.zero() } if axes := reuse.transposeAxes(); axes != nil { @@ -309,7 +307,6 @@ func memsetBools(a []bool, v bool) { } } - /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. the dumb, unoptimized version) @@ -385,4 +382,4 @@ func Permute(pattern []int, xs ...[]int) (retVal [][]int, err error) { } return } -*/ \ No newline at end of file +*/ From aff2ae51d6a744979511bf83fd7201d3dbcb4f54 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Thu, 30 May 2019 09:06:50 +0200 Subject: [PATCH 014/154] optimization for the array structure (#42) * chore: change the name of the receiver for coherency * fix: wrong formating for the error message * chore: fix error formating and unreachable code * chore: error formatting * chore: error formatting * fix: expected output of the readnumpy test * fix: replace the string equal by a regexp the test will not fail in case of a change of the formatting directive * feat: doing lazy initialization of the underlying value of an array The value is populated on a call to Data() --- array.go | 42 ++++++++++++++++++++++--------------- defaultengine.go | 2 +- defaultengine_linalg.go | 3 +-- defaultengine_matop_misc.go | 2 +- dense.go | 14 +++++++++++++ dense_io_test.go | 7 +++++-- errors.go | 2 +- sparse.go | 4 ++-- 8 files changed, 50 insertions(+), 26 deletions(-) diff --git a/array.go 
b/array.go index 321c9cf..7d07dff 100644 --- a/array.go +++ b/array.go @@ -33,20 +33,10 @@ func makeArray(t Dtype, length int) array { // makeArrayFromHeader makes an array given a header func makeArrayFromHeader(hdr storage.Header, t Dtype) array { - // build a type of []T - shdr := reflect.SliceHeader{ - Data: uintptr(hdr.Ptr), - Len: hdr.L, - Cap: hdr.C, - } - sliceT := reflect.SliceOf(t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - return array{ Header: hdr, t: t, - v: val.Interface(), + v: nil, } } @@ -224,16 +214,31 @@ func (a *array) swap(i, j int) { /* *Array is a Memory */ // Uintptr returns the pointer of the first value of the slab -func (t *array) Uintptr() uintptr { return uintptr(t.Ptr) } +func (a *array) Uintptr() uintptr { return uintptr(a.Ptr) } // MemSize returns how big the slice is in bytes -func (t *array) MemSize() uintptr { return uintptr(t.L) * t.t.Size() } +func (a *array) MemSize() uintptr { return uintptr(a.L) * a.t.Size() } // Pointer returns the pointer of the first value of the slab, as an unsafe.Pointer -func (t *array) Pointer() unsafe.Pointer { return t.Ptr } +func (a *array) Pointer() unsafe.Pointer { return a.Ptr } // Data returns the representation of a slice. -func (a array) Data() interface{} { return a.v } +func (a array) Data() interface{} { + if a.v == nil { + // build a type of []T + shdr := reflect.SliceHeader{ + Data: uintptr(a.Header.Ptr), + Len: a.Header.L, + Cap: a.Header.C, + } + sliceT := reflect.SliceOf(a.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + a.v = val.Interface() + + } + return a.v +} // Zero zeroes out the underlying array of the *Dense tensor. 
func (a array) Zero() { @@ -360,8 +365,11 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { return copyDense(dst, src), nil } - if !dst.IsNativelyAccessible() || !src.IsNativelyAccessible() { - return 0, errors.Errorf(inaccessibleData, "copy") + if !dst.IsNativelyAccessible() { + return 0, errors.Errorf(inaccessibleData, dst) + } + if !src.IsNativelyAccessible() { + return 0, errors.Errorf(inaccessibleData, src) } if diter == nil { diff --git a/defaultengine.go b/defaultengine.go index cace41a..bc92e8c 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -30,7 +30,7 @@ func (e StdEng) Memset(mem Memory, val interface{}) error { if ms, ok := mem.(MemSetter); ok { return ms.Memset(val) } - return errors.Errorf("Cannot memset %v with StdEng") + return errors.Errorf("Cannot memset %v with StdEng", mem) } func (e StdEng) Memclr(mem Memory) { diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 45a8527..c75f4a7 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -287,7 +287,6 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var rd *Dense if rd, err = a.TensorMul(b, axesA, axesB); err != nil { panic(err) - return } if reuse != nil { @@ -313,7 +312,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { var t *Dense var ok bool if err = e.checkAccessible(a); err != nil { - return nil, nil, nil, errors.Wrapf(err, "opFail", "SVD") + return nil, nil, nil, errors.Wrapf(err, "opFail %v", "SVD") } if t, ok = a.(*Dense); !ok { return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. 
Got %T instead", a) diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 9faed77..f71de3c 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -249,7 +249,7 @@ func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { bdata[i] = adata[i*stride] } default: - return nil, errors.Errorf(typeNYI, "Arbitrary sized diag") + return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) } return b, nil } diff --git a/dense.go b/dense.go index 2824261..69f4d70 100644 --- a/dense.go +++ b/dense.go @@ -2,6 +2,7 @@ package tensor import ( "fmt" + "reflect" "unsafe" "github.com/pkg/errors" @@ -112,6 +113,19 @@ func (t *Dense) Data() interface{} { if t.IsScalar() { return t.Get(0) } + if t.v == nil { + // build a type of []T + shdr := reflect.SliceHeader{ + Data: uintptr(t.Header.Ptr), + Len: t.Header.L, + Cap: t.Header.C, + } + sliceT := reflect.SliceOf(t.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + t.v = val.Interface() + + } return t.v } diff --git a/dense_io_test.go b/dense_io_test.go index 3c75973..482b170 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -5,6 +5,7 @@ import ( "encoding/gob" "os" "os/exec" + "regexp" "testing" "github.com/stretchr/testify/assert" @@ -45,10 +46,12 @@ func TestSaveLoadNumpy(t *testing.T) { t.Error(err) } - expected := "[[ 1. 5.]\n [ 10. 
-1.]]\n" + expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n` + if ok, _ := regexp.Match(expected, buf.Bytes()); !ok { + t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) + } if buf.String() != expected { - t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) } // cleanup diff --git a/errors.go b/errors.go index d688c09..fd3b9db 100644 --- a/errors.go +++ b/errors.go @@ -59,7 +59,7 @@ const ( unknownState = "Unknown state reached: Safe %t, Incr %t, Reuse %t" unsupportedDtype = "Array of %v is unsupported for %v" maskRequired = "Masked array type required for %v" - inaccessibleData = "Data in %p inaccessble" + inaccessibleData = "Data in %p inaccessble" methodNYI = "%q not yet implemented for %v" typeNYI = "%q not yet implemented for interactions with %T" diff --git a/sparse.go b/sparse.go index 5de67d4..1e46586 100644 --- a/sparse.go +++ b/sparse.go @@ -234,7 +234,7 @@ func (t *CS) T(axes ...int) error { UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() t.o = MakeDataOrder(t.o, Transposed) - return errors.Errorf(methodNYI, "T") + return errors.Errorf(methodNYI, "T", t) } // UT untransposes the CS @@ -244,7 +244,7 @@ func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } func (t *CS) Transpose() error { return nil } func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { - return nil, errors.Errorf(methodNYI, "Apply") + return nil, errors.Errorf(methodNYI, "Apply", t) } func (t *CS) Eq(other interface{}) bool { From 777a631d96a403f5162aae9dba59161cbb89b86d Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 2 Jun 2019 07:02:03 +1000 Subject: [PATCH 015/154] Fixed Travis (#44) * Fixed Travis * Fixed reflect type issue per https://github.com/golang/go/issues/32303 --- .travis.yml | 4 ++-- array.go | 2 +- array_getset.go | 10 +++++----- example_extension_matop_test.go | 1 + genlib2/array_getset.go | 10 +++++----- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git 
a/.travis.yml b/.travis.yml index 9a3402d..1bec54c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,9 @@ branches: only: - master go: - - 1.8.x - - 1.9.x - 1.10.x + - 1.11.x + - 1.12.x - tip env: diff --git a/array.go b/array.go index 7d07dff..d96b56f 100644 --- a/array.go +++ b/array.go @@ -259,7 +259,7 @@ func (a array) Zero() { ptr := uintptr(a.Ptr) for i := 0; i < a.L; i++ { want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(reflect.Zero(a.t)) } diff --git a/array_getset.go b/array_getset.go index fe65438..8896a69 100644 --- a/array_getset.go +++ b/array_getset.go @@ -70,7 +70,7 @@ func (a *array) Set(i int, x interface{}) { xv := reflect.ValueOf(x) ptr := uintptr(a.Ptr) want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -117,7 +117,7 @@ func (a *array) Get(i int) interface{} { return a.GetUnsafePointer(i) default: at := uintptr(a.Ptr) + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(at)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(at)) val = reflect.Indirect(val) return val.Interface() } @@ -294,7 +294,7 @@ func (a *array) Memset(x interface{}) error { ptr := uintptr(a.Ptr) for i := 0; i < a.L; i++ { want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -489,7 +489,7 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -754,7 +754,7 @@ func (t *array) zeroIter(it 
Iterator) (err error) { ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(reflect.Zero(t.t)) } diff --git a/example_extension_matop_test.go b/example_extension_matop_test.go index 0855b77..372c9a0 100644 --- a/example_extension_matop_test.go +++ b/example_extension_matop_test.go @@ -39,6 +39,7 @@ func Example_TransposeExtension() { LongStruct{3, 3, 3, 3, 3}, }), ) + fmt.Printf("Before:\n%v\n", T) retVal, _ := tensor.Transpose(T) // an alternative would be to use T.T(); T.Transpose() fmt.Printf("After:\n%v\n", retVal) diff --git a/genlib2/array_getset.go b/genlib2/array_getset.go index 8fa0702..f24b19c 100644 --- a/genlib2/array_getset.go +++ b/genlib2/array_getset.go @@ -27,7 +27,7 @@ func (a *array) Get(i int) interface{} { {{end -}} default: at := uintptr(a.Ptr) + uintptr(i) * a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(at)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(at)) val = reflect.Indirect(val) return val.Interface() } @@ -49,7 +49,7 @@ func (a *array) Set(i int, x interface{}) { xv := reflect.ValueOf(x) ptr := uintptr(a.Ptr) want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -80,7 +80,7 @@ func (a *array) Memset(x interface{}) error { ptr := uintptr(a.Ptr) for i := 0; i < a.L; i++ { want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -202,7 +202,7 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, 
unsafe.Pointer(want)) + val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) } @@ -235,7 +235,7 @@ const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t, unsafe.Pointer(want)) + val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(reflect.Zero(t.t)) } From 356c45cfe892e0bfda4eed8f25234a0d43c3df78 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Sun, 9 Jun 2019 15:06:01 +0200 Subject: [PATCH 016/154] Optimization on the Repeat operation (#43) * feat: enhance the performances of the Repeat operator If possible, avoid the use of the copyDenseSlice function to be able to reuse the source slice. * feat: add some bench for the Repeat operation * fix: use named subbenchmark for clarity on the output * Moved "fastCopy" of repeat to a new method. Added tests for slow path (by creating masked copies) * Moved benchmarks out to a separate file --- benchmark_dense_repeat_test.go | 14 +++++++++ defaultengine_matop_misc.go | 55 ++++++++++++++++++++++++++++++++++ dense_matop_test.go | 49 ++++++++++++++++++++++++++++-- 3 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 benchmark_dense_repeat_test.go diff --git a/benchmark_dense_repeat_test.go b/benchmark_dense_repeat_test.go new file mode 100644 index 0000000..f22138a --- /dev/null +++ b/benchmark_dense_repeat_test.go @@ -0,0 +1,14 @@ +package tensor + +import "testing" + +func BenchmarkDenseRepeat(b *testing.B) { + for _, tst := range repeatTests { + tst := tst + b.Run(tst.name, func(b *testing.B) { + for n := 0; n < b.N; n++ { + tst.tensor.Repeat(tst.axis, tst.repeats...) 
+ } + }) + } +} diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index f71de3c..6ecdfaf 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -8,6 +8,11 @@ var ( _ Diager = StdEng{} ) +type fastcopier interface { + fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error +} + +// Repeat ... func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { switch tt := t.(type) { case DenseTensor: @@ -54,6 +59,29 @@ func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseT } var destStart, srcStart int + // fastCopy is not bypassing the copyDenseSliced method to populate the output tensor + var fastCopy bool + var fce fastcopier + // we need an engine for fastCopying... + e := t.Engine() + // e can never be nil. Error would have occured elsewhere + var ok bool + if fce, ok = e.(fastcopier); ok { + fastCopy = true + } + + // In this case, let's not implement the fast copy to keep the code readable + if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() { + fastCopy = false + } + + if fastCopy { + if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil { + return nil, err + } + return d, nil + } + for i := 0; i < outers; i++ { for j := 0; j < size; j++ { var tmp int @@ -72,6 +100,32 @@ func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseT return d, nil } +func (StdEng) fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error { + var destStart, srcStart int + for i := 0; i < outers; i++ { + for j := 0; j < size; j++ { + var tmp int + tmp = repeats[j] + var tSlice array + tSlice = t.arr().slice(srcStart, t.len()) + + for k := 0; k < tmp; k++ { + if srcStart >= t.len() || destStart+stride > d.len() { + break + } + dSlice := d.arr().slice(destStart, d.len()) + if err := t.Engine().Memcpy(&dSlice, &tSlice); err != nil { + return err + } + 
destStart += newStride + } + srcStart += stride + } + } + return nil +} + +// Concat tensors func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { switch tt := t.(type) { case DenseTensor: @@ -196,6 +250,7 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen return retVal, nil } +// Diag ... func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { a, ok := t.(DenseTensor) if !ok { diff --git a/dense_matop_test.go b/dense_matop_test.go index ca02d3e..8052b2b 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -287,7 +287,7 @@ func TestTUT(t *testing.T) { assert.Nil(T.transposeWith) } -var repeatTests = []struct { +type repeatTest struct { name string tensor *Dense ne bool // should assert tensor not equal @@ -297,7 +297,9 @@ var repeatTests = []struct { correct interface{} shape Shape err bool -}{ +} + +var repeatTests = []repeatTest{ {"Scalar Repeat on axis 0", New(FromScalar(true)), true, 0, []int{3}, []bool{true, true, true}, @@ -435,6 +437,49 @@ func TestDense_Repeat(t *testing.T) { } } +func TestDense_Repeat_Slow(t *testing.T) { + rt2 := make([]repeatTest, len(repeatTests)) + for i, rt := range repeatTests { + rt2[i] = repeatTest{ + name: rt.name, + ne: rt.ne, + axis: rt.axis, + repeats: rt.repeats, + correct: rt.correct, + shape: rt.shape, + err: rt.err, + tensor: rt.tensor.Clone().(*Dense), + } + } + for i := range rt2 { + maskLen := rt2[i].tensor.len() + mask := make([]bool, maskLen) + rt2[i].tensor.mask = mask + } + + assert := assert.New(t) + + for i, test := range rt2 { + T, err := test.tensor.Repeat(test.axis, test.repeats...) + if checkErr(t, test.err, err, "Repeat", i) { + continue + } + + var D DenseTensor + if D, err = getDenseTensor(T); err != nil { + t.Errorf("Expected Repeat to return a *Dense. 
got %v of %T instead", T, T) + continue + } + + if test.ne { + assert.NotEqual(test.tensor, D, test.name) + } + + assert.Equal(test.correct, D.Data(), test.name) + assert.Equal(test.shape, D.Shape(), test.name) + } +} + func TestDense_CopyTo(t *testing.T) { assert := assert.New(t) var T, T2 *Dense From 2e19d2e09b59c6ae597fe7175b5aaf9c00723f1e Mon Sep 17 00:00:00 2001 From: Chewxy Date: Thu, 5 Sep 2019 10:29:17 +1000 Subject: [PATCH 017/154] Go1.13 ci (#45) * Updated gomod and deleted Gopkg * Updated gomod and go sum --- Gopkg.lock | 75 ------------------------------------------------------ Gopkg.toml | 54 --------------------------------------- go.mod | 17 +++++++++++++ go.sum | 56 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 129 deletions(-) delete mode 100644 Gopkg.lock delete mode 100644 Gopkg.toml create mode 100644 go.mod create mode 100644 go.sum diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index cb5367f..0000000 --- a/Gopkg.lock +++ /dev/null @@ -1,75 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
- - -[[projects]] - name = "github.com/chewxy/hm" - packages = ["."] - revision = "61efb3290a086d1335e8954b3734c102126818ba" - version = "v1.0.0" - -[[projects]] - name = "github.com/chewxy/math32" - packages = ["."] - revision = "d1e7b22839c693f54edf7811dd9487623abf2cd2" - version = "v1.0.0" - -[[projects]] - branch = "master" - name = "gorgonia.org/vecf32" - packages = ["."] - revision = "1f59516136c1a7f1c19871d3dc5f0d9928ffbd7c" - -[[projects]] - branch = "master" - name = "gorgonia.org/vecf64" - packages = ["."] - revision = "a97a4d31b6c9343b1860ef8ce583069671265b81" - -[[projects]] - name = "github.com/davecgh/go-spew" - packages = ["spew"] - revision = "346938d642f2ec3594ed81d874461961cd0faa76" - version = "v1.1.0" - -[[projects]] - name = "github.com/pkg/errors" - packages = ["."] - revision = "645ef00459ed84a119197bfb8d8205042c6df63d" - version = "v0.8.0" - -[[projects]] - name = "github.com/pmezard/go-difflib" - packages = ["difflib"] - revision = "792786c7400a136282c1664665ae0a8db921c6c2" - version = "v1.0.0" - -[[projects]] - name = "github.com/stretchr/testify" - packages = ["assert"] - revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" - version = "v1.1.4" - -[[projects]] - branch = "master" - name = "github.com/xtgo/set" - packages = ["."] - revision = "4431f6b51265b1e0b76af4dafc09d6f12c2bdcd0" - -[[projects]] - branch = "master" - name = "golang.org/x/tools" - packages = ["go/ast/astutil","go/buildutil","go/loader"] - revision = "e531a2a1c15f94033f6fa87666caeb19a688175f" - -[[projects]] - branch = "master" - name = "gonum.org/v1/gonum" - packages = ["blas","blas/blas64","blas/gonum","floats","internal/asm/c128","internal/asm/f32","internal/asm/f64","internal/math32","lapack","lapack/gonum","lapack/lapack64","mat"] - revision = "f818f8f7a9e59de54e475b747a3dc9c86ed141f1" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - inputs-digest = "eee289039fc6a17513fe715d3457e0997e6bffa8a27d57987d15cb1a17407705" - solver-name = "gps-cdcl" - 
solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index 1a52f27..0000000 --- a/Gopkg.toml +++ /dev/null @@ -1,54 +0,0 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - name = "github.com/chewxy/hm" - version = "~1.0.0" - -[[constraint]] - name = "github.com/chewxy/math32" - version = "~1.0.0" - -[[constraint]] - branch = "master" - name = "gorgonia.org/vecf32" - -[[constraint]] - branch = "master" - name = "gorgonia.org/vecf64" - -[[constraint]] - name = "github.com/pkg/errors" - version = "~0.8.0" - -[[constraint]] - name = "github.com/stretchr/testify" - version = "1.1.4" - -[[constraint]] - branch = "master" - name = "golang.org/x/tools" - -[[constraint]] - branch = "master" - name = "gonum.org/v1/gonum" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..70baaff --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module gorgonia.org/tensor + +go 1.13 + +require ( + github.com/chewxy/hm v1.0.0 + github.com/chewxy/math32 v1.0.4 + github.com/gogo/protobuf v1.3.0 + github.com/golang/protobuf v1.3.2 + github.com/google/flatbuffers v1.11.0 + github.com/pkg/errors v0.8.1 + github.com/stretchr/testify v1.4.0 + github.com/xtgo/set v1.0.0 // indirect + gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee + gorgonia.org/vecf32 v0.9.0 + gorgonia.org/vecf64 v0.9.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..850df1d --- /dev/null +++ b/go.sum @@ -0,0 +1,56 @@ +github.com/ajstarks/svgo 
v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= +github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= +github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= +github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo= +github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE= +github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= +github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= +github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM= +gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod 
h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0= +gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8= +gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= +gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= +gorgonia.org/vecf64 v0.7.0 h1:ZphOGJfnWlFfY7x8WAJAfO64IAtYqPPq9TEGem+ItZE= +gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q= +gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= +gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= From 0ea1dd1395ddb99c05212b88ab325df7a657eb27 Mon Sep 17 00:00:00 2001 From: Tiago Rodrigo Lampert Date: Sun, 13 Oct 2019 20:37:03 -0300 Subject: [PATCH 018/154] Fix typo (#49) * Fix typo stucture -> structure obiviating -> obviating * Fix typo (initalizes -> initializes) initalizes -> initializes * Fix typo (inaccessble -> inaccessible) inaccessble -> inaccessible * Fix typo (occured -> occurred) occured -> occurred --- README.md | 2 +- ap.go | 2 +- defaultengine_matop_misc.go | 2 +- errors.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 086cdc4..b93aa9a 100644 --- a/README.md +++ b/README.md @@ -156,7 +156,7 @@ And here's a visual representation of the `*Dense`. 
![dense](https://github.com/gorgonia/tensor/blob/master/media/dense.png?raw=true) -`*Dense` draws its inspiration from Go's slice. Underlying it all is a flat array, and access to elements are controlled by `*AP`. Where a Go is able to store its metadata in a 3-word stucture (obiviating the need to allocate memory), a `*Dense` unfortunately needs to allocate some memory. The majority of the data is stored in the `*AP` structure, which contains metadata such as shape, stride, and methods for accessing the array. +`*Dense` draws its inspiration from Go's slice. Underlying it all is a flat array, and access to elements are controlled by `*AP`. Where a Go slice is able to store its metadata in a 3-word structure (obviating the need to allocate memory), a `*Dense` unfortunately needs to allocate some memory. The majority of the data is stored in the `*AP` structure, which contains metadata such as shape, stride, and methods for accessing the array. `*Dense` embeds an `array` (not to be confused with Go's array), which is an abstracted data structure that looks like this: diff --git a/ap.go b/ap.go index 83df9c5..d9dac59 100644 --- a/ap.go +++ b/ap.go @@ -44,7 +44,7 @@ func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP { } } -// Init initalizes an already created AP with a shape and stries. +// Init initializes an already created AP with a shape and strides. // It will panic if AP is nil. func (ap *AP) Init(shape Shape, strides []int) { ap.shape = shape diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 6ecdfaf..06f53ae 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -64,7 +64,7 @@ func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseT var fce fastcopier // we need an engine for fastCopying... e := t.Engine() - // e can never be nil. Error would have occured elsewhere + // e can never be nil.
Error would have occurred elsewhere var ok bool if fce, ok = e.(fastcopier); ok { fastCopy = true diff --git a/errors.go b/errors.go index fd3b9db..461baf4 100644 --- a/errors.go +++ b/errors.go @@ -59,7 +59,7 @@ const ( unknownState = "Unknown state reached: Safe %t, Incr %t, Reuse %t" unsupportedDtype = "Array of %v is unsupported for %v" maskRequired = "Masked array type required for %v" - inaccessibleData = "Data in %p inaccessble" + inaccessibleData = "Data in %p inaccessible" methodNYI = "%q not yet implemented for %v" typeNYI = "%q not yet implemented for interactions with %T" From 8e1c77328ea1cdfe9e599541d000df0e86aa5a26 Mon Sep 17 00:00:00 2001 From: Ben Leitner <7515022+bdleitner@users.noreply.github.com> Date: Sun, 13 Oct 2019 17:31:24 -0700 Subject: [PATCH 019/154] Fix a bug where axes in the along array were only ever decremented by 1 to account for prior reductions, no matter how many reductions had occurred. (#48) * Factor out common code for the Max, Min, and Sum reductions. * Include tests for 4D tensors for complete as well as partial reduction across all 3 reduction methods. 
Closes #47 --- defaultengine_mapreduce.go | 122 +++++++------------------------------ dense_reduction_test.go | 112 ++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 100 deletions(-) diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 4964ab4..03b5c0e 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -2,8 +2,10 @@ package tensor import ( "reflect" + "sort" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" ) @@ -176,99 +178,23 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, } func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { - switch at := a.(type) { - case *Dense: - hdr := at.hdr() - typ := at.t.Type - monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value - if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { - var ret interface{} - if ret, err = execution.MonotonicSum(typ, hdr); err != nil { - return - } - return New(FromScalar(ret)), nil - } - var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.SumMethods(typ); err != nil { - return - } - defaultVal := reflect.Zero(typ).Interface() - - retVal = a - prev := -1 - dims := len(retVal.Shape()) - - for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } - - if axis >= dims { - err = errors.Errorf(dimMismatch, retVal.Dims(), axis) - return - } - if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil { - return - } - } - return - - default: - return nil, errors.Errorf("Cannot perform Sum on %T", a) - } + return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a, along...) 
} func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { - switch at := a.(type) { - case *Dense: - hdr := at.hdr() - typ := at.t.Type - monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value - if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { - var ret interface{} - if ret, err = execution.MonotonicMin(typ, hdr); err != nil { - return - } - return New(FromScalar(ret)), nil - } - var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.MinMethods(typ); err != nil { - return - } - defaultVal := reflect.Zero(typ).Interface() - - retVal = a - prev := -1 - dims := len(retVal.Shape()) - - for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } - - if axis >= dims { - err = errors.Errorf(dimMismatch, retVal.Dims(), axis) - return - } - - if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil { - return - } - } - return - - default: - return nil, errors.Errorf("Cannot perform Min on %T", a) - } + return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a, along...) } func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { + return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a, along...) 
+} + +func (e StdEng) reduce( + op string, + monotonicMethod func(t reflect.Type, a *storage.Header) (interface{}, error), + methods func(t reflect.Type) (interface{}, interface{}, interface{}, error), + a Tensor, + along ...int) (retVal Tensor, err error) { switch at := a.(type) { case *Dense: hdr := at.hdr() @@ -276,30 +202,25 @@ func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 { var ret interface{} - if ret, err = execution.MonotonicMax(typ, hdr); err != nil { + if ret, err = monotonicMethod(typ, hdr); err != nil { return } return New(FromScalar(ret)), nil } var firstFn, lastFn, defaultFn interface{} - if firstFn, lastFn, defaultFn, err = execution.MaxMethods(typ); err != nil { + if firstFn, lastFn, defaultFn, err = methods(typ); err != nil { return } defaultVal := reflect.Zero(typ).Interface() retVal = a - prev := -1 - dims := len(retVal.Shape()) + dimsReduced := 0 + sort.Slice(along, func(i, j int) bool { return along[i] < along[j] }) for _, axis := range along { - if prev == -1 { - prev = axis - } - if axis > prev { - axis-- - } - - if axis >= dims { + axis -= dimsReduced + dimsReduced++ + if axis >= retVal.Dims() { err = errors.Errorf(dimMismatch, retVal.Dims(), axis) return } @@ -311,8 +232,9 @@ func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { return default: - return nil, errors.Errorf("Cannot perform Max on %T", a) + return nil, errors.Errorf("Cannot perform %s on %T", op, a) } + } func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTensor, dataA, dataReuse *storage.Header, err error) { diff --git a/dense_reduction_test.go b/dense_reduction_test.go index ffe673d..f4bda25 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -6,6 +6,7 @@ import ( "testing" 
"github.com/stretchr/testify/assert" + "gorgonia.org/tensor/internal/execution" ) @@ -129,6 +130,9 @@ var sumTests = []struct { {"A.Sum(0,1) for int", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(15)}, {"A.Sum(1,0) for int", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(15)}, {"3T.Sum(1,2) for int", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{66, 210}}, + {"4T.Sum() for int", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(120)}, + {"4T.Sum(1,3) for int", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{44, 76}}, {"common case: T.Sum() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(15)}, {"A.Sum(0) for int8", Int8, Shape{2, 3}, []int{0}, Shape{3}, []int8{3, 5, 7}}, {"A.Sum(1) for int8", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{3, 12}}, @@ -141,72 +145,108 @@ var sumTests = []struct { {"A.Sum(0,1) for int16", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(15)}, {"A.Sum(1,0) for int16", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(15)}, {"3T.Sum(1,2) for int16", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{66, 210}}, + {"4T.Sum() for int16", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(120)}, + {"4T.Sum(1,3) for int16", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int16", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{44, 76}}, {"common case: T.Sum() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(15)}, {"A.Sum(0) for int32", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{3, 5, 7}}, {"A.Sum(1) for int32", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{3, 12}}, {"A.Sum(0,1) for int32", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(15)}, {"A.Sum(1,0) for int32", Int32, Shape{2, 3}, []int{1, 0}, ScalarShape(), int32(15)}, {"3T.Sum(1,2) for int32", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, 
[]int32{66, 210}}, + {"4T.Sum() for int32", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(120)}, + {"4T.Sum(1,3) for int32", Int32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int32{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int32", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{44, 76}}, {"common case: T.Sum() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(15)}, {"A.Sum(0) for int64", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{3, 5, 7}}, {"A.Sum(1) for int64", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{3, 12}}, {"A.Sum(0,1) for int64", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), int64(15)}, {"A.Sum(1,0) for int64", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(15)}, {"3T.Sum(1,2) for int64", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{66, 210}}, + {"4T.Sum() for int64", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(120)}, + {"4T.Sum(1,3) for int64", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int64", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{44, 76}}, {"common case: T.Sum() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(15)}, {"A.Sum(0) for uint", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{3, 5, 7}}, {"A.Sum(1) for uint", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{3, 12}}, {"A.Sum(0,1) for uint", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(15)}, {"A.Sum(1,0) for uint", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(15)}, {"3T.Sum(1,2) for uint", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{66, 210}}, + {"4T.Sum() for uint", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(120)}, + {"4T.Sum(1,3) for uint", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{44, 76}}, {"common case: T.Sum() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(15)}, 
{"A.Sum(0) for uint8", Uint8, Shape{2, 3}, []int{0}, Shape{3}, []uint8{3, 5, 7}}, {"A.Sum(1) for uint8", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{3, 12}}, {"A.Sum(0,1) for uint8", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(15)}, {"A.Sum(1,0) for uint8", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(15)}, {"3T.Sum(1,2) for uint8", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint8{66, 210}}, + {"4T.Sum() for uint8", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(120)}, + {"4T.Sum(1,3) for uint8", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint8", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{44, 76}}, {"common case: T.Sum() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(15)}, {"A.Sum(0) for uint16", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{3, 5, 7}}, {"A.Sum(1) for uint16", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{3, 12}}, {"A.Sum(0,1) for uint16", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint16(15)}, {"A.Sum(1,0) for uint16", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(15)}, {"3T.Sum(1,2) for uint16", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{66, 210}}, + {"4T.Sum() for uint16", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(120)}, + {"4T.Sum(1,3) for uint16", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint16", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{44, 76}}, {"common case: T.Sum() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(15)}, {"A.Sum(0) for uint32", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{3, 5, 7}}, {"A.Sum(1) for uint32", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{3, 12}}, {"A.Sum(0,1) for uint32", Uint32, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint32(15)}, {"A.Sum(1,0) for uint32", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(15)}, 
{"3T.Sum(1,2) for uint32", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{66, 210}}, + {"4T.Sum() for uint32", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(120)}, + {"4T.Sum(1,3) for uint32", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint32", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{44, 76}}, {"common case: T.Sum() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(15)}, {"A.Sum(0) for uint64", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{3, 5, 7}}, {"A.Sum(1) for uint64", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{3, 12}}, {"A.Sum(0,1) for uint64", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(15)}, {"A.Sum(1,0) for uint64", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(15)}, {"3T.Sum(1,2) for uint64", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{66, 210}}, + {"4T.Sum() for uint64", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(120)}, + {"4T.Sum(1,3) for uint64", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for uint64", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{44, 76}}, {"common case: T.Sum() for float32", Float32, Shape{2, 3}, []int{}, ScalarShape(), float32(15)}, {"A.Sum(0) for float32", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{3, 5, 7}}, {"A.Sum(1) for float32", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{3, 12}}, {"A.Sum(0,1) for float32", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(15)}, {"A.Sum(1,0) for float32", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(15)}, {"3T.Sum(1,2) for float32", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{66, 210}}, + {"4T.Sum() for float32", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(120)}, + {"4T.Sum(1,3) for float32", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{10, 18, 42, 50}}, + 
{"4T.Sum(0, 2, 3) for float32", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{44, 76}}, {"common case: T.Sum() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(15)}, {"A.Sum(0) for float64", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{3, 5, 7}}, {"A.Sum(1) for float64", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{3, 12}}, {"A.Sum(0,1) for float64", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(15)}, {"A.Sum(1,0) for float64", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(15)}, {"3T.Sum(1,2) for float64", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{66, 210}}, + {"4T.Sum() for float64", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(120)}, + {"4T.Sum(1,3) for float64", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for float64", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float64{44, 76}}, {"common case: T.Sum() for complex64", Complex64, Shape{2, 3}, []int{}, ScalarShape(), complex64(15)}, {"A.Sum(0) for complex64", Complex64, Shape{2, 3}, []int{0}, Shape{3}, []complex64{3, 5, 7}}, {"A.Sum(1) for complex64", Complex64, Shape{2, 3}, []int{1}, Shape{2}, []complex64{3, 12}}, {"A.Sum(0,1) for complex64", Complex64, Shape{2, 3}, []int{0, 1}, ScalarShape(), complex64(15)}, {"A.Sum(1,0) for complex64", Complex64, Shape{2, 3}, []int{1, 0}, ScalarShape(), complex64(15)}, {"3T.Sum(1,2) for complex64", Complex64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []complex64{66, 210}}, + {"4T.Sum() for complex64", Complex64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), complex64(120)}, + {"4T.Sum(1,3) for complex64", Complex64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []complex64{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for complex64", Complex64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []complex64{44, 76}}, {"common case: T.Sum() for complex128", Complex128, Shape{2, 3}, []int{}, ScalarShape(), complex128(15)}, 
{"A.Sum(0) for complex128", Complex128, Shape{2, 3}, []int{0}, Shape{3}, []complex128{3, 5, 7}}, {"A.Sum(1) for complex128", Complex128, Shape{2, 3}, []int{1}, Shape{2}, []complex128{3, 12}}, {"A.Sum(0,1) for complex128", Complex128, Shape{2, 3}, []int{0, 1}, ScalarShape(), complex128(15)}, {"A.Sum(1,0) for complex128", Complex128, Shape{2, 3}, []int{1, 0}, ScalarShape(), complex128(15)}, {"3T.Sum(1,2) for complex128", Complex128, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []complex128{66, 210}}, + {"4T.Sum() for complex128", Complex128, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), complex128(120)}, + {"4T.Sum(1,3) for complex128", Complex128, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []complex128{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for complex128", Complex128, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []complex128{44, 76}}, } func TestDense_Sum(t *testing.T) { @@ -244,72 +284,108 @@ var maxTests = []struct { {"A.Max(0,1)", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(5)}, {"A.Max(1,0)", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(5)}, {"3T.Max(1,2)", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{11, 23}}, + {"4T.Max()", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(15)}, + {"4T.Max(1,3)", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{11, 15}}, {"common case: T.Max() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(5)}, {"A.Max(0)", Int8, Shape{2, 3}, []int{0}, Shape{3}, []int8{3, 4, 5}}, {"A.Max(1)", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{2, 5}}, {"A.Max(0,1)", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(5)}, {"A.Max(1,0)", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(5)}, {"3T.Max(1,2)", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{11, 23}}, + {"4T.Max()", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(15)}, + {"4T.Max(1,3)", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{5, 7, 13, 
15}}, + {"4T.Max(0, 2, 3)", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{11, 15}}, {"common case: T.Max() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(5)}, {"A.Max(0)", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{3, 4, 5}}, {"A.Max(1)", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{2, 5}}, {"A.Max(0,1)", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(5)}, {"A.Max(1,0)", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(5)}, {"3T.Max(1,2)", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{11, 23}}, + {"4T.Max()", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(15)}, + {"4T.Max(1,3)", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{11, 15}}, {"common case: T.Max() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(5)}, {"A.Max(0)", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{3, 4, 5}}, {"A.Max(1)", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{2, 5}}, {"A.Max(0,1)", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(5)}, {"A.Max(1,0)", Int32, Shape{2, 3}, []int{1, 0}, ScalarShape(), int32(5)}, {"3T.Max(1,2)", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int32{11, 23}}, + {"4T.Max()", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(15)}, + {"4T.Max(1,3)", Int32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{11, 15}}, {"common case: T.Max() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(5)}, {"A.Max(0)", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{3, 4, 5}}, {"A.Max(1)", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{2, 5}}, {"A.Max(0,1)", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), int64(5)}, {"A.Max(1,0)", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(5)}, {"3T.Max(1,2)", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{11, 23}}, 
+ {"4T.Max()", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(15)}, + {"4T.Max(1,3)", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{11, 15}}, {"common case: T.Max() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(5)}, {"A.Max(0)", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{3, 4, 5}}, {"A.Max(1)", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{2, 5}}, {"A.Max(0,1)", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(5)}, {"A.Max(1,0)", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(5)}, {"3T.Max(1,2)", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{11, 23}}, + {"4T.Max()", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(15)}, + {"4T.Max(1,3)", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{11, 15}}, {"common case: T.Max() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(5)}, {"A.Max(0)", Uint8, Shape{2, 3}, []int{0}, Shape{3}, []uint8{3, 4, 5}}, {"A.Max(1)", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{2, 5}}, {"A.Max(0,1)", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(5)}, {"A.Max(1,0)", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(5)}, {"3T.Max(1,2)", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint8{11, 23}}, + {"4T.Max()", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(15)}, + {"4T.Max(1,3)", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{11, 15}}, {"common case: T.Max() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(5)}, {"A.Max(0)", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{3, 4, 5}}, {"A.Max(1)", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{2, 5}}, {"A.Max(0,1)", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), 
uint16(5)}, {"A.Max(1,0)", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(5)}, {"3T.Max(1,2)", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{11, 23}}, + {"4T.Max()", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(15)}, + {"4T.Max(1,3)", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{11, 15}}, {"common case: T.Max() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(5)}, {"A.Max(0)", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{3, 4, 5}}, {"A.Max(1)", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{2, 5}}, {"A.Max(0,1)", Uint32, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint32(5)}, {"A.Max(1,0)", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(5)}, {"3T.Max(1,2)", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{11, 23}}, + {"4T.Max()", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(15)}, + {"4T.Max(1,3)", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{11, 15}}, {"common case: T.Max() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(5)}, {"A.Max(0)", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{3, 4, 5}}, {"A.Max(1)", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{2, 5}}, {"A.Max(0,1)", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(5)}, {"A.Max(1,0)", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(5)}, {"3T.Max(1,2)", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{11, 23}}, + {"4T.Max()", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(15)}, + {"4T.Max(1,3)", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{11, 15}}, {"common case: T.Max() for float32", Float32, Shape{2, 3}, []int{}, 
ScalarShape(), float32(5)}, {"A.Max(0)", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{3, 4, 5}}, {"A.Max(1)", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{2, 5}}, {"A.Max(0,1)", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(5)}, {"A.Max(1,0)", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(5)}, {"3T.Max(1,2)", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{11, 23}}, + {"4T.Max()", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(15)}, + {"4T.Max(1,3)", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{11, 15}}, {"common case: T.Max() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(5)}, {"A.Max(0)", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{3, 4, 5}}, {"A.Max(1)", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{2, 5}}, {"A.Max(0,1)", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(5)}, {"A.Max(1,0)", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(5)}, {"3T.Max(1,2)", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{11, 23}}, + {"4T.Max()", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(15)}, + {"4T.Max(1,3)", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float64{11, 15}}, } func TestDense_Max(t *testing.T) { @@ -346,72 +422,108 @@ var minTests = []struct { {"A.Min(0,1)", Int, Shape{2, 3}, []int{0, 1}, ScalarShape(), int(0)}, {"A.Min(1,0)", Int, Shape{2, 3}, []int{1, 0}, ScalarShape(), int(0)}, {"3T.Min(1,2)", Int, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int{0, 12}}, + {"4T.Min()", Int, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int(0)}, + {"4T.Min(1,3)", Int, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int{0, 4}}, 
{"common case: T.Min() for int8", Int8, Shape{2, 3}, []int{}, ScalarShape(), int8(0)}, {"A.Min(0)", Int8, Shape{2, 3}, []int{0}, Shape{3}, []int8{0, 1, 2}}, {"A.Min(1)", Int8, Shape{2, 3}, []int{1}, Shape{2}, []int8{0, 3}}, {"A.Min(0,1)", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(0)}, {"A.Min(1,0)", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(0)}, {"3T.Min(1,2)", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{0, 12}}, + {"4T.Min()", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(0)}, + {"4T.Min(1,3)", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{0, 4}}, {"common case: T.Min() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(0)}, {"A.Min(0)", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{0, 1, 2}}, {"A.Min(1)", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{0, 3}}, {"A.Min(0,1)", Int16, Shape{2, 3}, []int{0, 1}, ScalarShape(), int16(0)}, {"A.Min(1,0)", Int16, Shape{2, 3}, []int{1, 0}, ScalarShape(), int16(0)}, {"3T.Min(1,2)", Int16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int16{0, 12}}, + {"4T.Min()", Int16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int16(0)}, + {"4T.Min(1,3)", Int16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int16{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int16{0, 4}}, {"common case: T.Min() for int32", Int32, Shape{2, 3}, []int{}, ScalarShape(), int32(0)}, {"A.Min(0)", Int32, Shape{2, 3}, []int{0}, Shape{3}, []int32{0, 1, 2}}, {"A.Min(1)", Int32, Shape{2, 3}, []int{1}, Shape{2}, []int32{0, 3}}, {"A.Min(0,1)", Int32, Shape{2, 3}, []int{0, 1}, ScalarShape(), int32(0)}, {"A.Min(1,0)", Int32, Shape{2, 3}, []int{1, 0}, ScalarShape(), int32(0)}, {"3T.Min(1,2)", Int32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int32{0, 12}}, + {"4T.Min()", Int32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int32(0)}, + {"4T.Min(1,3)", Int32, Shape{2, 2, 2, 2}, 
[]int{1, 3}, Shape{2, 2}, []int32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int32{0, 4}}, {"common case: T.Min() for int64", Int64, Shape{2, 3}, []int{}, ScalarShape(), int64(0)}, {"A.Min(0)", Int64, Shape{2, 3}, []int{0}, Shape{3}, []int64{0, 1, 2}}, {"A.Min(1)", Int64, Shape{2, 3}, []int{1}, Shape{2}, []int64{0, 3}}, {"A.Min(0,1)", Int64, Shape{2, 3}, []int{0, 1}, ScalarShape(), int64(0)}, {"A.Min(1,0)", Int64, Shape{2, 3}, []int{1, 0}, ScalarShape(), int64(0)}, {"3T.Min(1,2)", Int64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int64{0, 12}}, + {"4T.Min()", Int64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int64(0)}, + {"4T.Min(1,3)", Int64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Int64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int64{0, 4}}, {"common case: T.Min() for uint", Uint, Shape{2, 3}, []int{}, ScalarShape(), uint(0)}, {"A.Min(0)", Uint, Shape{2, 3}, []int{0}, Shape{3}, []uint{0, 1, 2}}, {"A.Min(1)", Uint, Shape{2, 3}, []int{1}, Shape{2}, []uint{0, 3}}, {"A.Min(0,1)", Uint, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint(0)}, {"A.Min(1,0)", Uint, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint(0)}, {"3T.Min(1,2)", Uint, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint{0, 12}}, + {"4T.Min()", Uint, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint(0)}, + {"4T.Min(1,3)", Uint, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint{0, 4}}, {"common case: T.Min() for uint8", Uint8, Shape{2, 3}, []int{}, ScalarShape(), uint8(0)}, {"A.Min(0)", Uint8, Shape{2, 3}, []int{0}, Shape{3}, []uint8{0, 1, 2}}, {"A.Min(1)", Uint8, Shape{2, 3}, []int{1}, Shape{2}, []uint8{0, 3}}, {"A.Min(0,1)", Uint8, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint8(0)}, {"A.Min(1,0)", Uint8, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint8(0)}, {"3T.Min(1,2)", Uint8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, 
[]uint8{0, 12}}, + {"4T.Min()", Uint8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint8(0)}, + {"4T.Min(1,3)", Uint8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint8{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint8{0, 4}}, {"common case: T.Min() for uint16", Uint16, Shape{2, 3}, []int{}, ScalarShape(), uint16(0)}, {"A.Min(0)", Uint16, Shape{2, 3}, []int{0}, Shape{3}, []uint16{0, 1, 2}}, {"A.Min(1)", Uint16, Shape{2, 3}, []int{1}, Shape{2}, []uint16{0, 3}}, {"A.Min(0,1)", Uint16, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint16(0)}, {"A.Min(1,0)", Uint16, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint16(0)}, {"3T.Min(1,2)", Uint16, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint16{0, 12}}, + {"4T.Min()", Uint16, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint16(0)}, + {"4T.Min(1,3)", Uint16, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint16{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint16, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint16{0, 4}}, {"common case: T.Min() for uint32", Uint32, Shape{2, 3}, []int{}, ScalarShape(), uint32(0)}, {"A.Min(0)", Uint32, Shape{2, 3}, []int{0}, Shape{3}, []uint32{0, 1, 2}}, {"A.Min(1)", Uint32, Shape{2, 3}, []int{1}, Shape{2}, []uint32{0, 3}}, {"A.Min(0,1)", Uint32, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint32(0)}, {"A.Min(1,0)", Uint32, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint32(0)}, {"3T.Min(1,2)", Uint32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint32{0, 12}}, + {"4T.Min()", Uint32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint32(0)}, + {"4T.Min(1,3)", Uint32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint32{0, 4}}, {"common case: T.Min() for uint64", Uint64, Shape{2, 3}, []int{}, ScalarShape(), uint64(0)}, {"A.Min(0)", Uint64, Shape{2, 3}, []int{0}, Shape{3}, []uint64{0, 1, 2}}, {"A.Min(1)", Uint64, Shape{2, 3}, []int{1}, Shape{2}, []uint64{0, 3}}, 
{"A.Min(0,1)", Uint64, Shape{2, 3}, []int{0, 1}, ScalarShape(), uint64(0)}, {"A.Min(1,0)", Uint64, Shape{2, 3}, []int{1, 0}, ScalarShape(), uint64(0)}, {"3T.Min(1,2)", Uint64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []uint64{0, 12}}, + {"4T.Min()", Uint64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), uint64(0)}, + {"4T.Min(1,3)", Uint64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []uint64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Uint64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []uint64{0, 4}}, {"common case: T.Min() for float32", Float32, Shape{2, 3}, []int{}, ScalarShape(), float32(0)}, {"A.Min(0)", Float32, Shape{2, 3}, []int{0}, Shape{3}, []float32{0, 1, 2}}, {"A.Min(1)", Float32, Shape{2, 3}, []int{1}, Shape{2}, []float32{0, 3}}, {"A.Min(0,1)", Float32, Shape{2, 3}, []int{0, 1}, ScalarShape(), float32(0)}, {"A.Min(1,0)", Float32, Shape{2, 3}, []int{1, 0}, ScalarShape(), float32(0)}, {"3T.Min(1,2)", Float32, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float32{0, 12}}, + {"4T.Min()", Float32, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float32(0)}, + {"4T.Min(1,3)", Float32, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float32{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Float32, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []float32{0, 4}}, {"common case: T.Min() for float64", Float64, Shape{2, 3}, []int{}, ScalarShape(), float64(0)}, {"A.Min(0)", Float64, Shape{2, 3}, []int{0}, Shape{3}, []float64{0, 1, 2}}, {"A.Min(1)", Float64, Shape{2, 3}, []int{1}, Shape{2}, []float64{0, 3}}, {"A.Min(0,1)", Float64, Shape{2, 3}, []int{0, 1}, ScalarShape(), float64(0)}, {"A.Min(1,0)", Float64, Shape{2, 3}, []int{1, 0}, ScalarShape(), float64(0)}, {"3T.Min(1,2)", Float64, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []float64{0, 12}}, + {"4T.Min()", Float64, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), float64(0)}, + {"4T.Min(1,3)", Float64, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []float64{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", Float64, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, 
[]float64{0, 4}}, } func TestDense_Min(t *testing.T) { From aaf076a18e0fd2fd64687c399ce0ae1cc1cc5d17 Mon Sep 17 00:00:00 2001 From: Tiago Rodrigo Lampert Date: Mon, 21 Oct 2019 03:35:28 -0300 Subject: [PATCH 020/154] Fixed some typos (#50) * Fix typo abstrations -> abstractions * Fix typo Simliarly -> Similarly --- README.md | 2 +- example_dense_cmp_test.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b93aa9a..459048d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ It comes to reason then there should be a data structure that handles these thin ### Basic Idea: Tensor ### A tensor is a multidimensional array. It's like a slice, but works in multiple dimensions. -With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstrations used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). +With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). Tensors come with their own set of usage patterns and abstractions. 
Most of these have analogues in slices, enumerated below (do note that certain slice operation will have more than one tensor analogue - this is due to the number of options available): diff --git a/example_dense_cmp_test.go b/example_dense_cmp_test.go index b006c19..6d72c4d 100644 --- a/example_dense_cmp_test.go +++ b/example_dense_cmp_test.go @@ -25,7 +25,7 @@ func ExampleDense_Gt_basic() { T3, _ = V.Gt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) @@ -197,7 +197,7 @@ func ExampleDense_Gte_basic() { T3, _ = V.Gte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) @@ -369,7 +369,7 @@ func ExampleDense_Lt_basic() { T3, _ = V.Lt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) @@ -540,7 +540,7 @@ func ExampleDense_Lte_basic() { T3, _ = V.Lte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) @@ -712,7 +712,7 @@ func ExampleDense_ElEq_basic() { T3, _ = V.ElEq(T2) 
fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) @@ -885,7 +885,7 @@ func ExampleDense_ElNe_basic() { T3, _ = V.ElNe(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) - // Simliarly for tensors that return the same type + // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) V = sliced.(*Dense) From 5338885595edda56aba4097a47badd7c9def1a2d Mon Sep 17 00:00:00 2001 From: bezineb5 Date: Sun, 8 Dec 2019 05:15:28 -0500 Subject: [PATCH 021/154] Fix scalar operations (fix for #52) (#54) * Fixed scalar operations * Fixed comparison operations with tensors which are actually scalars. Fixed copy-paste in comments. --- api_arith.go | 264 +++++++++++++++++++++++++++++--------- api_arith_test.go | 296 +++++++++++++++++++++++++++++++++++++++++++ api_cmp.go | 128 ++++++++++++++++--- api_cmp_test.go | 239 ++++++++++++++++++++++++++++++++++ genlib2/agg2_body.go | 8 +- 5 files changed, 848 insertions(+), 87 deletions(-) create mode 100644 api_cmp_test.go diff --git a/api_arith.go b/api_arith.go index b70dd9b..4e86ffa 100644 --- a/api_arith.go +++ b/api_arith.go @@ -26,19 +26,46 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { oe = at.standardEngine() switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Add(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Add(at, bt, opts...) - } - if adder, ok = at.Engine().(Adder); ok { - return adder.Add(at, bt, opts...) - } - if adder, ok = bt.Engine().(Adder); ok { - return adder.Add(at, bt, opts...) 
+ if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.Add(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.Add(at, bt, opts...) + } + if adder, ok = at.Engine().(Adder); ok { + return adder.Add(at, bt, opts...) + } + if adder, ok = bt.Engine().(Adder); ok { + return adder.Add(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Add") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.AddScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.AddScalar(at, bt, leftTensor, opts...) + } + if adder, ok = at.Engine().(Adder); ok { + return adder.AddScalar(at, bt, leftTensor, opts...) + } + if adder, ok = bt.Engine().(Adder); ok { + return adder.AddScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Add") } - return nil, errors.New("Neither engines of either operand support Add") default: if oe != nil { @@ -80,19 +107,46 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { oe = at.standardEngine() switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Sub(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Sub(at, bt, opts...) - } - if suber, ok = at.Engine().(Suber); ok { - return suber.Sub(at, bt, opts...) - } - if suber, ok = bt.Engine().(Suber); ok { - return suber.Sub(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor substraction + if oe != nil { + return oe.Sub(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.Sub(at, bt, opts...) 
+ } + if suber, ok = at.Engine().(Suber); ok { + return suber.Sub(at, bt, opts...) + } + if suber, ok = bt.Engine().(Suber); ok { + return suber.Sub(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Sub") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.SubScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.SubScalar(at, bt, leftTensor, opts...) + } + if suber, ok = at.Engine().(Suber); ok { + return suber.SubScalar(at, bt, leftTensor, opts...) + } + if suber, ok = bt.Engine().(Suber); ok { + return suber.SubScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Sub") } - return nil, errors.New("Neither engines of either operand support Sub") default: if oe != nil { @@ -149,10 +203,13 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { } return nil, errors.New("Neither engines of either operand support Mul") - } else { // one of the operands is a scalar + } else { // at least one of the operands is a scalar var leftTensor bool - if at.Shape().IsScalar() { + if !bt.Shape().IsScalar() { leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp } else { leftTensor = true // a Tensor * b Scalar-Tensor } @@ -214,19 +271,46 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { oe = at.standardEngine() switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Div(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Div(at, bt, opts...) - } - if diver, ok = at.Engine().(Diver); ok { - return diver.Div(at, bt, opts...) - } - if diver, ok = bt.Engine().(Diver); ok { - return diver.Div(at, bt, opts...) 
+ if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division + if oe != nil { + return oe.Div(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.Div(at, bt, opts...) + } + if diver, ok = at.Engine().(Diver); ok { + return diver.Div(at, bt, opts...) + } + if diver, ok = bt.Engine().(Diver); ok { + return diver.Div(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Div") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.DivScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.DivScalar(at, bt, leftTensor, opts...) + } + if diver, ok = at.Engine().(Diver); ok { + return diver.DivScalar(at, bt, leftTensor, opts...) + } + if diver, ok = bt.Engine().(Diver); ok { + return diver.DivScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Div") } - return nil, errors.New("Neither engines of either operand support Div") default: if oe != nil { @@ -268,19 +352,46 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { oe = at.standardEngine() switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Pow(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Pow(at, bt, opts...) - } - if power, ok = at.Engine().(Power); ok { - return power.Pow(at, bt, opts...) - } - if power, ok = bt.Engine().(Power); ok { - return power.Pow(at, bt, opts...) + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation + if oe != nil { + return oe.Pow(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.Pow(at, bt, opts...) 
+ } + if power, ok = at.Engine().(Power); ok { + return power.Pow(at, bt, opts...) + } + if power, ok = bt.Engine().(Power); ok { + return power.Pow(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Pow") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.PowScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.PowScalar(at, bt, leftTensor, opts...) + } + if power, ok = at.Engine().(Power); ok { + return power.PowScalar(at, bt, leftTensor, opts...) + } + if power, ok = bt.Engine().(Power); ok { + return power.PowScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support Pow") } - return nil, errors.New("Neither engines of either operand support Pow") default: if oe != nil { @@ -308,7 +419,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { panic("Unreachable") } -// Mod performs elementwise exponentiation on the Tensor(s). These operations are supported: +// Mod performs elementwise modulo on the Tensor(s). These operations are supported: // Mod(*Dense, scalar) // Mod(scalar, *Dense) // Mod(*Dense, *Dense) @@ -322,19 +433,46 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { oe = at.standardEngine() switch bt := b.(type) { case Tensor: - if oe != nil { - return oe.Mod(at, bt, opts...) - } - if oe = bt.standardEngine(); oe != nil { - return oe.Mod(at, bt, opts...) - } - if moder, ok = at.Engine().(Moder); ok { - return moder.Mod(at, bt, opts...) - } - if moder, ok = bt.Engine().(Moder); ok { - return moder.Mod(at, bt, opts...) 
+ if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo + if oe != nil { + return oe.Mod(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.Mod(at, bt, opts...) + } + if moder, ok = at.Engine().(Moder); ok { + return moder.Mod(at, bt, opts...) + } + if moder, ok = bt.Engine().(Moder); ok { + return moder.Mod(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support Mod") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.ModScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.ModScalar(at, bt, leftTensor, opts...) + } + if moder, ok = at.Engine().(Moder); ok { + return moder.ModScalar(at, bt, leftTensor, opts...) + } + if moder, ok = bt.Engine().(Moder); ok { + return moder.ModScalar(at, bt, leftTensor, opts...) 
+ } + return nil, errors.New("Neither engines of either operand support Mod") } - return nil, errors.New("Neither engines of either operand support Mod") default: if oe != nil { diff --git a/api_arith_test.go b/api_arith_test.go index 687e4b7..00bf271 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -152,3 +152,299 @@ func TestFMA(t *testing.T) { assert.Equal(t, f.Data(), f2.Data()) } + +func TestMulScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 6.0 + + res, err := Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{6, 4} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{21, 10} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestDivScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 3.0 + + res, err := Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = 
[]float64{3, 2} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{2, 3} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{3, 5} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestAddScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 5.0 + + res, err := Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{5, 4} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{10, 7} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestSubScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 4.0 + + res, 
err := Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{4, 2} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{3, 4} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{14, 8} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestModScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{5})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 1.0 + + res, err := Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{5, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{5})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{2, 1} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{22, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestPowScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) 
+ var correct interface{} = 36.0 + + res, err := Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{36, 16} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{216, 36} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{2187, 100} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} diff --git a/api_cmp.go b/api_cmp.go index e79398e..ffb602d 100644 --- a/api_cmp.go +++ b/api_cmp.go @@ -17,12 +17,30 @@ func Lt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { lter, ok = at.Engine().(Lter) switch bt := b.(type) { case Tensor: - if !ok { - if lter, ok = bt.Engine().(Lter); !ok { - return nil, errors.Errorf("Neither operands have engines that support Lt") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if lter, ok = bt.Engine().(Lter); !ok { + return nil, errors.Errorf("Neither operands have engines that support Lt") + } + } + + return lter.Lt(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support Lt") + } + return lter.LtScalar(at, bt, leftTensor, opts...) } - return lter.Lt(at, bt, opts...) 
default: if !ok { return nil, errors.Errorf("Engine does not support Lt") @@ -55,12 +73,29 @@ func Gt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { gter, ok = at.Engine().(Gter) switch bt := b.(type) { case Tensor: - if !ok { - if gter, ok = bt.Engine().(Gter); !ok { - return nil, errors.Errorf("Neither operands have engines that support Gt") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if gter, ok = bt.Engine().(Gter); !ok { + return nil, errors.Errorf("Neither operands have engines that support Gt") + } + } + return gter.Gt(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support Gt") + } + return gter.GtScalar(at, bt, leftTensor, opts...) } - return gter.Gt(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support Gt") @@ -93,12 +128,30 @@ func Lte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { lteer, ok = at.Engine().(Lteer) switch bt := b.(type) { case Tensor: - if !ok { - if lteer, ok = bt.Engine().(Lteer); !ok { - return nil, errors.Errorf("Neither operands have engines that support Lte") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if lteer, ok = bt.Engine().(Lteer); !ok { + return nil, errors.Errorf("Neither operands have engines that support Lte") + } + } + return lteer.Lte(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if !ok { + return nil, errors.Errorf("Engine does not support Lte") } + return lteer.LteScalar(at, bt, leftTensor, opts...) 
} - return lteer.Lte(at, bt, opts...) + default: if !ok { return nil, errors.Errorf("Engine does not support Lte") @@ -131,12 +184,29 @@ func Gte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { gteer, ok = at.Engine().(Gteer) switch bt := b.(type) { case Tensor: - if !ok { - if gteer, ok = bt.Engine().(Gteer); !ok { - return nil, errors.Errorf("Neither operands have engines that support Gte") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if gteer, ok = bt.Engine().(Gteer); !ok { + return nil, errors.Errorf("Neither operands have engines that support Gte") + } + } + return gteer.Gte(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support Gte") + } + return gteer.GteScalar(at, bt, leftTensor, opts...) } - return gteer.Gte(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support Gte") @@ -169,12 +239,30 @@ func ElEq(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { eleqer, ok = at.Engine().(ElEqer) switch bt := b.(type) { case Tensor: - if !ok { - if eleqer, ok = bt.Engine().(ElEqer); !ok { - return nil, errors.Errorf("Neither operands have engines that support ElEq") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if eleqer, ok = bt.Engine().(ElEqer); !ok { + return nil, errors.Errorf("Neither operands have engines that support ElEq") + } + } + return eleqer.ElEq(at, bt, opts...) 
+ } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor } + + if !ok { + return nil, errors.Errorf("Engine does not support ElEq") + } + return eleqer.EqScalar(at, bt, leftTensor, opts...) } - return eleqer.ElEq(at, bt, opts...) + default: if !ok { return nil, errors.Errorf("Engine does not support ElEq") diff --git a/api_cmp_test.go b/api_cmp_test.go new file mode 100644 index 0000000..9e785d7 --- /dev/null +++ b/api_cmp_test.go @@ -0,0 +1,239 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// This file contains the tests for API functions that aren't generated by genlib + +func TestLtScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{true, false} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{true, false} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = Lt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestGtScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = true + + res, err := Gt(a, b) + 
if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{false, true} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{true, false} + + res, err = Gt(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestLteScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{true, true, false} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 2})) + correct = []bool{true, false} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 2})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = Lte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestGteScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := 
New(WithBacking([]float64{2})) + var correct interface{} = true + + res, err := Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true, true} + + res, err = Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 3, 2})) + correct = []bool{false, true, true} + + res, err = Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 31, 2})) + b = New(WithBacking([]float64{7, 31, 10})) + correct = []bool{true, true, false} + + res, err = Gte(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestElEqScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = false + + res, err := ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{1, 2, 4})) + b = New(WithBacking([]float64{2})) + correct = []bool{false, true, false} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{3})) + b = New(WithBacking([]float64{6, 3, 2})) + correct = []bool{false, true, false} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 10})) + correct = []bool{false, true} + + res, err = ElEq(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, 
correct, res.Data()) +} diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 6c1716c..cef73e2 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -114,7 +114,7 @@ const agg2BodyRaw = `if useIter { retVal = a default: {{if .VV -}} - if swap{ + if swap { retVal = b.Clone().(Tensor) }else{ retVal = a.Clone().(Tensor) @@ -124,7 +124,7 @@ const agg2BodyRaw = `if useIter { retVal = a.Clone().(Tensor) if leftTensor { err = e.E.{{.Name}}Iter(typ, retVal.hdr(), dataB, ait, bit) - }else { + } else { err = e.E.{{.Name}}Iter(typ, dataA, retVal.hdr(), ait, bit) } {{end -}} @@ -166,8 +166,8 @@ const agg2BodyRaw = `if useIter { retVal = a.Clone().(Tensor) if leftTensor { err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) - }else { - err = e.E.{{.Name}}(typ, dataA, retVal.hdr()) + } else { + err = e.E.{{.Name}}(typ, dataA, retVal.hdr()) } {{end -}} } From 0025a335fcc1a52d5311d2503c0be14b8e776afc Mon Sep 17 00:00:00 2001 From: bezineb5 Date: Sun, 8 Dec 2019 05:25:36 -0500 Subject: [PATCH 022/154] Synchronized genlib2 with changes in the generated files. (#53) Added some documentation for the development dependencies. --- README.md | 4 +++- dense_io_test.go | 9 ++++++++- dense_reduction_test.go | 4 +++- genlib2/dense_io.go | 16 ++++++++-------- genlib2/dense_reduction_methods_tests.go | 9 +++++++++ 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 459048d..66d6372 100644 --- a/README.md +++ b/README.md @@ -218,8 +218,10 @@ One could argue that this sidesteps the compiler's type checking system, deferri Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. # How This Package is Developed # -Much of the code in this package is generated. The code to generate them is in the directory `genlib2`. +Much of the code in this package is generated. The code to generate them is in the directory `genlib2`. 
`genlib2` requires [`goimports`](https://godoc.org/golang.org/x/tools/cmd/goimports) binary to be available in the $PATH. +## Tests ## +Tests require python with numpy installed. You can select which python intepreter is being used by setting the environment variable `PYTHON_COMMAND` accordingly. The default value is `python`. ## Things Knowingly Untested For ## - `complex64` and `complex128` are excluded from quick check generation process [Issue #11](https://github.com/gorgonia/tensor/issues/11) diff --git a/dense_io_test.go b/dense_io_test.go index 482b170..cdbe610 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -24,7 +24,13 @@ func TestSaveLoadNumpy(t *testing.T) { script := "import numpy as np\nx = np.load('test.npy')\nprint(x)" - cmd := exec.Command("python") + // Configurable python command, in order to be able to use python or python3 + pythonCommand := os.Getenv("PYTHON_COMMAND") + if pythonCommand == "" { + pythonCommand = "python" + } + + cmd := exec.Command(pythonCommand) stdin, err := cmd.StdinPipe() if err != nil { t.Error(err) @@ -40,6 +46,7 @@ func TestSaveLoadNumpy(t *testing.T) { if err = cmd.Start(); err != nil { t.Error(err) + t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. 
Current value: PYTHON_COMMAND=%s", pythonCommand) } if err := cmd.Wait(); err != nil { diff --git a/dense_reduction_test.go b/dense_reduction_test.go index f4bda25..b10e3ac 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "gorgonia.org/tensor/internal/execution" ) @@ -139,6 +138,9 @@ var sumTests = []struct { {"A.Sum(0,1) for int8", Int8, Shape{2, 3}, []int{0, 1}, ScalarShape(), int8(15)}, {"A.Sum(1,0) for int8", Int8, Shape{2, 3}, []int{1, 0}, ScalarShape(), int8(15)}, {"3T.Sum(1,2) for int8", Int8, Shape{2, 3, 4}, []int{1, 2}, Shape{2}, []int8{66, -46}}, + {"4T.Sum() for int8", Int8, Shape{2, 2, 2, 2}, []int{}, ScalarShape(), int8(120)}, + {"4T.Sum(1,3) for int8", Int8, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []int8{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for int8", Int8, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []int8{44, 76}}, {"common case: T.Sum() for int16", Int16, Shape{2, 3}, []int{}, ScalarShape(), int16(15)}, {"A.Sum(0) for int16", Int16, Shape{2, 3}, []int{0}, Shape{3}, []int16{3, 5, 7}}, {"A.Sum(1) for int16", Int16, Shape{2, 3}, []int{1}, Shape{2}, []int16{3, 12}}, diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 2cad2d2..7f975ce 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -456,13 +456,13 @@ func (t *Dense) FBEncode() ([]byte, error) { var o uint32 switch { - case t.o.isRowMajor() && t.o.isContiguous(): + case t.o.IsRowMajor() && t.o.IsContiguous(): o = 0 - case t.o.isRowMajor() && !t.o.isContiguous(): + case t.o.IsRowMajor() && !t.o.IsContiguous(): o = 1 - case t.o.isColMajor() && t.o.isContiguous(): + case t.o.IsColMajor() && t.o.IsContiguous(): o = 2 - case t.o.isColMajor() && !t.o.isContiguous(): + case t.o.IsColMajor() && !t.o.IsContiguous(): o = 3 } @@ -571,13 +571,13 @@ func (t *Dense) PBEncode() ([]byte, error) { } switch { - case t.o.isRowMajor() && t.o.isContiguous(): + case t.o.IsRowMajor() && t.o.IsContiguous(): 
toSerialize.O = pb.RowMajorContiguous - case t.o.isRowMajor() && !t.o.isContiguous(): + case t.o.IsRowMajor() && !t.o.IsContiguous(): toSerialize.O = pb.RowMajorNonContiguous - case t.o.isColMajor() && t.o.isContiguous(): + case t.o.IsColMajor() && t.o.IsContiguous(): toSerialize.O = pb.ColMajorContiguous - case t.o.isColMajor() && !t.o.isContiguous(): + case t.o.IsColMajor() && !t.o.IsContiguous(): toSerialize.O = pb.ColMajorNonContiguous } toSerialize.T = pb.Triangle(t.Δ) diff --git a/genlib2/dense_reduction_methods_tests.go b/genlib2/dense_reduction_methods_tests.go index 62d042f..30defc9 100644 --- a/genlib2/dense_reduction_methods_tests.go +++ b/genlib2/dense_reduction_methods_tests.go @@ -23,6 +23,9 @@ const testDenseSumRaw = `var sumTests = []struct { {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, + {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, + {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, {{end -}} {{end -}} } @@ -65,6 +68,9 @@ const testDenseMaxRaw = `var maxTests = []struct { {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, + {"4T.Max()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, + {"4T.Max(1,3)", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, {{end -}} {{end -}} {{end -}} @@ -108,6 +114,9 @@ const testDenseMinRaw = `var minTests = []struct { {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, + {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, + {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, {{end -}} {{end -}} {{end -}} From 7abc187f50e3bb1c986fc996fac7564b98597de3 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 8 Dec 2019 21:39:24 +1100 Subject: [PATCH 023/154] Cleaned up pointer semantics (#55) * Cleaned up pointer semantics so that all pointer arithmetics are correctly done now * Fixed travis yml --- .travis.yml | 1 + array.go | 12 +++++------- array_getset.go | 7 +++---- consopt.go | 3 +-- genlib2/array_getset.go | 13 ++++++------- genlib2/dense_io.go | 20 ++++++++++---------- testutils_test.go | 10 ++++++++++ 7 files changed, 36 insertions(+), 30 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1bec54c..ee216a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,7 @@ go: - 1.10.x - 1.11.x - 1.12.x + - 1.13.x - tip env: diff --git a/array.go b/array.go index d96b56f..7e7116c 100644 --- a/array.go +++ b/array.go @@ -49,8 +49,7 @@ func arrayFromSlice(x interface{}) array { elT := xT.Elem() xV := reflect.ValueOf(x) - ptr := xV.Pointer() - uptr := unsafe.Pointer(ptr) + uptr := unsafe.Pointer(xV.Pointer()) return array{ Header: storage.Header{ @@ 
-71,8 +70,8 @@ func (a *array) fromSlice(x interface{}) { } elT := xT.Elem() xV := reflect.ValueOf(x) - ptr := xV.Pointer() - uptr := unsafe.Pointer(ptr) + uptr := unsafe.Pointer(xV.Pointer()) + a.Ptr = uptr a.L = xV.Len() a.C = xV.Cap() @@ -127,7 +126,6 @@ func (a array) byteSlice() []byte { // sliceInto creates a slice. Instead of returning an array, which would cause a lot of reallocations, sliceInto expects a array to // already have been created. This allows repetitive actions to be done without having to have many pointless allocation func (a *array) sliceInto(i, j int, res *array) { - base := uintptr(a.Ptr) c := a.C if i < 0 || j < i || j > c { @@ -138,10 +136,10 @@ func (a *array) sliceInto(i, j int, res *array) { res.C = c - i if c-1 > 0 { - res.Ptr = storage.ElementAt(i, unsafe.Pointer(base), a.t.Size()) + res.Ptr = storage.ElementAt(i, a.Ptr, a.t.Size()) } else { // don't advance pointer - res.Ptr = unsafe.Pointer(base) + res.Ptr = a.Ptr } res.fix() } diff --git a/array_getset.go b/array_getset.go index 8896a69..e016823 100644 --- a/array_getset.go +++ b/array_getset.go @@ -68,8 +68,7 @@ func (a *array) Set(i int, x interface{}) { a.SetUnsafePointer(i, xv) default: xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - want := ptr + uintptr(i)*a.t.Size() + want := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) @@ -116,8 +115,8 @@ func (a *array) Get(i int) interface{} { case reflect.UnsafePointer: return a.GetUnsafePointer(i) default: - at := uintptr(a.Ptr) + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t.Type, unsafe.Pointer(at)) + at := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) + val := reflect.NewAt(a.t.Type, at) val = reflect.Indirect(val) return val.Interface() } diff --git a/consopt.go b/consopt.go index ee4b4cf..332e0c1 100644 --- a/consopt.go +++ b/consopt.go @@ -110,8 +110,7 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt 
{ xv := reflect.New(xt) xvi := reflect.Indirect(xv) xvi.Set(reflect.ValueOf(x)) - ptr := xv.Pointer() - uptr := unsafe.Pointer(ptr) + uptr := unsafe.Pointer(xv.Pointer()) tt.array.Ptr = uptr tt.array.L = 1 diff --git a/genlib2/array_getset.go b/genlib2/array_getset.go index f24b19c..c21c767 100644 --- a/genlib2/array_getset.go +++ b/genlib2/array_getset.go @@ -26,15 +26,15 @@ func (a *array) Get(i int) interface{} { {{end -}} {{end -}} default: - at := uintptr(a.Ptr) + uintptr(i) * a.t.Size() - val := reflect.NewAt(a.t.Type, unsafe.Pointer(at)) + at := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i) * a.t.Size()) + val := reflect.NewAt(a.t.Type, at) val = reflect.Indirect(val) return val.Interface() } } ` -const setRaw = `// Set sets the value of the underlying array at the index i. +const setRaw = `// Set sets the value of the underlying array at the index i. func (a *array) Set(i int, x interface{}) { switch a.t.Kind() { {{range .Kinds -}} @@ -47,8 +47,7 @@ func (a *array) Set(i int, x interface{}) { {{end -}} default: xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - want := ptr + uintptr(i)*a.t.Size() + want := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) val = reflect.Indirect(val) val.Set(xv) @@ -75,7 +74,7 @@ func (a *array) Memset(x interface{}) error { {{end -}} {{end -}} } - + xv := reflect.ValueOf(x) ptr := uintptr(a.Ptr) for i := 0; i < a.L; i++ { @@ -192,7 +191,7 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { } data := t.{{sliceOf .}} for i, err = it.Next(); err == nil; i, err = it.Next(){ - data[i] = xv + data[i] = xv } err = handleNoOp(err) {{end -}} diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 7f975ce..73754df 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -56,7 +56,7 @@ func (r *binaryReader) Err() error { // http://docs.scipy.org/doc/numpy/neps/npy-format.html // // Gorgonia specifically uses Version 1.0, as 65535 bytes should be more 
than enough for the headers. -// The values are written in little endian order, because let's face it - +// The values are written in little endian order, because let's face it - // 90% of the world's computers are running on x86+ processors. // // This method does not close the writer. Closing (if needed) is deferred to the caller @@ -64,7 +64,7 @@ func (r *binaryReader) Err() error { func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string if npdt, err = t.t.numpyDtype(); err != nil{ - return + return } header := "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" @@ -86,7 +86,7 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.seq = 0 if t.IsMasked(){ fillval:=t.FillValue() - it := FlatMaskedIteratorFromDense(t) + it := FlatMaskedIteratorFromDense(t) for i, err := it.Next(); err == nil; i, err = it.Next() { if t.mask[i] { bw.w(fillval) @@ -119,7 +119,7 @@ func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { } cw := csv.NewWriter(w) - it := IteratorFromDense(t) + it := IteratorFromDense(t) coord := it.Coord() // rows := t.Shape()[0] @@ -195,7 +195,7 @@ func (t *Dense) GobEncode() (p []byte, err error){ if err = encoder.Encode(&data); err != nil { return } - + return buf.Bytes(), err } ` @@ -205,7 +205,7 @@ func (t *Dense) GobDecode(p []byte) (err error){ buf := bytes.NewBuffer(p) decoder := gob.NewDecoder(buf) - + var shape Shape if err = decoder.Decode(&shape); err != nil { return @@ -232,12 +232,12 @@ func (t *Dense) GobDecode(p []byte) (err error){ if err = decoder.Decode(&mask); err != nil { return } - + var data interface{} if err = decoder.Decode(&data); err != nil { return } - + t.fromSlice(data) t.addMask(mask) t.fix() @@ -298,7 +298,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") - + var shape Shape for _, s := range sizesStr { @@ -339,7 +339,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ } 
` -const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. // If into is nil, then a backing slice will be created. func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { var err error diff --git a/testutils_test.go b/testutils_test.go index 71a43a4..20cdda9 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -1,6 +1,7 @@ package tensor import ( + "bytes" "errors" "math" "math/cmplx" @@ -562,3 +563,12 @@ func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{} } return true } + +// DummyState is a dummy fmt.State, used to debug things +type DummyState struct { + *bytes.Buffer +} + +func (d *DummyState) Width() (int, bool) { return 0, false } +func (d *DummyState) Precision() (int, bool) { return 0, false } +func (d *DummyState) Flag(c int) bool { return false } From a6e7a3d4d545be09210a26da01ae7883aad377a2 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 5 Jan 2020 00:21:01 +1100 Subject: [PATCH 024/154] Fixes a bug in `ShallowClone` (#56) --- dense.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dense.go b/dense.go index 69f4d70..09847ef 100644 --- a/dense.go +++ b/dense.go @@ -354,6 +354,12 @@ func (t *Dense) ShallowClone() *Dense { t.AP.CloneTo(&retVal.AP) retVal.flag = t.flag retVal.array = t.array + + retVal.old = t.old + retVal.transposeWith = t.transposeWith + retVal.viewOf = t.viewOf + retVal.mask = t.mask + retVal.maskIsSoft = t.maskIsSoft return retVal } From b702564ac20091d8513d68ed3bdb10d357a99935 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Fri, 10 Jan 2020 09:43:18 +1100 Subject: [PATCH 025/154] Added example for Apply and mutation of data (#58) --- example_apply_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 example_apply_test.go diff 
--git a/example_apply_test.go b/example_apply_test.go new file mode 100644 index 0000000..1e11641 --- /dev/null +++ b/example_apply_test.go @@ -0,0 +1,37 @@ +package tensor_test + +import ( + "fmt" + + "gorgonia.org/tensor" +) + +func ExampleDense_Apply() { + a := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) + cube := func(a float64) float64 { return a * a * a } + + b, err := a.Apply(cube) + if err != nil { + fmt.Printf("b is an error %v", err) + } + fmt.Printf("a and b are the same object - %t\n", a.Eq(b)) + fmt.Printf("a is unmutated\n%v\n", a) + + c, err := a.Apply(cube, tensor.WithReuse(a)) + if err != nil { + fmt.Printf("c is an error %v\n", err) + } + fmt.Printf("a and c are the same object - %t\n", a.Eq(c)) + + fmt.Printf("a is now mutated\n%v\n", a) + // Output: + // a and b are the same object - false + // a is unmutated + // ⎡1 2⎤ + // ⎣3 4⎦ + // + // a and c are the same object - true + // a is now mutated + // ⎡ 1 8⎤ + // ⎣27 64⎦ +} From 37ecf88a8f1c5944876f5ba285e9ee9d3549e55c Mon Sep 17 00:00:00 2001 From: Darrell <577768+cfgt@users.noreply.github.com> Date: Fri, 10 Jan 2020 21:41:38 +1100 Subject: [PATCH 026/154] Change license from MIT to Apache license (#59) --- LICENSE | 202 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c7a1c7b --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Gorgonia Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From 63215a9b99d87d3c2650d9aa0acd883f7b05993d Mon Sep 17 00:00:00 2001 From: Chewxy Date: Mon, 2 Mar 2020 10:07:10 +1100 Subject: [PATCH 027/154] Fixed Engine stuff. Also Go 1.14 fixed (#61) * Fixed Engine stuff. 
Also Go 1.14 fixed * Fixed Travis (Gonum only supports the two most recent versions of Go) --- .travis.yml | 4 +--- consopt.go | 1 + defaultengine_linalg.go | 2 +- defaultengine_matop_misc.go | 4 ++-- dense_linalg.go | 6 +++--- dense_matop.go | 2 +- dense_test.go | 19 ------------------- dense_views.go | 2 +- known_race_test.go | 33 +++++++++++++++++++++++++++++++++ sparse.go | 2 +- 10 files changed, 44 insertions(+), 31 deletions(-) create mode 100644 known_race_test.go diff --git a/.travis.yml b/.travis.yml index ee216a5..2e9a505 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,8 @@ branches: only: - master go: - - 1.10.x - - 1.11.x - - 1.12.x - 1.13.x + - 1.14.x - tip env: diff --git a/consopt.go b/consopt.go index 332e0c1..19d47ad 100644 --- a/consopt.go +++ b/consopt.go @@ -140,6 +140,7 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { // Memory must be manually managed by the caller. // Tensors called with this construction option will not be returned to any pool - rather, all references to the pointers will be null'd. // Use with caution. 
+//go:nocheckptr func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index c75f4a7..5e0ecd3 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -355,7 +355,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // extract values var um, vm mat.Dense - s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}) + s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}, WithEngine(e)) svd.Values(s.Data().([]float64)) if uv { svd.UTo(&um) diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 06f53ae..15c5f2d 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -33,7 +33,7 @@ func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseT axis = 0 } - d := recycledDense(t.Dtype(), newShape) + d := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) var outers int if t.IsScalar() { @@ -155,7 +155,7 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") } - retVal := recycledDense(a.Dtype(), newShape) + retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) if isMasked { retVal.makeMask() } diff --git a/dense_linalg.go b/dense_linalg.go index 6493808..c5362c5 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -88,7 +88,7 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) if t.o.IsColMajor() { AsFortran(nil)(retVal) } @@ -137,7 +137,7 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) if 
t.o.IsColMajor() { AsFortran(nil)(retVal) } @@ -176,7 +176,7 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) } if retVal == nil { - retVal = recycledDense(t.t, expectedShape) + retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) if t.o.IsColMajor() { AsFortran(nil)(retVal) } diff --git a/dense_matop.go b/dense_matop.go index 1a3b815..43e1967 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -73,7 +73,7 @@ func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { } } - retVal = recycledDense(t.t, Shape{t.len()}) + retVal = recycledDense(t.t, Shape{t.len()}, WithEngine(t.e)) copyDense(retVal, t) retVal.e = t.e diff --git a/dense_test.go b/dense_test.go index bcb9ba3..d3f43a3 100644 --- a/dense_test.go +++ b/dense_test.go @@ -5,7 +5,6 @@ import ( "testing" "testing/quick" "time" - "unsafe" "github.com/stretchr/testify/assert" ) @@ -90,24 +89,6 @@ func TestFromScalar(t *testing.T) { assert.Equal(t, []float64{3.14}, data) } -func TestFromMemory(t *testing.T) { - // dummy memory - this could be an externally malloc'd memory, or a mmap'ed file. - // but here we're just gonna let Go manage memory. 
- s := make([]float64, 100) - ptr := uintptr(unsafe.Pointer(&s[0])) - size := uintptr(100 * 8) - - T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size)) - if len(T.Float32s()) != 200 { - t.Error("expected 200 Float32s") - } - assert.Equal(t, make([]float32, 200), T.Data()) - assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1)) - - fail := func() { New(FromMemory(ptr, size), Of(Float32)) } - assert.Panics(t, fail, "Expected bad New() call to panic") -} - func Test_recycledDense(t *testing.T) { T := recycledDense(Float64, ScalarShape()) assert.Equal(t, float64(0), T.Data()) diff --git a/dense_views.go b/dense_views.go index d56ee8c..201ff20 100644 --- a/dense_views.go +++ b/dense_views.go @@ -9,7 +9,7 @@ func (t *Dense) Materialize() Tensor { return t } - retVal := recycledDense(t.t, t.shape.Clone()) + retVal := recycledDense(t.t, t.shape.Clone(), WithEngine(t.e)) copyDenseIter(retVal, t, nil, nil) retVal.e = t.e retVal.oe = t.oe diff --git a/known_race_test.go b/known_race_test.go new file mode 100644 index 0000000..cb9e265 --- /dev/null +++ b/known_race_test.go @@ -0,0 +1,33 @@ +// +build !race + +package tensor + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +// This test will fail the `go test -race`. +// +// This is because FromMemory() will use uintptr in a way that is incorrect according to the checkptr directive of Go 1.14+ +// +// Though it's incorrect, it's the only way to use heterogenous, readable memory (i.e. CUDA). +func TestFromMemory(t *testing.T) { + // dummy memory - this could be an externally malloc'd memory, or a mmap'ed file. + // but here we're just gonna let Go manage memory. 
+ s := make([]float64, 100) + ptr := uintptr(unsafe.Pointer(&s[0])) + size := uintptr(100 * 8) + + T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size)) + if len(T.Float32s()) != 200 { + t.Error("expected 200 Float32s") + } + assert.Equal(t, make([]float32, 200), T.Data()) + assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1)) + + fail := func() { New(FromMemory(ptr, size), Of(Float32)) } + assert.Panics(t, fail, "Expected bad New() call to panic") +} diff --git a/sparse.go b/sparse.go index 1e46586..3126843 100644 --- a/sparse.go +++ b/sparse.go @@ -331,7 +331,7 @@ func (t *CS) Dense() *Dense { // use } - d := recycledDense(t.t, t.Shape().Clone()) + d := recycledDense(t.t, t.Shape().Clone(), WithEngine(t.e)) if t.o.IsColMajor() { for i := 0; i < len(t.indptr)-1; i++ { for j := t.indptr[i]; j < t.indptr[i+1]; j++ { From 31c42912a8712a5d3a7d05efa0f648434b088e46 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 10 Mar 2020 11:33:59 +1100 Subject: [PATCH 028/154] Some sketch in simplifying `tensor.Tensor` --- tensor.go | 9 ++++----- utils.go | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tensor.go b/tensor.go index a04e425..9f3c619 100644 --- a/tensor.go +++ b/tensor.go @@ -5,7 +5,6 @@ package tensor // import "gorgonia.org/tensor" import ( "encoding/gob" "fmt" - "io" "unsafe" "github.com/pkg/errors" @@ -74,10 +73,10 @@ type Tensor interface { fmt.Stringer // all Tensors are serializable to these formats - WriteNpy(io.Writer) error - ReadNpy(io.Reader) error - gob.GobEncoder - gob.GobDecoder + //WriteNpy(io.Writer) error + //ReadNpy(io.Reader) error + //gob.GobEncoder + //gob.GobDecoder standardEngine() standardEngine headerer diff --git a/utils.go b/utils.go index 8e62448..98f7546 100644 --- a/utils.go +++ b/utils.go @@ -55,11 +55,11 @@ func ProdInts(a []int) (retVal int) { // if len(a) != len(b) { // return false // } - +// // if (a == nil) != (b == nil) { 
// return false // } - +// // b = b[:len(a)] // for i, v := range a { // if v != b[i] { From 956f24a67618b177b6fffaa743ef68046e1802c6 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 16 Mar 2020 13:09:22 +1100 Subject: [PATCH 029/154] Added Scalar type to prepare for v0.10.0 --- scalar.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 scalar.go diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..05cc219 --- /dev/null +++ b/scalar.go @@ -0,0 +1,81 @@ +// +build ignore + +package tensor + +import ( + "fmt" + "io" + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +var _ Tensor = Scalar{} + +// Scalar is a representation of a scalar value on the CPU. +type Scalar struct{ v interface{} } + +func MakeScalar(v interface{}) Scalar { + if s, ok := v.(Scalar); ok { + return s + } + if s, ok := v.(*Scalar); ok { + return Scalar{s.v} + } + return Scalar{v} +} + +func (s Scalar) Shape() Shape { return ScalarShape() } +func (s Scalar) Strides() []int { return nil } +func (s Scalar) Dtype() Dtype { return Dtype{reflect.TypeOf(s.v)} } +func (s Scalar) Dims() int { return 0 } +func (s Scalar) Size() int { return 0 } // TODO +func (s Scalar) DataSize() int { return 0 } +func (s Scalar) RequiresIterator() bool { return false } +func (s Scalar) Iterator() Iterator { return nil } +func (s Scalar) DataOrder() DataOrder { return 0 } // TODO + +func (s Scalar) Slice(...Slice) (View, error) { return nil, errors.New("Cannot slice a scalar") } +func (s Scalar) At(at ...int) (interface{}, error) { return nil, errors.New("Get a value of a scalar") } +func (s Scalar) SetAt(_ interface{}, _ ...int) error { return errors.New("Cannot set value of scalar") } +func (s Scalar) Reshape(_ ...int) error { return errors.New("Cannot reshape a scalar") } +func (s Scalar) T(_ ...int) error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) UT() {} +func (s 
Scalar) Transpose() error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { + // TODO + return nil, errors.New("Cannot apply ") +} + +func (s Scalar) Zero() {} //TODO +func (s Scalar) Memset(interface{}) error { return errors.New("Cannot Memset") } +func (s Scalar) Data() interface{} { return s.v } +func (s Scalar) Eq(other interface{}) bool { return s == other } +func (s Scalar) Clone() interface{} { return s } + +func (s Scalar) IsScalar() bool { return true } +func (s Scalar) ScalarValue() interface{} { return s.v } + +func (s Scalar) Engine() Engine { return nil } +func (s Scalar) MemSize() uintptr { return 0 } +func (s Scalar) Uintptr() uintptr { return 0 } +func (s Scalar) Pointer() unsafe.Pointer { return nil } +func (s Scalar) IsNativelyAccessible() bool { return true } +func (s Scalar) IsManuallyManaged() bool { return false } + +func (s Scalar) Format(t fmt.State, c rune) {} // TODO +func (s Scalar) String() string { return fmt.Sprintf("%v", s) } + +func (s Scalar) WriteNpy(io.Writer) error { return errors.Errorf(methodNYI, "WriteNpy", "Scalar") } +func (s Scalar) ReadNpy(io.Reader) error { return errors.Errorf(methodNYI, "ReadNypy", "Scalar") } +func (s Scalar) GobEncode() ([]byte, error) { + return nil, errors.Errorf(methodNYI, "GobEncode", "Scalar") +} +func (s Scalar) GobDecode([]byte) error { return errors.Errorf(methodNYI, "GobDecode", "Scalar") } + +func (s Scalar) standardEngine() standardEngine { return StdEng{} } +func (s Scalar) hdr() *storage.Header { return nil } +func (s Scalar) arr() array { return array{} } +func (s Scalar) arrPtr() *array { return nil } From 097d44d562783026fedccaaab745114a7e6238c4 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 18 Mar 2020 10:35:08 +1100 Subject: [PATCH 030/154] Added RepeatReuse API. (#62) * Added RepeatReuse API. Right now any engine implementing Repeater is expected to also implement RepeatReuse. 
This imposition is strictly unnecessary, but also makes the number of possible combinations easier to handle for now. The reason for adding `RepeatReuse` is due to benchmarks done by @mattn who has found that the `Repeat` function in Gorgonia causes a lot of additional allocation. Given that gorgonia.org/gorgonia can actually determine ahead of time how much space to use, a RepeatReuse API was designed rapidly to allow taking advantage of that. We cannot just tack on a `FuncOpt` to `Repeat` as the variadic parameters are reserved for the number of repeats to be made. Thus a new function `RepeatReuse` is required. Originally the name was `RepeatWithReuse` but that is a rather long name. * Added more examples to Repeat, if only as a method to document to myself --- api_matop.go | 8 +++ defaultengine_matop_misc.go | 47 +++++++++++++++--- dense_svd_test.go | 2 +- engine.go | 1 + example_dense_matop_test.go | 86 ++++++++++++++++++++++++++++++++- example_extension_matop_test.go | 2 +- example_mapreduce_test.go | 6 +-- tensor.go | 3 ++ 8 files changed, 141 insertions(+), 14 deletions(-) diff --git a/api_matop.go b/api_matop.go index 8c687b2..2bc616c 100644 --- a/api_matop.go +++ b/api_matop.go @@ -13,6 +13,14 @@ func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { return nil, errors.New("Engine does not support Repeat") } +// RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given // ???? , but the results will still be valid. +func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { + if r, ok := t.Engine().(Repeater); ok { + return r.RepeatReuse(t, reuse, axis, repeats...) + } + return nil, errors.New("Engine does not support Repeat") +} + // T safely transposes a Tensor. It returns a tensor that is not a view of the input tensor - rather, the data is all copied. 
func T(t Tensor, axes ...int) (retVal Tensor, err error) { switch tt := t.(type) { diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 15c5f2d..0c508f1 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -16,25 +16,56 @@ type fastcopier interface { func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { switch tt := t.(type) { case DenseTensor: - return e.denseRepeat(tt, axis, repeats) + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) default: return nil, errors.Errorf("NYI") } } -func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseTensor, err error) { - var newShape Shape - var size int - if newShape, repeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { - return nil, errors.Wrap(err, "Unable to get repeated shape") +// RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. +func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { + switch tt := t.(type) { + case DenseTensor: + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + + rr, ok := reuse.(DenseTensor) + if !ok { + return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse) + } + if !reuse.Shape().Eq(newShape) { + return nil, errors.Errorf("Reuse shape is %v. 
Expected shape is %v", reuse.Shape(), newShape) + } + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) + default: + return nil, errors.Errorf("NYI") } +} +func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { + if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { + return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") + } + newAxis = axis if axis == AllAxes { - axis = 0 + newAxis = 0 } - d := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) + return +} +func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) { + d, err := assertDense(reuse) + if err != nil { + return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense") + } var outers int if t.IsScalar() { outers = 1 diff --git a/dense_svd_test.go b/dense_svd_test.go index 89c5306..282868b 100644 --- a/dense_svd_test.go +++ b/dense_svd_test.go @@ -104,7 +104,7 @@ func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) { return nil } -func Example_DenseSVD() { +func ExampleDense_SVD() { T := New( WithShape(4, 5), WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}), diff --git a/engine.go b/engine.go index 9e3ede7..ae508a8 100644 --- a/engine.go +++ b/engine.go @@ -89,6 +89,7 @@ type DenseStacker interface { // Repeater is any engine that can repeat values along the given axis. 
type Repeater interface { Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) + RepeatReuse(t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) } // Diager is any engine that can return a tensor that only contains the diagonal values of the input diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 97a2cb8..08ab492 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -1,6 +1,8 @@ package tensor -import "fmt" +import ( + "fmt" +) func ExampleDense_Slice() { var T Tensor @@ -205,3 +207,85 @@ func ExampleDense_Vstack() { // Vstacking (4) with (1, 2): Tensor has to be at least 2 dimensions // Vstacking (1, 2) with (4): Tensor has to be at least 2 dimensions } + +func ExampleRepeatReuse() { + var T, T1 *Dense + T = New(WithBacking([]float64{1, 2, 3, 4}), WithShape(1, 4)) + T1 = New(Of(Float64), WithShape(3, 4)) + + var T2 Tensor + var err error + if T2, err = RepeatReuse(T, T1, 0, 3); err != nil { + fmt.Printf("Err %v", err) + } + fmt.Printf("RepeatReuse(T, T1):\n%v", T2) + fmt.Printf("T1 == T2: %t\n", T1 == T2) + + // But if your reuse is wrongly shaped, an error occurs + T1 = New(Of(Float64), WithShape(1, 4)) // too small + if _, err = RepeatReuse(T, T1, 0, 3); err != nil { + fmt.Printf("Expected Error: %v\n", err) + } + + // Output: + // RepeatReuse(T, T1): + // ⎡1 2 3 4⎤ + // ⎢1 2 3 4⎥ + // ⎣1 2 3 4⎦ + // T1 == T2: true + // Expected Error: Reuse shape is (1, 4). Expected shape is (3, 4) +} + +func ExampleRepeat_uncommonUses() { + T := New(WithBacking([]int{1, 2, 3, 4, 5, 6}), WithShape(2, 3)) + fmt.Printf("T:\n%v", T) + + fmt.Println("Axis 0 has 2 elements. 
So we will need to write the number of times each element is to be repeated") + fmt.Println("Here, Repeat(T, 0, 3, 2) results in this:") + T1, err := Repeat(T, 0, 3, 2) + if err != nil { + fmt.Printf("Err %v", err) + } + fmt.Printf("%v", T1) + fmt.Println("Observe the 0th element ([1 2 3]) has been repeated 3 times, and the 1st element ([4 5 6]) has been repeated twice") + fmt.Println("") + + fmt.Println("We can also repeat on Axis 1. Now along Axis 1 there are 3 elements: ([1 4], [2 5], [3 6])") + fmt.Println("So we have to specify how many times to repeat each element.") + fmt.Println("Repeat(T, 1, 2, 3, 2) yields the following result:") + T1, err = Repeat(T, 1, 2, 3, 2) + if err != nil { + fmt.Printf("Err %v", err) + } + fmt.Printf("%v", T1) + fmt.Println("Once again, observe that the 1st element ([2 5]) has been repeated 3 times, while the rest have been repeated twice") + /* + // TODO break this out to another example + T1, err = Repeat(T, AllAxes, 2, 3, 2, 2, 2, 2) + if err != nil { + fmt.Printf("Err %v", err) + } + fmt.Printf("%#v", T1) + */ + + // Output: + // T: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // Axis 0 has 2 elements. So we will need to write the number of times each element is to be repeated + // Here, Repeat(T, 0, 3, 2) results in this: + // ⎡1 2 3⎤ + // ⎢1 2 3⎥ + // ⎢1 2 3⎥ + // ⎢4 5 6⎥ + // ⎣4 5 6⎦ + // Observe the 0th element ([1 2 3]) has been repeated 3 times, and the 1st element ([4 5 6]) has been repeated twice + // + // We can also repeat on Axis 1. Now along Axis 1 there are 3 elements: ([1 4], [2 5], [3 6]) + // So we have to specify how many times to repeat each element. 
+ // Repeat(T, 1, 2, 3, 2) yields the following result: + // ⎡1 1 2 2 2 3 3⎤ + // ⎣4 4 5 5 5 6 6⎦ + // Once again, observe that the 1st element ([2 5]) has been repeated 3 times, while the rest have been repeated twice + +} diff --git a/example_extension_matop_test.go b/example_extension_matop_test.go index 372c9a0..09f63a4 100644 --- a/example_extension_matop_test.go +++ b/example_extension_matop_test.go @@ -24,7 +24,7 @@ func (ss s) Start() int { return int(ss) } func (ss s) End() int { return int(ss) + 1 } func (ss s) Step() int { return 1 } -func Example_TransposeExtension() { +func ExampleTranspose_extension() { // For documentation if you're reading this on godoc: // // type LongStruct struct { diff --git a/example_mapreduce_test.go b/example_mapreduce_test.go index 27f4a6b..4f42a72 100644 --- a/example_mapreduce_test.go +++ b/example_mapreduce_test.go @@ -2,7 +2,7 @@ package tensor import "fmt" -func Example_Sum() { +func ExampleSum() { T := New(WithBacking([]float64{0, 1, 2, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) @@ -31,7 +31,7 @@ func Example_Sum() { // Summed along (1, 0): 6 } -func Example_Argmax() { +func ExampleArgmax() { T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) @@ -49,7 +49,7 @@ func Example_Argmax() { // Argmax is *tensor.Dense of int } -func Example_Argmin() { +func ExampleArgmin() { T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) diff --git a/tensor.go b/tensor.go index a04e425..b06066a 100644 --- a/tensor.go +++ b/tensor.go @@ -105,6 +105,9 @@ func assertDense(t Tensor) (*Dense, error) { if retVal, ok := t.(*Dense); ok { return retVal, nil } + if retVal, ok := t.(Densor); ok { + return retVal.Dense(), nil + } return nil, errors.Errorf("%T is not *Dense", t) } From c9019034597a18ab2650df18ff822b8fb2863945 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 22 Mar 2020 10:29:04 +1100 Subject: [PATCH 031/154] delete LICENCE (#63) * delete LICENCE * 
Rename LICENSE to LICENCE --- LICENCE | 391 +++++++++++++++++++++++++++++--------------------------- LICENSE | 202 ----------------------------- 2 files changed, 202 insertions(+), 391 deletions(-) delete mode 100644 LICENSE diff --git a/LICENCE b/LICENCE index 7bec963..c7a1c7b 100644 --- a/LICENCE +++ b/LICENCE @@ -1,189 +1,202 @@ -The Gorgonia Licence - -Copyright (c) 2016 Xuanyi Chew - -Licensed under the Gorgonia License, Version 1.0 (the "License"); -you may not use this file except in compliance with the License. - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. 
- - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Significant Contribution" shall mean any Contribution that indicates a deep - understanding of the Work and/or its Derivatives thereof. - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. You are not permitted - to directly commercially profit from this Work unless You are also a - Significant Contributor, which is listed under the Contributors list. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. -END OF TERMS AND CONDITIONS + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2019 Gorgonia Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE b/LICENSE deleted file mode 100644 index c7a1c7b..0000000 --- a/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. 
- - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright 2019 Gorgonia Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. From f10d2f09e8b90b9c5a7f9e9723c172888b6a3805 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Fri, 10 Apr 2020 12:33:58 +1000 Subject: [PATCH 032/154] Repeat reuse with optimizations (#65) * Added RepeatReuse API. Right now any engine implementing Repeater is expected to also implement RepeatReuse. This imposition is strictly unnecessary, but also makes the number of possible combinations easier to handle for now. The reason for adding `RepeatReuse` is due to benchmarks done by @mattn who has found that the `Repeat` function in Gorgonia causes a lot of additional allocation. Given that gorgonia.org/gorgonia can actually determine ahead of time how much space to use, a RepeatReuse API was designed rapidly to allow taking advantage of that. We cannot just tack on a `FuncOpt` to `Repeat` as the variadic parameters are reserved for the number of repeats to be made. Thus a new function `RepeatReuse` is required. Originally the name was `RepeatWithReuse` but that is a rather long name. 
* Added more examples to Repeat, if only as a method to document to myself * Added some performance optimization to array and Repeat * Ugh, a bit of a conflict cleanups (line end types) --- array.go | 70 +++- defaultengine_matop_misc.go | 691 ++++++++++++++++++------------------ dense.go | 2 +- 3 files changed, 414 insertions(+), 349 deletions(-) diff --git a/array.go b/array.go index 7e7116c..40995ec 100644 --- a/array.go +++ b/array.go @@ -9,6 +9,26 @@ import ( "gorgonia.org/tensor/internal/storage" ) +//go:notinheap +type rawdata []byte + +// array2 is a type that will not be allocated on the heap. This is useful for operational stuff - no unnecessary allocations required. + +//go:notinheap +type array2 struct { + storage.Header + t Dtype + v interface{} +} + +func (a array2) toarray() array { + return array{ + Header: a.Header, + t: a.t, + v: a.v, + } +} + // array is the underlying generic array. type array struct { storage.Header // the header - the Go representation (a slice) @@ -145,7 +165,7 @@ func (a *array) sliceInto(i, j int, res *array) { } // slice slices an array -func (a array) slice(start, end int) array { +func (a array) slice(start, end int) array2 { if end > a.L { panic("Index out of range") } @@ -169,7 +189,11 @@ func (a array) slice(start, end int) array { C: C, } - return makeArrayFromHeader(hdr, a.t) + return array2{ + Header: hdr, + t: a.t, + v: nil, + } } // swap swaps the elements i and j in the array @@ -271,7 +295,7 @@ func (a *array) rtype() reflect.Type { return a.t.Type } // malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory func malloc(t Dtype, length int) unsafe.Pointer { size := int(calcMemSize(t, length)) - s := make([]byte, size) + s := make(rawdata, size) return unsafe.Pointer(&s[0]) } @@ -342,11 +366,43 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, } } if e := src.Engine(); e != nil { - d := dst.arr().slice(dstart, dend) - s := 
src.arr().slice(sstart, send) - if err := e.Memcpy(&d, &s); err != nil { - panic(err) + darr := dst.arr() + sarr := src.arr() + d := darr.slice(dstart, dend) + s := sarr.slice(sstart, send) + + switch e.(type) { + case NonStdEngine: + da := d.toarray() + sa := s.toarray() + if err := e.Memcpy(&da, &sa); err != nil { + panic(err) + } + default: + // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. + // + // THE PURPOSE of this optimization is to make this perform better under + // default circumstances. + // + // The original code simply uses t.Engine().Memcpy(&dSlice, &tSlice). + // A variant can still be seen in the NonStdEngine case above. + // + // The `array.slice()` method has been optimized to return `array2`, which is a + // non-heap allocated type. + // a value of `array2` cannot have its address taken - e.g. + // var a array2 + // doSomething(&a) // ← this cannot be done + // + // We *could* make `array2` implement Memory. But then a lot of runtime.convT2I and + // runtime.convI2T would be called. Which defeats the purpose of making things fast. + // + // So instead, we check to see if the Engine uses standard allocation methods. + // Typically this means `StdEng`. + // + // If so, we directly use storage.Copy instead of using the engine + storage.Copy(d.t.Type, &d.Header, &s.Header) } + return d.Len() } return copyArraySliced(dst.arr(), dstart, dend, src.arr(), sstart, send) diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 0c508f1..b0fc6c1 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,341 +1,350 @@ -package tensor - -import ( - "github.com/pkg/errors" -) - -var ( - _ Diager = StdEng{} -) - -type fastcopier interface { - fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error -} - -// Repeat ... 
-func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { - switch tt := t.(type) { - case DenseTensor: - newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) - if err != nil { - return nil, err - } - rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) - return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) - default: - return nil, errors.Errorf("NYI") - } -} - -// RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. -func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { - switch tt := t.(type) { - case DenseTensor: - newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) - if err != nil { - return nil, err - } - - rr, ok := reuse.(DenseTensor) - if !ok { - return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse) - } - if !reuse.Shape().Eq(newShape) { - return nil, errors.Errorf("Reuse shape is %v. 
Expected shape is %v", reuse.Shape(), newShape) - } - return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) - default: - return nil, errors.Errorf("NYI") - } -} - -func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { - if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { - return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") - } - newAxis = axis - if axis == AllAxes { - newAxis = 0 - } - - return -} - -func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) { - d, err := assertDense(reuse) - if err != nil { - return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense") - } - var outers int - if t.IsScalar() { - outers = 1 - } else { - outers = ProdInts(t.Shape()[0:axis]) - if outers == 0 { - outers = 1 - } - } - - var stride, newStride int - if newShape.IsVector() || t.IsVector() { - stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector - } else { - stride = t.ostrides()[axis] - } - - if newShape.IsVector() { - newStride = 1 - } else { - newStride = d.ostrides()[axis] - } - - var destStart, srcStart int - // fastCopy is not bypassing the copyDenseSliced method to populate the output tensor - var fastCopy bool - var fce fastcopier - // we need an engine for fastCopying... - e := t.Engine() - // e can never be nil. 
Error would have occurred elsewhere - var ok bool - if fce, ok = e.(fastcopier); ok { - fastCopy = true - } - - // In this case, let's not implement the fast copy to keep the code readable - if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() { - fastCopy = false - } - - if fastCopy { - if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil { - return nil, err - } - return d, nil - } - - for i := 0; i < outers; i++ { - for j := 0; j < size; j++ { - var tmp int - tmp = repeats[j] - - for k := 0; k < tmp; k++ { - if srcStart >= t.len() || destStart+stride > d.len() { - break - } - copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len()) - destStart += newStride - } - srcStart += stride - } - } - return d, nil -} - -func (StdEng) fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error { - var destStart, srcStart int - for i := 0; i < outers; i++ { - for j := 0; j < size; j++ { - var tmp int - tmp = repeats[j] - var tSlice array - tSlice = t.arr().slice(srcStart, t.len()) - - for k := 0; k < tmp; k++ { - if srcStart >= t.len() || destStart+stride > d.len() { - break - } - dSlice := d.arr().slice(destStart, d.len()) - if err := t.Engine().Memcpy(&dSlice, &tSlice); err != nil { - return err - } - destStart += newStride - } - srcStart += stride - } - } - return nil -} - -// Concat tensors -func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { - switch tt := t.(type) { - case DenseTensor: - var denses []DenseTensor - if denses, err = tensorsToDenseTensors(others); err != nil { - return nil, errors.Wrap(err, "Concat failed") - } - return e.denseConcat(tt, axis, denses) - default: - return nil, errors.Errorf("NYI") - } -} - -func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) { - ss := make([]Shape, len(Ts)) - var err error - var isMasked bool - for i, T := range Ts { - ss[i] = T.Shape() - if mt, ok := 
T.(MaskedTensor); ok { - isMasked = isMasked || mt.IsMasked() - } - } - - var newShape Shape - if newShape, err = a.Shape().Concat(axis, ss...); err != nil { - return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") - } - - retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) - if isMasked { - retVal.makeMask() - } - - all := make([]DenseTensor, len(Ts)+1) - all[0] = a - copy(all[1:], Ts) - - // TODO: OPIMIZATION - // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) - // just flat copy - // - - // isOuter is true when the axis is the outermost axis - // isInner is true when the axis is the inner most axis - isOuter := axis == 0 - isInner := axis == (a.Shape().Dims() - 1) - - // special case - var start, end int - for _, T := range all { - end += T.Shape()[axis] - slices := make([]Slice, axis+1) - slices[axis] = makeRS(start, end) - - var v *Dense - if v, err = sliceDense(retVal, slices...); err != nil { - return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") - } - - switch { - case v.IsVector() && T.IsMatrix() && axis == 0: - v.reshape(v.shape[0], 1) - case T.IsRowVec() && axis == 0: - T.reshape(T.Shape()[1]) - case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): - copyArray(v.arrPtr(), T.arrPtr()) - if mt, ok := T.(MaskedTensor); ok { - copy(v.mask, mt.Mask()) - } - continue - default: - diff := retVal.Shape().Dims() - v.Shape().Dims() - if diff > 0 && isOuter { - newShape := make(Shape, v.Shape().Dims()+diff) - for i := 0; i < diff; i++ { - newShape[i] = 1 - } - copy(newShape[diff:], v.Shape()) - v.reshape(newShape...) 
- } else if diff > 0 && isInner { - newShape := v.Shape().Clone() - newStrides := v.strides - for i := 0; i < diff; i++ { - newShape = append(newShape, 1) - newStrides = append(newStrides, 1) - } - v.shape = newShape - v.strides = newStrides - } - } - - var vmask, Tmask []bool - vmask = v.mask - v.mask = nil - if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { - Tmask = mt.Mask() - mt.SetMask(nil) - - } - - if err = assignArray(v, T); err != nil { - return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") - } - // if it's a masked tensor, we copy the mask as well - if Tmask != nil { - if vmask != nil { - if cap(vmask) < len(Tmask) { - vmask2 := make([]bool, len(Tmask)) - copy(vmask2, vmask) - vmask = vmask2 - } - copy(vmask, Tmask) - v.SetMask(vmask) - } - // mt.SetMask(Tmask) - } - - start = end - } - - return retVal, nil -} - -// Diag ... -func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { - a, ok := t.(DenseTensor) - if !ok { - return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") - } - - if a.Dims() != 2 { - err = errors.Errorf(dimMismatch, 2, a.Dims()) - return - } - - if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { - return nil, errors.Wrap(err, "Diagonal") - } - - rstride := a.Strides()[0] - cstride := a.Strides()[1] - - r := a.Shape()[0] - c := a.Shape()[1] - - m := MinInt(r, c) - stride := rstride + cstride - - b := a.Clone().(DenseTensor) - b.Zero() - - switch a.rtype().Size() { - case 1: - bdata := b.hdr().Uint8s() - adata := a.hdr().Uint8s() - for i := 0; i < m; i++ { - bdata[i] = adata[i*stride] - } - case 2: - bdata := b.hdr().Uint16s() - adata := a.hdr().Uint16s() - for i := 0; i < m; i++ { - bdata[i] = adata[i*stride] - } - case 4: - bdata := b.hdr().Uint32s() - adata := a.hdr().Uint32s() - for i := 0; i < m; i++ { - bdata[i] = adata[i*stride] - } - case 8: - bdata := b.hdr().Uint64s() - adata := a.hdr().Uint64s() - for i := 0; i < m; i++ { - bdata[i] = adata[i*stride] - } - default: - 
return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) - } - return b, nil -} +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +var ( + _ Diager = StdEng{} +) + +type fastcopier interface { + fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error +} + +// Repeat ... +func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { + switch tt := t.(type) { + case DenseTensor: + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) + default: + return nil, errors.Errorf("NYI") + } +} + +// RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. +func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { + switch tt := t.(type) { + case DenseTensor: + newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) + if err != nil { + return nil, err + } + + rr, ok := reuse.(DenseTensor) + if !ok { + return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse) + } + if !reuse.Shape().Eq(newShape) { + return nil, errors.Errorf("Reuse shape is %v. 
Expected shape is %v", reuse.Shape(), newShape) + } + return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) + default: + return nil, errors.Errorf("NYI") + } +} + +func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { + if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { + return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") + } + newAxis = axis + if axis == AllAxes { + newAxis = 0 + } + + return +} + +func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) { + d, err := assertDense(reuse) + if err != nil { + return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense") + } + var outers int + if t.IsScalar() { + outers = 1 + } else { + outers = ProdInts(t.Shape()[0:axis]) + if outers == 0 { + outers = 1 + } + } + + var stride, newStride int + if newShape.IsVector() || t.IsVector() { + stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector + } else { + stride = t.ostrides()[axis] + } + + if newShape.IsVector() { + newStride = 1 + } else { + newStride = d.ostrides()[axis] + } + + var destStart, srcStart int + // fastCopy is not bypassing the copyDenseSliced method to populate the output tensor + var fastCopy bool + var fce fastcopier + // we need an engine for fastCopying... + e := t.Engine() + // e can never be nil. 
Error would have occurred elsewhere + var ok bool + if fce, ok = e.(fastcopier); ok { + fastCopy = true + } + + // In this case, let's not implement the fast copy to keep the code readable + if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() { + fastCopy = false + } + + // if d is not a fastcopier, then we also cannot use fast copy + if _, ok := d.Engine().(fastcopier); !ok { + fastCopy = false + } + + if fastCopy { + if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil { + return nil, err + } + return d, nil + } + + for i := 0; i < outers; i++ { + for j := 0; j < size; j++ { + var tmp int + tmp = repeats[j] + + for k := 0; k < tmp; k++ { + if srcStart >= t.len() || destStart+stride > d.len() { + break + } + copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len()) + destStart += newStride + } + srcStart += stride + } + } + return d, nil +} + +func (e StdEng) fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error { + var destStart, srcStart int + for i := 0; i < outers; i++ { + for j := 0; j < size; j++ { + var tmp int + tmp = repeats[j] + var tSlice array2 + tarr := t.arr() + tSlice = tarr.slice(srcStart, t.len()) + + for k := 0; k < tmp; k++ { + if srcStart >= t.len() || destStart+stride > d.len() { + break + } + arr := d.arr() + dSlice := arr.slice(destStart, d.len()) + + // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. 
+ storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header) + + destStart += newStride + } + srcStart += stride + } + } + return nil +} + +// Concat tensors +func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { + switch tt := t.(type) { + case DenseTensor: + var denses []DenseTensor + if denses, err = tensorsToDenseTensors(others); err != nil { + return nil, errors.Wrap(err, "Concat failed") + } + return e.denseConcat(tt, axis, denses) + default: + return nil, errors.Errorf("NYI") + } +} + +func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) { + ss := make([]Shape, len(Ts)) + var err error + var isMasked bool + for i, T := range Ts { + ss[i] = T.Shape() + if mt, ok := T.(MaskedTensor); ok { + isMasked = isMasked || mt.IsMasked() + } + } + + var newShape Shape + if newShape, err = a.Shape().Concat(axis, ss...); err != nil { + return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") + } + + retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) + if isMasked { + retVal.makeMask() + } + + all := make([]DenseTensor, len(Ts)+1) + all[0] = a + copy(all[1:], Ts) + + // TODO: OPIMIZATION + // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) + // just flat copy + // + + // isOuter is true when the axis is the outermost axis + // isInner is true when the axis is the inner most axis + isOuter := axis == 0 + isInner := axis == (a.Shape().Dims() - 1) + + // special case + var start, end int + for _, T := range all { + end += T.Shape()[axis] + slices := make([]Slice, axis+1) + slices[axis] = makeRS(start, end) + + var v *Dense + if v, err = sliceDense(retVal, slices...); err != nil { + return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") + } + + switch { + case v.IsVector() && T.IsMatrix() && axis == 0: + v.reshape(v.shape[0], 1) + case T.IsRowVec() && axis == 
0: + T.reshape(T.Shape()[1]) + case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): + copyArray(v.arrPtr(), T.arrPtr()) + if mt, ok := T.(MaskedTensor); ok { + copy(v.mask, mt.Mask()) + } + continue + default: + diff := retVal.Shape().Dims() - v.Shape().Dims() + if diff > 0 && isOuter { + newShape := make(Shape, v.Shape().Dims()+diff) + for i := 0; i < diff; i++ { + newShape[i] = 1 + } + copy(newShape[diff:], v.Shape()) + v.reshape(newShape...) + } else if diff > 0 && isInner { + newShape := v.Shape().Clone() + newStrides := v.strides + for i := 0; i < diff; i++ { + newShape = append(newShape, 1) + newStrides = append(newStrides, 1) + } + v.shape = newShape + v.strides = newStrides + } + } + + var vmask, Tmask []bool + vmask = v.mask + v.mask = nil + if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { + Tmask = mt.Mask() + mt.SetMask(nil) + + } + + if err = assignArray(v, T); err != nil { + return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") + } + // if it's a masked tensor, we copy the mask as well + if Tmask != nil { + if vmask != nil { + if cap(vmask) < len(Tmask) { + vmask2 := make([]bool, len(Tmask)) + copy(vmask2, vmask) + vmask = vmask2 + } + copy(vmask, Tmask) + v.SetMask(vmask) + } + // mt.SetMask(Tmask) + } + + start = end + } + + return retVal, nil +} + +// Diag ... 
+func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { + a, ok := t.(DenseTensor) + if !ok { + return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") + } + + if a.Dims() != 2 { + err = errors.Errorf(dimMismatch, 2, a.Dims()) + return + } + + if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { + return nil, errors.Wrap(err, "Diagonal") + } + + rstride := a.Strides()[0] + cstride := a.Strides()[1] + + r := a.Shape()[0] + c := a.Shape()[1] + + m := MinInt(r, c) + stride := rstride + cstride + + b := a.Clone().(DenseTensor) + b.Zero() + + switch a.rtype().Size() { + case 1: + bdata := b.hdr().Uint8s() + adata := a.hdr().Uint8s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 2: + bdata := b.hdr().Uint16s() + adata := a.hdr().Uint16s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 4: + bdata := b.hdr().Uint32s() + adata := a.hdr().Uint32s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 8: + bdata := b.hdr().Uint64s() + adata := a.hdr().Uint64s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + default: + return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) + } + return b, nil +} diff --git a/dense.go b/dense.go index 09847ef..e80b2f5 100644 --- a/dense.go +++ b/dense.go @@ -605,7 +605,7 @@ func (t *Dense) SetMask(mask []bool) { } func (t *Dense) slice(start, end int) { - t.array = t.array.slice(start, end) + t.array = t.array.slice(start, end).toarray() } // RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion From ca49742127debcb8d70439e6bc422cf784e0dc18 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Fri, 10 Apr 2020 12:51:47 +1000 Subject: [PATCH 033/154] Gomod (#66) --- go.mod | 2 +- go.sum | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 70baaff..b93eb28 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/pkg/errors v0.8.1 
github.com/stretchr/testify v1.4.0 github.com/xtgo/set v1.0.0 // indirect - gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee + gonum.org/v1/gonum v0.7.0 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) diff --git a/go.sum b/go.sum index 850df1d..cd19944 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM= -gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= +gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw= +gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= @@ -45,12 +45,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0= -gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8= gorgonia.org/vecf32 v0.9.0 
h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= -gorgonia.org/vecf64 v0.7.0 h1:ZphOGJfnWlFfY7x8WAJAfO64IAtYqPPq9TEGem+ItZE= -gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q= gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= From 7e357ab9632b35c8c8eb12a58d51214f4615144d Mon Sep 17 00:00:00 2001 From: Johannes Lauinger Date: Mon, 1 Jun 2020 20:23:27 +0200 Subject: [PATCH 034/154] fix possible memory confusion in unsafe slice cast (#68) --- internal/storage/header.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/storage/header.go b/internal/storage/header.go index 23eeb22..e34e254 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -87,12 +87,12 @@ func CopyIter(t reflect.Type, dst, src *Header, diter, siter Iterator) int { func AsByteSlice(a *Header, t reflect.Type) []byte { size := a.L * int(t.Size()) - hdr := reflect.SliceHeader{ - Data: uintptr(a.Ptr), - Len: size, - Cap: size, - } - return *(*[]byte)(unsafe.Pointer(&hdr)) + b := make([]byte, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + hdr.Data = uintptr(a.Ptr) + hdr.Cap = size + hdr.Len = size + return b } // Element gets the pointer of ith element From 49e78e0c4c16cc776f7a1645c23f5ff5151024f5 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 2 Jun 2020 04:54:20 +1000 Subject: [PATCH 035/154] V0.9.6 performance (#67) * Added a bit of optimizations for Repeat * There was a bug in prepDataSV and prepDataVS that led to unnecessary use of iterators. * Changed the notion of an array's equality - it doesn't need to check for whether the caps are the same. 
* Added a check in reuse to Reshape if not correct, but the storage is correct * Updated the file that generates array_getset.go * Incorporated #68 which fixes some possibly unsafe slice issue --- array_getset.go | 10 ++--- defaultengine_matop_misc.go | 77 ++++++++++++++++++++++++++++++++++--- defaultengine_prep.go | 14 +++++-- dense_format.go | 4 +- flags.go | 1 + genlib2/array_getset.go | 3 +- 6 files changed, 92 insertions(+), 17 deletions(-) diff --git a/array_getset.go b/array_getset.go index e016823..0054bc3 100644 --- a/array_getset.go +++ b/array_getset.go @@ -507,11 +507,11 @@ func (a array) Eq(other interface{}) bool { if oa.L != a.L { return false } - - if oa.C != a.C { - return false - } - + /* + if oa.C != a.C { + return false + } + */ // same exact thing if uintptr(oa.Ptr) == uintptr(a.Ptr) { return true diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index b0fc6c1..217a85e 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -137,22 +137,87 @@ func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, return d, nil } -func (e StdEng) fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error { +func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, stride, newStride int, repeats []int) error { + sarr := src.arr() + darr := dest.arr() + var destStart, srcStart int for i := 0; i < outers; i++ { + // faster shortcut for common case. + // + // Consider a case where: + // a := ⎡ 1 ⎤ + // ⎢ 2 ⎥ + // ⎢ 3 ⎥ + // ⎣ 4 ⎦ + // a has a shape of (4, 1). it is a *Dense. + // + // Now assume we want to repeat it on axis 1, 3 times. We want to repeat it into `b`, + // which is already allocated and zeroed, as shown below + // + // b := ⎡ 0 0 0 ⎤ + // ⎢ 0 0 0 ⎥ + // ⎢ 0 0 0 ⎥ + // ⎣ 0 0 0 ⎦ + // + // Now, both `a` and `b` have a stride of 1. 
+ // + // The desired result is: + // b := ⎡ 1 1 1 ⎤ + // ⎢ 2 2 2 ⎥ + // ⎢ 3 3 3 ⎥ + // ⎣ 4 4 4 ⎦ + /// + // Observe that this is simply broadcasting (copying) a[0] (a scalar value) to the row b[0], and so on and so forth. + // This can be done without knowing the full type - we simply copy the bytes over. + if stride == 1 && newStride == 1 { + for sz := 0; sz < size; sz++ { + tmp := repeats[sz] + + // first we get the bounds of the src and the dest + // the srcStart and destStart are the indices assuming a flat array of []T + // we need to get the byte slice equivalent. + bSrcStart := srcStart * int(sarr.t.Size()) + bSrcEnd := (srcStart + stride) * int(sarr.t.Size()) + bDestStart := destStart * int(darr.t.Size()) + bDestEnd := (destStart + tmp) * int(darr.t.Size()) + + // then we get the data as a slice of raw bytes + sBS := storage.AsByteSlice(&sarr.Header, sarr.t.Type) + dBS := storage.AsByteSlice(&darr.Header, darr.t.Type) + + // recall that len(src) < len(dest) + // it's easier to understand if we define the ranges. + // Less prone to errors. + sRange := sBS[bSrcStart:bSrcEnd] + dRange := dBS[bDestStart:bDestEnd] + + // finally we copy things. + for i := 0; i < len(dRange); i += len(sRange) { + copy(dRange[i:], sRange) + } + srcStart += stride + destStart += tmp + } + + // we can straightaway broadcast + + continue + } + for j := 0; j < size; j++ { var tmp int tmp = repeats[j] var tSlice array2 - tarr := t.arr() - tSlice = tarr.slice(srcStart, t.len()) + + tSlice = sarr.slice(srcStart, src.len()) for k := 0; k < tmp; k++ { - if srcStart >= t.len() || destStart+stride > d.len() { + if srcStart >= src.len() || destStart+stride > dest.len() { break } - arr := d.arr() - dSlice := arr.slice(destStart, d.len()) + + dSlice := darr.slice(destStart, destStart+newStride) // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. 
storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header) diff --git a/defaultengine_prep.go b/defaultengine_prep.go index c203253..fca9848 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -40,6 +40,14 @@ func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opt err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.len(), expShape.TotalSize()) return } + if !reuse.Shape().Eq(expShape) { + cloned := expShape.Clone() + if err = reuse.Reshape(cloned...); err != nil { + return + + } + ReturnInts([]int(cloned)) + } if !incr && reuse != nil { reuse.setDataOrder(o) @@ -119,7 +127,6 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea iit = reuse.Iterator() } } - // log.Printf("Use Itrer %v ", useIter) // swap if _, ok := a.(*CS); ok { @@ -146,7 +153,7 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse } useIter = a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) || - (reuse != nil && reuse.DataOrder().HasSameOrder(a.DataOrder())) + (reuse != nil && !reuse.DataOrder().HasSameOrder(a.DataOrder())) if useIter { ait = a.Iterator() if reuse != nil { @@ -170,7 +177,8 @@ func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse } useIter = b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) || - (reuse != nil && reuse.DataOrder().HasSameOrder(b.DataOrder())) + (reuse != nil && !reuse.DataOrder().HasSameOrder(b.DataOrder())) + if useIter { bit = b.Iterator() if reuse != nil { diff --git a/dense_format.go b/dense_format.go index 859477f..3f4f5e7 100644 --- a/dense_format.go +++ b/dense_format.go @@ -45,8 +45,8 @@ type fmtState struct { meta bool flat bool - ext bool - comp bool + ext bool // extended (i.e no elision) + comp bool // compact c rune // c is here mainly for struct packing reasons w, p int // width and precision diff --git a/flags.go b/flags.go index 
e8a00d0..547136e 100644 --- a/flags.go +++ b/flags.go @@ -54,6 +54,7 @@ func (f DataOrder) toggleColMajor() DataOrder { return f ^ (ColMajor) } func (f DataOrder) clearTransposed() DataOrder { return f &^ (Transposed) } +// HasSameOrder returns true if both data orders are the same (either both are ColMajor or both are RowMajor) func (f DataOrder) HasSameOrder(other DataOrder) bool { return (f.IsColMajor() && other.IsColMajor()) || (f.IsRowMajor() && other.IsRowMajor()) } diff --git a/genlib2/array_getset.go b/genlib2/array_getset.go index c21c767..73a686b 100644 --- a/genlib2/array_getset.go +++ b/genlib2/array_getset.go @@ -97,10 +97,11 @@ func (a array) Eq(other interface{}) bool { if oa.L != a.L { return false } - + /* if oa.C != a.C { return false } + */ // same exact thing if uintptr(oa.Ptr) == uintptr(a.Ptr){ From a2c1134aff153b795056cf4f5ac6d10c9e82f8b4 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 23 Jun 2020 13:17:56 +1000 Subject: [PATCH 036/154] Exported StandardEngine --- api_arith.go | 73 +++++++++++++++++++---- array.go | 4 ++ consopt.go | 2 +- defaultengine.go | 152 +++++++++++++++++++++++------------------------ dense.go | 8 ++- dense_linalg.go | 2 +- engine.go | 2 +- scalar.go | 14 ++++- sparse.go | 2 +- tensor.go | 9 ++- 10 files changed, 168 insertions(+), 100 deletions(-) diff --git a/api_arith.go b/api_arith.go index 4e86ffa..8ef78db 100644 --- a/api_arith.go +++ b/api_arith.go @@ -19,7 +19,7 @@ import ( // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var adder Adder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -100,7 +100,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var suber Suber - var oe 
standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -181,7 +181,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var muler Muler - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -264,7 +264,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var diver Diver - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -345,7 +345,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var power Power - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -426,7 +426,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var moder Moder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: @@ -570,13 +570,64 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return } + ad, aok := a.(*Dense) + _, bok := b.(*Dense) + if aok && bok { + // fast path + return ad.MatMul(b, opts...) + } - switch at := a.(type) { - case *Dense: - bt := b.(*Dense) - return at.MatMul(bt, opts...) 
+ // check that both are matrices + if !a.Shape().IsMatrix() || !b.Shape().IsMatrix() { + err = errors.Errorf("MatMul requires both operands to be matrices. Got t's shape: %v, other's shape: %v", a.Shape(), b.Shape()) + return } - panic("Unreachable") + + // checks that t is mxk matrix + var m, n, k int + m = a.Shape()[0] + k = a.Shape()[1] + n = b.Shape()[1] + + // check shape + if k != b.Shape()[0] { + err = errors.Errorf(shapeMismatch, a.Shape(), b.Shape()) + return + } + + // check whether retVal has the same size as the resulting matrix would be: mxn + expectedShape := Shape{m, n} + + // find an engine + aEng, aok := a.Engine().(MatMuler) + bEng, bok := b.Engine().(MatMuler) + mm := aEng + var eng Engine = a.Engine() + if !aok { + mm = bEng + eng = b.Engine() + if !bok { + return nil, errors.Errorf("Neither a or b have an engine that is a MatMuler. a: %T, b: %T", a.Engine(), b.Engine()) + } + } + + // parse function options, and get a preallocated value + var reuse *Dense + fo := ParseFuncOpts(opts...) + defer returnOpOpt(fo) + if reuse, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + err = errors.Wrapf(err, opFail, "MatMul") + return + } + + if reuse == nil { + reuse = recycledDense(a.Dtype(), expectedShape, WithEngine(eng)) + } + retVal = reuse + if err = mm.MatMul(a, b, retVal); err != nil { + return + } + return handleIncr(retVal.(*Dense), fo.Reuse(), fo.Incr(), expectedShape) } // MatVecMul performs matrix-vector multiplication between two Tensors. `a` is expected to be a matrix, and `b` is expected to be a vector diff --git a/array.go b/array.go index 40995ec..3960a09 100644 --- a/array.go +++ b/array.go @@ -12,6 +12,10 @@ import ( //go:notinheap type rawdata []byte +func (d rawdata) Uintptr() uintptr { return uintptr(unsafe.Pointer(&d[0])) } +func (d rawdata) MemSize() uintptr { return uintptr(len(d)) } +func (d rawdata) Pointer() unsafe.Pointer { return unsafe.Pointer(&d[0]) } + // array2 is a type that will not be allocated on the heap. 
This is useful for operational stuff - no unnecessary allocations required. //go:notinheap diff --git a/consopt.go b/consopt.go index 19d47ad..14c9695 100644 --- a/consopt.go +++ b/consopt.go @@ -175,7 +175,7 @@ func WithEngine(e Engine) ConsOpt { } tt.oe = nil - if oe, ok := e.(standardEngine); ok { + if oe, ok := e.(StandardEngine); ok { tt.oe = oe } case *CS: diff --git a/defaultengine.go b/defaultengine.go index bc92e8c..67449aa 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -1,76 +1,76 @@ -package tensor - -import ( - "unsafe" - - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/execution" -) - -// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. -type StdEng struct { - execution.E -} - -// makeArray allocates a slice for the array -func (e StdEng) makeArray(arr *array, t Dtype, size int) { - memsize := calcMemSize(t, size) - s := make([]byte, memsize) - arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() -} - -func (e StdEng) AllocAccessible() bool { return true } -func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } -func (e StdEng) Free(mem Memory, size int64) error { return nil } -func (e StdEng) Memset(mem Memory, val interface{}) error { - if ms, ok := mem.(MemSetter); ok { - return ms.Memset(val) - } - return errors.Errorf("Cannot memset %v with StdEng", mem) -} - -func (e StdEng) Memclr(mem Memory) { - if z, ok := mem.(Zeroer); ok { - z.Zero() - } - return -} - -func (e StdEng) Memcpy(dst, src Memory) error { - switch dt := dst.(type) { - case *array: - switch st := src.(type) { - case *array: - copyArray(dt, st) - return nil - case arrayer: - copyArray(dt, st.arrPtr()) - return nil - } - case arrayer: - switch st := src.(type) { - case *array: - copyArray(dt.arrPtr(), st) - return nil - case arrayer: - copyArray(dt.arrPtr(), st.arrPtr()) - return nil - } - } - return errors.Errorf("Failed 
to copy %T %T", dst, src) -} - -func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } - -func (e StdEng) WorksWith(order DataOrder) bool { return true } - -func (e StdEng) checkAccessible(t Tensor) error { - if !t.IsNativelyAccessible() { - return errors.Errorf(inaccessibleData, t) - } - return nil -} +package tensor + +import ( + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/execution" +) + +// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. +type StdEng struct { + execution.E +} + +// makeArray allocates a slice for the array +func (e StdEng) makeArray(arr *array, t Dtype, size int) { + memsize := calcMemSize(t, size) + s := make([]byte, memsize) + arr.t = t + arr.L = size + arr.C = size + arr.Ptr = unsafe.Pointer(&s[0]) + arr.fix() +} + +func (e StdEng) AllocAccessible() bool { return true } +func (e StdEng) Alloc(size int64) (Memory, error) { return make(rawdata, size), nil } +func (e StdEng) Free(mem Memory, size int64) error { return nil } +func (e StdEng) Memset(mem Memory, val interface{}) error { + if ms, ok := mem.(MemSetter); ok { + return ms.Memset(val) + } + return errors.Errorf("Cannot memset %v with StdEng", mem) +} + +func (e StdEng) Memclr(mem Memory) { + if z, ok := mem.(Zeroer); ok { + z.Zero() + } + return +} + +func (e StdEng) Memcpy(dst, src Memory) error { + switch dt := dst.(type) { + case *array: + switch st := src.(type) { + case *array: + copyArray(dt, st) + return nil + case arrayer: + copyArray(dt, st.arrPtr()) + return nil + } + case arrayer: + switch st := src.(type) { + case *array: + copyArray(dt.arrPtr(), st) + return nil + case arrayer: + copyArray(dt.arrPtr(), st.arrPtr()) + return nil + } + } + return errors.Errorf("Failed to copy %T %T", dst, src) +} + +func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } + +func (e StdEng) WorksWith(order DataOrder) bool { return true } + 
+func (e StdEng) checkAccessible(t Tensor) error { + if !t.IsNativelyAccessible() { + return errors.Errorf(inaccessibleData, t) + } + return nil +} diff --git a/dense.go b/dense.go index e80b2f5..09d72e3 100644 --- a/dense.go +++ b/dense.go @@ -19,7 +19,7 @@ type Dense struct { flag MemoryFlag e Engine // execution engine for the *Dense - oe standardEngine // optimized engine + oe StandardEngine // optimized engine // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes old AP @@ -86,7 +86,9 @@ func (t *Dense) makeArray(size int) { case arrayMaker: te.makeArray(&t.array, t.t, size) return + case StandardEngine: default: + } mem, err := t.e.Alloc(calcMemSize(t.t, size)) @@ -273,7 +275,7 @@ func (t *Dense) fix() { t.e = StdEng{} } - if oe, ok := t.e.(standardEngine); ok { + if oe, ok := t.e.(StandardEngine); ok { t.oe = oe } @@ -622,4 +624,4 @@ func (t *Dense) RequiresIterator() bool { func (t *Dense) Iterator() Iterator { return IteratorFromDense(t) } -func (t *Dense) standardEngine() standardEngine { return t.oe } +func (t *Dense) standardEngine() StandardEngine { return t.oe } diff --git a/dense_linalg.go b/dense_linalg.go index c5362c5..0edf270 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -393,7 +393,7 @@ func handleReuse(reuse Tensor, expectedShape Shape) (retVal *Dense, err error) { } return } - return + return nil, nil } // handleIncr is the cleanup step for when there is an Tensor to increment. 
If the result tensor is the same as the reuse Tensor, the result tensor gets returned to the pool diff --git a/engine.go b/engine.go index ae508a8..f42e067 100644 --- a/engine.go +++ b/engine.go @@ -29,7 +29,7 @@ type Engine interface { WorksWith(order DataOrder) bool // WorksWith returns true if the data order can be directly worked with } -type standardEngine interface { +type StandardEngine interface { Engine Adder diff --git a/scalar.go b/scalar.go index 05cc219..3721cbc 100644 --- a/scalar.go +++ b/scalar.go @@ -13,6 +13,17 @@ import ( ) var _ Tensor = Scalar{} +var _ ScalarRep = Scalar{} +var _ ScalarRep = ScalarDense{} + +// ScalarDense wraps a *Dense to provide a typesafe alternative for a scalar to be represented in a *Dense. +type ScalarDense struct { + *Dense +} + +func (s ScalarDense) IsScalar() bool { return true } + +func (s ScalarDense) ScalarValue() interface{} { return s.Dense.Data() } // Scalar is a representation of a scalar value on the CPU. type Scalar struct{ v interface{} } @@ -71,9 +82,10 @@ func (s Scalar) String() string { return fmt.Sprintf("%v", s) } func (s Scalar) WriteNpy(io.Writer) error { return errors.Errorf(methodNYI, "WriteNpy", "Scalar") } func (s Scalar) ReadNpy(io.Reader) error { return errors.Errorf(methodNYI, "ReadNypy", "Scalar") } func (s Scalar) GobEncode() ([]byte, error) { + // TODO return nil, errors.Errorf(methodNYI, "GobEncode", "Scalar") } -func (s Scalar) GobDecode([]byte) error { return errors.Errorf(methodNYI, "GobDecode", "Scalar") } +func (s Scalar) GobDecode([]byte) error { return errors.Errorf(methodNYI, "GobDecode", "Scalar") } // TODO func (s Scalar) standardEngine() standardEngine { return StdEng{} } func (s Scalar) hdr() *storage.Header { return nil } diff --git a/sparse.go b/sparse.go index 3126843..dcddb2e 100644 --- a/sparse.go +++ b/sparse.go @@ -381,4 +381,4 @@ func (t *CS) IsManuallyManaged() bool { return t.f.manuallyManaged() } func (t *CS) arr() array { return t.array } func (t *CS) arrPtr() 
*array { return &t.array } -func (t *CS) standardEngine() standardEngine { return nil } +func (t *CS) standardEngine() StandardEngine { return nil } diff --git a/tensor.go b/tensor.go index 094adad..3a182ff 100644 --- a/tensor.go +++ b/tensor.go @@ -54,10 +54,6 @@ type Tensor interface { Eq Cloner - // type overloading methods - IsScalar() bool - ScalarValue() interface{} - // engine/memory related stuff // all Tensors should be able to be expressed of as a slab of memory // Note: the size of each element can be acquired by T.Dtype().Size() @@ -78,9 +74,12 @@ type Tensor interface { //gob.GobEncoder //gob.GobDecoder - standardEngine() standardEngine + standardEngine() StandardEngine headerer arrayer + + // TO BE DEPRECATED + ScalarRep } // New creates a new Dense Tensor. For sparse arrays use their relevant construction function From ed5851a3786394e0d0fce6637a1d76745ff392c8 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 30 Jun 2020 13:26:02 +1000 Subject: [PATCH 037/154] Bezineb5 fix interface scalar (#73) * Fixed an issue with the leftTensor parameter leading to a bug for scalars. * OK this works * Fixes #70 and #72. Though this patch is quite at the surface. I haven't really got the time to dig in why the behaviour is as such, given that I'm feeling quite ill atm. 
I will come back and fix it if need be in the future Co-authored-by: Benjamin <> Co-authored-by: wzzhu <> --- api_arith_test.go | 124 ++++++++++++++++++++++++++++++++ array_getset.go | 1 + defaultengine_arith.go | 78 +++++++++++++------- genlib2/agg2_body.go | 25 ++++--- internal/storage/header.go | 17 +++++ internal/storage/header_test.go | 43 +++++++++++ known_issues_test.go | 74 +++++++++++++++++++ 7 files changed, 330 insertions(+), 32 deletions(-) create mode 100644 internal/storage/header_test.go create mode 100644 known_issues_test.go diff --git a/api_arith_test.go b/api_arith_test.go index 00bf271..ca45f8f 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -207,6 +207,24 @@ func TestMulScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // Interface - tensor + ai := 2.0 + b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) + correct = []float64{6.0} + + res, err = Mul(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Commutativity + res, err = Mul(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) } func TestDivScalarScalar(t *testing.T) { @@ -253,6 +271,28 @@ func TestDivScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 3.0 + + res, err = Div(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 3.0 + + res, err = Div(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) } func TestAddScalarScalar(t *testing.T) { @@ -309,6 +349,24 @@ func TestAddScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 2.0 + b = 
New(WithBacking([]float64{3})) + correct = 5.0 + + res, err = Add(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) } func TestSubScalarScalar(t *testing.T) { @@ -355,6 +413,28 @@ func TestSubScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 4.0 + + res, err = Sub(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 4.0 + + res, err = Sub(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) } func TestModScalarScalar(t *testing.T) { @@ -401,6 +481,28 @@ func TestModScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 5.0 + b = New(WithBacking([]float64{2})) + correct = 1.0 + + res, err = Mod(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{5})) + bi := 2.0 + correct = 1.0 + + res, err = Mod(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) } func TestPowScalarScalar(t *testing.T) { @@ -447,4 +549,26 @@ func TestPowScalarScalar(t *testing.T) { t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 36.0 + + res, err = Pow(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 36.0 + + res, err = Pow(a, bi) + if err != nil { + t.Fatalf("Error: %v", 
err) + } + assert.Equal(t, correct, res.Data()) } diff --git a/array_getset.go b/array_getset.go index 0054bc3..1f71afd 100644 --- a/array_getset.go +++ b/array_getset.go @@ -512,6 +512,7 @@ func (a array) Eq(other interface{}) bool { return false } */ + // same exact thing if uintptr(oa.Ptr) == uintptr(a.Ptr) { return true diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 3017aaa..5779897 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -481,17 +481,22 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Add(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Add(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Add(typ, retVal.hdr(), dataB) - } else { - err = e.E.Add(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Add(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return @@ -569,17 +574,22 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Sub(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Sub(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Sub(typ, retVal.hdr(), dataB) - } else { - err = e.E.Sub(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Sub(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return @@ -657,17 +667,22 @@ func (e StdEng) MulScalar(t Tensor, s 
interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Mul(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Mul(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Mul(typ, retVal.hdr(), dataB) - } else { - err = e.E.Mul(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Mul(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return @@ -745,17 +760,22 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Div(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Div(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Div(typ, retVal.hdr(), dataB) - } else { - err = e.E.Div(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Div(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return @@ -833,17 +853,22 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Pow(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Pow(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Pow(typ, retVal.hdr(), dataB) - } else { - err = e.E.Pow(typ, dataA, retVal.hdr()) + if !leftTensor { 
+ storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Pow(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return @@ -921,17 +946,22 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.Mod(typ, dataA, dataReuse) + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } retVal = reuse case !safe: err = e.E.Mod(typ, dataA, dataB) + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } retVal = a default: retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.Mod(typ, retVal.hdr(), dataB) - } else { - err = e.E.Mod(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.Mod(typ, retVal.hdr(), dataB) } returnHeader(scalarHeader) return diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index cef73e2..29c73ee 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -16,7 +16,7 @@ const cmpPrepRaw = `var safe, same bool const arithPrepRaw = `var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") - } + } ` const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); err != nil { @@ -88,7 +88,7 @@ const agg2BodyRaw = `if useIter { case incr: err = e.E.{{.Name}}IterIncr(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse - {{if .VV -}} + {{if .VV -}} case toReuse: storage.CopyIter(typ,dataReuse, dataA, iit, ait) ait.Reset() @@ -149,10 +149,20 @@ const agg2BodyRaw = `if useIter { case toReuse && !leftTensor: storage.Copy(typ, dataReuse, dataB) err = e.E.{{.Name}}(typ, dataA, dataReuse) + {{if not .VV -}} + if t.Shape().IsScalarEquiv() { + storage.Copy(typ, dataReuse, dataA) + } + {{end -}} retVal = reuse {{end -}} case !safe: err = e.E.{{.Name}}(typ, 
dataA, dataB) + {{if not .VV -}} + if t.Shape().IsScalarEquiv() && !leftTensor { + storage.Copy(typ, dataB, dataA) + } + {{end -}} retVal = a default: {{if .VV -}} @@ -164,11 +174,10 @@ const agg2BodyRaw = `if useIter { err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) {{else -}} retVal = a.Clone().(Tensor) - if leftTensor { - err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) - } else { - err = e.E.{{.Name}}(typ, dataA, retVal.hdr()) + if !leftTensor { + storage.Fill(typ, retVal.hdr(), dataA) } + err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) {{end -}} } {{if not .VV -}}returnHeader(scalarHeader){{end}} @@ -195,7 +204,7 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created reuse = NewDense(Bool, a.Shape().Clone(), WithEngine(e)) dataReuse = reuse.hdr() if useIter{ - iit = IteratorFromDense(reuse) + iit = IteratorFromDense(reuse) } } @@ -247,7 +256,7 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Inv}}Same(typ, dataReuse, dataA) retVal = reuse return - } + } } {{end -}} diff --git a/internal/storage/header.go b/internal/storage/header.go index e34e254..249f2fc 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -56,6 +56,23 @@ func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, ssta return copied / size } +func Fill(t reflect.Type, dst, src *Header) int { + dstBA := AsByteSlice(dst, t) + srcBA := AsByteSlice(src, t) + size := int(t.Size()) + lenSrc := len(srcBA) + + dstart := 0 + for { + copied := copy(dstBA[dstart:], srcBA) + dstart += copied + if copied < lenSrc { + break + } + } + return dstart / size +} + func CopyIter(t reflect.Type, dst, src *Header, diter, siter Iterator) int { dstBA := AsByteSlice(dst, t) srcBA := AsByteSlice(src, t) diff --git a/internal/storage/header_test.go b/internal/storage/header_test.go new file mode 100644 index 0000000..c59fe58 --- /dev/null +++ b/internal/storage/header_test.go @@ -0,0 +1,43 @@ +package storage + +import ( + 
"reflect" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestFill(t *testing.T) { + // A longer than B + a := headerFromSlice([]int{0, 1, 2, 3, 4}) + b := headerFromSlice([]int{10, 11}) + copied := Fill(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, copied, 5) + assert.Equal(t, a.Ints(), []int{10, 11, 10, 11, 10}) + + // B longer than A + a = headerFromSlice([]int{10, 11}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Fill(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, copied, 2) + assert.Equal(t, a.Ints(), []int{0, 1}) +} + +func headerFromSlice(x interface{}) Header { + xT := reflect.TypeOf(x) + if xT.Kind() != reflect.Slice { + panic("Expected a slice") + } + + xV := reflect.ValueOf(x) + uptr := unsafe.Pointer(xV.Pointer()) + + return Header{ + Ptr: uptr, + L: xV.Len(), + C: xV.Cap(), + } +} diff --git a/known_issues_test.go b/known_issues_test.go new file mode 100644 index 0000000..20d8717 --- /dev/null +++ b/known_issues_test.go @@ -0,0 +1,74 @@ +package tensor + +import ( + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" +) + +func TestIssue70(t *testing.T) { + a := 2.0 + b := NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) + var correct interface{} = []float64{6.0} + + res, err := Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + t.Logf("a %v b %v, res %v", a, b, res) +} + +func TestIssue72(t *testing.T) { + a := New(FromScalar(3.14)) + b := 0.0 + + bsa, err := Sub(b, a) + if err != nil { + t.Fatal(err) + } + t.Logf("%v", bsa) + ret, err := Sub(b, bsa, UseUnsafe()) + if err != nil { + t.Fatal(err) + } + t.Logf("%v %v", ret, bsa) + + invReuseScalar := func(q *Dense) bool { + a := q.Clone().(*Dense) + //if !a.Shape().IsScalarEquiv() { + // return true + //} + b := identityVal(0, q.t) + reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, numberTypes, unsignedTypes) + _, ok := 
q.Engine().(Suber) + we = we || !ok + //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a, reuse) + + ret, err := Sub(b, a, WithReuse(reuse)) + if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + //log.Printf("b-a(r) | b:%v, a %v, r %v, ret %v", b, a, reuse, ret) + ret, err = Sub(b, ret, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(invReuseScalar, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) + } +} From 6b9a52fce98234c6fa2a30b5b522ddfb1b12c104 Mon Sep 17 00:00:00 2001 From: Johannes Lauinger Date: Tue, 30 Jun 2020 05:37:54 +0200 Subject: [PATCH 038/154] fix 48 more possible memory confusion bugs (#72) --- native/iterator_native.go | 352 +++++++++++++++++-------------------- native/iterator_native2.go | 192 ++++++++++---------- 2 files changed, 256 insertions(+), 288 deletions(-) diff --git a/native/iterator_native.go b/native/iterator_native.go index a360aeb..6cf492c 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -59,12 +59,11 @@ func MatrixB(t *Dense) (retVal [][]bool, err error) { retVal = make([][]bool, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]bool)(unsafe.Pointer(hdr)) + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -90,12 +89,11 @@ func Tensor3B(t *Dense) (retVal [][][]bool, err error) { retVal[i] = make([][]bool, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := 
&reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]bool)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -129,12 +127,11 @@ func MatrixI(t *Dense) (retVal [][]int, err error) { retVal = make([][]int, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]int)(unsafe.Pointer(hdr)) + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -160,12 +157,11 @@ func Tensor3I(t *Dense) (retVal [][][]int, err error) { retVal[i] = make([][]int, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]int)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -199,12 +195,11 @@ func MatrixI8(t *Dense) (retVal [][]int8, err error) { retVal = make([][]int8, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]int8)(unsafe.Pointer(hdr)) + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -230,12 +225,11 @@ func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { retVal[i] = make([][]int8, rows) for j := range retVal[i] { start := 
i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]int8)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -269,12 +263,11 @@ func MatrixI16(t *Dense) (retVal [][]int16, err error) { retVal = make([][]int16, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]int16)(unsafe.Pointer(hdr)) + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -300,12 +293,11 @@ func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { retVal[i] = make([][]int16, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]int16)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -339,12 +331,11 @@ func MatrixI32(t *Dense) (retVal [][]int32, err error) { retVal = make([][]int32, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]int32)(unsafe.Pointer(hdr)) + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -370,12 +361,11 @@ func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { retVal[i] = 
make([][]int32, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]int32)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -409,12 +399,11 @@ func MatrixI64(t *Dense) (retVal [][]int64, err error) { retVal = make([][]int64, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]int64)(unsafe.Pointer(hdr)) + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -440,12 +429,11 @@ func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { retVal[i] = make([][]int64, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]int64)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -479,12 +467,11 @@ func MatrixU(t *Dense) (retVal [][]uint, err error) { retVal = make([][]uint, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]uint)(unsafe.Pointer(hdr)) + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -510,12 +497,11 @@ func Tensor3U(t 
*Dense) (retVal [][][]uint, err error) { retVal[i] = make([][]uint, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]uint)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -549,12 +535,11 @@ func MatrixU8(t *Dense) (retVal [][]uint8, err error) { retVal = make([][]uint8, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]uint8)(unsafe.Pointer(hdr)) + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -580,12 +565,11 @@ func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { retVal[i] = make([][]uint8, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]uint8)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -619,12 +603,11 @@ func MatrixU16(t *Dense) (retVal [][]uint16, err error) { retVal = make([][]uint16, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]uint16)(unsafe.Pointer(hdr)) + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len 
= cols } return } @@ -650,12 +633,11 @@ func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { retVal[i] = make([][]uint16, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]uint16)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -689,12 +671,11 @@ func MatrixU32(t *Dense) (retVal [][]uint32, err error) { retVal = make([][]uint32, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]uint32)(unsafe.Pointer(hdr)) + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -720,12 +701,11 @@ func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { retVal[i] = make([][]uint32, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]uint32)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -759,12 +739,11 @@ func MatrixU64(t *Dense) (retVal [][]uint64, err error) { retVal = make([][]uint64, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]uint64)(unsafe.Pointer(hdr)) + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + 
hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -790,12 +769,11 @@ func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { retVal[i] = make([][]uint64, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]uint64)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -829,12 +807,11 @@ func MatrixF32(t *Dense) (retVal [][]float32, err error) { retVal = make([][]float32, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]float32)(unsafe.Pointer(hdr)) + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -860,12 +837,11 @@ func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { retVal[i] = make([][]float32, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]float32)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -899,12 +875,11 @@ func MatrixF64(t *Dense) (retVal [][]float64, err error) { retVal = make([][]float64, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]float64)(unsafe.Pointer(hdr)) + 
retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -930,12 +905,11 @@ func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { retVal[i] = make([][]float64, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]float64)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -969,12 +943,11 @@ func MatrixC64(t *Dense) (retVal [][]complex64, err error) { retVal = make([][]complex64, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]complex64)(unsafe.Pointer(hdr)) + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -1000,12 +973,11 @@ func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { retVal[i] = make([][]complex64, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]complex64)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -1039,12 +1011,11 @@ func MatrixC128(t *Dense) (retVal [][]complex128, err error) { retVal = make([][]complex128, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: 
uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]complex128)(unsafe.Pointer(hdr)) + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -1070,12 +1041,11 @@ func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { retVal[i] = make([][]complex128, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]complex128)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return @@ -1109,12 +1079,11 @@ func MatrixStr(t *Dense) (retVal [][]string, err error) { retVal = make([][]string, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]string)(unsafe.Pointer(hdr)) + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } @@ -1140,12 +1109,11 @@ func Tensor3Str(t *Dense) (retVal [][][]string, err error) { retVal[i] = make([][]string, rows) for j := range retVal[i] { start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]string)(unsafe.Pointer(hdr)) + retVal[i][j] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return diff --git a/native/iterator_native2.go b/native/iterator_native2.go index 
85045ce..9a0ae34 100644 --- a/native/iterator_native2.go +++ b/native/iterator_native2.go @@ -50,12 +50,12 @@ func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]bool, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]bool)(unsafe.Pointer(hdr))) + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -88,12 +88,12 @@ func SelectI(t *Dense, axis int) (retVal [][]int, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]int, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]int)(unsafe.Pointer(hdr))) + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -126,12 +126,12 @@ func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]int8, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]int8)(unsafe.Pointer(hdr))) + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -164,12 +164,12 @@ func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]int16, 0, upper) for i, r 
:= 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]int16)(unsafe.Pointer(hdr))) + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -202,12 +202,12 @@ func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]int32, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]int32)(unsafe.Pointer(hdr))) + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -240,12 +240,12 @@ func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]int64, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]int64)(unsafe.Pointer(hdr))) + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -278,12 +278,12 @@ func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]uint, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]uint)(unsafe.Pointer(hdr))) + s := make([]uint, 0) + hdr := 
(*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -316,12 +316,12 @@ func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]uint8, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]uint8)(unsafe.Pointer(hdr))) + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -354,12 +354,12 @@ func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]uint16, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]uint16)(unsafe.Pointer(hdr))) + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -392,12 +392,12 @@ func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]uint32, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]uint32)(unsafe.Pointer(hdr))) + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -430,12 +430,12 @@ func SelectU64(t *Dense, 
axis int) (retVal [][]uint64, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]uint64, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]uint64)(unsafe.Pointer(hdr))) + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -468,12 +468,12 @@ func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]float32, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]float32)(unsafe.Pointer(hdr))) + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -506,12 +506,12 @@ func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]float64, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]float64)(unsafe.Pointer(hdr))) + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -544,12 +544,12 @@ func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]complex64, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: 
uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]complex64)(unsafe.Pointer(hdr))) + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -582,12 +582,12 @@ func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]complex128, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]complex128)(unsafe.Pointer(hdr))) + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil @@ -620,12 +620,12 @@ func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]string, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]string)(unsafe.Pointer(hdr))) + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Cap = stride + hdr.Len = stride + retVal = append(retVal, s) r++ } return retVal, nil From ef379c5f3c40982d1d236a997b26a22e039b36e8 Mon Sep 17 00:00:00 2001 From: cpllbstr <44461532+cpllbstr@users.noreply.github.com> Date: Wed, 8 Jul 2020 11:12:11 +0300 Subject: [PATCH 039/154] slice with step fix + tests (#74) --- ap.go | 5 ++++- dense_matop_test.go | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/ap.go b/ap.go index d9dac59..22c7d9e 100644 --- a/ap.go +++ b/ap.go @@ -246,8 +246,11 @@ 
func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err er // a slice where start == end is [] ndStart = ndStart + start*stride ndEnd = ndEnd - (size-end)*stride + if step > 0 { - newShape[i] = (end - start) / step + if newShape[i] = (end - start) / step; (end-start)%step > 0 && i > 0 { + newShape[i]++ + } newStrides[i] = stride * step //fix diff --git a/dense_matop_test.go b/dense_matop_test.go index 8052b2b..5e5486d 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -558,6 +558,25 @@ var denseSliceTests = []struct { {"A[:, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, ss(0)}, Shape{4, 1}, []int{5}, Range(Float64, 0, 16)}, {"A[:, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5)}, Shape{4, 4}, []int{5, 1}, Range(Float64, 1, 20)}, {"A[:, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{4, 2}, []int{5, 2}, Range(Float64, 1, 20)}, + + // 3tensor with leading and trailing 1s + + {"3T1[0]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{ss(0)}, Shape{9, 1}, []int{1, 1}, Range(Float64, 0, 9)}, + {"3T1[nil, 0:2]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2, 1}, []int{9, 1, 1}, Range(Float64, 0, 2)}, + {"3T1[nil, 0:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 0, 5)}, + {"3T1[nil, 1:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 1, 5)}, + {"3T1[nil, 1:9:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 9, 3)}, Shape{1, 3, 1}, []int{9, 3, 1}, Range(Float64, 1, 9)}, + + // 3tensor + {"3T[0]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(0)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 0, 18)}, + {"3T[1]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 18, 36)}, + {"3T[1, 2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), 
ss(2)}, Shape{2}, []int{1}, Range(Float64, 22, 24)}, + {"3T[1, 2:4]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 4)}, Shape{2, 2}, []int{2, 1}, Range(Float64, 22, 26)}, + {"3T[1, 2:8:2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 2)}, Shape{3, 2}, []int{4, 1}, Range(Float64, 22, 34)}, + {"3T[1, 2:8:3]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 3)}, Shape{2, 2}, []int{6, 1}, Range(Float64, 22, 34)}, + {"3T[1, 2:9:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2)}, Shape{4, 7}, []int{14, 1}, Range(Float64, 77, 126)}, + {"3T[1, 2:9:2, 1]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), ss(1)}, Shape{4}, []int{14}, Range(Float64, 78, 121)}, // should this be a colvec? + {"3T[1, 2:9:2, 1:4:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), makeRS(1, 4, 2)}, Shape{4, 2}, []int{14, 2}, Range(Float64, 78, 123)}, } func TestDense_Slice(t *testing.T) { From cd80d2fbd4419015361ec697529d653317a22bcc Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 15:41:03 -0400 Subject: [PATCH 040/154] add conversion from arrow to tensor --- dense_arrow.go | 35 +++++++++++++++++++++++++++++++++++ example_apply_test.go | 25 +++++++++++++++++++++++++ go.mod | 1 + go.sum | 6 ++++++ 4 files changed, 67 insertions(+) create mode 100644 dense_arrow.go diff --git a/dense_arrow.go b/dense_arrow.go new file mode 100644 index 0000000..98f5abf --- /dev/null +++ b/dense_arrow.go @@ -0,0 +1,35 @@ +// Code generated by genlib2. DO NOT EDIT. + +package tensor + +import ( + "fmt" + + "github.com/apache/arrow/go/arrow" + arrowArray "github.com/apache/arrow/go/arrow/array" +) + +// FromArrowArray converts an "arrow/array".Interface into a Tensor. 
+func FromArrowArray(a arrowArray.Interface) *Dense { + a.Retain() + + r := a.Len() + + // TODO(poopoothegorilla): instead of creating bool ValidMask maybe + // bitmapBytes can be used from arrow API + mask := make([]bool, r) + for i := 0; i < r; i++ { + mask[i] = a.IsNull(i) + } + + switch a.DataType() { + case arrow.PrimitiveTypes.Float64: + backing := a.(*arrowArray.Float64).Float64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} diff --git a/example_apply_test.go b/example_apply_test.go index 1e11641..c84f4c7 100644 --- a/example_apply_test.go +++ b/example_apply_test.go @@ -3,6 +3,8 @@ package tensor_test import ( "fmt" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" "gorgonia.org/tensor" ) @@ -35,3 +37,26 @@ func ExampleDense_Apply() { // ⎡ 1 8⎤ // ⎣27 64⎦ } + +func ExampleDense_Arrow() { + pool := memory.NewGoAllocator() + + b := array.NewFloat64Builder(pool) + defer b.Release() + + b.AppendValues( + []float64{1, 2, 3, -1, 4, 5}, + []bool{true, true, true, false, true, true}, + ) + + arr := b.NewFloat64Array() + defer arr.Release() + fmt.Printf("arrow array = %v\n", arr) + + a := tensor.FromArrowArray(arr) + fmt.Printf("tensor = %v\n", a) + + // Output: + // arrow array = [1 2 3 (null) 4 5] + // tensor = C[ 1 2 3 -- 4 5] +} diff --git a/go.mod b/go.mod index b93eb28..1bd9f13 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module gorgonia.org/tensor go 1.13 require ( + github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.4 github.com/gogo/protobuf v1.3.0 diff --git a/go.sum b/go.sum index cd19944..412a6f2 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,7 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/apache/arrow 
v0.0.0-20200720164908-23b19f65e1eb h1:/guPTo4KRiOQnB4UX0Sn9kk5k7kCC00eSKsoykKc0tU= +github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb h1:vBEPOeLNZ2RUgG/e+G2tOIucgCojRKRPorB3STXC+xw= +github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= @@ -23,6 +26,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= @@ -35,6 +39,8 @@ golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86h golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod 
h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw= gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM= From 0490bc4d46a93f8d000a32601f036e9b0b6d44f9 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 15:49:06 -0400 Subject: [PATCH 041/154] add defer release for arrow array --- dense_arrow.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dense_arrow.go b/dense_arrow.go index 98f5abf..fc2a131 100644 --- a/dense_arrow.go +++ b/dense_arrow.go @@ -12,6 +12,7 @@ import ( // FromArrowArray converts an "arrow/array".Interface into a Tensor. func FromArrowArray(a arrowArray.Interface) *Dense { a.Retain() + defer a.Release() r := a.Len() From de0920c273967a060e2db3d990c553601995fbe6 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 21:24:38 -0400 Subject: [PATCH 042/154] move arrow code to generators --- example_dense_arrow_test.go | 32 ++++++++++++++++++++++++ genlib2/declarations.go | 22 ++++++++++++++++ dense_arrow.go => genlib2/dense_arrow.go | 24 +++++++++--------- genlib2/dense_compat.go | 12 +++++++-- 4 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 example_dense_arrow_test.go rename dense_arrow.go => genlib2/dense_arrow.go (57%) diff --git a/example_dense_arrow_test.go b/example_dense_arrow_test.go new file mode 100644 index 0000000..201a3d6 --- /dev/null +++ b/example_dense_arrow_test.go @@ -0,0 +1,32 @@ +package tensor_test + +import ( + "fmt" + + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" + "gorgonia.org/tensor" +) + +func ExampleDense_Arrow() { + pool := memory.NewGoAllocator() + + b := array.NewFloat64Builder(pool) + defer b.Release() + + b.AppendValues( + []float64{1, 2, 3, -1, 4, 5}, + []bool{true, true, true, false, true, true}, + ) + + arr := b.NewFloat64Array() + defer arr.Release() + fmt.Printf("arrow array = %v\n", arr) + + a := 
tensor.FromArrowArray(arr) + fmt.Printf("tensor = %v\n", a) + + // Output: + // arrow array = [1 2 3 (null) 4 5] + // tensor = C[ 1 2 3 -- 4 5] +} diff --git a/genlib2/declarations.go b/genlib2/declarations.go index 970f4ae..7bcd6bc 100644 --- a/genlib2/declarations.go +++ b/genlib2/declarations.go @@ -120,6 +120,27 @@ var stdTypes = [...]string{ "UnsafePointer", } +var arrowBinaryTypes = []string{ + "String", +} + +var arrowFixedWidthTypes = []string{ + "Boolean", +} + +var arrowPrimitiveTypes = []string{ + "Int8", + "Int16", + "Int32", + "Int64", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Float32", + "Float64", +} + var parameterizedKinds = [...]reflect.Kind{ reflect.Array, reflect.Chan, @@ -130,6 +151,7 @@ var parameterizedKinds = [...]reflect.Kind{ reflect.Slice, reflect.Struct, } + var number = [...]reflect.Kind{ reflect.Int, reflect.Int8, diff --git a/dense_arrow.go b/genlib2/dense_arrow.go similarity index 57% rename from dense_arrow.go rename to genlib2/dense_arrow.go index fc2a131..b3e63ba 100644 --- a/dense_arrow.go +++ b/genlib2/dense_arrow.go @@ -1,15 +1,12 @@ -// Code generated by genlib2. DO NOT EDIT. +package main -package tensor - -import ( - "fmt" - - "github.com/apache/arrow/go/arrow" - arrowArray "github.com/apache/arrow/go/arrow/array" -) +type ArrowData struct { + BinaryTypes []string + FixedWidthTypes []string + PrimitiveTypes []string +} -// FromArrowArray converts an "arrow/array".Interface into a Tensor. +const compatArrowRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. 
func FromArrowArray(a arrowArray.Interface) *Dense { a.Retain() defer a.Release() @@ -24,13 +21,16 @@ func FromArrowArray(a arrowArray.Interface) *Dense { } switch a.DataType() { - case arrow.PrimitiveTypes.Float64: - backing := a.(*arrowArray.Float64).Float64Values() + {{range .ArrowData.PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{.}}: + backing := a.(*arrowArray.{{.}}).{{.}}Values() retVal := New(WithBacking(backing, mask), WithShape(r, 1)) return retVal + {{end -}} default: panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) } panic("Unreachable") } +` diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index fb353d0..018e3db 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -238,16 +238,24 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { ` var ( - conversions *template.Template - compats *template.Template + conversions *template.Template + compats *template.Template + compatsArrow *template.Template ) func init() { conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw)) compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw)) + compatsArrow = template.Must(template.New("compat_arrow").Funcs(funcs).Parse(compatArrowRaw)) } func generateDenseCompat(f io.Writer, generic Kinds) { conversions.Execute(f, generic) compats.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatsArrow.Execute(f, arrowData) } From 040ea29e30a875dcebbb4694e888a0e3b2b4015e Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 21:50:52 -0400 Subject: [PATCH 043/154] add binary types and fixed width types --- genlib2/dense_arrow.go | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/genlib2/dense_arrow.go b/genlib2/dense_arrow.go index b3e63ba..2d225d0 100644 --- 
a/genlib2/dense_arrow.go +++ b/genlib2/dense_arrow.go @@ -21,7 +21,33 @@ func FromArrowArray(a arrowArray.Interface) *Dense { } switch a.DataType() { - {{range .ArrowData.PrimitiveTypes -}} + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{.}}: + {{if eq . "String" -}} + backing := make([]string, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{.}}: + {{if eq . "Boolean" -}} + backing := make([]bool, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .PrimitiveTypes -}} case arrow.PrimitiveTypes.{{.}}: backing := a.(*arrowArray.{{.}}).{{.}}Values() retVal := New(WithBacking(backing, mask), WithShape(r, 1)) From ad41a322cc9200490f3046f461c4202b6edecbad Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 22:38:13 -0400 Subject: [PATCH 044/154] update native iterator --- genlib2/native_iterator.go | 33 ++++---- genlib2/native_select.go | 12 +-- native/iterator_native.go | 160 ++++++++++++++++++++++--------------- native/iterator_native2.go | 32 ++++---- 4 files changed, 134 insertions(+), 103 deletions(-) diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index 5b586e9..7a0c720 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -29,7 +29,8 @@ const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dty ` const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}} -// If the *Dense does not represent a vector of the wanted type, it will return an error. 
+// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil { return nil, err @@ -38,7 +39,8 @@ func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { } // Matrix{{short .}} converts a *Dense into a [][]{{asType .}} -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil { return nil, err @@ -54,17 +56,16 @@ func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { retVal = make([][]{{asType .}}, rows) for i := range retVal { start := i * rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i] = *(*[]{{asType .}})(unsafe.Pointer(hdr)) + retVal[i] = make([]{{asType .}}, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } return } -// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { @@ -84,18 +85,16 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { for i := range retVal { retVal[i] = make([][]{{asType .}}, rows) for j := range retVal[i] { + retVal[i][j] = make([]{{asType .}}, 0) start := i*layerStride + j*rowStride - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[start])), - Len: cols, - Cap: cols, - } - retVal[i][j] = *(*[]{{asType .}})(unsafe.Pointer(hdr)) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols } } return -} -` +}` const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { assert := assert.New(t) diff --git a/genlib2/native_select.go b/genlib2/native_select.go index a386eaa..6b1e277 100644 --- a/genlib2/native_select.go +++ b/genlib2/native_select.go @@ -44,12 +44,12 @@ func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) upper := ProdInts(t.Shape()[:axis+1]) retVal = make([][]{{asType .}}, 0, upper) for i, r := 0, 0; r < upper; i += stride { - hdr := &reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(&data[i])), - Len: stride, - Cap: stride, - } - retVal = append(retVal, *(*[]{{asType .}})(unsafe.Pointer(hdr))) + s := make([]{{asType .}}, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) r++ } return retVal, nil diff --git a/native/iterator_native.go b/native/iterator_native.go index 6cf492c..958e160 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -34,7 +34,8 @@ func checkNativeIterable(t *Dense, dims int, dt Dtype) error { /* Native Iterables for bool */ // VectorB converts a *Dense into a []bool -// If the *Dense does not represent a vector of the wanted type, it 
will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorB(t *Dense) (retVal []bool, err error) { if err = checkNativeIterable(t, 1, Bool); err != nil { return nil, err @@ -43,7 +44,8 @@ func VectorB(t *Dense) (retVal []bool, err error) { } // MatrixB converts a *Dense into a [][]bool -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixB(t *Dense) (retVal [][]bool, err error) { if err = checkNativeIterable(t, 2, Bool); err != nil { return nil, err @@ -68,7 +70,7 @@ func MatrixB(t *Dense) (retVal [][]bool, err error) { return } -// Tensor3B converts a *Dense into a [][][]bool. +// Tensor3B converts a *Dense into a [][][]bool. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3B(t *Dense) (retVal [][][]bool, err error) { if err = checkNativeIterable(t, 3, Bool); err != nil { @@ -88,8 +90,8 @@ func Tensor3B(t *Dense) (retVal [][][]bool, err error) { for i := range retVal { retVal[i] = make([][]bool, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -102,7 +104,8 @@ func Tensor3B(t *Dense) (retVal [][][]bool, err error) { /* Native Iterables for int */ // VectorI converts a *Dense into a []int -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
func VectorI(t *Dense) (retVal []int, err error) { if err = checkNativeIterable(t, 1, Int); err != nil { return nil, err @@ -111,7 +114,8 @@ func VectorI(t *Dense) (retVal []int, err error) { } // MatrixI converts a *Dense into a [][]int -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixI(t *Dense) (retVal [][]int, err error) { if err = checkNativeIterable(t, 2, Int); err != nil { return nil, err @@ -136,7 +140,7 @@ func MatrixI(t *Dense) (retVal [][]int, err error) { return } -// Tensor3I converts a *Dense into a [][][]int. +// Tensor3I converts a *Dense into a [][][]int. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I(t *Dense) (retVal [][][]int, err error) { if err = checkNativeIterable(t, 3, Int); err != nil { @@ -156,8 +160,8 @@ func Tensor3I(t *Dense) (retVal [][][]int, err error) { for i := range retVal { retVal[i] = make([][]int, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -170,7 +174,8 @@ func Tensor3I(t *Dense) (retVal [][][]int, err error) { /* Native Iterables for int8 */ // VectorI8 converts a *Dense into a []int8 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorI8(t *Dense) (retVal []int8, err error) { if err = checkNativeIterable(t, 1, Int8); err != nil { return nil, err @@ -179,7 +184,8 @@ func VectorI8(t *Dense) (retVal []int8, err error) { } // MatrixI8 converts a *Dense into a [][]int8 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. 
+// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixI8(t *Dense) (retVal [][]int8, err error) { if err = checkNativeIterable(t, 2, Int8); err != nil { return nil, err @@ -204,7 +210,7 @@ func MatrixI8(t *Dense) (retVal [][]int8, err error) { return } -// Tensor3I8 converts a *Dense into a [][][]int8. +// Tensor3I8 converts a *Dense into a [][][]int8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { if err = checkNativeIterable(t, 3, Int8); err != nil { @@ -224,8 +230,8 @@ func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { for i := range retVal { retVal[i] = make([][]int8, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -238,7 +244,8 @@ func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { /* Native Iterables for int16 */ // VectorI16 converts a *Dense into a []int16 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorI16(t *Dense) (retVal []int16, err error) { if err = checkNativeIterable(t, 1, Int16); err != nil { return nil, err @@ -247,7 +254,8 @@ func VectorI16(t *Dense) (retVal []int16, err error) { } // MatrixI16 converts a *Dense into a [][]int16 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
func MatrixI16(t *Dense) (retVal [][]int16, err error) { if err = checkNativeIterable(t, 2, Int16); err != nil { return nil, err @@ -272,7 +280,7 @@ func MatrixI16(t *Dense) (retVal [][]int16, err error) { return } -// Tensor3I16 converts a *Dense into a [][][]int16. +// Tensor3I16 converts a *Dense into a [][][]int16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { if err = checkNativeIterable(t, 3, Int16); err != nil { @@ -292,8 +300,8 @@ func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { for i := range retVal { retVal[i] = make([][]int16, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -306,7 +314,8 @@ func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { /* Native Iterables for int32 */ // VectorI32 converts a *Dense into a []int32 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorI32(t *Dense) (retVal []int32, err error) { if err = checkNativeIterable(t, 1, Int32); err != nil { return nil, err @@ -315,7 +324,8 @@ func VectorI32(t *Dense) (retVal []int32, err error) { } // MatrixI32 converts a *Dense into a [][]int32 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixI32(t *Dense) (retVal [][]int32, err error) { if err = checkNativeIterable(t, 2, Int32); err != nil { return nil, err @@ -340,7 +350,7 @@ func MatrixI32(t *Dense) (retVal [][]int32, err error) { return } -// Tensor3I32 converts a *Dense into a [][][]int32. 
+// Tensor3I32 converts a *Dense into a [][][]int32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { if err = checkNativeIterable(t, 3, Int32); err != nil { @@ -360,8 +370,8 @@ func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { for i := range retVal { retVal[i] = make([][]int32, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -374,7 +384,8 @@ func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { /* Native Iterables for int64 */ // VectorI64 converts a *Dense into a []int64 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorI64(t *Dense) (retVal []int64, err error) { if err = checkNativeIterable(t, 1, Int64); err != nil { return nil, err @@ -383,7 +394,8 @@ func VectorI64(t *Dense) (retVal []int64, err error) { } // MatrixI64 converts a *Dense into a [][]int64 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixI64(t *Dense) (retVal [][]int64, err error) { if err = checkNativeIterable(t, 2, Int64); err != nil { return nil, err @@ -408,7 +420,7 @@ func MatrixI64(t *Dense) (retVal [][]int64, err error) { return } -// Tensor3I64 converts a *Dense into a [][][]int64. +// Tensor3I64 converts a *Dense into a [][][]int64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { if err = checkNativeIterable(t, 3, Int64); err != nil { @@ -428,8 +440,8 @@ func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { for i := range retVal { retVal[i] = make([][]int64, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -442,7 +454,8 @@ func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { /* Native Iterables for uint */ // VectorU converts a *Dense into a []uint -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorU(t *Dense) (retVal []uint, err error) { if err = checkNativeIterable(t, 1, Uint); err != nil { return nil, err @@ -451,7 +464,8 @@ func VectorU(t *Dense) (retVal []uint, err error) { } // MatrixU converts a *Dense into a [][]uint -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixU(t *Dense) (retVal [][]uint, err error) { if err = checkNativeIterable(t, 2, Uint); err != nil { return nil, err @@ -476,7 +490,7 @@ func MatrixU(t *Dense) (retVal [][]uint, err error) { return } -// Tensor3U converts a *Dense into a [][][]uint. +// Tensor3U converts a *Dense into a [][][]uint. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U(t *Dense) (retVal [][][]uint, err error) { if err = checkNativeIterable(t, 3, Uint); err != nil { @@ -496,8 +510,8 @@ func Tensor3U(t *Dense) (retVal [][][]uint, err error) { for i := range retVal { retVal[i] = make([][]uint, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -510,7 +524,8 @@ func Tensor3U(t *Dense) (retVal [][][]uint, err error) { /* Native Iterables for uint8 */ // VectorU8 converts a *Dense into a []uint8 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorU8(t *Dense) (retVal []uint8, err error) { if err = checkNativeIterable(t, 1, Uint8); err != nil { return nil, err @@ -519,7 +534,8 @@ func VectorU8(t *Dense) (retVal []uint8, err error) { } // MatrixU8 converts a *Dense into a [][]uint8 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixU8(t *Dense) (retVal [][]uint8, err error) { if err = checkNativeIterable(t, 2, Uint8); err != nil { return nil, err @@ -544,7 +560,7 @@ func MatrixU8(t *Dense) (retVal [][]uint8, err error) { return } -// Tensor3U8 converts a *Dense into a [][][]uint8. +// Tensor3U8 converts a *Dense into a [][][]uint8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { if err = checkNativeIterable(t, 3, Uint8); err != nil { @@ -564,8 +580,8 @@ func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { for i := range retVal { retVal[i] = make([][]uint8, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -578,7 +594,8 @@ func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { /* Native Iterables for uint16 */ // VectorU16 converts a *Dense into a []uint16 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorU16(t *Dense) (retVal []uint16, err error) { if err = checkNativeIterable(t, 1, Uint16); err != nil { return nil, err @@ -587,7 +604,8 @@ func VectorU16(t *Dense) (retVal []uint16, err error) { } // MatrixU16 converts a *Dense into a [][]uint16 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixU16(t *Dense) (retVal [][]uint16, err error) { if err = checkNativeIterable(t, 2, Uint16); err != nil { return nil, err @@ -612,7 +630,7 @@ func MatrixU16(t *Dense) (retVal [][]uint16, err error) { return } -// Tensor3U16 converts a *Dense into a [][][]uint16. +// Tensor3U16 converts a *Dense into a [][][]uint16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { if err = checkNativeIterable(t, 3, Uint16); err != nil { @@ -632,8 +650,8 @@ func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { for i := range retVal { retVal[i] = make([][]uint16, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -646,7 +664,8 @@ func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { /* Native Iterables for uint32 */ // VectorU32 converts a *Dense into a []uint32 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorU32(t *Dense) (retVal []uint32, err error) { if err = checkNativeIterable(t, 1, Uint32); err != nil { return nil, err @@ -655,7 +674,8 @@ func VectorU32(t *Dense) (retVal []uint32, err error) { } // MatrixU32 converts a *Dense into a [][]uint32 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixU32(t *Dense) (retVal [][]uint32, err error) { if err = checkNativeIterable(t, 2, Uint32); err != nil { return nil, err @@ -680,7 +700,7 @@ func MatrixU32(t *Dense) (retVal [][]uint32, err error) { return } -// Tensor3U32 converts a *Dense into a [][][]uint32. +// Tensor3U32 converts a *Dense into a [][][]uint32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { if err = checkNativeIterable(t, 3, Uint32); err != nil { @@ -700,8 +720,8 @@ func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { for i := range retVal { retVal[i] = make([][]uint32, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -714,7 +734,8 @@ func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { /* Native Iterables for uint64 */ // VectorU64 converts a *Dense into a []uint64 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorU64(t *Dense) (retVal []uint64, err error) { if err = checkNativeIterable(t, 1, Uint64); err != nil { return nil, err @@ -723,7 +744,8 @@ func VectorU64(t *Dense) (retVal []uint64, err error) { } // MatrixU64 converts a *Dense into a [][]uint64 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixU64(t *Dense) (retVal [][]uint64, err error) { if err = checkNativeIterable(t, 2, Uint64); err != nil { return nil, err @@ -748,7 +770,7 @@ func MatrixU64(t *Dense) (retVal [][]uint64, err error) { return } -// Tensor3U64 converts a *Dense into a [][][]uint64. +// Tensor3U64 converts a *Dense into a [][][]uint64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { if err = checkNativeIterable(t, 3, Uint64); err != nil { @@ -768,8 +790,8 @@ func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { for i := range retVal { retVal[i] = make([][]uint64, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -782,7 +804,8 @@ func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { /* Native Iterables for float32 */ // VectorF32 converts a *Dense into a []float32 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorF32(t *Dense) (retVal []float32, err error) { if err = checkNativeIterable(t, 1, Float32); err != nil { return nil, err @@ -791,7 +814,8 @@ func VectorF32(t *Dense) (retVal []float32, err error) { } // MatrixF32 converts a *Dense into a [][]float32 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixF32(t *Dense) (retVal [][]float32, err error) { if err = checkNativeIterable(t, 2, Float32); err != nil { return nil, err @@ -816,7 +840,7 @@ func MatrixF32(t *Dense) (retVal [][]float32, err error) { return } -// Tensor3F32 converts a *Dense into a [][][]float32. +// Tensor3F32 converts a *Dense into a [][][]float32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { if err = checkNativeIterable(t, 3, Float32); err != nil { @@ -836,8 +860,8 @@ func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { for i := range retVal { retVal[i] = make([][]float32, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -850,7 +874,8 @@ func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { /* Native Iterables for float64 */ // VectorF64 converts a *Dense into a []float64 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorF64(t *Dense) (retVal []float64, err error) { if err = checkNativeIterable(t, 1, Float64); err != nil { return nil, err @@ -859,7 +884,8 @@ func VectorF64(t *Dense) (retVal []float64, err error) { } // MatrixF64 converts a *Dense into a [][]float64 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixF64(t *Dense) (retVal [][]float64, err error) { if err = checkNativeIterable(t, 2, Float64); err != nil { return nil, err @@ -884,7 +910,7 @@ func MatrixF64(t *Dense) (retVal [][]float64, err error) { return } -// Tensor3F64 converts a *Dense into a [][][]float64. +// Tensor3F64 converts a *Dense into a [][][]float64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { if err = checkNativeIterable(t, 3, Float64); err != nil { @@ -904,8 +930,8 @@ func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { for i := range retVal { retVal[i] = make([][]float64, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -918,7 +944,8 @@ func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { /* Native Iterables for complex64 */ // VectorC64 converts a *Dense into a []complex64 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorC64(t *Dense) (retVal []complex64, err error) { if err = checkNativeIterable(t, 1, Complex64); err != nil { return nil, err @@ -927,7 +954,8 @@ func VectorC64(t *Dense) (retVal []complex64, err error) { } // MatrixC64 converts a *Dense into a [][]complex64 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixC64(t *Dense) (retVal [][]complex64, err error) { if err = checkNativeIterable(t, 2, Complex64); err != nil { return nil, err @@ -952,7 +980,7 @@ func MatrixC64(t *Dense) (retVal [][]complex64, err error) { return } -// Tensor3C64 converts a *Dense into a [][][]complex64. +// Tensor3C64 converts a *Dense into a [][][]complex64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { if err = checkNativeIterable(t, 3, Complex64); err != nil { @@ -972,8 +1000,8 @@ func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { for i := range retVal { retVal[i] = make([][]complex64, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -986,7 +1014,8 @@ func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { /* Native Iterables for complex128 */ // VectorC128 converts a *Dense into a []complex128 -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorC128(t *Dense) (retVal []complex128, err error) { if err = checkNativeIterable(t, 1, Complex128); err != nil { return nil, err @@ -995,7 +1024,8 @@ func VectorC128(t *Dense) (retVal []complex128, err error) { } // MatrixC128 converts a *Dense into a [][]complex128 -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixC128(t *Dense) (retVal [][]complex128, err error) { if err = checkNativeIterable(t, 2, Complex128); err != nil { return nil, err @@ -1020,7 +1050,7 @@ func MatrixC128(t *Dense) (retVal [][]complex128, err error) { return } -// Tensor3C128 converts a *Dense into a [][][]complex128. +// Tensor3C128 converts a *Dense into a [][][]complex128. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { if err = checkNativeIterable(t, 3, Complex128); err != nil { @@ -1040,8 +1070,8 @@ func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { for i := range retVal { retVal[i] = make([][]complex128, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -1054,7 +1084,8 @@ func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { /* Native Iterables for string */ // VectorStr converts a *Dense into a []string -// If the *Dense does not represent a vector of the wanted type, it will return an error. +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. func VectorStr(t *Dense) (retVal []string, err error) { if err = checkNativeIterable(t, 1, String); err != nil { return nil, err @@ -1063,7 +1094,8 @@ func VectorStr(t *Dense) (retVal []string, err error) { } // MatrixStr converts a *Dense into a [][]string -// If the *Dense does not represent a matrix of the wanted type, it will return an error. +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. func MatrixStr(t *Dense) (retVal [][]string, err error) { if err = checkNativeIterable(t, 2, String); err != nil { return nil, err @@ -1088,7 +1120,7 @@ func MatrixStr(t *Dense) (retVal [][]string, err error) { return } -// Tensor3Str converts a *Dense into a [][][]string. +// Tensor3Str converts a *Dense into a [][][]string. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3Str(t *Dense) (retVal [][][]string, err error) { if err = checkNativeIterable(t, 3, String); err != nil { @@ -1108,8 +1140,8 @@ func Tensor3Str(t *Dense) (retVal [][][]string, err error) { for i := range retVal { retVal[i] = make([][]string, rows) for j := range retVal[i] { - start := i*layerStride + j*rowStride retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols diff --git a/native/iterator_native2.go b/native/iterator_native2.go index 9a0ae34..934863d 100644 --- a/native/iterator_native2.go +++ b/native/iterator_native2.go @@ -53,8 +53,8 @@ func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { s := make([]bool, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -91,8 +91,8 @@ func SelectI(t *Dense, axis int) (retVal [][]int, err error) { s := make([]int, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -129,8 +129,8 @@ func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { s := make([]int8, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -167,8 +167,8 @@ func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { s := make([]int16, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -205,8 +205,8 @@ func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { s := make([]int32, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) 
hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -243,8 +243,8 @@ func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { s := make([]int64, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -281,8 +281,8 @@ func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { s := make([]uint, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -319,8 +319,8 @@ func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { s := make([]uint8, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -357,8 +357,8 @@ func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { s := make([]uint16, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -395,8 +395,8 @@ func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { s := make([]uint32, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -433,8 +433,8 @@ func SelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { s := make([]uint64, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -471,8 +471,8 @@ func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { s := make([]float32, 0) 
hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -509,8 +509,8 @@ func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { s := make([]float64, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -547,8 +547,8 @@ func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { s := make([]complex64, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -585,8 +585,8 @@ func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { s := make([]complex128, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } @@ -623,8 +623,8 @@ func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { s := make([]string, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Cap = stride hdr.Len = stride + hdr.Cap = stride retVal = append(retVal, s) r++ } From 40d2b9c03e003087f635318ffd847da4554146dd Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 20 Jul 2020 22:39:58 -0400 Subject: [PATCH 045/154] remove arrow specific files and fix arrow imports --- dense_compat.go | 78 +++++++++++++++++++ example_apply_test.go | 25 ------ ...ow_test.go => example_dense_compat_test.go | 0 genlib2/dense_arrow.go | 62 --------------- genlib2/dense_compat.go | 72 +++++++++++++++++ go.mod | 2 +- go.sum | 2 + 7 files changed, 153 insertions(+), 88 deletions(-) rename example_dense_arrow_test.go => example_dense_compat_test.go (100%) delete mode 100644 genlib2/dense_arrow.go 
diff --git a/dense_compat.go b/dense_compat.go index a1b90ab..1e94358 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -8,6 +8,8 @@ import ( "math/cmplx" "reflect" + arrow "github.com/apache/arrow/go/arrow" + arrowArray "github.com/apache/arrow/go/arrow/array" "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/mat" @@ -433,3 +435,79 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { retVal = mat.NewDense(r, c, data) return } + +// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. +func FromArrowArray(a arrowArray.Interface) *Dense { + a.Retain() + defer a.Release() + + r := a.Len() + + // TODO(poopoothegorilla): instead of creating bool ValidMask maybe + // bitmapBytes can be used from arrow API + mask := make([]bool, r) + for i := 0; i < r; i++ { + mask[i] = a.IsNull(i) + } + + switch a.DataType() { + case arrow.BinaryTypes.String: + backing := make([]string, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.(*arrowArray.String).Value(i) + } + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.FixedWidthTypes.Boolean: + backing := make([]bool, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.(*arrowArray.Boolean).Value(i) + } + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int8: + backing := a.(*arrowArray.Int8).Int8Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int16: + backing := a.(*arrowArray.Int16).Int16Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int32: + backing := a.(*arrowArray.Int32).Int32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Int64: + backing := a.(*arrowArray.Int64).Int64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 
+ return retVal + case arrow.PrimitiveTypes.Uint8: + backing := a.(*arrowArray.Uint8).Uint8Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint16: + backing := a.(*arrowArray.Uint16).Uint16Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint32: + backing := a.(*arrowArray.Uint32).Uint32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Uint64: + backing := a.(*arrowArray.Uint64).Uint64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Float32: + backing := a.(*arrowArray.Float32).Float32Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + case arrow.PrimitiveTypes.Float64: + backing := a.(*arrowArray.Float64).Float64Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} diff --git a/example_apply_test.go b/example_apply_test.go index c84f4c7..1e11641 100644 --- a/example_apply_test.go +++ b/example_apply_test.go @@ -3,8 +3,6 @@ package tensor_test import ( "fmt" - "github.com/apache/arrow/go/arrow/array" - "github.com/apache/arrow/go/arrow/memory" "gorgonia.org/tensor" ) @@ -37,26 +35,3 @@ func ExampleDense_Apply() { // ⎡ 1 8⎤ // ⎣27 64⎦ } - -func ExampleDense_Arrow() { - pool := memory.NewGoAllocator() - - b := array.NewFloat64Builder(pool) - defer b.Release() - - b.AppendValues( - []float64{1, 2, 3, -1, 4, 5}, - []bool{true, true, true, false, true, true}, - ) - - arr := b.NewFloat64Array() - defer arr.Release() - fmt.Printf("arrow array = %v\n", arr) - - a := tensor.FromArrowArray(arr) - fmt.Printf("tensor = %v\n", a) - - // Output: - // arrow array = [1 2 3 (null) 4 5] - // tensor = C[ 1 2 3 -- 4 5] -} diff --git 
a/example_dense_arrow_test.go b/example_dense_compat_test.go similarity index 100% rename from example_dense_arrow_test.go rename to example_dense_compat_test.go diff --git a/genlib2/dense_arrow.go b/genlib2/dense_arrow.go deleted file mode 100644 index 2d225d0..0000000 --- a/genlib2/dense_arrow.go +++ /dev/null @@ -1,62 +0,0 @@ -package main - -type ArrowData struct { - BinaryTypes []string - FixedWidthTypes []string - PrimitiveTypes []string -} - -const compatArrowRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. -func FromArrowArray(a arrowArray.Interface) *Dense { - a.Retain() - defer a.Release() - - r := a.Len() - - // TODO(poopoothegorilla): instead of creating bool ValidMask maybe - // bitmapBytes can be used from arrow API - mask := make([]bool, r) - for i := 0; i < r; i++ { - mask[i] = a.IsNull(i) - } - - switch a.DataType() { - {{range .BinaryTypes -}} - case arrow.BinaryTypes.{{.}}: - {{if eq . "String" -}} - backing := make([]string, a.Len()) - for i := 0; i < len(backing); i++ { - backing[i] = a.Value(i) - } - {{else -}} - backing := a.(*arrowArray.{{.}}).{{.}}Values() - {{end -}} - retVal := New(WithBacking(backing, mask), WithShape(r, 1)) - return retVal - {{end -}} - {{range .FixedWidthTypes -}} - case arrow.FixedWidthTypes.{{.}}: - {{if eq . 
"Boolean" -}} - backing := make([]bool, a.Len()) - for i := 0; i < len(backing); i++ { - backing[i] = a.Value(i) - } - {{else -}} - backing := a.(*arrowArray.{{.}}).{{.}}Values() - {{end -}} - retVal := New(WithBacking(backing, mask), WithShape(r, 1)) - return retVal - {{end -}} - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{.}}: - backing := a.(*arrowArray.{{.}}).{{.}}Values() - retVal := New(WithBacking(backing, mask), WithShape(r, 1)) - return retVal - {{end -}} - default: - panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) - } - - panic("Unreachable") -} -` diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 018e3db..6e19b83 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -5,6 +5,12 @@ import ( "text/template" ) +const importsArrowRaw = `import ( + arrowArray "github.com/apache/arrow/go/arrow/array" + arrow "github.com/apache/arrow/go/arrow" +) +` + const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} { switch to { {{range .Kinds -}} @@ -237,19 +243,85 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { ` +type ArrowData struct { + BinaryTypes []string + FixedWidthTypes []string + PrimitiveTypes []string +} + +const compatArrowRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. +func FromArrowArray(a arrowArray.Interface) *Dense { + a.Retain() + defer a.Release() + + r := a.Len() + + // TODO(poopoothegorilla): instead of creating bool ValidMask maybe + // bitmapBytes can be used from arrow API + mask := make([]bool, r) + for i := 0; i < r; i++ { + mask[i] = a.IsNull(i) + } + + switch a.DataType() { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{.}}: + {{if eq . 
"String" -}} + backing := make([]string, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.(*arrowArray.{{.}}).Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{.}}: + {{if eq . "Boolean" -}} + backing := make([]bool, a.Len()) + for i := 0; i < len(backing); i++ { + backing[i] = a.(*arrowArray.{{.}}).Value(i) + } + {{else -}} + backing := a.(*arrowArray.{{.}}).{{.}}Values() + {{end -}} + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{.}}: + backing := a.(*arrowArray.{{.}}).{{.}}Values() + retVal := New(WithBacking(backing, mask), WithShape(r, 1)) + return retVal + {{end -}} + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} +` + var ( + importsArrow *template.Template conversions *template.Template compats *template.Template compatsArrow *template.Template ) func init() { + importsArrow = template.Must(template.New("imports_arrow").Funcs(funcs).Parse(importsArrowRaw)) conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw)) compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw)) compatsArrow = template.Must(template.New("compat_arrow").Funcs(funcs).Parse(compatArrowRaw)) } func generateDenseCompat(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // colisions + importsArrow.Execute(f, generic) conversions.Execute(f, generic) compats.Execute(f, generic) arrowData := ArrowData{ diff --git a/go.mod b/go.mod index 1bd9f13..a2b2ad0 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module gorgonia.org/tensor go 1.13 require ( - github.com/apache/arrow/go/arrow 
v0.0.0-20200720164908-23b19f65e1eb + github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7 github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.4 github.com/gogo/protobuf v1.3.0 diff --git a/go.sum b/go.sum index 412a6f2..1d7bc0c 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3 github.com/apache/arrow v0.0.0-20200720164908-23b19f65e1eb h1:/guPTo4KRiOQnB4UX0Sn9kk5k7kCC00eSKsoykKc0tU= github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb h1:vBEPOeLNZ2RUgG/e+G2tOIucgCojRKRPorB3STXC+xw= github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= +github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7 h1:vStV77omxrTbs1UsDogumkL0+TU28S6Ebp0LmKCM7KE= +github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= From 2c590b5f91449971cb887ed2a409099e15ff13e4 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Sun, 26 Jul 2020 11:45:53 -0400 Subject: [PATCH 046/154] add tests for arrow to tensor function --- dense_compat_test.go | 202 ++++++++++++++++++++++++++++++++++ genlib2/dense_compat_tests.go | 65 ++++++++++- 2 files changed, 266 insertions(+), 1 deletion(-) diff --git a/dense_compat_test.go b/dense_compat_test.go index 494fc25..4b3bfc5 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -5,6 +5,9 @@ package tensor import ( "testing" + arrow "github.com/apache/arrow/go/arrow" + arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" "github.com/stretchr/testify/assert" "gonum.org/v1/gonum/mat" ) @@ -105,3 +108,202 @@ func 
TestFromMat64(t *testing.T) { } } } + +var toArrowArrayTests = []struct { + data interface{} + dt arrow.DataType + shape Shape +}{ + { + data: Range(Int8, 0, 6), + dt: arrow.PrimitiveTypes.Int8, + shape: Shape{6, 1}, + }, + { + data: Range(Int16, 0, 6), + dt: arrow.PrimitiveTypes.Int16, + shape: Shape{6, 1}, + }, + { + data: Range(Int32, 0, 6), + dt: arrow.PrimitiveTypes.Int32, + shape: Shape{6, 1}, + }, + { + data: Range(Int64, 0, 6), + dt: arrow.PrimitiveTypes.Int64, + shape: Shape{6, 1}, + }, + { + data: Range(Uint8, 0, 6), + dt: arrow.PrimitiveTypes.Uint8, + shape: Shape{6, 1}, + }, + { + data: Range(Uint16, 0, 6), + dt: arrow.PrimitiveTypes.Uint16, + shape: Shape{6, 1}, + }, + { + data: Range(Uint32, 0, 6), + dt: arrow.PrimitiveTypes.Uint32, + shape: Shape{6, 1}, + }, + { + data: Range(Uint64, 0, 6), + dt: arrow.PrimitiveTypes.Uint64, + shape: Shape{6, 1}, + }, + { + data: Range(Float32, 0, 6), + dt: arrow.PrimitiveTypes.Float32, + shape: Shape{6, 1}, + }, + { + data: Range(Float64, 0, 6), + dt: arrow.PrimitiveTypes.Float64, + shape: Shape{6, 1}, + }, +} + +func TestFromArrowArray(t *testing.T) { + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + case arrow.PrimitiveTypes.Int8: + b := arrowArray.NewInt8Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int8, 0, 6).([]int8), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int16: + b := arrowArray.NewInt16Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int16, 0, 6).([]int16), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Int32: + b := arrowArray.NewInt32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int32, 0, 6).([]int32), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer 
m.Release() + case arrow.PrimitiveTypes.Int64: + b := arrowArray.NewInt64Builder(pool) + defer b.Release() + b.AppendValues( + Range(Int64, 0, 6).([]int64), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint8: + b := arrowArray.NewUint8Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint8, 0, 6).([]uint8), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint16: + b := arrowArray.NewUint16Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint16, 0, 6).([]uint16), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint32: + b := arrowArray.NewUint32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint32, 0, 6).([]uint32), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Uint64: + b := arrowArray.NewUint64Builder(pool) + defer b.Release() + b.AppendValues( + Range(Uint64, 0, 6).([]uint64), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Float32: + b := arrowArray.NewFloat32Builder(pool) + defer b.Release() + b.AppendValues( + Range(Float32, 0, 6).([]float32), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + case arrow.PrimitiveTypes.Float64: + b := arrowArray.NewFloat64Builder(pool) + defer b.Release() + b.AppendValues( + Range(Float64, 0, 6).([]float64), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + case arrow.PrimitiveTypes.Int8: + conv := taat.data.([]int8) + assert.Equal(conv, T.Int8s(), "test %d: []int8 from %v", i, taat.dt) + case 
arrow.PrimitiveTypes.Int16: + conv := taat.data.([]int16) + assert.Equal(conv, T.Int16s(), "test %d: []int16 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Int32: + conv := taat.data.([]int32) + assert.Equal(conv, T.Int32s(), "test %d: []int32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Int64: + conv := taat.data.([]int64) + assert.Equal(conv, T.Int64s(), "test %d: []int64 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint8: + conv := taat.data.([]uint8) + assert.Equal(conv, T.Uint8s(), "test %d: []uint8 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint16: + conv := taat.data.([]uint16) + assert.Equal(conv, T.Uint16s(), "test %d: []uint16 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint32: + conv := taat.data.([]uint32) + assert.Equal(conv, T.Uint32s(), "test %d: []uint32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Uint64: + conv := taat.data.([]uint64) + assert.Equal(conv, T.Uint64s(), "test %d: []uint64 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Float32: + conv := taat.data.([]float32) + assert.Equal(conv, T.Float32s(), "test %d: []float32 from %v", i, taat.dt) + case arrow.PrimitiveTypes.Float64: + conv := taat.data.([]float64) + assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, taat.dt) + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index d2fd049..a539da3 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -94,14 +94,77 @@ func TestFromMat64(t *testing.T){ } ` +const compatArrowTestsRaw = `var toArrowArrayTests = []struct{ + data interface{} + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + data: Range({{.}}, 0, 6), + dt: arrow.PrimitiveTypes.{{ . 
}}, + shape: Shape{6,1}, + }, + {{end -}} +} +func TestFromArrowArray(t *testing.T){ + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + Range({{ . }}, 0, 6).([]{{lower . }}), + nil, // TODO(poopoothegorilla): add valid bitmask + ) + m = b.NewArray() + defer m.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + conv := taat.data.([]{{lower . }}) + assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} +` + var ( - compatTests *template.Template + compatTests *template.Template + compatArrowTests *template.Template ) func init() { compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) + compatArrowTests = template.Must(template.New("testArrowCompat").Funcs(funcs).Parse(compatArrowTestsRaw)) } func generateDenseCompatTests(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // colisions + importsArrow.Execute(f, generic) compatTests.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatArrowTests.Execute(f, arrowData) } From 061a685ce673171829144b43a4a7452df1f39cd0 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Sun, 26 Jul 2020 13:49:52 -0400 Subject: [PATCH 047/154] add checks for valid mask in arrow to tensor tests --- dense_compat_test.go | 34 
++++++++++++++++++++++++---------- genlib2/dense_compat_tests.go | 7 ++++++- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/dense_compat_test.go b/dense_compat_test.go index 4b3bfc5..51c1e39 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -111,56 +111,67 @@ func TestFromMat64(t *testing.T) { var toArrowArrayTests = []struct { data interface{} + valid []bool dt arrow.DataType shape Shape }{ { data: Range(Int8, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Int8, shape: Shape{6, 1}, }, { data: Range(Int16, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Int16, shape: Shape{6, 1}, }, { data: Range(Int32, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Int32, shape: Shape{6, 1}, }, { data: Range(Int64, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Int64, shape: Shape{6, 1}, }, { data: Range(Uint8, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Uint8, shape: Shape{6, 1}, }, { data: Range(Uint16, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Uint16, shape: Shape{6, 1}, }, { data: Range(Uint32, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Uint32, shape: Shape{6, 1}, }, { data: Range(Uint64, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Uint64, shape: Shape{6, 1}, }, { data: Range(Float32, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Float32, shape: Shape{6, 1}, }, { data: Range(Float64, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.Float64, shape: Shape{6, 1}, }, @@ -180,7 +191,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Int8, 0, 6).([]int8), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m 
= b.NewArray() defer m.Release() @@ -189,7 +200,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Int16, 0, 6).([]int16), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -198,7 +209,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Int32, 0, 6).([]int32), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -207,7 +218,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Int64, 0, 6).([]int64), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -216,7 +227,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Uint8, 0, 6).([]uint8), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -225,7 +236,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Uint16, 0, 6).([]uint16), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -234,7 +245,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Uint32, 0, 6).([]uint32), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -243,7 +254,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Uint64, 0, 6).([]uint64), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -252,7 +263,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Float32, 0, 6).([]float32), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -261,7 +272,7 @@ func TestFromArrowArray(t *testing.T) { defer b.Release() b.AppendValues( Range(Float64, 0, 6).([]float64), - nil, // 
TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -304,6 +315,9 @@ func TestFromArrowArray(t *testing.T) { default: t.Errorf("DataType not supported in tests: %v", taat.dt) } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } assert.True(T.Shape().Eq(taat.shape)) } } diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index a539da3..488a1d5 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -96,12 +96,14 @@ func TestFromMat64(t *testing.T){ const compatArrowTestsRaw = `var toArrowArrayTests = []struct{ data interface{} + valid []bool dt arrow.DataType shape Shape }{ {{range .PrimitiveTypes -}} { data: Range({{.}}, 0, 6), + valid: []bool{true, true, true, false, true, true}, dt: arrow.PrimitiveTypes.{{ . }}, shape: Shape{6,1}, }, @@ -122,7 +124,7 @@ func TestFromArrowArray(t *testing.T){ defer b.Release() b.AppendValues( Range({{ . }}, 0, 6).([]{{lower . }}), - nil, // TODO(poopoothegorilla): add valid bitmask + taat.valid, ) m = b.NewArray() defer m.Release() @@ -141,6 +143,9 @@ func TestFromArrowArray(t *testing.T){ default: t.Errorf("DataType not supported in tests: %v", taat.dt) } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } assert.True(T.Shape().Eq(taat.shape)) } } From 9b3b7c725ec745bf7895cc3232bad3055de0bec0 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Sun, 26 Jul 2020 14:03:38 -0400 Subject: [PATCH 048/154] small changes to address inconsistencies --- dense_compat.go | 8 ++++---- genlib2/dense_compat.go | 10 +++++----- genlib2/dense_compat_tests.go | 2 +- genlib2/native_iterator.go | 5 +++-- native/iterator_native.go | 32 ++++++++++++++++---------------- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/dense_compat.go b/dense_compat.go index 1e94358..554466d 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -452,15 +452,15 @@ func FromArrowArray(a 
arrowArray.Interface) *Dense { switch a.DataType() { case arrow.BinaryTypes.String: - backing := make([]string, a.Len()) - for i := 0; i < len(backing); i++ { + backing := make([]string, r) + for i := 0; i < r; i++ { backing[i] = a.(*arrowArray.String).Value(i) } retVal := New(WithBacking(backing, mask), WithShape(r, 1)) return retVal case arrow.FixedWidthTypes.Boolean: - backing := make([]bool, a.Len()) - for i := 0; i < len(backing); i++ { + backing := make([]bool, r) + for i := 0; i < r; i++ { backing[i] = a.(*arrowArray.Boolean).Value(i) } retVal := New(WithBacking(backing, mask), WithShape(r, 1)) diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 6e19b83..13da09d 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -267,8 +267,8 @@ func FromArrowArray(a arrowArray.Interface) *Dense { {{range .BinaryTypes -}} case arrow.BinaryTypes.{{.}}: {{if eq . "String" -}} - backing := make([]string, a.Len()) - for i := 0; i < len(backing); i++ { + backing := make([]string, r) + for i := 0; i < r; i++ { backing[i] = a.(*arrowArray.{{.}}).Value(i) } {{else -}} @@ -280,8 +280,8 @@ func FromArrowArray(a arrowArray.Interface) *Dense { {{range .FixedWidthTypes -}} case arrow.FixedWidthTypes.{{.}}: {{if eq . 
"Boolean" -}} - backing := make([]bool, a.Len()) - for i := 0; i < len(backing); i++ { + backing := make([]bool, r) + for i := 0; i < r; i++ { backing[i] = a.(*arrowArray.{{.}}).Value(i) } {{else -}} @@ -320,7 +320,7 @@ func init() { func generateDenseCompat(f io.Writer, generic Kinds) { // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming - // colisions + // collisions importsArrow.Execute(f, generic) conversions.Execute(f, generic) compats.Execute(f, generic) diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index 488a1d5..036d774 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -163,7 +163,7 @@ func init() { func generateDenseCompatTests(f io.Writer, generic Kinds) { // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming - // colisions + // collisions importsArrow.Execute(f, generic) compatTests.Execute(f, generic) arrowData := ArrowData{ diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index 7a0c720..565d9e9 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -65,7 +65,7 @@ func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { return } -// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { @@ -94,7 +94,8 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { } } return -}` +} +` const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { assert := assert.New(t) diff --git a/native/iterator_native.go b/native/iterator_native.go index 958e160..d9727fe 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -70,7 +70,7 @@ func MatrixB(t *Dense) (retVal [][]bool, err error) { return } -// Tensor3B converts a *Dense into a [][][]bool. +// Tensor3B converts a *Dense into a [][][]bool. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3B(t *Dense) (retVal [][][]bool, err error) { if err = checkNativeIterable(t, 3, Bool); err != nil { @@ -140,7 +140,7 @@ func MatrixI(t *Dense) (retVal [][]int, err error) { return } -// Tensor3I converts a *Dense into a [][][]int. +// Tensor3I converts a *Dense into a [][][]int. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I(t *Dense) (retVal [][][]int, err error) { if err = checkNativeIterable(t, 3, Int); err != nil { @@ -210,7 +210,7 @@ func MatrixI8(t *Dense) (retVal [][]int8, err error) { return } -// Tensor3I8 converts a *Dense into a [][][]int8. +// Tensor3I8 converts a *Dense into a [][][]int8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { if err = checkNativeIterable(t, 3, Int8); err != nil { @@ -280,7 +280,7 @@ func MatrixI16(t *Dense) (retVal [][]int16, err error) { return } -// Tensor3I16 converts a *Dense into a [][][]int16. +// Tensor3I16 converts a *Dense into a [][][]int16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { if err = checkNativeIterable(t, 3, Int16); err != nil { @@ -350,7 +350,7 @@ func MatrixI32(t *Dense) (retVal [][]int32, err error) { return } -// Tensor3I32 converts a *Dense into a [][][]int32. +// Tensor3I32 converts a *Dense into a [][][]int32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { if err = checkNativeIterable(t, 3, Int32); err != nil { @@ -420,7 +420,7 @@ func MatrixI64(t *Dense) (retVal [][]int64, err error) { return } -// Tensor3I64 converts a *Dense into a [][][]int64. +// Tensor3I64 converts a *Dense into a [][][]int64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { if err = checkNativeIterable(t, 3, Int64); err != nil { @@ -490,7 +490,7 @@ func MatrixU(t *Dense) (retVal [][]uint, err error) { return } -// Tensor3U converts a *Dense into a [][][]uint. +// Tensor3U converts a *Dense into a [][][]uint. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3U(t *Dense) (retVal [][][]uint, err error) { if err = checkNativeIterable(t, 3, Uint); err != nil { @@ -560,7 +560,7 @@ func MatrixU8(t *Dense) (retVal [][]uint8, err error) { return } -// Tensor3U8 converts a *Dense into a [][][]uint8. +// Tensor3U8 converts a *Dense into a [][][]uint8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { if err = checkNativeIterable(t, 3, Uint8); err != nil { @@ -630,7 +630,7 @@ func MatrixU16(t *Dense) (retVal [][]uint16, err error) { return } -// Tensor3U16 converts a *Dense into a [][][]uint16. +// Tensor3U16 converts a *Dense into a [][][]uint16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { if err = checkNativeIterable(t, 3, Uint16); err != nil { @@ -700,7 +700,7 @@ func MatrixU32(t *Dense) (retVal [][]uint32, err error) { return } -// Tensor3U32 converts a *Dense into a [][][]uint32. +// Tensor3U32 converts a *Dense into a [][][]uint32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { if err = checkNativeIterable(t, 3, Uint32); err != nil { @@ -770,7 +770,7 @@ func MatrixU64(t *Dense) (retVal [][]uint64, err error) { return } -// Tensor3U64 converts a *Dense into a [][][]uint64. +// Tensor3U64 converts a *Dense into a [][][]uint64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { if err = checkNativeIterable(t, 3, Uint64); err != nil { @@ -840,7 +840,7 @@ func MatrixF32(t *Dense) (retVal [][]float32, err error) { return } -// Tensor3F32 converts a *Dense into a [][][]float32. +// Tensor3F32 converts a *Dense into a [][][]float32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { if err = checkNativeIterable(t, 3, Float32); err != nil { @@ -910,7 +910,7 @@ func MatrixF64(t *Dense) (retVal [][]float64, err error) { return } -// Tensor3F64 converts a *Dense into a [][][]float64. +// Tensor3F64 converts a *Dense into a [][][]float64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { if err = checkNativeIterable(t, 3, Float64); err != nil { @@ -980,7 +980,7 @@ func MatrixC64(t *Dense) (retVal [][]complex64, err error) { return } -// Tensor3C64 converts a *Dense into a [][][]complex64. +// Tensor3C64 converts a *Dense into a [][][]complex64. 
// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { if err = checkNativeIterable(t, 3, Complex64); err != nil { @@ -1050,7 +1050,7 @@ func MatrixC128(t *Dense) (retVal [][]complex128, err error) { return } -// Tensor3C128 converts a *Dense into a [][][]complex128. +// Tensor3C128 converts a *Dense into a [][][]complex128. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { if err = checkNativeIterable(t, 3, Complex128); err != nil { @@ -1120,7 +1120,7 @@ func MatrixStr(t *Dense) (retVal [][]string, err error) { return } -// Tensor3Str converts a *Dense into a [][][]string. +// Tensor3Str converts a *Dense into a [][][]string. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. func Tensor3Str(t *Dense) (retVal [][][]string, err error) { if err = checkNativeIterable(t, 3, String); err != nil { From e3dc770596f02c4e79e04f11c5dd0cb4fd6e75b3 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 27 Jul 2020 19:34:32 -0400 Subject: [PATCH 049/154] add fixedwidth and binary types to arrow tensor tests --- dense_compat_test.go | 18 ++++++++++++++++++ genlib2/dense_compat_tests.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/dense_compat_test.go b/dense_compat_test.go index 51c1e39..0d98680 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -186,6 +186,24 @@ func TestFromArrowArray(t *testing.T) { var m arrowArray.Interface switch taat.dt { + case arrow.BinaryTypes.String: + b := arrowArray.NewStringBuilder(pool) + defer b.Release() + b.AppendValues( + []string{"0", "1", "2", "3", "4", "5"}, + taat.valid, + ) + m = b.NewArray() + defer m.Release() + case arrow.FixedWidthTypes.Boolean: + b := arrowArray.NewBooleanBuilder(pool) + defer b.Release() + b.AppendValues( + 
[]bool{true, false, true, false, true, false}, + taat.valid, + ) + m = b.NewArray() + defer m.Release() case arrow.PrimitiveTypes.Int8: b := arrowArray.NewInt8Builder(pool) defer b.Release() diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index 036d774..8a45e94 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -118,6 +118,36 @@ func TestFromArrowArray(t *testing.T){ var m arrowArray.Interface switch taat.dt { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "String" -}} + []string{"0", "1", "2", "3", "4", "5"}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "Boolean" -}} + []bool{true, false, true, false, true, false}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} {{range .PrimitiveTypes -}} case arrow.PrimitiveTypes.{{ . }}: b := arrowArray.New{{ . 
}}Builder(pool) From 942e6a56649bb138709ddcb15101ea0f31792595 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 3 Aug 2020 00:31:31 -0400 Subject: [PATCH 050/154] add Arrow Tensor to Gorgonia Tensor conversion --- dense_compat.go | 119 +++++++++++++ dense_compat_test.go | 307 ++++++++++++++++++++++++++++++++++ genlib2/dense_compat.go | 49 +++++- genlib2/dense_compat_tests.go | 80 ++++++++- 4 files changed, 543 insertions(+), 12 deletions(-) diff --git a/dense_compat.go b/dense_compat.go index 554466d..126c1b4 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -10,6 +10,7 @@ import ( arrow "github.com/apache/arrow/go/arrow" arrowArray "github.com/apache/arrow/go/arrow/array" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/mat" @@ -511,3 +512,121 @@ func FromArrowArray(a arrowArray.Interface) *Dense { panic("Unreachable") } + +// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. 
+func FromArrowTensor(a arrowTensor.Interface) *Dense { + a.Retain() + defer a.Release() + + var shape []int + for _, val := range a.Shape() { + shape = append(shape, int(val)) + } + + switch a.DataType() { + case arrow.PrimitiveTypes.Int8: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Int8).Int8Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Int16: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Int16).Int16Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Int32: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Int32).Int32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Int64: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Int64).Int64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Uint8: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Uint8).Uint8Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Uint16: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Uint16).Uint16Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + 
case arrow.PrimitiveTypes.Uint32: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Uint32).Uint32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Uint64: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Uint64).Uint64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Float32: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Float32).Float32Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + case arrow.PrimitiveTypes.Float64: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.Float64).Float64Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} diff --git a/dense_compat_test.go b/dense_compat_test.go index 0d98680..38e395c 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -8,6 +8,7 @@ import ( arrow "github.com/apache/arrow/go/arrow" arrowArray "github.com/apache/arrow/go/arrow/array" "github.com/apache/arrow/go/arrow/memory" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/stretchr/testify/assert" "gonum.org/v1/gonum/mat" ) @@ -339,3 +340,309 @@ func TestFromArrowArray(t *testing.T) { assert.True(T.Shape().Eq(taat.shape)) } } + +var toArrowTensorTests = []struct { + rowMajorData interface{} + colMajorData interface{} + dt arrow.DataType + shape Shape +}{ + { + rowMajorData: 
[]int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Int8, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Int16, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Int32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Int64, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Uint8, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Uint16, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Uint32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Uint64, + shape: Shape{2, 5}, + }, + { + rowMajorData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Float32, + shape: Shape{2, 5}, + }, + { + rowMajorData: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.Float64, + shape: Shape{2, 5}, + }, +} + +func TestFromArrowTensor(t *testing.T) { + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr 
arrowArray.Interface + var colMajorArr arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + case arrow.PrimitiveTypes.Int8: + b := arrowArray.NewInt8Builder(pool) + defer b.Release() + b.AppendValues( + []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt8(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt8(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int8SizeBytes), int64(arrow.Int8SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int16: + b := arrowArray.NewInt16Builder(pool) + defer b.Release() + b.AppendValues( + []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt16(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt16(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int16SizeBytes), int64(arrow.Int16SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int32: + b := arrowArray.NewInt32Builder(pool) + defer b.Release() + b.AppendValues( + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = 
arrowTensor.NewInt32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int32SizeBytes), int64(arrow.Int32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Int64: + b := arrowArray.NewInt64Builder(pool) + defer b.Release() + b.AppendValues( + []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewInt64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewInt64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Int64SizeBytes), int64(arrow.Int64SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint8: + b := arrowArray.NewUint8Builder(pool) + defer b.Release() + b.AppendValues( + []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint8(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint8(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint8SizeBytes), int64(arrow.Uint8SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint16: + b := arrowArray.NewUint16Builder(pool) + defer b.Release() + b.AppendValues( + []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint16(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + 
colMajor = arrowTensor.NewUint16(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint16SizeBytes), int64(arrow.Uint16SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint32: + b := arrowArray.NewUint32Builder(pool) + defer b.Release() + b.AppendValues( + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint32SizeBytes), int64(arrow.Uint32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Uint64: + b := arrowArray.NewUint64Builder(pool) + defer b.Release() + b.AppendValues( + []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewUint64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewUint64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Uint64SizeBytes), int64(arrow.Uint64SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Float32: + b := arrowArray.NewFloat32Builder(pool) + defer b.Release() + b.AppendValues( + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewFloat32(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", 
"y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewFloat32(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Float32SizeBytes), int64(arrow.Float32SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + case arrow.PrimitiveTypes.Float64: + b := arrowArray.NewFloat64Builder(pool) + defer b.Release() + b.AppendValues( + []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.NewFloat64(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.NewFloat64(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.Float64SizeBytes), int64(arrow.Float64SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 13da09d..acecb21 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -7,6 +7,7 @@ import ( const importsArrowRaw = `import ( arrowArray "github.com/apache/arrow/go/arrow/array" + arrowTensor "github.com/apache/arrow/go/arrow/tensor" arrow "github.com/apache/arrow/go/arrow" ) ` @@ -249,7 +250,7 @@ type ArrowData struct { PrimitiveTypes []string } -const compatArrowRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. 
+const compatArrowArrayRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. func FromArrowArray(a arrowArray.Interface) *Dense { a.Retain() defer a.Release() @@ -304,18 +305,51 @@ func FromArrowArray(a arrowArray.Interface) *Dense { } ` +const compatArrowTensorRaw = `// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. +func FromArrowTensor(a arrowTensor.Interface) *Dense { + a.Retain() + defer a.Release() + + var shape []int + for _, val := range a.Shape() { + shape = append(shape, int(val)) + } + + switch a.DataType() { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{.}}: + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + backing := a.(*arrowTensor.{{.}}).{{.}}Values() + if a.IsColMajor() { + return New(WithShape(shape...), AsFortran(backing)) + } + + return New(WithShape(shape...), WithBacking(backing)) + {{end -}} + default: + panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) + } + + panic("Unreachable") +} +` + var ( - importsArrow *template.Template - conversions *template.Template - compats *template.Template - compatsArrow *template.Template + importsArrow *template.Template + conversions *template.Template + compats *template.Template + compatsArrowArray *template.Template + compatsArrowTensor *template.Template ) func init() { importsArrow = template.Must(template.New("imports_arrow").Funcs(funcs).Parse(importsArrowRaw)) conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw)) compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw)) - compatsArrow = template.Must(template.New("compat_arrow").Funcs(funcs).Parse(compatArrowRaw)) + compatsArrowArray = template.Must(template.New("compat_arrow_array").Funcs(funcs).Parse(compatArrowArrayRaw)) + compatsArrowTensor = template.Must(template.New("compat_arrow_tensor").Funcs(funcs).Parse(compatArrowTensorRaw)) } func 
generateDenseCompat(f io.Writer, generic Kinds) { @@ -329,5 +363,6 @@ func generateDenseCompat(f io.Writer, generic Kinds) { FixedWidthTypes: arrowFixedWidthTypes, PrimitiveTypes: arrowPrimitiveTypes, } - compatsArrow.Execute(f, arrowData) + compatsArrowArray.Execute(f, arrowData) + compatsArrowTensor.Execute(f, arrowData) } diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index 8a45e94..dd4b442 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -94,7 +94,7 @@ func TestFromMat64(t *testing.T){ } ` -const compatArrowTestsRaw = `var toArrowArrayTests = []struct{ +const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ data interface{} valid []bool dt arrow.DataType @@ -181,14 +181,83 @@ func TestFromArrowArray(t *testing.T){ } ` +const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ + rowMajorData interface{} + colMajorData interface{} + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{2,5}, + }, + {{end -}} +} +func TestFromArrowTensor(t *testing.T){ + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr arrowArray.Interface + var colMajorArr arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + []{{lower . 
}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + nil, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} +` + var ( - compatTests *template.Template - compatArrowTests *template.Template + compatTests *template.Template + compatArrowArrayTests *template.Template + compatArrowTensorTests *template.Template ) func init() { compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) - compatArrowTests = template.Must(template.New("testArrowCompat").Funcs(funcs).Parse(compatArrowTestsRaw)) + compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) + compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) } func generateDenseCompatTests(f io.Writer, generic Kinds) { @@ -201,5 +270,6 @@ func generateDenseCompatTests(f io.Writer, generic Kinds) { FixedWidthTypes: arrowFixedWidthTypes, PrimitiveTypes: arrowPrimitiveTypes, } - compatArrowTests.Execute(f, arrowData) + compatArrowArrayTests.Execute(f, arrowData) + 
compatArrowTensorTests.Execute(f, arrowData) } From e48784454c69dea86773b17b6c44f9d25030bea9 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 3 Aug 2020 09:09:41 -0400 Subject: [PATCH 051/154] take contiguous check out of every type in switch --- dense_compat.go | 34 ++++------------------------------ genlib2/dense_compat.go | 7 ++++--- 2 files changed, 8 insertions(+), 33 deletions(-) diff --git a/dense_compat.go b/dense_compat.go index 126c1b4..a8e394e 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -523,11 +523,12 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { shape = append(shape, int(val)) } + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + switch a.DataType() { case arrow.PrimitiveTypes.Int8: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Int8).Int8Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -535,9 +536,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Int16: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Int16).Int16Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -545,9 +543,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Int32: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Int32).Int32Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -555,9 +550,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Int64: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Int64).Int64Values() if a.IsColMajor() { return 
New(WithShape(shape...), AsFortran(backing)) @@ -565,9 +557,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Uint8: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Uint8).Uint8Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -575,9 +564,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Uint16: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Uint16).Uint16Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -585,9 +571,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Uint32: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Uint32).Uint32Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -595,9 +578,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Uint64: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Uint64).Uint64Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -605,9 +585,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case arrow.PrimitiveTypes.Float32: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Float32).Float32Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) @@ -615,9 +592,6 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { return New(WithShape(shape...), WithBacking(backing)) case 
arrow.PrimitiveTypes.Float64: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.Float64).Float64Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index acecb21..38593c3 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -315,12 +315,13 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { shape = append(shape, int(val)) } + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + switch a.DataType() { {{range .PrimitiveTypes -}} case arrow.PrimitiveTypes.{{.}}: - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") - } backing := a.(*arrowTensor.{{.}}).{{.}}Values() if a.IsColMajor() { return New(WithShape(shape...), AsFortran(backing)) From d23e193031f65224f8d7508f0c535b531dfdab4b Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Mon, 3 Aug 2020 18:11:00 -0400 Subject: [PATCH 052/154] add masks to the FromArrowTensor conversion --- dense_compat.go | 53 +++++++----- dense_compat_test.go | 158 ++++++++++++++++++++-------------- genlib2/dense_compat.go | 17 +++- genlib2/dense_compat_tests.go | 16 +++- 4 files changed, 152 insertions(+), 92 deletions(-) diff --git a/dense_compat.go b/dense_compat.go index a8e394e..2e4a1d5 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -10,6 +10,7 @@ import ( arrow "github.com/apache/arrow/go/arrow" arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/bitutil" arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/chewxy/math32" "github.com/pkg/errors" @@ -518,86 +519,94 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { a.Retain() defer a.Release() + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + var shape []int for _, val := range a.Shape() { shape = append(shape, int(val)) } - if !a.IsContiguous() { - panic("Non-contiguous data is 
Unsupported") + l := a.Len() + validMask := a.Data().Buffers()[0].Bytes() + dataOffset := a.Data().Offset() + mask := make([]bool, l) + for i := 0; i < l; i++ { + mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) } switch a.DataType() { case arrow.PrimitiveTypes.Int8: backing := a.(*arrowTensor.Int8).Int8Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int16: backing := a.(*arrowTensor.Int16).Int16Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int32: backing := a.(*arrowTensor.Int32).Int32Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int64: backing := a.(*arrowTensor.Int64).Int64Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint8: backing := a.(*arrowTensor.Uint8).Uint8Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case 
arrow.PrimitiveTypes.Uint16: backing := a.(*arrowTensor.Uint16).Uint16Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint32: backing := a.(*arrowTensor.Uint32).Uint32Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint64: backing := a.(*arrowTensor.Uint64).Uint64Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Float32: backing := a.(*arrowTensor.Float32).Float32Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Float64: backing := a.(*arrowTensor.Float64).Float64Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) default: panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) } diff --git a/dense_compat_test.go b/dense_compat_test.go index 38e395c..f872329 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -342,70 +342,92 @@ func 
TestFromArrowArray(t *testing.T) { } var toArrowTensorTests = []struct { - rowMajorData interface{} - colMajorData interface{} - dt arrow.DataType - shape Shape + rowMajorData interface{} + colMajorData interface{} + rowMajorValid []bool + colMajorValid []bool + dt arrow.DataType + shape Shape }{ { - rowMajorData: []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []int8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Int8, - shape: Shape{2, 5}, + rowMajorData: []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int8, + shape: Shape{2, 5}, }, { - rowMajorData: []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []int16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Int16, - shape: Shape{2, 5}, + rowMajorData: []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int16, + shape: Shape{2, 5}, }, { - rowMajorData: []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []int32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Int32, - shape: Shape{2, 5}, + rowMajorData: []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int32, + shape: Shape{2, 5}, }, { - rowMajorData: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []int64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: 
arrow.PrimitiveTypes.Int64, - shape: Shape{2, 5}, + rowMajorData: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []int64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Int64, + shape: Shape{2, 5}, }, { - rowMajorData: []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []uint8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Uint8, - shape: Shape{2, 5}, + rowMajorData: []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint8{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint8, + shape: Shape{2, 5}, }, { - rowMajorData: []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []uint16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Uint16, - shape: Shape{2, 5}, + rowMajorData: []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint16{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint16, + shape: Shape{2, 5}, }, { - rowMajorData: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []uint32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Uint32, - shape: Shape{2, 5}, + rowMajorData: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint32, + shape: Shape{2, 5}, }, { - rowMajorData: 
[]uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []uint64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Uint64, - shape: Shape{2, 5}, + rowMajorData: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []uint64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Uint64, + shape: Shape{2, 5}, }, { - rowMajorData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []float32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Float32, - shape: Shape{2, 5}, + rowMajorData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float32{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Float32, + shape: Shape{2, 5}, }, { - rowMajorData: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []float64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - dt: arrow.PrimitiveTypes.Float64, - shape: Shape{2, 5}, + rowMajorData: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []float64{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.Float64, + shape: Shape{2, 5}, }, } @@ -427,14 +449,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []int8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -448,14 +470,14 @@ func TestFromArrowTensor(t 
*testing.T) { defer b.Release() b.AppendValues( []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []int16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -469,14 +491,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -490,14 +512,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -511,14 +533,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -532,14 +554,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []uint16{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -553,14 +575,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer 
rowMajorArr.Release() b.AppendValues( []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -574,14 +596,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -595,14 +617,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -616,14 +638,14 @@ func TestFromArrowTensor(t *testing.T) { defer b.Release() b.AppendValues( []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -640,9 +662,17 @@ func TestFromArrowTensor(t *testing.T) { colMajorT = FromArrowTensor(colMajor) assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: column major %v", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + } assert.True(colMajorT.Shape().Eq(taat.shape)) assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + 
assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + } assert.True(rowMajorT.Shape().Eq(taat.shape)) } } diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 38593c3..45dfa23 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -7,6 +7,7 @@ import ( const importsArrowRaw = `import ( arrowArray "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/bitutil" arrowTensor "github.com/apache/arrow/go/arrow/tensor" arrow "github.com/apache/arrow/go/arrow" ) @@ -310,13 +311,21 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { a.Retain() defer a.Release() + if !a.IsContiguous() { + panic("Non-contiguous data is Unsupported") + } + var shape []int for _, val := range a.Shape() { shape = append(shape, int(val)) } - if !a.IsContiguous() { - panic("Non-contiguous data is Unsupported") + l := a.Len() + validMask := a.Data().Buffers()[0].Bytes() + dataOffset := a.Data().Offset() + mask := make([]bool, l) + for i := 0; i < l; i++ { + mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) } switch a.DataType() { @@ -324,10 +333,10 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { case arrow.PrimitiveTypes.{{.}}: backing := a.(*arrowTensor.{{.}}).{{.}}Values() if a.IsColMajor() { - return New(WithShape(shape...), AsFortran(backing)) + return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) } - return New(WithShape(shape...), WithBacking(backing)) + return New(WithShape(shape...), WithBacking(backing, mask)) {{end -}} default: panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index dd4b442..384314f 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -184,6 +184,8 @@ func TestFromArrowArray(t *testing.T){ const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ rowMajorData interface{} colMajorData interface{} + 
rowMajorValid []bool + colMajorValid []bool dt arrow.DataType shape Shape }{ @@ -191,6 +193,8 @@ const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ { rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, dt: arrow.PrimitiveTypes.{{ . }}, shape: Shape{2,5}, }, @@ -215,14 +219,14 @@ func TestFromArrowTensor(t *testing.T){ defer b.Release() b.AppendValues( []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) rowMajorArr = b.NewArray() defer rowMajorArr.Release() b.AppendValues( []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - nil, + taat.rowMajorValid, ) colMajorArr = b.NewArray() defer colMajorArr.Release() @@ -240,9 +244,17 @@ func TestFromArrowTensor(t *testing.T){ colMajorT = FromArrowTensor(colMajor) assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: column major %v", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + } assert.True(colMajorT.Shape().Eq(taat.shape)) assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + } assert.True(rowMajorT.Shape().Eq(taat.shape)) } } From 0ac34d108ae04f164298fe15cb272fea06de4171 Mon Sep 17 00:00:00 2001 From: poopoothegorilla Date: Fri, 7 Aug 2020 11:51:30 -0400 Subject: [PATCH 053/154] add transpose function for masks and add mask to fortran function --- consopt.go | 10 
++++-- defaultengine_matop_transpose.go | 19 +++++++++++ defaultengine_matop_transpose_inplace.go | 43 ++++++++++++++++++++++++ dense_compat.go | 20 +++++------ dense_compat_test.go | 8 ++--- genlib2/dense_compat.go | 2 +- genlib2/dense_compat_tests.go | 8 ++--- 7 files changed, 89 insertions(+), 21 deletions(-) diff --git a/consopt.go b/consopt.go index 19d47ad..c0c57c7 100644 --- a/consopt.go +++ b/consopt.go @@ -54,7 +54,7 @@ func WithBacking(x interface{}, argMask ...[]bool) ConsOpt { // WithMask is a construction option for a Tensor // Use it as such: // mask := []bool{true,true,false,false} -// t := New(WithBacking(backing)) +// t := New(WithBacking(backing), WithMask(mask)) // It can be used with other construction options like WithShape // The supplied mask can be any type. If non-boolean, then tensor mask is set to true // wherever non-zero value is obtained @@ -191,7 +191,11 @@ func WithEngine(e Engine) ConsOpt { // AsFortran creates a *Dense with a col-major layout. // If the optional backing argument is passed, the backing is assumed to be C-order (row major), and // it will be transposed before being used. 
-func AsFortran(backing interface{}) ConsOpt { +func AsFortran(backing interface{}, argMask ...[]bool) ConsOpt { + var mask []bool + if len(argMask) > 0 { + mask = argMask[0] + } f := func(t Tensor) { switch tt := t.(type) { case *Dense: @@ -201,10 +205,12 @@ func AsFortran(backing interface{}) ConsOpt { // create a temporary tensor, to which the transpose will be done tmp := NewDense(tt.Dtype(), tt.shape.Clone()) copyArray(tmp.arrPtr(), tt.arrPtr()) + tmp.SetMask(mask) tmp.T() tmp.Transpose() // copy the data back to the current tensor copyArray(tt.arrPtr(), tmp.arrPtr()) + tt.SetMask(tmp.Mask()) // cleanup: return the temporary tensor back to the pool ReturnTensor(tmp) } diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index 8f7c86c..76a2f0a 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -24,6 +24,8 @@ func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { return } + e.transposeMask(a) + switch a.rtype().Size() { case 1: e.denseTranspose1(a, expStrides) @@ -38,6 +40,23 @@ func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { } } +func (e StdEng) transposeMask(a DenseTensor) { + if !a.(*Dense).IsMasked() { + return + } + + orig := a.(*Dense).Mask() + tmp := make([]bool, len(orig)) + + it := newFlatIterator(a.Info()) + var j int + for i, err := it.Next(); err == nil; i, err = it.Next() { + tmp[j] = orig[i] + j++ + } + copy(orig, tmp) +} + func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { var tmpArr array e.makeArray(&tmpArr, a.Dtype(), a.Size()) diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index d8a87e4..612e1cc 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -24,6 +24,8 @@ func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { return } + e.transposeMask(a) + switch a.rtype().Size() { case 1: e.denseTranspose1(a, expStrides) @@ -38,6 
+40,47 @@ func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { } } +func (e StdEng) transposeMask(a DenseTensor) { + if !a.(*Dense).IsMasked() { + return + } + + shape := a.Shape() + if len(shape) != 2 { + // TODO(poopoothegorilla): currently only two dimensions are implemented + return + } + n, m := shape[0], shape[1] + mask := a.(*Dense).Mask() + size := len(mask) + + track := NewBitMap(size) + track.Set(0) + track.Set(size - 1) + + for i := 0; i < size; i++ { + srci := i + if track.IsSet(srci) { + continue + } + srcv := mask[srci] + for { + oc := srci % n + or := (srci - oc) / n + desti := oc*m + or + + if track.IsSet(desti) { + break + } + track.Set(desti) + destv := mask[desti] + mask[desti] = srcv + srci = desti + srcv = destv + } + } +} + func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { axes := a.transposeAxes() size := a.len() diff --git a/dense_compat.go b/dense_compat.go index 2e4a1d5..dcbefa2 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -540,70 +540,70 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { case arrow.PrimitiveTypes.Int8: backing := a.(*arrowTensor.Int8).Int8Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int16: backing := a.(*arrowTensor.Int16).Int16Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int32: backing := a.(*arrowTensor.Int32).Int32Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Int64: backing := 
a.(*arrowTensor.Int64).Int64Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint8: backing := a.(*arrowTensor.Uint8).Uint8Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint16: backing := a.(*arrowTensor.Uint16).Uint16Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint32: backing := a.(*arrowTensor.Uint32).Uint32Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Uint64: backing := a.(*arrowTensor.Uint64).Uint64Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Float32: backing := a.(*arrowTensor.Float32).Float32Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) case arrow.PrimitiveTypes.Float64: backing := a.(*arrowTensor.Float64).Float64Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), 
WithBacking(backing, mask)) diff --git a/dense_compat_test.go b/dense_compat_test.go index f872329..c641203 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -662,16 +662,16 @@ func TestFromArrowTensor(t *testing.T) { colMajorT = FromArrowTensor(colMajor) assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) - assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) for i, invalid := range rowMajorT.Mask() { - assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) } assert.True(colMajorT.Shape().Eq(taat.shape)) assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) - assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) for i, invalid := range colMajorT.Mask() { - assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) } assert.True(rowMajorT.Shape().Eq(taat.shape)) } diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 45dfa23..e3b5b52 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -333,7 +333,7 @@ func FromArrowTensor(a arrowTensor.Interface) *Dense { case arrow.PrimitiveTypes.{{.}}: backing := a.(*arrowTensor.{{.}}).{{.}}Values() if a.IsColMajor() { - return New(WithShape(shape...), WithMask(mask), AsFortran(backing)) + return New(WithShape(shape...), AsFortran(backing, mask)) } return New(WithShape(shape...), WithBacking(backing, mask)) 
diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index 384314f..d21831a 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -244,16 +244,16 @@ func TestFromArrowTensor(t *testing.T){ colMajorT = FromArrowTensor(colMajor) assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) - assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) for i, invalid := range rowMajorT.Mask() { - assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) } assert.True(colMajorT.Shape().Eq(taat.shape)) assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) - assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) for i, invalid := range colMajorT.Mask() { - assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v", i, taat.dt) + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) } assert.True(rowMajorT.Shape().Eq(taat.shape)) } From c10583cad498af000060e92a60097cc1bca08701 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 8 Sep 2020 14:50:38 +1000 Subject: [PATCH 054/154] Add complex support to linalg functionalities (#77) * Added Complex support to linalg functionality * Generate more clean iterators * Pulled these back from master * Oops missed one file * Go mod tidy * Fixed go mod and apache arrow go mod --- blas.go | 4 +- defaultengine_linalg.go | 102 ++++++++++++++++++++++++++++++++++++++-- dense_linalg.go | 4 +- dense_mapreduce.go 
| 2 +- go.mod | 8 ++-- go.sum | 23 ++++----- interfaces.go | 2 + tensor.go | 21 +++++++++ 8 files changed, 139 insertions(+), 27 deletions(-) diff --git a/blas.go b/blas.go index c708400..9bc170b 100644 --- a/blas.go +++ b/blas.go @@ -15,8 +15,8 @@ var whichblas BLAS type BLAS interface { blas.Float32 blas.Float64 - // blas.Complex64 - // blas.Complex128 + blas.Complex64 + blas.Complex128 } // only blastoise.Implementation() and cubone.Implementation() are batchedBLAS - diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 5e0ecd3..d9a16aa 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -317,7 +317,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { if t, ok = a.(*Dense); !ok { return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a) } - if !isFloat(t.Dtype()) { + if err = typeclassCheck(a.Dtype(), floatTypes); err != nil { return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype()) } @@ -373,7 +373,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // It returns a scalar value, wrapped in an interface{}, which is not quite nice. 
func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { var ad, bd DenseTensor - if ad, bd, err = e.checkTwoFloatTensors(a, b); err != nil { + if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Inner") } @@ -384,6 +384,12 @@ func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { case []float64: B := bd.Float64s() retVal = whichblas.Ddot(len(A), A, 1, B, 1) + case []complex64: + B := bd.Complex64s() + retVal = whichblas.Cdotu(len(A), A, 1, B, 1) + case []complex128: + B := bd.Complex128s() + retVal = whichblas.Zdotu(len(A), A, 1, B, 1) } return } @@ -395,7 +401,7 @@ func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { // check all are DenseTensors var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.MatVecMul") } @@ -443,6 +449,16 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { y := pd.Float32s() alpha, beta := float32(1), float32(0) whichblas.Sgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) + case []complex64: + x := bd.Complex64s() + y := pd.Complex64s() + var alpha, beta complex64 = complex(1, 0), complex(0, 0) + whichblas.Cgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) + case []complex128: + x := bd.Complex128s() + y := pd.Complex128s() + var alpha, beta complex128 = complex(1, 0), complex(0, 0) + whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) default: return errors.Errorf(typeNYI, "matVecMul", bd.Data()) } @@ -457,7 +473,7 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { // check all are DenseTensors var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != 
nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.MatMul") } @@ -537,6 +553,24 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { } else { whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) } + case []complex64: + B := bd.Complex64s() + C := pd.Complex64s() + var alpha, beta complex64 = complex(1, 0), complex(0, 0) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Cgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Cgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } + case []complex128: + B := bd.Complex128s() + C := pd.Complex128s() + var alpha, beta complex128 = complex(1, 0), complex(0, 0) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Zgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } default: return errors.Errorf(typeNYI, "matMul", ad.Data()) } @@ -547,7 +581,7 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { // check all are DenseTensors var ad, bd, pd DenseTensor - if ad, bd, pd, err = e.checkThreeFloatTensors(a, b, prealloc); err != nil { + if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { return errors.Wrapf(err, opFail, "StdEng.Outer") } @@ -599,6 +633,16 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { A := pd.Float32s() alpha := float32(1) whichblas.Sger(m, n, alpha, x, incX, y, incY, A, lda) + case []complex64: + y := bd.Complex64s() + A := pd.Complex64s() + var alpha complex64 = complex(1, 0) + whichblas.Cgeru(m, n, alpha, x, incX, y, incY, A, lda) + case []complex128: + y := bd.Complex128s() + A := pd.Complex128s() + var alpha complex128 = complex(1, 0) + whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda) default: return errors.Errorf(typeNYI, "outer", b.Data()) } @@ 
-654,3 +698,51 @@ func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTe } return } + +func (e StdEng) checkTwoFloatComplexTensors(a, b Tensor) (ad, bd DenseTensor, err error) { + if err = e.checkAccessible(a); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + } + if err = e.checkAccessible(b); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + } + + if a.Dtype() != b.Dtype() { + return nil, nil, errors.New("Expected a and b to have the same Dtype") + } + + if ad, err = getFloatComplexDenseTensor(a); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") + } + if bd, err = getFloatComplexDenseTensor(b); err != nil { + return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") + } + return +} + +func (e StdEng) checkThreeFloatComplexTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { + if err = e.checkAccessible(a); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") + } + if err = e.checkAccessible(b); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") + } + if err = e.checkAccessible(ret); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") + } + + if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { + return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype") + } + + if ad, err = getFloatComplexDenseTensor(a); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") + } + if bd, err = getFloatComplexDenseTensor(b); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") + } + if retVal, err = getFloatComplexDenseTensor(ret); err != nil { + return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects 
retVal to be be a DenseTensor") + } + return +} diff --git a/dense_linalg.go b/dense_linalg.go index c5362c5..7478cae 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -17,8 +17,8 @@ func (t *Dense) Trace() (retVal interface{}, err error) { // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { // check that the data is a float - if !isFloat(t.t) { - return nil, errors.Errorf(unsupportedDtype, t.t, "Inner") + if err = typeclassCheck(t.t, floatcmplxTypes); err != nil { + return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") } // check both are vectors diff --git a/dense_mapreduce.go b/dense_mapreduce.go index 6072614..2677fe7 100644 --- a/dense_mapreduce.go +++ b/dense_mapreduce.go @@ -11,7 +11,7 @@ func (t *Dense) Apply(fn interface{}, opts ...FuncOpt) (retVal Tensor, err error if m, ok := e.(Mapper); ok { return m.Map(fn, t, opts...) } - return nil, errors.Errorf("Execution engine for %v not a mapper", t) + return nil, errors.Errorf("Execution engine %T for %v not a mapper", e, t) } // Reduce applies a reduction function and reduces the values along the given axis. 
diff --git a/go.mod b/go.mod index a2b2ad0..eba9e51 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module gorgonia.org/tensor go 1.13 require ( - github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7 + github.com/apache/arrow/go/arrow latest github.com/chewxy/hm v1.0.0 - github.com/chewxy/math32 v1.0.4 + github.com/chewxy/math32 v1.0.6 github.com/gogo/protobuf v1.3.0 github.com/golang/protobuf v1.3.2 github.com/google/flatbuffers v1.11.0 - github.com/pkg/errors v0.8.1 - github.com/stretchr/testify v1.4.0 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.6.0 github.com/xtgo/set v1.0.0 // indirect gonum.org/v1/gonum v0.7.0 gorgonia.org/vecf32 v0.9.0 diff --git a/go.sum b/go.sum index 1d7bc0c..8a00eb2 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,11 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/apache/arrow v0.0.0-20200720164908-23b19f65e1eb h1:/guPTo4KRiOQnB4UX0Sn9kk5k7kCC00eSKsoykKc0tU= -github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb h1:vBEPOeLNZ2RUgG/e+G2tOIucgCojRKRPorB3STXC+xw= -github.com/apache/arrow/go/arrow v0.0.0-20200720164908-23b19f65e1eb/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= -github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7 h1:vStV77omxrTbs1UsDogumkL0+TU28S6Ebp0LmKCM7KE= -github.com/apache/arrow/go/arrow v0.0.0-20200720215425-c09a82a388e7/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= +github.com/apache/arrow/go/arrow v0.0.0-20200907212344-3e3e18b1450c h1:jgTtVgKyVsVk+Voxktyl5YJbNrRFq8hty0Rvou6p5ps= +github.com/apache/arrow/go/arrow v0.0.0-20200907212344-3e3e18b1450c/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= 
-github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo= -github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/chewxy/math32 v1.0.6 h1:JWZYUNl2rtgVVui6z8JBsDgkOG2DYmfSODyo95yKfx4= +github.com/chewxy/math32 v1.0.6/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= @@ -22,15 +19,15 @@ github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify 
v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho= +github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -51,8 +48,8 @@ gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6d gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= diff --git a/interfaces.go b/interfaces.go index c4a11c2..c0fd7e3 100644 --- a/interfaces.go +++ b/interfaces.go @@ -146,4 +146,6 @@ type unsafeMem interface { GetF32(i int) float32 Float64s() []float64 Float32s() []float32 + Complex64s() []complex64 + Complex128s() []complex128 } diff --git a/tensor.go b/tensor.go index b06066a..ff7e347 100644 --- a/tensor.go +++ b/tensor.go @@ -143,6 +143,27 @@ func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { return } +// getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that 
implements Float +func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { + if t == nil { + return + } + if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil { + err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype()) + return + } + + if retVal, err = getDenseTensor(t); err != nil { + err = errors.Wrapf(err, opFail, "getFloatDense") + return + } + if retVal == nil { + return + } + + return +} + func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) { var sliced Tensor if sliced, err = t.Slice(slices...); err != nil { From 11dd72007671cc220c0b208043de204292462003 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 9 Sep 2020 13:12:43 +1000 Subject: [PATCH 055/154] Gomod (#80) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index eba9e51..5f4bbe2 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module gorgonia.org/tensor go 1.13 require ( - github.com/apache/arrow/go/arrow latest + github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.6 github.com/gogo/protobuf v1.3.0 diff --git a/go.sum b/go.sum index 8a00eb2..7bbd20d 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/apache/arrow/go/arrow v0.0.0-20200907212344-3e3e18b1450c h1:jgTtVgKyVsVk+Voxktyl5YJbNrRFq8hty0Rvou6p5ps= -github.com/apache/arrow/go/arrow v0.0.0-20200907212344-3e3e18b1450c/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= +github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df h1:iXnL0pMIR/RDUWl0kCbc0CQ3UyehlyV+t/DYCLJTbFc= +github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm 
v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= From 714c7e5b8424bbe4f42acf488c9a3d4caab4bb3d Mon Sep 17 00:00:00 2001 From: Strigi-Form <43497306+strigi-form@users.noreply.github.com> Date: Mon, 19 Oct 2020 03:09:31 +0200 Subject: [PATCH 056/154] fix: permute strides (#84) * fix: permute strides * Added known issues test * Added IsVectorLike to Shape (not used). Fixed the logic as well * Cleaned up iterator even more Co-authored-by: chewxy --- ap.go | 7 ++++- dense_matop_test.go | 2 +- iterator.go | 69 +++++++++++++++++--------------------------- known_issues_test.go | 41 ++++++++++++++++++++++++-- shape.go | 13 +++++++++ utils.go | 9 ++++++ 6 files changed, 94 insertions(+), 47 deletions(-) diff --git a/ap.go b/ap.go index 22c7d9e..aab0a50 100644 --- a/ap.go +++ b/ap.go @@ -110,6 +110,11 @@ func (ap *AP) Format(state fmt.State, c rune) { // row vector func (ap *AP) IsVector() bool { return ap.shape.IsVector() } +// IsVectorLike returns true if the shape is vector-like (i.e. the shape only has one dim that is a non-1). 
+func (ap *AP) IsVectorLike() bool { + return ap.shape.IsVectorLike() && allones(ap.strides) +} + // IsColVec returns true when the access pattern has the shape (x, 1) func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() } @@ -324,7 +329,7 @@ func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { if axes[0] == 0 { return } - copy(strides, currentStride) + strides[0], strides[1] = 1, 1 shape[0], shape[1] = currentShape[1], currentShape[0] default: copy(shape, currentShape) diff --git a/dense_matop_test.go b/dense_matop_test.go index 5e5486d..5bf9533 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -138,7 +138,7 @@ var transposeTests = []struct { Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}}, {"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3}, - Shape{4, 1}, []int{4, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, + Shape{4, 1}, []int{1, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, {"v.T()", Shape{4}, nil, []int{0, 1, 2, 3}, Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}}, diff --git a/iterator.go b/iterator.go index 0db1158..0fe0a5d 100644 --- a/iterator.go +++ b/iterator.go @@ -121,13 +121,13 @@ type FlatIterator struct { *AP //state - track []int - nextIndex int - lastIndex int - strides0 int - size int - done bool - reverse bool // if true, iterator starts at end of array and runs backwards + track []int + nextIndex int + lastIndex int + size int + done bool + veclikeDim int // the dimension of a vectorlike shape that is not a 1. + reverse bool // if true, iterator starts at end of array and runs backwards isScalar bool isVector bool @@ -137,23 +137,24 @@ type FlatIterator struct { // newFlatIterator creates a new FlatIterator. 
func newFlatIterator(ap *AP) *FlatIterator { - var strides0 int - - if len(ap.strides) == 1 { - strides0 = ap.strides[0] + var dim int + if ap.IsVectorLike() { + for d, i := range ap.shape { + if i != 1 { + dim = d + break + } + } } - // else if ap.o.isColMajor() { - // strides0 = ap.strides[len(ap.strides)-1] - // } return &FlatIterator{ - AP: ap, - track: make([]int, len(ap.shape)), - size: ap.shape.TotalSize(), - strides0: strides0, + AP: ap, + track: make([]int, len(ap.shape)), + size: ap.shape.TotalSize(), + veclikeDim: dim, isScalar: ap.IsScalar(), - isVector: len(ap.strides) == 1, + isVector: ap.IsVectorLike(), } } @@ -265,20 +266,11 @@ func (it *FlatIterator) NextInvalid() (int, int, error) { func (it *FlatIterator) singleNext() (int, error) { it.lastIndex = it.nextIndex - // it.lastIndex += it.strides[0] - it.nextIndex += it.strides0 + it.nextIndex++ var tracked int - switch { - case it.IsRowVec(): - it.track[1]++ - tracked = it.track[1] - case it.IsColVec(), it.IsVector(): - it.track[0]++ - tracked = it.track[0] - default: - panic("This ain't supposed to happen") - } + it.track[it.veclikeDim]++ + tracked = it.track[it.veclikeDim] if tracked >= it.size { it.done = true @@ -289,20 +281,11 @@ func (it *FlatIterator) singleNext() (int, error) { func (it *FlatIterator) singlePrevious() (int, error) { it.lastIndex = it.nextIndex - // it.lastIndex += it.strides[0] - it.nextIndex -= it.strides0 + it.nextIndex-- var tracked int - switch { - case it.IsRowVec(): - it.track[1]-- - tracked = it.track[1] - case it.IsColVec(), it.IsVector(): - it.track[0]-- - tracked = it.track[0] - default: - panic("This ain't supposed to happen") - } + it.track[it.veclikeDim]-- + tracked = it.track[it.veclikeDim] if tracked < 0 { it.done = true diff --git a/known_issues_test.go b/known_issues_test.go index 20d8717..9f0d714 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -46,19 +46,21 @@ func TestIssue72(t *testing.T) { we, willFailEq := willerr(a, numberTypes, 
unsignedTypes) _, ok := q.Engine().(Suber) we = we || !ok - //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a, reuse) + //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a.Shape(), reuse.Shape()) ret, err := Sub(b, a, WithReuse(reuse)) if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { if err != nil { + t.Logf("err %v", err) return false } return true } - //log.Printf("b-a(r) | b:%v, a %v, r %v, ret %v", b, a, reuse, ret) + //log.Printf("b-a(r) | b:%v, a %v, r %v, ret %v", b, a.Shape(), reuse.Shape(), ret.Shape()) ret, err = Sub(b, ret, UseUnsafe()) if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + t.Errorf("a %v ", a.Shape()) return false } if reuse != ret { @@ -71,4 +73,39 @@ func TestIssue72(t *testing.T) { if err := quick.Check(invReuseScalar, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + +} + +func TestIssue83(t *testing.T) { + backing := []float64{-1, 0, 1} + var TT Tensor + TT = New( + WithShape(1, 3), + WithBacking(backing)) + TT, _ = T(TT) + + it := IteratorFromDense(TT.(*Dense)) + for i, ierr := it.Next(); ierr == nil; i, ierr = it.Next() { + if ierr != nil { + t.Error(ierr) + } + if i >= len(backing) { + t.Errorf("Iterator should not return an `i` greater than %v", i) + } + } + + backing = []float64{1, 2, 3, 4, 5, 5, 4, 3, 2, 1} + TT = New(WithShape(10, 1, 1, 1), WithBacking(backing)) + it = IteratorFromDense(TT.(*Dense)) + + var vals []float64 + for i, ierr := it.Next(); ierr == nil; i, ierr = it.Next() { + if ierr != nil { + t.Error(ierr) + } + v := TT.Data().([]float64)[i] + vals = append(vals, v) + } + t.Logf("%v", vals) + } diff --git a/shape.go b/shape.go index cecb41d..c448d0b 100644 --- a/shape.go +++ b/shape.go @@ -184,6 +184,19 @@ func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } // IsRowVec returns true when the access pattern has the shape (1, x) func (s Shape) IsRowVec() 
bool { return len(s) == 2 && (s[0] == 1 && s[1] > 1) } +// IsVectorLike returns true when the shape looks like a vector +// e.g. a number that is surrounded by 1s: +// (1, 1, ... 1, 10, 1, 1... 1) +func (s Shape) IsVectorLike() bool { + var nonOnes int + for _, i := range s { + if i != 1 { + nonOnes++ + } + } + return nonOnes == 1 || nonOnes == 0 // if there is only one non-one then it's a vector or a scalarlike. +} + // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices func (s Shape) IsMatrix() bool { return len(s) == 2 } diff --git a/utils.go b/utils.go index 8e62448..3936208 100644 --- a/utils.go +++ b/utils.go @@ -307,6 +307,15 @@ func memsetBools(a []bool, v bool) { } } +func allones(a []int) bool { + for i := range a { + if a[i] != 1 { + return false + } + } + return true +} + /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. the dumb, unoptimized version) From 54b092c12aeb7c1c461dcd62068b40e57477aa42 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 20 Oct 2020 22:38:00 +1100 Subject: [PATCH 057/154] Go1.15fix (#87) * Fixes #86 and #85 * Fixed the things that Go 1.15 complains about --- array.go | 36 ++------- defaultengine_matop_misc.go | 2 +- dense.go | 2 +- dense_maskcmp_methods_test.go | 113 ++++++++++++++------------- genlib2/dense_maskedmethods_tests.go | 30 ++++++- 5 files changed, 95 insertions(+), 88 deletions(-) diff --git a/array.go b/array.go index 40995ec..d6c07c6 100644 --- a/array.go +++ b/array.go @@ -9,26 +9,6 @@ import ( "gorgonia.org/tensor/internal/storage" ) -//go:notinheap -type rawdata []byte - -// array2 is a type that will not be allocated on the heap. This is useful for operational stuff - no unnecessary allocations required. 
- -//go:notinheap -type array2 struct { - storage.Header - t Dtype - v interface{} -} - -func (a array2) toarray() array { - return array{ - Header: a.Header, - t: a.t, - v: a.v, - } -} - // array is the underlying generic array. type array struct { storage.Header // the header - the Go representation (a slice) @@ -165,7 +145,7 @@ func (a *array) sliceInto(i, j int, res *array) { } // slice slices an array -func (a array) slice(start, end int) array2 { +func (a array) slice(start, end int) array { if end > a.L { panic("Index out of range") } @@ -189,7 +169,7 @@ func (a array) slice(start, end int) array2 { C: C, } - return array2{ + return array{ Header: hdr, t: a.t, v: nil, @@ -295,7 +275,7 @@ func (a *array) rtype() reflect.Type { return a.t.Type } // malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory func malloc(t Dtype, length int) unsafe.Pointer { size := int(calcMemSize(t, length)) - s := make(rawdata, size) + s := make([]byte, size) return unsafe.Pointer(&s[0]) } @@ -368,13 +348,11 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, if e := src.Engine(); e != nil { darr := dst.arr() sarr := src.arr() - d := darr.slice(dstart, dend) - s := sarr.slice(sstart, send) + da := darr.slice(dstart, dend) + sa := sarr.slice(sstart, send) switch e.(type) { case NonStdEngine: - da := d.toarray() - sa := s.toarray() if err := e.Memcpy(&da, &sa); err != nil { panic(err) } @@ -400,10 +378,10 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, // Typically this means `StdEng`. 
// // If so, we directly use storage.Copy instead of using the engine - storage.Copy(d.t.Type, &d.Header, &s.Header) + storage.Copy(da.t.Type, &da.Header, &sa.Header) } - return d.Len() + return da.Len() } return copyArraySliced(dst.arr(), dstart, dend, src.arr(), sstart, send) } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 217a85e..0fa6b90 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -208,7 +208,7 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, for j := 0; j < size; j++ { var tmp int tmp = repeats[j] - var tSlice array2 + var tSlice array tSlice = sarr.slice(srcStart, src.len()) diff --git a/dense.go b/dense.go index e80b2f5..09847ef 100644 --- a/dense.go +++ b/dense.go @@ -605,7 +605,7 @@ func (t *Dense) SetMask(mask []bool) { } func (t *Dense) slice(start, end int) { - t.array = t.array.slice(start, end).toarray() + t.array = t.array.slice(start, end) } // RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion diff --git a/dense_maskcmp_methods_test.go b/dense_maskcmp_methods_test.go index 94e365c..d16a78d 100644 --- a/dense_maskcmp_methods_test.go +++ b/dense_maskcmp_methods_test.go @@ -3,6 +3,7 @@ package tensor import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -604,26 +605,26 @@ func TestDense_MaskedEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() 
- T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -1244,26 +1245,26 @@ func TestDense_MaskedNotEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -1985,26 +1986,26 @@ func TestDense_MaskedGreater_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) 
assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -2625,26 +2626,26 @@ func TestDense_MaskedGreaterEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -3265,26 +3266,26 @@ func TestDense_MaskedLess_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; 
i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -3905,26 +3906,26 @@ func TestDense_MaskedLessEqual_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) @@ -4545,26 +4546,26 @@ func TestDense_MaskedInside_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) 
@@ -5185,26 +5186,26 @@ func TestDense_MaskedOutside_Str(t *testing.T) { assert.False(T.IsMasked()) data := T.Strings() for i := range data { - data[i] = string(i) + data[i] = fmt.Sprint(i) } - T.MaskedEqual(string(0)) + T.MaskedEqual(fmt.Sprint(0)) assert.True(T.IsMasked()) - T.MaskedEqual(string(1)) + T.MaskedEqual(fmt.Sprint(1)) assert.True(T.mask[0] && T.mask[1]) - T.MaskedNotEqual(string(2)) + T.MaskedNotEqual(fmt.Sprint(2)) assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() - T.MaskedInside(string(1), string(22)) + T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() - T.MaskedOutside(string(1), string(22)) + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { - T.MaskedEqual(string(i * 10)) + T.MaskedEqual(fmt.Sprint(i * 10)) } it := IteratorFromDense(T) diff --git a/genlib2/dense_maskedmethods_tests.go b/genlib2/dense_maskedmethods_tests.go index addefbd..9b53e1a 100644 --- a/genlib2/dense_maskedmethods_tests.go +++ b/genlib2/dense_maskedmethods_tests.go @@ -18,29 +18,57 @@ const testMaskCmpMethodRaw = `func TestDense_{{title .Name}}_{{short .Kind}}(t * assert.False(T.IsMasked()) data := T.{{sliceOf .Kind}} for i := range data { +{{if eq "string" (asType .Kind) -}} + data[i] = fmt.Sprint(i) +{{else -}} data[i] = {{asType .Kind}}(i) +{{end -}} } +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(0)) +{{else -}} T.MaskedEqual({{asType .Kind}}(0)) +{{end -}} assert.True(T.IsMasked()) +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(1)) +{{else -}} T.MaskedEqual({{asType .Kind}}(1)) +{{end -}} assert.True(T.mask[0] && T.mask[1]) +{{if eq "string" (asType .Kind) -}} + T.MaskedNotEqual(fmt.Sprint(2)) +{{else -}} T.MaskedNotEqual({{asType .Kind}}(2)) +{{end -}} assert.False(T.mask[2] && !(T.mask[0])) T.ResetMask() +{{if eq "string" (asType .Kind) -}} + 
T.MaskedInside(fmt.Sprint(1), fmt.Sprint(22)) +{{else -}} T.MaskedInside({{asType .Kind}}(1), {{asType .Kind}}(22)) +{{end -}} assert.True(!T.mask[0] && !T.mask[23] && T.mask[1] && T.mask[22]) T.ResetMask() +{{if eq "string" (asType .Kind) -}} + T.MaskedOutside(fmt.Sprint(1), fmt.Sprint(22)) +{{else -}} T.MaskedOutside({{asType .Kind}}(1), {{asType .Kind}}(22)) +{{end -}} assert.True(T.mask[0] && T.mask[23] && !T.mask[1] && !T.mask[22]) T.ResetMask() for i := 0; i < 5; i++ { +{{if eq "string" (asType .Kind) -}} + T.MaskedEqual(fmt.Sprint(i*10)) +{{else -}} T.MaskedEqual({{asType .Kind}}(i*10)) +{{end -}} } it := IteratorFromDense(T) - + j := 0 for _, err := it.Next(); err == nil; _, err = it.Next() { j++ From d52bdbe53f57050e582b681eba484ad6f21ebba1 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 21 Oct 2020 10:36:54 +1100 Subject: [PATCH 058/154] Fix#88 (#89) * Fixes #88 by introducing a scalarDtypeCheck() function. * Added generated tests for the wrong types --- api_arith_generated_test.go | 130 ++++++++++++++++++++++++++++++++++++ defaultengine_arith.go | 24 +++++++ defaultengine_cmp.go | 24 +++++++ defaultengine_prep.go | 19 ++++++ dense_arith_test.go | 130 ++++++++++++++++++++++++++++++++++++ genlib2/agg2_body.go | 4 ++ genlib2/agg3_body.go | 36 +++++++++- genlib2/arith_tests.go | 18 +++++ known_issues_test.go | 11 +++ 9 files changed, 394 insertions(+), 2 deletions(-) diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index 30345d6..1120fba 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -662,6 +662,32 @@ func TestAddScalar(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor 
as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Add (tensor as right, scalar as left) failed: %v", err) + } } func TestSubScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -716,6 +742,32 @@ func TestSubScalar(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Sub (tensor as right, scalar as left) failed: %v", err) + } } func TestMulScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -770,6 +822,32 @@ func TestMulScalar(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(b, a) + if err == nil { + return false + } + _ = 
ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Mul (tensor as right, scalar as left) failed: %v", err) + } } func TestDivScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -799,6 +877,32 @@ func TestDivScalar(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Div (tensor as right, scalar as left) failed: %v", err) + } } func TestPowScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -828,6 +932,32 @@ func TestPowScalar(t *testing.T) { t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Pow (tensor as right, scalar as left) failed: %v", err) + } } func TestAddScalar_unsafe(t *testing.T) { 
iden1 := func(q *Dense) bool { diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 5779897..72b171d 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -416,6 +416,10 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Add failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Add failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { @@ -509,6 +513,10 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Sub failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Sub failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { @@ -602,6 +610,10 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Mul failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Mul failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { @@ -695,6 +707,10 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Div failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Div failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { @@ -788,6 +804,10 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Pow 
failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Pow failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { @@ -881,6 +901,10 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Mod failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Mod failed") + } + var reuse DenseTensor var safe, toReuse, incr bool if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index b3651d7..8c2b919 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -496,6 +496,10 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO return nil, errors.Wrapf(err, "Gt failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Gt failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { @@ -608,6 +612,10 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Gte failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Gte failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { @@ -720,6 +728,10 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO return nil, errors.Wrapf(err, "Lt failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Lt failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), 
t.Dtype(), t.DataOrder(), false, opts...); err != nil { @@ -832,6 +844,10 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func return nil, errors.Wrapf(err, "Lte failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Lte failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { @@ -940,6 +956,10 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO return nil, errors.Wrapf(err, "Eq failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Eq failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { @@ -1048,6 +1068,10 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO return nil, errors.Wrapf(err, "Ne failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "Ne failed") + } + var reuse DenseTensor var safe, same bool if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { diff --git a/defaultengine_prep.go b/defaultengine_prep.go index fca9848..ea2b2f5 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -1,6 +1,8 @@ package tensor import ( + "reflect" + "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" // "log" @@ -102,6 +104,23 @@ func unaryCheck(a Tensor, tc *typeclass) error { return nil } +// scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor. 
+func scalarDtypeCheck(a Tensor, b interface{}) error { + var dt Dtype + switch bt := b.(type) { + case Dtyper: + dt = bt.Dtype() + default: + t := reflect.TypeOf(b) + dt = Dtype{t} + } + + if a.Dtype() != dt { + return errors.Errorf("Expected scalar to have the same Dtype as the tensor (%v). Got %T instead ", a.Dtype(), b) + } + return nil +} + // prepDataVV prepares the data given the input and reuse tensors. It also retruns several indicators // // useIter indicates that the iterator methods should be used. diff --git a/dense_arith_test.go b/dense_arith_test.go index 039db67..d414dd2 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -662,6 +662,32 @@ func TestDense_AddScalar(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Add (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Add(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Add (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_SubScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -716,6 +742,32 @@ func TestDense_SubScalar(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: 
quickchecks}); err != nil { + t.Errorf("WrongType test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Sub(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Sub (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_MulScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -770,6 +822,32 @@ func TestDense_MulScalar(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Mul(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Mul (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_DivScalar(t *testing.T) { inv1 := func(q *Dense) bool { @@ -799,6 +877,32 @@ func TestDense_DivScalar(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Div (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Div(b, a) + if err == nil { + return false + } + _ = 
ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Div (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_PowScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -828,6 +932,32 @@ func TestDense_PowScalar(t *testing.T) { t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) } + type Foo int + wt1 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(a, b) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for Pow (tensor as left, scalar as right) failed: %v", err) + } + + wt2 := func(a *Dense) bool { + b := Foo(0) + ret, err := Pow(b, a) + if err == nil { + return false + } + _ = ret + return true + } + if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for Pow (tensor as right, scalar as left) failed: %v", err) + } } func TestDense_AddScalar_unsafe(t *testing.T) { iden1 := func(q *Dense) bool { diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 29c73ee..cf87bb0 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -40,6 +40,10 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); return nil, errors.Wrapf(err, "{{.Name}} failed") } + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "{{.Name}} failed") + } + var reuse DenseTensor {{template "prep" . 
-}} diff --git a/genlib2/agg3_body.go b/genlib2/agg3_body.go index 0eb274c..e7b6592 100644 --- a/genlib2/agg3_body.go +++ b/genlib2/agg3_body.go @@ -72,7 +72,7 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { } return true } - + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } @@ -167,7 +167,7 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { return true } {{template "callInv" .}} - + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } @@ -246,12 +246,42 @@ if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks {{end -}} ` +const denseArithScalarWrongTypeTestRaw = `type Foo int +wt1 := func(a *Dense) bool{ + b := Foo(0) + {{template "call0" .}} + if err == nil { + return false + } + _ = ret + return true +} +if err := quick.Check(wt1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongType test for {{.Name}} (tensor as left, scalar as right) failed: %v", err) +} + +wt2 := func(a *Dense) bool{ + b := Foo(0) + {{template "call1" .}} + if err == nil { + return false + } + _ = ret + return true +} +if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("WrongTYpe test for {{.Name}} (tensor as right, scalar as left) failed: %v", err) +} +` + var ( denseArithBody *template.Template denseArithScalarBody *template.Template denseIdentityArithTest *template.Template denseIdentityArithScalarTest *template.Template + + denseArithScalarWrongTypeTest *template.Template ) func init() { @@ -260,4 +290,6 @@ func init() { denseIdentityArithTest = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw)) denseIdentityArithScalarTest = template.Must(template.New("dense scalar identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw)) + + denseArithScalarWrongTypeTest = template.Must(template.New("dense scalar 
wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index e53f859..369cc0f 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -64,6 +64,7 @@ func (fn *ArithTest) WriteBody(w io.Writer) { if fn.IsInv { fn.writeInv(w) } + fn.WriteScalarWrongType(w) } func (fn *ArithTest) canWrite() bool { @@ -142,6 +143,23 @@ func (fn *ArithTest) writeInv(w io.Writer) { t.Execute(w, fn) } +func (fn *ArithTest) WriteScalarWrongType(w io.Writer) { + if !fn.scalars { + return + } + if fn.FuncOpt != "" { + return + } + t := template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) + template.Must(t.New("call0").Parse(APICallVSRaw)) + template.Must(t.New("call1").Parse(APICallSVRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} func (fn *ArithTest) Write(w io.Writer) { sig := fn.Signature() diff --git a/known_issues_test.go b/known_issues_test.go index 9f0d714..19bea52 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -109,3 +109,14 @@ func TestIssue83(t *testing.T) { t.Logf("%v", vals) } + +func TestIssue88(t *testing.T) { + a := New(WithShape(4, 2), WithBacking([]float64{1, 1, 1, 1, 1, 1, 1, 1})) + b := New(WithShape(2, 4), WithBacking([]float64{0, 1, 0, 1, 0, 1, 0, 1})) + c, _ := a.MatMul(b) + _, err := Div(c, 2) + if err == nil { + t.Fatal("Expected an error") + } + +} From 58c18c3ff9f8793928967856f56366c4b4d1b010 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 26 Oct 2020 10:54:22 +1100 Subject: [PATCH 059/154] Fixed #90 --- flags.go | 3 +++ known_issues_test.go | 30 ++++++++++++++++++++++++++++++ utils.go | 33 ++++++++++++++++++--------------- 3 files changed, 
51 insertions(+), 15 deletions(-) diff --git a/flags.go b/flags.go index 547136e..22fed67 100644 --- a/flags.go +++ b/flags.go @@ -98,6 +98,8 @@ const ( // ManuallyManaged indicates that the memory is managed by something else. Any Tensor with // manually managed memory will not be returned to the pool. ManuallyManaged + // IsOverallocated indicates that the memory for a given tensor is overallocated (i.e. the size-in-use is smaller than the size allocated) + IsOverallocated ) func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { @@ -113,6 +115,7 @@ func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { func (f MemoryFlag) nativelyAccessible() bool { return !((f & NativelyInaccessible) != 0) } func (f MemoryFlag) manuallyManaged() bool { return (f & ManuallyManaged) != 0 } +func (f MemoryFlag) isOverallocated() bool { return (f & IsOverallocated) != 0 } // OpOpt are the options used to call ops type OpOpt struct { diff --git a/known_issues_test.go b/known_issues_test.go index 19bea52..766056c 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -118,5 +118,35 @@ func TestIssue88(t *testing.T) { if err == nil { t.Fatal("Expected an error") } +} + +var ltoiTestCases = []struct { + name string + shape Shape + strides []int + coordinates []int + correct int + willErr bool +}{ + {"\"scalar\" - scalarshape", Shape{}, nil, []int{0}, 0, false}, + {"\"scalar\" - scalarshape, non empty strides", Shape{}, []int{1}, []int{0}, 0, false}, + {"\"scalar\" - scalarlike", Shape{1, 1, 1}, []int{1, 1, 1}, []int{0, 0, 0}, 0, false}, + {"vector", Shape{10}, []int{1}, []int{1}, 1, false}, + {"rowvec", Shape{1, 10}, []int{10, 1}, []int{0, 1}, 1, false}, + {"colvec", Shape{10, 1}, []int{1, 1}, []int{1, 0}, 1, false}, + {"rowvec- funny strides", Shape{1, 10}, []int{1}, []int{0, 1}, 1, false}, + {"colvec - funny strides", Shape{10, 1}, []int{1}, []int{1, 0}, 1, false}, +} + +func TestIssue90(t *testing.T) { + for i, c := range ltoiTestCases { + at, err := Ltoi(c.shape, 
c.strides, c.coordinates...) + if !checkErr(t, c.willErr, err, c.name, i) { + continue + } + if at != c.correct { + t.Errorf("Expected Ltoi(%v, %v, %v) to be %v. Got %v instead", c.shape, c.strides, c.coordinates, c.correct, at) + } + } } diff --git a/utils.go b/utils.go index 3936208..42ef597 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) const AllAxes int = -1 @@ -93,6 +95,14 @@ func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { // Ltoi is Location to Index. Provide a shape, a strides, and a list of integers as coordinates, and returns the index at which the element is. func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { + if shape.IsScalarEquiv() { + for _, v := range coords { + if v != 0 { + return -1, errors.Errorf("Scalar shape only allows 0 as an index") + } + } + return 0, nil + } for i, coord := range coords { if i >= len(shape) { err = errors.Errorf(dimMismatch, len(shape), i) @@ -107,23 +117,16 @@ func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { } var stride int - if shape.IsRowVec() { - if i == 0 && len(coords) == 2 { - continue - } - stride = strides[0] - } else if shape.IsColVec() { - if i == 1 && len(coords) == 2 { - continue - } + switch { + case shape.IsVector() && len(strides) == 1: stride = strides[0] - } else { - if i >= len(strides) { - err = errors.Errorf(dimMismatch, len(strides), i) - return - } + case i >= len(strides): + err = errors.Errorf(dimMismatch, len(strides), i) + return + default: stride = strides[i] } + at += stride * coord } return at, nil From cf15c86e1a4e4721881e7d0cd2018d3898927025 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Mon, 26 Oct 2020 11:06:03 +1100 Subject: [PATCH 060/154] Fixed #90 (#91) --- flags.go | 3 +++ known_issues_test.go | 30 ++++++++++++++++++++++++++++++ utils.go | 33 ++++++++++++++++++--------------- 3 files changed, 51 insertions(+), 15 deletions(-) diff 
--git a/flags.go b/flags.go index 547136e..22fed67 100644 --- a/flags.go +++ b/flags.go @@ -98,6 +98,8 @@ const ( // ManuallyManaged indicates that the memory is managed by something else. Any Tensor with // manually managed memory will not be returned to the pool. ManuallyManaged + // IsOverallocated indicates that the memory for a given tensor is overallocated (i.e. the size-in-use is smaller than the size allocated) + IsOverallocated ) func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { @@ -113,6 +115,7 @@ func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { func (f MemoryFlag) nativelyAccessible() bool { return !((f & NativelyInaccessible) != 0) } func (f MemoryFlag) manuallyManaged() bool { return (f & ManuallyManaged) != 0 } +func (f MemoryFlag) isOverallocated() bool { return (f & IsOverallocated) != 0 } // OpOpt are the options used to call ops type OpOpt struct { diff --git a/known_issues_test.go b/known_issues_test.go index 19bea52..766056c 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -118,5 +118,35 @@ func TestIssue88(t *testing.T) { if err == nil { t.Fatal("Expected an error") } +} + +var ltoiTestCases = []struct { + name string + shape Shape + strides []int + coordinates []int + correct int + willErr bool +}{ + {"\"scalar\" - scalarshape", Shape{}, nil, []int{0}, 0, false}, + {"\"scalar\" - scalarshape, non empty strides", Shape{}, []int{1}, []int{0}, 0, false}, + {"\"scalar\" - scalarlike", Shape{1, 1, 1}, []int{1, 1, 1}, []int{0, 0, 0}, 0, false}, + {"vector", Shape{10}, []int{1}, []int{1}, 1, false}, + {"rowvec", Shape{1, 10}, []int{10, 1}, []int{0, 1}, 1, false}, + {"colvec", Shape{10, 1}, []int{1, 1}, []int{1, 0}, 1, false}, + {"rowvec- funny strides", Shape{1, 10}, []int{1}, []int{0, 1}, 1, false}, + {"colvec - funny strides", Shape{10, 1}, []int{1}, []int{1, 0}, 1, false}, +} + +func TestIssue90(t *testing.T) { + for i, c := range ltoiTestCases { + at, err := Ltoi(c.shape, c.strides, c.coordinates...) 
+ if !checkErr(t, c.willErr, err, c.name, i) { + continue + } + if at != c.correct { + t.Errorf("Expected Ltoi(%v, %v, %v) to be %v. Got %v instead", c.shape, c.strides, c.coordinates, c.correct, at) + } + } } diff --git a/utils.go b/utils.go index 3936208..42ef597 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) const AllAxes int = -1 @@ -93,6 +95,14 @@ func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { // Ltoi is Location to Index. Provide a shape, a strides, and a list of integers as coordinates, and returns the index at which the element is. func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { + if shape.IsScalarEquiv() { + for _, v := range coords { + if v != 0 { + return -1, errors.Errorf("Scalar shape only allows 0 as an index") + } + } + return 0, nil + } for i, coord := range coords { if i >= len(shape) { err = errors.Errorf(dimMismatch, len(shape), i) @@ -107,23 +117,16 @@ func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { } var stride int - if shape.IsRowVec() { - if i == 0 && len(coords) == 2 { - continue - } - stride = strides[0] - } else if shape.IsColVec() { - if i == 1 && len(coords) == 2 { - continue - } + switch { + case shape.IsVector() && len(strides) == 1: stride = strides[0] - } else { - if i >= len(strides) { - err = errors.Errorf(dimMismatch, len(strides), i) - return - } + case i >= len(strides): + err = errors.Errorf(dimMismatch, len(strides), i) + return + default: stride = strides[i] } + at += stride * coord } return at, nil From ff7c2168573ad7226668924a863bc0663d335878 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 28 Oct 2020 12:55:53 +1100 Subject: [PATCH 061/154] Fix#57 (#92) * Fixed #90 * Added unsqueeze, which is used in concat (instead of reshape). unsqueeze works like reshape but is privileged - reshape cannot be used on non-contiguous views. unsqueeze can * Updated tests as well. 
* Updated travis * updated go mod --- .travis.yml | 1 + defaultengine_matop_misc.go | 5 ++ dense.go | 16 +++++- dense_assign.go | 4 +- dense_matop.go | 4 +- dense_matop_test.go | 3 +- dense_test.go | 43 ++++++++++++++++ go.mod | 8 +-- go.sum | 98 ++++++++++++++++++++++++++++++++++--- known_issues_test.go | 1 - 10 files changed, 167 insertions(+), 16 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2e9a505..7933b1a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ branches: go: - 1.13.x - 1.14.x + - 1.15.x - tip env: diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 0fa6b90..bd9ee18 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -291,6 +291,7 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") } + // keep dims after slicing switch { case v.IsVector() && T.IsMatrix() && axis == 0: v.reshape(v.shape[0], 1) @@ -320,6 +321,10 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen } v.shape = newShape v.strides = newStrides + } else if T.Shape()[axis] == 1 { + if err := v.unsqueeze(axis); err != nil { + return nil, errors.Wrapf(err, "Unable to keep dims after slicing a shape %v on axis %d where the size is 1", T.Shape(), axis) + } } } diff --git a/dense.go b/dense.go index 09847ef..7604056 100644 --- a/dense.go +++ b/dense.go @@ -79,7 +79,6 @@ func (t *Dense) addMask(mask []bool) { } func (t *Dense) makeArray(size int) { - switch te := t.e.(type) { case NonStdEngine: t.flag = MakeMemoryFlag(t.flag, ManuallyManaged) @@ -99,7 +98,6 @@ func (t *Dense) makeArray(size int) { t.array.C = size t.array.fix() return - } // Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging. 
@@ -158,6 +156,20 @@ func (t *Dense) reshape(dims ...int) error { return t.sanity() } +func (t *Dense) unsqueeze(axis int) error { + if axis > t.shape.Dims()+1 { + return errors.Errorf("Cannot unsqueeze on axis %d when the tensor has shape %v", axis, t.shape) + } + t.shape = append(t.shape, 1) + copy(t.shape[axis+1:], t.shape[axis:]) + t.shape[axis] = 1 + + t.strides = append(t.strides, 1) + copy(t.strides[axis+1:], t.strides[axis:]) + + return nil +} + // ScalarValue returns the scalar value of a *Tensor, // IF and ONLY IF it's a Tensor representation of a scalar value. // This is required because operations like a (vec · vec) would return a scalar value. diff --git a/dense_assign.go b/dense_assign.go index 8b2783e..bd8bceb 100644 --- a/dense_assign.go +++ b/dense_assign.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) func overlaps(a, b DenseTensor) bool { if a.cap() == 0 || b.cap() == 0 { diff --git a/dense_matop.go b/dense_matop.go index 43e1967..5ce693b 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // T performs a thunked transpose. 
It doesn't actually do anything, except store extra information about the post-transposed shapes and strides // Usually this is more than enough, as BLAS will handle the rest of the transpose diff --git a/dense_matop_test.go b/dense_matop_test.go index 5bf9533..755309f 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -741,7 +741,7 @@ var concatTests = []struct { {"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}}, {"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}}, - // {"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{222}}, + {"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{0, 1, 2, 3, 4, 5, 0, 1, 6, 7, 8, 9, 10, 11, 2, 3}}, } func TestDense_Concat(t *testing.T) { @@ -772,6 +772,7 @@ func TestDense_Concat(t *testing.T) { t.Errorf("Test %v failed: %v", cts.name, err) continue } + assert.True(cts.correctShape.Eq(T2.Shape())) assert.Equal(cts.correctData, T2.Data()) } diff --git a/dense_test.go b/dense_test.go index d3f43a3..d7e81e6 100644 --- a/dense_test.go +++ b/dense_test.go @@ -95,3 +95,46 @@ func Test_recycledDense(t *testing.T) { assert.Equal(t, StdEng{}, T.e) assert.Equal(t, StdEng{}, T.oe) } + +func TestDense_unsqueeze(t *testing.T) { + assert := assert.New(t) + T := New(WithShape(3, 3, 2), WithBacking([]float64{ + 1, 2, 3, 4, 5, 6, + 60, 50, 40, 30, 20, 10, + 100, 200, 300, 400, 500, 600, + })) + + if err := T.unsqueeze(0); err != nil { + t.Fatal(err) + } + + assert.True(T.Shape().Eq(Shape{1, 3, 3, 2})) + assert.Equal([]int{6, 6, 2, 1}, T.Strides()) // if you do shapes.CalcStrides() it'd be {18,6,2,1} + + // reset + T.Reshape(3, 3, 2) + + if err := T.unsqueeze(1); err != nil { + t.Fatal(err) + } + assert.True(T.Shape().Eq(Shape{3, 1, 3, 
2})) + assert.Equal([]int{6, 2, 2, 1}, T.Strides()) + + // reset + T.Reshape(3, 3, 2) + if err := T.unsqueeze(2); err != nil { + t.Fatal(err) + } + t.Logf("%v", T) + assert.True(T.Shape().Eq(Shape{3, 3, 1, 2})) + assert.Equal([]int{6, 2, 1, 1}, T.Strides()) + + // reset + T.Reshape(3, 3, 2) + if err := T.unsqueeze(3); err != nil { + t.Fatal(err) + } + t.Logf("%v", T) + assert.True(T.Shape().Eq(Shape{3, 3, 2, 1})) + assert.Equal([]int{6, 2, 1, 1}, T.Strides()) +} diff --git a/go.mod b/go.mod index 5f4bbe2..7c1db9f 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,16 @@ module gorgonia.org/tensor go 1.13 require ( - github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df + github.com/apache/arrow/go/arrow v0.0.0-20201027203332-c3091dd3f8ca github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.6 - github.com/gogo/protobuf v1.3.0 - github.com/golang/protobuf v1.3.2 + github.com/gogo/protobuf v1.3.1 + github.com/golang/protobuf v1.4.2 github.com/google/flatbuffers v1.11.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.6.0 github.com/xtgo/set v1.0.0 // indirect - gonum.org/v1/gonum v0.7.0 + gonum.org/v1/gonum v0.8.1 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) diff --git a/go.sum b/go.sum index 7bbd20d..d1d71b5 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,48 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df h1:iXnL0pMIR/RDUWl0kCbc0CQ3UyehlyV+t/DYCLJTbFc= -github.com/apache/arrow/go/arrow v0.0.0-20200909005831-30143fc493df/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= +github.com/apache/arrow/go/arrow v0.0.0-20201027203332-c3091dd3f8ca h1:OYqlohQ0r1GB7SeG03ct5Xox668iVXgThaNyKLeC01E= +github.com/apache/arrow/go/arrow 
v0.0.0-20201027203332-c3091dd3f8ca/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= github.com/chewxy/math32 v1.0.6 h1:JWZYUNl2rtgVVui6z8JBsDgkOG2DYmfSODyo95yKfx4= github.com/chewxy/math32 v1.0.6/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE= -github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod 
h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0 
h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -23,6 +50,7 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -30,22 +58,78 @@ github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgh github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp 
v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= +golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009 h1:W0lCpv29Hv0UaM1LXb9QlBHLNP8UFfcKjblhVCWftOM= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod 
h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw= -gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM= +gonum.org/v1/gonum v0.8.1 h1:wGtP3yGpc5mCLOLeTeBdjeui9oZSz5De0eOjMLC/QuQ= +gonum.org/v1/gonum v0.8.1/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f h1:Yv4xsIx7HZOoyUGSJ2ksDyWE2qIBXROsZKt2ny3hCGM= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= +google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= 
+google.golang.org/grpc/cmd/protoc-gen-go-grpc v0.0.0-20200910201057-6591123024b3/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= @@ -54,4 +138,6 @@ gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/known_issues_test.go b/known_issues_test.go index 766056c..36d4125 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -139,7 +139,6 @@ var ltoiTestCases = []struct { } func TestIssue90(t *testing.T) { - for i, c := range ltoiTestCases { at, err := Ltoi(c.shape, c.strides, c.coordinates...) if !checkErr(t, c.willErr, err, c.name, i) { From 912c427f1a69d15a599860a3ee5574e10325c4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mira=C3=A7=20G=C3=BClg=C3=B6n=C3=BCl?= Date: Wed, 28 Oct 2020 08:07:45 +0300 Subject: [PATCH 062/154] fixed numpy shape issue for 1D vector (#93) * fixed numpy shape issue for 1D vector * Added the numpy header fix to genlib2 so that next time the dense_io.go file is generated with the correct fix Co-authored-by: Chewxy --- dense_io.go | 11 +++++++++-- dense_io_test.go | 40 +++++++++++++++++++++++++++++++++++++++- genlib2/dense_io.go | 11 +++++++++-- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/dense_io.go b/dense_io.go index c55c66a..e4717f8 100644 --- a/dense_io.go +++ b/dense_io.go @@ -167,8 +167,15 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { return } - header := "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" - header = fmt.Sprintf(header, npdt, t.Shape()) + var header string + if t.Dims() == 1 { + // when t is a 1D vector, numpy expects "(N,)" instead of "(N)" which t.Shape() returns. 
+ header = "{'descr': '<%v', 'fortran_order': False, 'shape': (%d,)}" + header = fmt.Sprintf(header, npdt, t.Shape()[0]) + } else { + header = "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" + header = fmt.Sprintf(header, npdt, t.Shape()) + } padding := 16 - ((10 + len(header)) % 16) if padding > 0 { header = header + strings.Repeat(" ", padding) diff --git a/dense_io_test.go b/dense_io_test.go index cdbe610..99afa43 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -18,12 +18,19 @@ func TestSaveLoadNumpy(t *testing.T) { assert := assert.New(t) T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1})) + // also checks the 1D Vector. + T1D := New(WithShape(4), WithBacking([]float64{1, 5, 10, -1})) + f, _ := os.OpenFile("test.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) + f1D, _ := os.OpenFile("test1D.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) + T.WriteNpy(f) f.Close() - script := "import numpy as np\nx = np.load('test.npy')\nprint(x)" + T1D.WriteNpy(f1D) + f1D.Close() + script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)" // Configurable python command, in order to be able to use python or python3 pythonCommand := os.Getenv("PYTHON_COMMAND") if pythonCommand == "" { @@ -67,6 +74,11 @@ func TestSaveLoadNumpy(t *testing.T) { t.Error(err) } + err = os.Remove("test1D.npy") + if err != nil { + t.Error(err) + } + // ok now to test if it can read T2 := new(Dense) buf = new(bytes.Buffer) @@ -78,6 +90,17 @@ func TestSaveLoadNumpy(t *testing.T) { assert.Equal(T.Strides(), T2.Strides()) assert.Equal(T.Data(), T2.Data()) + // ok now to test if it can read 1D + T1D2 := new(Dense) + buf = new(bytes.Buffer) + T1D.WriteNpy(buf) + if err = T1D2.ReadNpy(buf); err != nil { + t.Fatal(err) + } + assert.Equal(T1D.Shape(), T1D2.Shape()) + assert.Equal(T1D.Strides(), T1D2.Strides()) + assert.Equal(T1D.Data(), T1D2.Data()) + // try with masked array. 
masked elements should be filled with default value T.ResetMask(false) T.mask[0] = true @@ -92,6 +115,21 @@ func TestSaveLoadNumpy(t *testing.T) { data := T.Float64s() data[0] = T.FillValue().(float64) assert.Equal(data, T3.Data()) + + // try with 1D masked array. masked elements should be filled with default value + T1D.ResetMask(false) + T1D.mask[0] = true + T1D3 := new(Dense) + buf = new(bytes.Buffer) + T1D.WriteNpy(buf) + if err = T1D3.ReadNpy(buf); err != nil { + t.Fatal(err) + } + assert.Equal(T1D.Shape(), T1D3.Shape()) + assert.Equal(T1D.Strides(), T1D3.Strides()) + data = T1D.Float64s() + data[0] = T1D.FillValue().(float64) + assert.Equal(data, T1D3.Data()) } func TestSaveLoadCSV(t *testing.T) { diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 73754df..e6e4b0f 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -67,8 +67,15 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { return } - header := "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" - header = fmt.Sprintf(header, npdt, t.Shape()) + var header string + if t.Dims() == 1 { + // when t is a 1D vector, numpy expects "(N,)" instead of "(N)" which t.Shape() returns. + header = "{'descr': '<%v', 'fortran_order': False, 'shape': (%d,)}" + header = fmt.Sprintf(header, npdt, t.Shape()[0]) + } else { + header = "{'descr': '<%v', 'fortran_order': False, 'shape': %v}" + header = fmt.Sprintf(header, npdt, t.Shape()) + } padding := 16 - ((10 + len(header)) % 16) if padding > 0 { header = header + strings.Repeat(" ", padding) From ac5bb6d1626ddf9367a9610f6985dbd4ff742616 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sat, 31 Oct 2020 15:56:13 +1100 Subject: [PATCH 063/154] Reshapefix (#94) * Fixed #90 * Added a check for valid reshapes --- dense.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dense.go b/dense.go index 7604056..fa9693d 100644 --- a/dense.go +++ b/dense.go @@ -140,6 +140,10 @@ func (t *Dense) Engine() Engine { return t.e } // Reshape reshapes a *Dense. 
If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens func (t *Dense) Reshape(dims ...int) error { + if t.Shape().TotalSize() != Shape(dims).TotalSize() { + return errors.Errorf("Cannot reshape %v into %v", t.Shape(), dims) + } + if t.viewOf != 0 && t.o.IsNotContiguous() { return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") } @@ -335,6 +339,7 @@ func (t *Dense) sanity() error { if t.viewOf == 0 && size != expected && !t.IsScalar() { return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") } + // TODO: sanity check for views return nil } From eff3be2aeec0f7b92d45e908c7d9272616f6878b Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 24 Nov 2020 22:21:24 +1100 Subject: [PATCH 064/154] Fmt h (#95) * Fixed #90 * Added a Headers only formatting option --- dense_format.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dense_format.go b/dense_format.go index 3f4f5e7..d20133d 100644 --- a/dense_format.go +++ b/dense_format.go @@ -153,6 +153,10 @@ func (f *fmtState) populate(t *Dense) { } func (f *fmtState) acceptableRune(d *Dense) { + if f.c == 'H' { + f.meta = true + return // accept H as header only + } switch d.t.Kind() { case reflect.Float64: switch f.c { @@ -277,6 +281,9 @@ func (t *Dense) Format(s fmt.State, c rune) { } fmt.Fprintf(f, " %v %v\n", t.Shape(), t.Strides()) } + if f.c == 'H' { + return + } if !t.IsNativelyAccessible() { fmt.Fprintf(f, "Inaccesible data") From 19dd23830a48ef20b14220a0926ee7f4bc175214 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 2 Dec 2020 09:26:53 +1100 Subject: [PATCH 065/154] Indexselect (#96) Added the `IndexSelect` function which allows users to select values of a tensor based on an index tensor (i.e. 
a tensor of ints) --- api_matop.go | 16 ++ defaultengine_selbyidx.go | 237 ++++++++++++++++++++++++ dense_selbyidx_test.go | 126 +++++++++++++ engine.go | 8 + internal/execution/eng_arith_manual.go | 247 +++++++++++++++++++++++++ 5 files changed, 634 insertions(+) create mode 100644 defaultengine_selbyidx.go create mode 100644 dense_selbyidx_test.go create mode 100644 internal/execution/eng_arith_manual.go diff --git a/api_matop.go b/api_matop.go index 2bc616c..fe3cad9 100644 --- a/api_matop.go +++ b/api_matop.go @@ -123,3 +123,19 @@ func Diag(t Tensor) (retVal Tensor, err error) { } return nil, errors.Errorf("Unable to perform diagonalization of tensor ") } + +// ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor. +// The `indices` tensor has to be a vector-like tensor of ints. +func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sbi, ok := a.Engine().(ByIndiceser); ok { + return sbi.SelectByIndices(a, indices, axis, opts...) + } + return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine()) +} + +func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sbi, ok := a.Engine().(ByIndiceser); ok { + return sbi.SelectByIndicesB(a, b, indices, axis, opts...) + } + return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine()) +} diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go new file mode 100644 index 0000000..4200b5b --- /dev/null +++ b/defaultengine_selbyidx.go @@ -0,0 +1,237 @@ +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" + + "reflect" +) + +func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !b.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. 
Got %v instead", b.Shape()) + } + if b.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + } + + // if b is a scalar, then use Slice + if a.Shape().IsScalarEquiv() { + slices := make([]Slice, a.Shape().Dims()) + slices[axis] = ss(b.Data().([]int)[0]) + return a.Slice(slices...) + } + + expectedShape := a.Shape().Clone() + expectedShape[axis] = b.Shape().TotalSize() + + var reuse DenseTensor + var safe, toReuse, _ bool + if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(a.Dtype())) + } + + if !safe { + if a.Shape()[axis] != b.Shape().TotalSize() { + expected := a.Shape().Clone() + expected[axis] = b.Shape().TotalSize() + return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. 
The input a has %v ", expected, a.Shape()) + } + + reuse = a.(DenseTensor) + } + + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.Add") + } + + if useIter { + e.iterSelectByIdx(axis, dataA, dataB, dataReuse, ait, bit, iit) + //TODO + return + } + + e.selectByIdx(axis, dataB.Ints(), typ, dataA, dataReuse, a.(*Dense).AP, reuse.(*Dense).AP) + return reuse, nil +} + +func (e StdEng) iterSelectByIdx(axis int, dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator) { + panic("iterSelectByIdx is not yet implemented") +} + +func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, dataRetVal *storage.Header, apA, apRet AP) { + isInnermost := axis == apA.shape.Dims()-1 + + outer := ProdInts(apA.shape[:axis]) + + axStride := apA.strides[axis] + retStride := apRet.strides[axis] + var outerRetStride int + if outer == 0 { + // then it's the outermost + outer = 1 + outerRetStride = apRet.strides[axis] * 2 + } else { + outerRetStride = apRet.strides[axis-1] + } + + srcCoord := make([]int, apA.shape.Dims()) + dstCoord := make([]int, apRet.shape.Dims()) + + if isInnermost { + prevStride := apA.strides[axis-1] + retPrevStride := apRet.strides[axis-1] + for i, idx := range indices { + srcCoord[axis] = idx + dstCoord[axis] = i + start, _ := Ltoi(apA.shape, apA.strides, srcCoord...) + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + for o := 0; o < outer; o++ { + end := start + axStride + dstEnd := dstStart + retStride + + storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) + + start += prevStride + dstStart += retPrevStride + + } + } + return + } + + for i, idx := range indices { + srcCoord[axis] = idx + dstCoord[axis] = i + start, _ := Ltoi(apA.shape, apA.strides, srcCoord...) 
+ dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + + for o := 0; o < outer; o++ { + end := start + axStride + dstEnd := dstStart + retStride + + storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) + + start = end + axStride + dstStart = dstEnd + (outerRetStride - retStride) + } + } +} + +// SelectByIndicesB is the backwards function of SelectByIndices. +func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !indices.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape()) + } + if indices.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + } + + // if b is a scalar, then use Slice + if a.Shape().IsScalarEquiv() { + slices := make([]Slice, a.Shape().Dims()) + slices[axis] = ss(b.Data().([]int)[0]) + return a.Slice(slices...) + } + + expectedShape := a.Shape().Clone() + + var reuse DenseTensor + var _, toReuse, _ bool + if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if !toReuse && reuse == nil { + // create reuse + reuse = New(WithShape(expectedShape...), Of(a.Dtype())) + } + + typ := a.Dtype().Type + var _, dataB, dataReuse *storage.Header + var _, bit, iit Iterator + var useIter bool + if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB") + } + + if useIter { + e.iterSelectByIndicesB(axis, dataB, dataReuse, bit, iit) + //TODO + return + } + + e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, b.(*Dense).AP, reuse.(*Dense).AP) + + return reuse, nil +} + +func (e StdEng) iterSelectByIndicesB(axis int, dataB, dataGradA *storage.Header, bit, iit Iterator) { + panic("iterSelectByIndicesB not implemented 
yet") +} + +func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, dataB, dataGradA *storage.Header, apB, apRet AP) { + isInnermost := axis == apB.shape.Dims()-1 + + outer := ProdInts(apB.shape[:axis]) + + axStride := apB.strides[axis] + retStride := apRet.strides[axis] + var outerRetStride int + if outer == 0 { + // then it's the outermost + outer = 1 + outerRetStride = apRet.strides[axis] * 2 + } else { + outerRetStride = apRet.strides[axis-1] + } + + dstCoord := make([]int, apB.shape.Dims()) + srcCoord := make([]int, apRet.shape.Dims()) + + if isInnermost { + retPrevStride := apB.strides[axis-1] + prevStride := apRet.strides[axis-1] + for i, idx := range indices { + dstCoord[axis] = idx + srcCoord[axis] = i + dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...) + start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...) + for o := 0; o < outer; o++ { + dstEnd := dstStart + axStride + end := start + retStride + + e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end) + + dstStart += prevStride + start += retPrevStride + + } + } + return + } + + for i, idx := range indices { + dstCoord[axis] = idx + srcCoord[axis] = i + dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...) + start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...) 
+ + for o := 0; o < outer; o++ { + dstEnd := dstStart + axStride + end := start + retStride + + e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end) + + dstStart = dstEnd + axStride + start = end + (outerRetStride - retStride) + } + } +} diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go new file mode 100644 index 0000000..ca6b34f --- /dev/null +++ b/dense_selbyidx_test.go @@ -0,0 +1,126 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDense_SelectByIndices(t *testing.T) { + assert := assert.New(t) + + a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) + indices := New(WithBacking([]int{1, 1})) + + e := StdEng{} + + a1, err := e.SelectByIndices(a, indices, 1) + if err != nil { + t.Errorf("%v", err) + } + correct1 := []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23} + assert.Equal(correct1, a1.Data()) + + a0, err := e.SelectByIndices(a, indices, 0) + if err != nil { + t.Errorf("%v", err) + } + correct0 := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} + assert.Equal(correct0, a0.Data()) + + a2, err := e.SelectByIndices(a, indices, 2) + if err != nil { + t.Errorf("%v", err) + } + correct2 := []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21} + assert.Equal(correct2, a2.Data()) + + // !safe + aUnsafe := a.Clone().(*Dense) + indices = New(WithBacking([]int{1, 1, 1})) + aUnsafeSelect, err := e.SelectByIndices(aUnsafe, indices, 0, UseUnsafe()) + if err != nil { + t.Errorf("%v", err) + } + correctUnsafe := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} + assert.Equal(correctUnsafe, aUnsafeSelect.Data()) + + // 3 indices, just to make sure the sanity of the algorithm + indices = New(WithBacking([]int{1, 1, 1})) + a1, err = e.SelectByIndices(a, indices, 1) + if err != nil { + t.Errorf("%v", 
err) + } + correct1 = []float64{ + 4, 5, 6, 7, + 4, 5, 6, 7, + 4, 5, 6, 7, + + 12, 13, 14, 15, + 12, 13, 14, 15, + 12, 13, 14, 15, + + 20, 21, 22, 23, + 20, 21, 22, 23, + 20, 21, 22, 23, + } + assert.Equal(correct1, a1.Data()) + + a0, err = e.SelectByIndices(a, indices, 0) + if err != nil { + t.Errorf("%v", err) + } + correct0 = []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} + assert.Equal(correct0, a0.Data()) + + a2, err = e.SelectByIndices(a, indices, 2) + if err != nil { + t.Errorf("%v", err) + } + correct2 = []float64{1, 1, 1, 5, 5, 5, 9, 9, 9, 13, 13, 13, 17, 17, 17, 21, 21, 21} + assert.Equal(correct2, a2.Data()) +} + +func TestDense_SelectByIndicesB(t *testing.T) { + + a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) + indices := New(WithBacking([]int{1, 1})) + + t.Logf("a\n%v", a) + + e := StdEng{} + + a1, err := e.SelectByIndices(a, indices, 1) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("a1\n%v", a1) + + a1Grad, err := e.SelectByIndicesB(a, a1, indices, 1) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("a1Grad \n%v", a1Grad) + + a0, err := e.SelectByIndices(a, indices, 0) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("a0\n%v", a0) + a0Grad, err := e.SelectByIndicesB(a, a0, indices, 0) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("a0Grad\n%v", a0Grad) + + a2, err := e.SelectByIndices(a, indices, 2) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("\n%v", a2) + a2Grad, err := e.SelectByIndicesB(a, a2, indices, 2) + if err != nil { + t.Errorf("%v", err) + } + t.Logf("a2Grad\n%v", a2Grad) +} diff --git a/engine.go b/engine.go index ae508a8..a8ec63c 100644 --- a/engine.go +++ b/engine.go @@ -394,6 +394,14 @@ type InfChecker interface { HasInf(t Tensor) (bool, error) } +/* Advanced Indexing */ + +// ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the 
`indices` tensor. +type ByIndiceser interface { + SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) +} + /* Internal interfaces for faster shit */ type denseArgmaxer interface { diff --git a/internal/execution/eng_arith_manual.go b/internal/execution/eng_arith_manual.go new file mode 100644 index 0000000..941d644 --- /dev/null +++ b/internal/execution/eng_arith_manual.go @@ -0,0 +1,247 @@ +package execution + +import ( + "reflect" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e E) AddSliced(t reflect.Type, dataA *storage.Header, dstStart, dstEnd int, dataB *storage.Header, srcStart, srcEnd int) (err error) { + a := &storage.Header{ + Ptr: storage.ElementAt(dstStart, dataA.Ptr, t.Size()), + L: dstEnd - dstStart, + C: dataA.C - dstStart, + } + if a.C == 0 { + a.C = 1 + } + + b := &storage.Header{ + Ptr: storage.ElementAt(srcStart, dataB.Ptr, t.Size()), + L: srcEnd - srcStart, + C: dataB.C - srcStart, + } + if b.C == 0 { + b.C = 1 + } + + as := isScalar(a) + bs := isScalar(b) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + + switch { + case as && bs: + VecAddI(at, bt) + case as && !bs: + AddSVI(at[0], bt) + case !as && bs: + AddVSI(at, bt[0]) + default: + VecAddI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecAddI8(at, bt) + case as && !bs: + AddSVI8(at[0], bt) + case !as && bs: + AddVSI8(at, bt[0]) + default: + VecAddI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecAddI16(at, bt) + case as && !bs: + AddSVI16(at[0], bt) + case !as && bs: + AddVSI16(at, bt[0]) + default: + VecAddI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecAddI32(at, bt) + case as && !bs: + AddSVI32(at[0], bt) + case !as && bs: + AddVSI32(at, 
bt[0]) + default: + VecAddI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecAddI64(at, bt) + case as && !bs: + AddSVI64(at[0], bt) + case !as && bs: + AddVSI64(at, bt[0]) + default: + VecAddI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecAddU(at, bt) + case as && !bs: + AddSVU(at[0], bt) + case !as && bs: + AddVSU(at, bt[0]) + default: + VecAddU(at, bt) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecAddU8(at, bt) + case as && !bs: + AddSVU8(at[0], bt) + case !as && bs: + AddVSU8(at, bt[0]) + default: + VecAddU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecAddU16(at, bt) + case as && !bs: + AddSVU16(at[0], bt) + case !as && bs: + AddVSU16(at, bt[0]) + default: + VecAddU16(at, bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecAddU32(at, bt) + case as && !bs: + AddSVU32(at[0], bt) + case !as && bs: + AddVSU32(at, bt[0]) + default: + VecAddU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecAddU64(at, bt) + case as && !bs: + AddSVU64(at[0], bt) + case !as && bs: + AddVSU64(at, bt[0]) + default: + VecAddU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecAddF32(at, bt) + case as && !bs: + AddSVF32(at[0], bt) + case !as && bs: + AddVSF32(at, bt[0]) + default: + VecAddF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecAddF64(at, bt) + case as && !bs: + AddSVF64(at[0], bt) + case !as && bs: + AddVSF64(at, bt[0]) + default: + VecAddF64(at, bt) + } + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + switch { + case as && bs: + VecAddC64(at, bt) + case as && !bs: + AddSVC64(at[0], 
bt) + case !as && bs: + AddVSC64(at, bt[0]) + default: + VecAddC64(at, bt) + } + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + switch { + case as && bs: + VecAddC128(at, bt) + case as && !bs: + AddSVC128(at[0], bt) + case !as && bs: + AddVSC128(at, bt[0]) + default: + VecAddC128(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecAddStr(at, bt) + case as && !bs: + AddSVStr(at[0], bt) + case !as && bs: + AddVSStr(at, bt[0]) + default: + VecAddStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Add", t) + } +} From 4c3771b119367ec6f119aeed034a44c6b5b71a20 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 9 Dec 2020 06:11:29 +1100 Subject: [PATCH 066/154] Clarify Semantics of how Shape and Data works. (#97) * Fixed #90 * Starting to clarify some semantic * With the semantics clarified, the consopts need to change a bit too * Updated the semantics to make it more clear * Added an example to Dense.Data() to clarify the semantics. Added tests for certain consopts that may be breaking -race * Added mmap example for FromMemory * Fixed ap.T and clarified shapes better in ap.go Added an example for T * Fixes SelectByIndices * Unnecessary checks for 0 removed, given that ProdInts have changed in function --- ap.go | 17 +++++-- ap_test.go | 4 +- consopt.go | 10 +++- consopt_test.go | 94 ++++++++++++++++++++++++++++++++++++ defaultengine_matop_misc.go | 3 -- defaultengine_selbyidx.go | 7 +-- example_dense_basics_test.go | 45 +++++++++++++++++ example_dense_matop_test.go | 77 +++++++++++++++++++++++++++++ shape.go | 8 +-- shape_test.go | 12 ++++- utils.go | 21 +------- 11 files changed, 258 insertions(+), 40 deletions(-) create mode 100644 consopt_test.go create mode 100644 example_dense_basics_test.go diff --git a/ap.go b/ap.go index aab0a50..145af0a 100644 --- a/ap.go +++ b/ap.go @@ -8,9 +8,12 @@ import ( // An AP is an access pattern. 
It tells the various ndarrays how to access their data through the use of strides // Through the AP, there are several definitions of things, most notably there are two very specific "special cases": -// Scalar has Dims() of 0. However, its shape can take several forms: -// - (1, 1) +// Scalar has Dims() of 0. // - (1) +// Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0. +// - (1, 1) +// - (1, 1, 1) +// - (1, 1, 1, 1), etc // Vector has Dims() of 1, but its shape can take several forms: // - (x, 1) // - (1, x) @@ -121,9 +124,12 @@ func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() } // IsRowVec returns true when the access pattern has the shape (1, x) func (ap *AP) IsRowVec() bool { return ap.shape.IsRowVec() } -// IsScalar returns true if the access pattern indicates it's a scalar value +// IsScalar returns true if the access pattern indicates it's a scalar value. func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() } +// IsScalarEquiv returns true if the access pattern is equivalent to a scalar shape. +func (ap *AP) IsScalarEquiv() bool { return ap.shape.IsScalarEquiv() } + // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 } @@ -297,6 +303,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err er // T returns the transposed metadata based on the given input func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { + // prep axes if len(axes) > 0 && len(axes) != ap.Dims() { err = errors.Errorf(dimMismatch, ap.Dims(), len(axes)) @@ -312,6 +319,10 @@ func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { } a = axes + if ap.shape.IsScalarEquiv() { + return ap.Clone(), a, noopError{} + } + // if axes is 0, 1, 2, 3... 
then no op if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 { return ap.Clone(), a, noopError{} diff --git a/ap_test.go b/ap_test.go index 091d6e6..f5a230e 100644 --- a/ap_test.go +++ b/ap_test.go @@ -112,12 +112,14 @@ func TestAccessPatternIsX(t *testing.T) { ap = dummyScalar1() assert.True(ap.IsScalar()) + assert.True(ap.IsScalarEquiv()) assert.False(ap.IsVector()) assert.False(ap.IsColVec()) assert.False(ap.IsRowVec()) ap = dummyScalar2() - assert.True(ap.IsScalar()) + assert.False(ap.IsScalar()) + assert.True(ap.IsScalarEquiv()) assert.False(ap.IsVector()) assert.False(ap.IsColVec()) assert.False(ap.IsRowVec()) diff --git a/consopt.go b/consopt.go index c0c57c7..2134896 100644 --- a/consopt.go +++ b/consopt.go @@ -112,10 +112,17 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { xvi.Set(reflect.ValueOf(x)) uptr := unsafe.Pointer(xv.Pointer()) + var v interface{} + if !tt.Shape().IsScalar() { + sl := reflect.MakeSlice(reflect.SliceOf(xt), 1, 1) + zeroth := sl.Index(0) + zeroth.Set(reflect.ValueOf(x)) + v = sl.Interface() + } tt.array.Ptr = uptr tt.array.L = 1 tt.array.C = 1 - tt.v = x + tt.v = v tt.t = Dtype{xt} tt.mask = mask @@ -146,7 +153,6 @@ func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { switch tt := t.(type) { case *Dense: tt.v = nil // if there were any underlying slices it should be GC'd - tt.array.Ptr = unsafe.Pointer(ptr) tt.array.L = int(memsize / tt.t.Size()) tt.array.C = int(memsize / tt.t.Size()) diff --git a/consopt_test.go b/consopt_test.go new file mode 100644 index 0000000..65d5396 --- /dev/null +++ b/consopt_test.go @@ -0,0 +1,94 @@ +package tensor + +import ( + "fmt" + "io/ioutil" + "os" + "syscall" + "testing" + "testing/quick" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +type F64 float64 + +func newF64(f float64) *F64 { r := F64(f); return &r } + +func (f *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(f)) } + +func (f *F64) MemSize() uintptr { return 8 } + +func (f 
*F64) Pointer() unsafe.Pointer { return unsafe.Pointer(f) } + +func Test_FromMemory(t *testing.T) { + fn := func(F float64) bool { + f := newF64(F) + T := New(WithShape(), Of(Float64), FromMemory(f.Uintptr(), f.MemSize())) + data := T.Data().(float64) + + if data != F { + return false + } + return true + } + if err := quick.Check(fn, &quick.Config{MaxCount: 1000000}); err != nil { + t.Logf("%v", err) + } + + f, err := ioutil.TempFile("", "test") + if err != nil { + t.Fatal(err) + } + // fill in with fake data + backing := make([]byte, 8*1024*1024) // 1024*1024 matrix of float64 + asFloats := *(*[]float64)(unsafe.Pointer(&backing)) + asFloats = asFloats[: 1024*1024 : 1024*1024] + asFloats[0] = 3.14 + asFloats[2] = 6.28 + asFloats[1024*1024-1] = 3.14 + asFloats[1024*1024-3] = 6.28 + f.Write(backing) + + // defer cleanup + defer os.Remove(f.Name()) + + // do the mmap stuff + stat, err := f.Stat() + if err != nil { + t.Fatal(err) + } + + size := int(stat.Size()) + fd := int(f.Fd()) + bs, err := syscall.Mmap(fd, 0, size, syscall.PROT_READ, syscall.MAP_SHARED) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := syscall.Munmap(bs); err != nil { + t.Error(err) + } + }() + T := New(WithShape(1024, 1024), Of(Float64), FromMemory(uintptr(unsafe.Pointer(&bs[0])), uintptr(size))) + + s := fmt.Sprintf("%v", T) + expected := `⎡3.14 0 6.28 0 ... 0 0 0 0⎤ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +. +. +. +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎢ 0 0 0 0 ... 0 0 0 0⎥ +⎣ 0 0 0 0 ... 
0 6.28 0 3.14⎦ +` + if s != expected { + t.Errorf("Expected mmap'd tensor to be exactly the same.") + } + + assert.True(t, T.IsManuallyManaged()) +} diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index bd9ee18..f45c9bb 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -72,9 +72,6 @@ func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, outers = 1 } else { outers = ProdInts(t.Shape()[0:axis]) - if outers == 0 { - outers = 1 - } } var stride, newStride int diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index 4200b5b..ab7f4f1 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -75,9 +75,8 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da axStride := apA.strides[axis] retStride := apRet.strides[axis] var outerRetStride int - if outer == 0 { + if axis == 0 { // then it's the outermost - outer = 1 outerRetStride = apRet.strides[axis] * 2 } else { outerRetStride = apRet.strides[axis-1] @@ -185,9 +184,7 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data axStride := apB.strides[axis] retStride := apRet.strides[axis] var outerRetStride int - if outer == 0 { - // then it's the outermost - outer = 1 + if axis == 0 { outerRetStride = apRet.strides[axis] * 2 } else { outerRetStride = apRet.strides[axis-1] diff --git a/example_dense_basics_test.go b/example_dense_basics_test.go new file mode 100644 index 0000000..1b08ca0 --- /dev/null +++ b/example_dense_basics_test.go @@ -0,0 +1,45 @@ +package tensor + +import ( + "fmt" +) + +// Data shows how the shape of the *Dense actually affects the return value of .Data(). 
+func ExampleDense_Data() { + T := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})) + fmt.Printf("Basics:\n======\nAny kind of arrays: %v\n", T.Data()) + + fmt.Printf("\nScalar-like\n===========\n") + T = New(WithShape(), FromScalar(3.14)) + fmt.Printf("WithShape(), FromScalar: %v\n", T.Data()) + + T = New(WithShape(), WithBacking([]float64{3.14})) + fmt.Printf("WithShape(), With a slice of 1 as backing: %v\n", T.Data()) + + T = New(WithShape(1), FromScalar(3.14)) + fmt.Printf("WithShape(1), With an initial scalar: %v\n", T.Data()) + + T = New(WithShape(1, 1), WithBacking([]float64{3.14})) + fmt.Printf("WithShape(1, 1), With an initial scalar: %v\n", T.Data()) + + T = New(WithShape(1, 1), FromScalar(3.14)) + fmt.Printf("WithShape(1, 1), With an initial scalar: %v\n", T.Data()) + + T.Reshape() + fmt.Printf("After reshaping to (): %v\n", T.Data()) + + // Output: + // Basics: + // ====== + // Any kind of arrays: [1 2 3 4] + // + // Scalar-like + // =========== + // WithShape(), FromScalar: 3.14 + // WithShape(), With a slice of 1 as backing: 3.14 + // WithShape(1), With an initial scalar: [3.14] + // WithShape(1, 1), With an initial scalar: [3.14] + // WithShape(1, 1), With an initial scalar: [3.14] + // After reshaping to (): 3.14 + +} diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 08ab492..30df83c 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -289,3 +289,80 @@ func ExampleRepeat_uncommonUses() { // Once again, observe that the 1st element ([2 5]) has been repeated 3 times, while the rest have been repeated twice } + +func ExampleT() { + // Usual example of 2D matrix being transposed: + M := New(WithBacking([]int{1, 2, 3, 4, 5, 6}), WithShape(2, 3)) + M2, err := T(M) + if err != nil { + fmt.Printf("Err: %v\n", err) + } + fmt.Printf("M:\n%v\nM2:\n%v\n", M, M2) + + // T accepts optional parameters describing the permutation of axes. + // In a 2D case, there are only two options: (0, 1) or (1, 0). 
+ // The latter is default if no parameters are passed in. + // The former is a no-op as rearranging a matrix so that the 0th axis becomes the 0th axis + // and the first axis becomes the first axis is not going to do anything. + // + // However, note that M3 is a different result. + M3, err := T(M, 0, 1) + if err != nil { + fmt.Printf("Err: %v\n", err) + } + fmt.Printf("M3:\n%v\nM == M3: %t", M3, M == M3) + + // Output: + // M: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // M2: + // ⎡1 4⎤ + // ⎢2 5⎥ + // ⎣3 6⎦ + // + // M3: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // M == M3: false + +} + +func ExampleT_scalarlike() { + // Be aware when dealing with scalarlike tensors + // scalar/scalarlikes have no effect when calling T() + // but the result is put into a new tensor + S := New(WithBacking([]float32{3.14}), WithShape()) + S2, err := T(S) + if err != nil { + fmt.Printf("Err %v", err) + } + fmt.Printf("S: %v S2 %v S == S2: %t\n", S, S2, S == S2) + + // however do note that scalars and scalarlikes are not the same thing. + // for example, consider this: + _, err = T(S, 1, 0) + fmt.Printf("error when the axes are more than the shape's dims: %v\n", err) + + // but if you have a tensor that is a scalar-like: + S.Reshape(1, 1) + S2, err = T(S, 1, 0) + if err != nil { + fmt.Printf("Err: %v\n", err) + } + fmt.Printf("S:\n%v\nS2:\n%v\nS == S2: %t\n", S, S2, S == S2) + + // Output: + // S: 3.14 S2 3.14 S == S2: false + // error when the axes are more than the shape's dims: Dimension mismatch. 
Expected 0, got 2 + // S: + // ⎡3.14⎤ + // + // S2: + // ⎡3.14⎤ + // + // S == S2: false + +} diff --git a/shape.go b/shape.go index c448d0b..a0396bb 100644 --- a/shape.go +++ b/shape.go @@ -26,7 +26,7 @@ func (s Shape) TotalSize() int { // CalcStrides calculates the default strides for a shape func (s Shape) CalcStrides() []int { - if s.IsScalar() { + if s.IsScalarEquiv() { return nil } @@ -52,7 +52,7 @@ func (s Shape) CalcStrides() []int { // CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions // during calculation of stride func (s Shape) CalcStridesWithMask(mask []bool) []int { - if s.IsScalar() { + if s.IsScalarEquiv() { return nil } @@ -87,7 +87,7 @@ func (s Shape) CalcStridesWithMask(mask []bool) []int { // CalcStridesColMajor is like CalcStrides, but assumes a col major layout func (s Shape) CalcStridesColMajor() []int { - if s.IsScalar() { + if s.IsScalarEquiv() { return nil } @@ -155,7 +155,7 @@ func (s Shape) Clone() Shape { // IsScalar returns true if the access pattern indicates it's a scalar value func (s Shape) IsScalar() bool { - return len(s) == 0 || (len(s) == 1 && s[0] == 1) + return len(s) == 0 } // IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value diff --git a/shape_test.go b/shape_test.go index aa9e1de..cb21bed 100644 --- a/shape_test.go +++ b/shape_test.go @@ -47,12 +47,14 @@ func TestShapeIsX(t *testing.T) { // scalar shape s = Shape{} assert.True(s.IsScalar()) + assert.True(s.IsScalarEquiv()) assert.False(s.IsVector()) assert.False(s.IsColVec()) assert.False(s.IsRowVec()) s = Shape{1} - assert.True(s.IsScalar()) + assert.False(s.IsScalar()) + assert.True(s.IsScalarEquiv()) assert.False(s.IsVector()) assert.False(s.IsColVec()) assert.False(s.IsRowVec()) @@ -129,11 +131,17 @@ func TestShapeEquality(t *testing.T) { var s1, s2 Shape // scalar - s1 = Shape{1} + s1 = Shape{} s2 = Shape{} assert.True(s1.Eq(s2)) assert.True(s2.Eq(s1)) + // 
scalars and scalar equiv are not the same! + s1 = Shape{1} + s2 = Shape{} + assert.False(s1.Eq(s2)) + assert.False(s2.Eq(s1)) + // vector s1 = Shape{3} s2 = Shape{5} diff --git a/utils.go b/utils.go index 42ef597..064c812 100644 --- a/utils.go +++ b/utils.go @@ -42,35 +42,16 @@ func SumInts(a []int) (retVal int) { // ProdInts returns the internal product of an int slice func ProdInts(a []int) (retVal int) { + retVal = 1 if len(a) == 0 { return } - retVal = 1 for _, v := range a { retVal *= v } return } -// EqInts returns true if slices have same value -// func EqInts(a, b []int) bool { -// if len(a) != len(b) { -// return false -// } - -// if (a == nil) != (b == nil) { -// return false -// } - -// b = b[:len(a)] -// for i, v := range a { -// if v != b[i] { -// return false -// } -// } -// return true -// } - // IsMonotonicInts returns true if the slice of ints is monotonically increasing. It also returns true for incr1 if every succession is a succession of 1 func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { var prev int From e325ad90803a29b1745512345d43cc1436deec80 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sat, 12 Dec 2020 09:40:52 +1100 Subject: [PATCH 067/154] Cuda11 (#99) * Fixed #90 * Removed everything that is unsafe.Pointer to use uintptr. Any thing that requires a unsafe.Pointer to remain will have to use a refcounter. 
* genlib'd the RC stuff into scalarHeaders * Fixed so -race will not complain * Updated travis to make sure things test with race as well * Added some tests for Float64Engine and Float32Engine * Moved to using raw byte slices as per Bryan C Mills' suggestion * More fixed from just moving to raw byte slices * Fixed more things for array * Fixed tests * Fixed all syntax errors * removed .v from array * Fixed some off that scalar business * Fixed the slice bits * tests pass * Added benchmark script * Fixed eng_arith_manual * Fixed inplace transpose as well --- .travis/test.sh | 3 +- api_arith_test.go | 1148 +++++++++++----------- api_matop.go | 2 +- array.go | 316 +++--- array_getset.go | 164 ++-- bench.sh | 23 + consopt.go | 40 +- defaultengine.go | 145 ++- defaultengine_arith.go | 84 +- defaultengine_cmp.go | 96 +- defaultengine_matop_misc.go | 4 +- defaultengine_matop_stack.go | 6 +- defaultengine_matop_transpose.go | 3 +- defaultengine_matop_transpose_inplace.go | 3 +- defaultengine_prep.go | 8 +- defaultenginefloat32.go | 11 +- defaultenginefloat32_test.go | 42 + defaultenginefloat64.go | 8 +- defaultenginefloat64_test.go | 42 + dense.go | 65 +- dense_assign.go | 10 +- dense_io.go | 2 +- dense_matop.go | 1 - dense_matop_test.go | 2 + engine.go | 5 - genlib2/agg1_body.go | 66 +- genlib2/agg2_body.go | 36 +- genlib2/array_getset.go | 81 +- genlib2/dense_io.go | 2 +- genlib2/main.go | 1 + internal/execution/e.go | 2 +- internal/execution/eng_arith.go | 144 +-- internal/execution/eng_arith_manual.go | 22 +- internal/execution/eng_cmp.go | 144 +-- internal/execution/eng_map.go | 2 +- internal/storage/consts.go | 29 + internal/storage/getset.go | 72 +- internal/storage/header.go | 67 +- known_race_test.go | 1 + perf.go | 9 +- sparse.go | 14 +- tensor.go | 5 +- testutils_test.go | 13 +- 43 files changed, 1534 insertions(+), 1409 deletions(-) create mode 100755 bench.sh create mode 100644 defaultenginefloat32_test.go create mode 100644 defaultenginefloat64_test.go create 
mode 100644 internal/storage/consts.go diff --git a/.travis/test.sh b/.travis/test.sh index 37fdd87..381a409 100644 --- a/.travis/test.sh +++ b/.travis/test.sh @@ -6,6 +6,7 @@ go test -v -a -covermode=atomic -coverprofile=test.cover . go test -tags='avx' -a -covermode=atomic -coverprofile=avx.cover . go test -tags='sse' -a -covermode=atomic -coverprofile=sse.cover . go test -tags='inplacetranspose' -a -covermode=atomic -coverprofile=inplacetranspose.cover . +go test -race -a . go test -a -covermode=atomic -coverprofile=native.cover ./native/. # because coveralls only accepts one coverage file at one time... we combine them into one gigantic one @@ -14,4 +15,4 @@ echo "mode: set" > ./final.cover tail -q -n +2 "${covers[@]}" >> ./final.cover goveralls -coverprofile=./final.cover -service=travis-ci -set +ex \ No newline at end of file +set +ex diff --git a/api_arith_test.go b/api_arith_test.go index ca45f8f..75a4838 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -1,574 +1,574 @@ -package tensor - -import ( - "log" - "math/rand" - "testing" - "testing/quick" - "time" - - "github.com/stretchr/testify/assert" -) - -// This file contains the tests for API functions that aren't generated by genlib - -func TestMod(t *testing.T) { - a := New(WithBacking([]float64{1, 2, 3, 4})) - b := New(WithBacking([]float64{1, 1, 1, 1})) - var correct interface{} = []float64{0, 0, 0, 0} - - // vec-vec - res, err := Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar - if res, err = Mod(a, 1.0); err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestFMA(t *testing.T) { - same := func(q *Dense) bool { - a := q.Clone().(*Dense) - x := q.Clone().(*Dense) - y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) - y.Memset(identityVal(100, q.Dtype())) - WithEngine(q.Engine())(y) - y2 := y.Clone().(*Dense) - - we, willFailEq := willerr(a, numberTypes, nil) - _, ok1 := 
q.Engine().(FMAer) - _, ok2 := q.Engine().(Muler) - _, ok3 := q.Engine().(Adder) - we = we || (!ok1 && (!ok2 || !ok3)) - - f, err := FMA(a, x, y) - if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { - if err != nil { - log.Printf("q.Engine() %T", q.Engine()) - return false - } - return true - } - - we, _ = willerr(a, numberTypes, nil) - _, ok := a.Engine().(Muler) - we = we || !ok - wi, err := Mul(a, x, WithIncr(y2)) - if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly { - if err != nil { - return false - } - return true - } - return qcEqCheck(t, q.Dtype(), willFailEq, wi, f) - } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(same, &quick.Config{Rand: r}); err != nil { - t.Error(err) - } - - // specific engines - var eng Engine - - // FLOAT64 ENGINE - - // vec-vec - eng = Float64Engine{} - a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) - x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng)) - y := New(Of(Float64), WithShape(100), WithEngine(eng)) - - f, err := FMA(a, x, y) - if err != nil { - t.Fatal(err) - } - - a2 := New(WithBacking(Range(Float64, 0, 100))) - x2 := New(WithBacking(Range(Float64, 1, 101))) - y2 := New(Of(Float64), WithShape(100)) - f2, err := Mul(a2, x2, WithIncr(y2)) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // vec-scalar - a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) - y = New(Of(Float64), WithShape(100)) - - if f, err = FMA(a, 2.0, y); err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float64, 0, 100))) - y2 = New(Of(Float64), WithShape(100)) - if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // FLOAT32 engine - eng = Float32Engine{} - a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) - x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng)) - y = New(Of(Float32), WithShape(100), WithEngine(eng)) - 
- f, err = FMA(a, x, y) - if err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float32, 0, 100))) - x2 = New(WithBacking(Range(Float32, 1, 101))) - y2 = New(Of(Float32), WithShape(100)) - f2, err = Mul(a2, x2, WithIncr(y2)) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - - // vec-scalar - a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) - y = New(Of(Float32), WithShape(100)) - - if f, err = FMA(a, float32(2), y); err != nil { - t.Fatal(err) - } - - a2 = New(WithBacking(Range(Float32, 0, 100))) - y2 = New(Of(Float32), WithShape(100)) - if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil { - t.Fatal(err) - } - - assert.Equal(t, f.Data(), f2.Data()) - -} - -func TestMulScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{2})) - b := New(WithBacking([]float64{3})) - var correct interface{} = 6.0 - - res, err := Mul(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Mul(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{3, 2})) - b = New(WithBacking([]float64{2})) - correct = []float64{6, 4} - - res, err = Mul(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Mul(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{3, 5})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{21, 10} - - res, err = Mul(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Mul(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Interface - tensor - ai := 2.0 - b = NewDense(Float64, Shape{1, 1}, 
WithBacking([]float64{3})) - correct = []float64{6.0} - - res, err = Mul(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Commutativity - res, err = Mul(b, ai) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestDivScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{6})) - b := New(WithBacking([]float64{2})) - var correct interface{} = 3.0 - - res, err := Div(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{6, 4})) - b = New(WithBacking([]float64{2})) - correct = []float64{3, 2} - - res, err = Div(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor-scalar - a = New(WithBacking([]float64{6})) - b = New(WithBacking([]float64{3, 2})) - correct = []float64{2, 3} - - res, err = Div(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{21, 10})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{3, 5} - - res, err = Div(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // interface-scalar - ai := 6.0 - b = New(WithBacking([]float64{2})) - correct = 3.0 - - res, err = Div(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-interface - a = New(WithBacking([]float64{6})) - bi := 2.0 - correct = 3.0 - - res, err = Div(a, bi) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestAddScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{2})) - b := New(WithBacking([]float64{3})) - var correct interface{} = 5.0 - - res, err := Add(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - 
assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Add(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{3, 2})) - b = New(WithBacking([]float64{2})) - correct = []float64{5, 4} - - res, err = Add(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Add(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{3, 5})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{10, 7} - - res, err = Add(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Add(b, a) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // interface-scalar - ai := 2.0 - b = New(WithBacking([]float64{3})) - correct = 5.0 - - res, err = Add(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // Test commutativity - res, err = Add(b, ai) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestSubScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{6})) - b := New(WithBacking([]float64{2})) - var correct interface{} = 4.0 - - res, err := Sub(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{6, 4})) - b = New(WithBacking([]float64{2})) - correct = []float64{4, 2} - - res, err = Sub(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor-scalar - a = New(WithBacking([]float64{6})) - b = New(WithBacking([]float64{3, 2})) - correct = []float64{3, 4} - - res, err = Sub(a, b) - if err != nil { - 
t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{21, 10})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{14, 8} - - res, err = Sub(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // interface-scalar - ai := 6.0 - b = New(WithBacking([]float64{2})) - correct = 4.0 - - res, err = Sub(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-interface - a = New(WithBacking([]float64{6})) - bi := 2.0 - correct = 4.0 - - res, err = Sub(a, bi) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestModScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{5})) - b := New(WithBacking([]float64{2})) - var correct interface{} = 1.0 - - res, err := Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{5, 4})) - b = New(WithBacking([]float64{2})) - correct = []float64{1, 0} - - res, err = Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor-scalar - a = New(WithBacking([]float64{5})) - b = New(WithBacking([]float64{3, 2})) - correct = []float64{2, 1} - - res, err = Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{22, 10})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{1, 0} - - res, err = Mod(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // interface-scalar - ai := 5.0 - b = New(WithBacking([]float64{2})) - correct = 1.0 - - res, err = Mod(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-interface - a = 
New(WithBacking([]float64{5})) - bi := 2.0 - correct = 1.0 - - res, err = Mod(a, bi) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} - -func TestPowScalarScalar(t *testing.T) { - // scalar-scalar - a := New(WithBacking([]float64{6})) - b := New(WithBacking([]float64{2})) - var correct interface{} = 36.0 - - res, err := Pow(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-tensor - a = New(WithBacking([]float64{6, 4})) - b = New(WithBacking([]float64{2})) - correct = []float64{36, 16} - - res, err = Pow(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor-scalar - a = New(WithBacking([]float64{6})) - b = New(WithBacking([]float64{3, 2})) - correct = []float64{216, 36} - - res, err = Pow(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // tensor - tensor - a = New(WithBacking([]float64{3, 10})) - b = New(WithBacking([]float64{7, 2})) - correct = []float64{2187, 100} - - res, err = Pow(a, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // interface-scalar - ai := 6.0 - b = New(WithBacking([]float64{2})) - correct = 36.0 - - res, err = Pow(ai, b) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) - - // scalar-interface - a = New(WithBacking([]float64{6})) - bi := 2.0 - correct = 36.0 - - res, err = Pow(a, bi) - if err != nil { - t.Fatalf("Error: %v", err) - } - assert.Equal(t, correct, res.Data()) -} +package tensor + +import ( + "log" + "math/rand" + "testing" + "testing/quick" + "time" + + "github.com/stretchr/testify/assert" +) + +// This file contains the tests for API functions that aren't generated by genlib + +func TestMod(t *testing.T) { + a := New(WithBacking([]float64{1, 2, 3, 4})) + b := New(WithBacking([]float64{1, 1, 1, 1})) + var correct interface{} = 
[]float64{0, 0, 0, 0} + + // vec-vec + res, err := Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar + if res, err = Mod(a, 1.0); err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestFMA(t *testing.T) { + same := func(q *Dense) bool { + a := q.Clone().(*Dense) + x := q.Clone().(*Dense) + y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) + y.Memset(identityVal(100, q.Dtype())) + WithEngine(q.Engine())(y) + y2 := y.Clone().(*Dense) + + we, willFailEq := willerr(a, numberTypes, nil) + _, ok1 := q.Engine().(FMAer) + _, ok2 := q.Engine().(Muler) + _, ok3 := q.Engine().(Adder) + we = we || (!ok1 && (!ok2 || !ok3)) + + f, err := FMA(a, x, y) + if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { + if err != nil { + log.Printf("q.Engine() %T", q.Engine()) + return false + } + return true + } + + we, _ = willerr(a, numberTypes, nil) + _, ok := a.Engine().(Muler) + we = we || !ok + wi, err := Mul(a, x, WithIncr(y2)) + if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly { + if err != nil { + return false + } + return true + } + return qcEqCheck(t, q.Dtype(), willFailEq, wi, f) + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(same, &quick.Config{Rand: r}); err != nil { + t.Error(err) + } + + // specific engines + var eng Engine + + // FLOAT64 ENGINE + + // vec-vec + eng = Float64Engine{} + a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) + x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng)) + y := New(Of(Float64), WithShape(100), WithEngine(eng)) + + f, err := FMA(a, x, y) + if err != nil { + t.Fatal(err) + } + + a2 := New(WithBacking(Range(Float64, 0, 100))) + x2 := New(WithBacking(Range(Float64, 1, 101))) + y2 := New(Of(Float64), WithShape(100)) + f2, err := Mul(a2, x2, WithIncr(y2)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // 
vec-scalar + a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) + y = New(Of(Float64), WithShape(100)) + + if f, err = FMA(a, 2.0, y); err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float64, 0, 100))) + y2 = New(Of(Float64), WithShape(100)) + if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // FLOAT32 engine + eng = Float32Engine{} + a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) + x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng)) + y = New(Of(Float32), WithShape(100), WithEngine(eng)) + + f, err = FMA(a, x, y) + if err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float32, 0, 100))) + x2 = New(WithBacking(Range(Float32, 1, 101))) + y2 = New(Of(Float32), WithShape(100)) + f2, err = Mul(a2, x2, WithIncr(y2)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + + // vec-scalar + a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) + y = New(Of(Float32), WithShape(100)) + + if f, err = FMA(a, float32(2), y); err != nil { + t.Fatal(err) + } + + a2 = New(WithBacking(Range(Float32, 0, 100))) + y2 = New(Of(Float32), WithShape(100)) + if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil { + t.Fatal(err) + } + + assert.Equal(t, f.Data(), f2.Data()) + +} + +func TestMulScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 6.0 + + res, err := Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{6, 4} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, 
res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{21, 10} + + res, err = Mul(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Mul(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Interface - tensor + ai := 2.0 + b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) + correct = []float64{6.0} + + res, err = Mul(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Commutativity + res, err = Mul(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestDivScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 3.0 + + res, err := Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{3, 2} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{2, 3} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{3, 5} + + res, err = Div(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = 
New(WithBacking([]float64{2})) + correct = 3.0 + + res, err = Div(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 3.0 + + res, err = Div(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestAddScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{2})) + b := New(WithBacking([]float64{3})) + var correct interface{} = 5.0 + + res, err := Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{3, 2})) + b = New(WithBacking([]float64{2})) + correct = []float64{5, 4} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 5})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{10, 7} + + res, err = Add(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, a) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 2.0 + b = New(WithBacking([]float64{3})) + correct = 5.0 + + res, err = Add(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // Test commutativity + res, err = Add(b, ai) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestSubScalarScalar(t *testing.T) { + // scalar-scalar + a := 
New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 4.0 + + res, err := Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{4, 2} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{3, 4} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{21, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{14, 8} + + res, err = Sub(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 4.0 + + res, err = Sub(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 4.0 + + res, err = Sub(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestModScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{5})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 1.0 + + res, err := Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{5, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{5})) + b = New(WithBacking([]float64{3, 2})) + correct = 
[]float64{2, 1} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{22, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{1, 0} + + res, err = Mod(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 5.0 + b = New(WithBacking([]float64{2})) + correct = 1.0 + + res, err = Mod(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{5})) + bi := 2.0 + correct = 1.0 + + res, err = Mod(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} + +func TestPowScalarScalar(t *testing.T) { + // scalar-scalar + a := New(WithBacking([]float64{6})) + b := New(WithBacking([]float64{2})) + var correct interface{} = 36.0 + + res, err := Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // scalar-tensor + a = New(WithBacking([]float64{6, 4})) + b = New(WithBacking([]float64{2})) + correct = []float64{36, 16} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor-scalar + a = New(WithBacking([]float64{6})) + b = New(WithBacking([]float64{3, 2})) + correct = []float64{216, 36} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // tensor - tensor + a = New(WithBacking([]float64{3, 10})) + b = New(WithBacking([]float64{7, 2})) + correct = []float64{2187, 100} + + res, err = Pow(a, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) + + // interface-scalar + ai := 6.0 + b = New(WithBacking([]float64{2})) + correct = 36.0 + + res, err = Pow(ai, b) + if err != nil { + t.Fatalf("Error: %v", err) + } + 
assert.Equal(t, correct, res.Data()) + + // scalar-interface + a = New(WithBacking([]float64{6})) + bi := 2.0 + correct = 36.0 + + res, err = Pow(a, bi) + if err != nil { + t.Fatalf("Error: %v", err) + } + assert.Equal(t, correct, res.Data()) +} diff --git a/api_matop.go b/api_matop.go index fe3cad9..75c2452 100644 --- a/api_matop.go +++ b/api_matop.go @@ -13,7 +13,7 @@ func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { return nil, errors.New("Engine does not support Repeat") } -// RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given // ???? , but the results will still be valid. +// RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given, but the results will still be valid. func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { if r, ok := t.Engine().(Repeater); ok { return r.RepeatReuse(t, reuse, axis, repeats...) diff --git a/array.go b/array.go index d6c07c6..ca948d6 100644 --- a/array.go +++ b/array.go @@ -3,6 +3,7 @@ package tensor import ( "fmt" "reflect" + "sync" "unsafe" "github.com/pkg/errors" @@ -11,33 +12,21 @@ import ( // array is the underlying generic array. type array struct { - storage.Header // the header - the Go representation (a slice) - t Dtype // the element type - v interface{} // an additional reference to the underlying slice. 
This is not strictly necessary, but does improve upon anything that calls .Data() -} - -// makeHeader makes a array Header -func makeHeader(t Dtype, length int) storage.Header { - return storage.Header{ - Ptr: malloc(t, length), - L: length, - C: length, - } + storage.Header // the header - the Go representation (a slice) + t Dtype // the element type } // makeArray makes an array. The memory allocation is handled by Go func makeArray(t Dtype, length int) array { - hdr := makeHeader(t, length) - return makeArrayFromHeader(hdr, t) -} - -// makeArrayFromHeader makes an array given a header -func makeArrayFromHeader(hdr storage.Header, t Dtype) array { + v := malloc(t, length) + hdr := storage.Header{ + Raw: v, + } return array{ Header: hdr, t: t, - v: nil, } + } // arrayFromSlice creates an array from a slice. If x is not a slice, it will panic. @@ -48,20 +37,18 @@ func arrayFromSlice(x interface{}) array { } elT := xT.Elem() - xV := reflect.ValueOf(x) - uptr := unsafe.Pointer(xV.Pointer()) - return array{ Header: storage.Header{ - Ptr: uptr, - L: xV.Len(), - C: xV.Cap(), + Raw: storage.AsByteSlice(x), }, t: Dtype{elT}, - v: x, } } +func (a *array) Len() int { return a.Header.TypedLen(a.t.Type) } + +func (a *array) Cap() int { return a.Header.TypedLen(a.t.Type) } + // fromSlice populates the value from a slice func (a *array) fromSlice(x interface{}) { xT := reflect.TypeOf(x) @@ -69,14 +56,8 @@ func (a *array) fromSlice(x interface{}) { panic("Expected a slice") } elT := xT.Elem() - xV := reflect.ValueOf(x) - uptr := unsafe.Pointer(xV.Pointer()) - - a.Ptr = uptr - a.L = xV.Len() - a.C = xV.Cap() + a.Raw = storage.AsByteSlice(x) a.t = Dtype{elT} - a.v = x } // fromSliceOrTensor populates the value from a slice or anything that can form an array @@ -85,94 +66,52 @@ func (a *array) fromSliceOrArrayer(x interface{}) { xp := T.arrPtr() // if the underlying array hasn't been allocated, or not enough has been allocated - if a.Ptr == nil || a.L < xp.L || a.C < xp.C { - a.t = 
xp.t - a.L = xp.L - a.C = xp.C - a.Ptr = malloc(a.t, a.L) + if a.Header.Raw == nil { + a.Header.Raw = malloc(xp.t, xp.Len()) } a.t = xp.t - a.L = xp.L - a.C = xp.C copyArray(a, T.arrPtr()) - a.v = nil // tell the GC to release whatever a.v may hold - a.forcefix() // fix it such that a.v has a value and is not nil return } a.fromSlice(x) } -// fix fills the a.v empty interface{} if it's not nil -func (a *array) fix() { - if a.v == nil { - a.forcefix() - } -} - -// forcefix fills the a.v empty interface{}. No checks are made if the thing is empty -func (a *array) forcefix() { - sliceT := reflect.SliceOf(a.t.Type) - ptr := unsafe.Pointer(&a.Header) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - a.v = val.Interface() -} - // byteSlice casts the underlying slice into a byte slice. Useful for copying and zeroing, but not much else -func (a array) byteSlice() []byte { - return storage.AsByteSlice(&a.Header, a.t.Type) -} +func (a array) byteSlice() []byte { return a.Header.Raw } // sliceInto creates a slice. Instead of returning an array, which would cause a lot of reallocations, sliceInto expects a array to // already have been created. 
This allows repetitive actions to be done without having to have many pointless allocation func (a *array) sliceInto(i, j int, res *array) { - c := a.C + c := a.Cap() if i < 0 || j < i || j > c { panic(fmt.Sprintf("Cannot slice %v - index %d:%d is out of bounds", a, i, j)) } - res.L = j - i - res.C = c - i + s := i * int(a.t.Size()) + e := j * int(a.t.Size()) + c = c - i + + res.Raw = a.Raw[s:e] - if c-1 > 0 { - res.Ptr = storage.ElementAt(i, a.Ptr, a.t.Size()) - } else { - // don't advance pointer - res.Ptr = a.Ptr - } - res.fix() } // slice slices an array func (a array) slice(start, end int) array { - if end > a.L { + if end > a.Len() { panic("Index out of range") } if end < start { panic("Index out of range") } - L := end - start - C := a.C - start - - var startptr unsafe.Pointer - if a.C-start > 0 { - startptr = storage.ElementAt(start, a.Ptr, a.t.Size()) - } else { - startptr = a.Ptr - } - - hdr := storage.Header{ - Ptr: startptr, - L: L, - C: C, - } + s := start * int(a.t.Size()) + e := end * int(a.t.Size()) return array{ - Header: hdr, + Header: storage.Header{Raw: a.Raw[s:e]}, t: a.t, - v: nil, } } @@ -216,30 +155,24 @@ func (a *array) swap(i, j int) { /* *Array is a Memory */ // Uintptr returns the pointer of the first value of the slab -func (a *array) Uintptr() uintptr { return uintptr(a.Ptr) } +func (a *array) Uintptr() uintptr { return uintptr(unsafe.Pointer(&a.Header.Raw[0])) } // MemSize returns how big the slice is in bytes -func (a *array) MemSize() uintptr { return uintptr(a.L) * a.t.Size() } - -// Pointer returns the pointer of the first value of the slab, as an unsafe.Pointer -func (a *array) Pointer() unsafe.Pointer { return a.Ptr } +func (a *array) MemSize() uintptr { return uintptr(len(a.Header.Raw)) } // Data returns the representation of a slice. 
func (a array) Data() interface{} { - if a.v == nil { - // build a type of []T - shdr := reflect.SliceHeader{ - Data: uintptr(a.Header.Ptr), - Len: a.Header.L, - Cap: a.Header.C, - } - sliceT := reflect.SliceOf(a.t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - a.v = val.Interface() - + // build a type of []T + shdr := reflect.SliceHeader{ + Data: a.Uintptr(), + Len: a.Len(), + Cap: a.Cap(), } - return a.v + sliceT := reflect.SliceOf(a.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + return val.Interface() + } // Zero zeroes out the underlying array of the *Dense tensor. @@ -258,10 +191,10 @@ func (a array) Zero() { } return } - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) + + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(reflect.Zero(a.t)) } @@ -273,10 +206,9 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ // malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory -func malloc(t Dtype, length int) unsafe.Pointer { +func malloc(t Dtype, length int) []byte { size := int(calcMemSize(t, length)) - s := make([]byte, size) - return unsafe.Pointer(&s[0]) + return make([]byte, size) } // calcMemSize calulates the memory size of an array (given its size) @@ -427,95 +359,79 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { return storage.CopyIter(dst.rtype(), dst.hdr(), src.hdr(), diter, siter), nil } -func getPointer(a interface{}) unsafe.Pointer { +type scalarPtrCount struct { + Ptr unsafe.Pointer + Count int +} + +// scalarRCLock is a lock for the reference counting list. 
+var scalarRCLock sync.Mutex + +// scalarRC is a bunch of reference counted pointers to scalar values +var scalarRC = make(map[uintptr]*sync.Pool) // uintptr is the size, the pool stores []byte + +func scalarPool(size uintptr) *sync.Pool { + scalarRCLock.Lock() + pool, ok := scalarRC[size] + if !ok { + pool = &sync.Pool{ + New: func() interface{} { return make([]byte, size) }, + } + scalarRC[size] = pool + } + scalarRCLock.Unlock() + return pool +} + +func allocScalar(a interface{}) []byte { + atype := reflect.TypeOf(a) + size := atype.Size() + pool := scalarPool(size) + return pool.Get().([]byte) +} + +func freeScalar(bs []byte) { + if bs == nil { + return + } + + // zero out + for i := range bs { + bs[i] = 0 + } + + size := uintptr(len(bs)) + + // put it back into pool + pool := scalarPool(size) + pool.Put(bs) +} + +// scalarToHeader creates a Header from a scalar value +func scalarToHeader(a interface{}) (hdr *storage.Header, newAlloc bool) { + var raw []byte switch at := a.(type) { case Memory: - return at.Pointer() - case bool: - return unsafe.Pointer(&at) - case int: - return unsafe.Pointer(&at) - case int8: - return unsafe.Pointer(&at) - case int16: - return unsafe.Pointer(&at) - case int32: - return unsafe.Pointer(&at) - case int64: - return unsafe.Pointer(&at) - case uint: - return unsafe.Pointer(&at) - case uint8: - return unsafe.Pointer(&at) - case uint16: - return unsafe.Pointer(&at) - case uint32: - return unsafe.Pointer(&at) - case uint64: - return unsafe.Pointer(&at) - case float32: - return unsafe.Pointer(&at) - case float64: - return unsafe.Pointer(&at) - case complex64: - return unsafe.Pointer(&at) - case complex128: - return unsafe.Pointer(&at) - case string: - return unsafe.Pointer(&at) - case uintptr: - return unsafe.Pointer(at) - case unsafe.Pointer: - return at - - // POINTERS - - case *bool: - return unsafe.Pointer(at) - case *int: - return unsafe.Pointer(at) - case *int8: - return unsafe.Pointer(at) - case *int16: - return 
unsafe.Pointer(at) - case *int32: - return unsafe.Pointer(at) - case *int64: - return unsafe.Pointer(at) - case *uint: - return unsafe.Pointer(at) - case *uint8: - return unsafe.Pointer(at) - case *uint16: - return unsafe.Pointer(at) - case *uint32: - return unsafe.Pointer(at) - case *uint64: - return unsafe.Pointer(at) - case *float32: - return unsafe.Pointer(at) - case *float64: - return unsafe.Pointer(at) - case *complex64: - return unsafe.Pointer(at) - case *complex128: - return unsafe.Pointer(at) - case *string: - return unsafe.Pointer(at) - case *uintptr: - return unsafe.Pointer(*at) - case *unsafe.Pointer: - return *at - } - - panic("Cannot get pointer") + raw = storage.FromMemory(at.Uintptr(), at.MemSize()) + default: + raw = allocScalar(a) + newAlloc = true + } + hdr = borrowHeader() + hdr.Raw = raw + if newAlloc { + copyScalarToPrealloc(a, hdr.Raw) + } + + return hdr, newAlloc } -// scalarToHeader creates a Header from a scalar value -func scalarToHeader(a interface{}) *storage.Header { - hdr := borrowHeader() - hdr.Ptr = getPointer(a) - hdr.L = 1 - hdr.C = 1 - return hdr +func copyScalarToPrealloc(a interface{}, bs []byte) { + xV := reflect.ValueOf(a) + xT := reflect.TypeOf(a) + + p := unsafe.Pointer(&bs[0]) + v := reflect.NewAt(xT, p) + reflect.Indirect(v).Set(xV) + return } diff --git a/array_getset.go b/array_getset.go index 1f71afd..c19fe68 100644 --- a/array_getset.go +++ b/array_getset.go @@ -7,6 +7,7 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" ) // Set sets the value of the underlying array at the index i. 
@@ -68,8 +69,7 @@ func (a *array) Set(i int, x interface{}) { a.SetUnsafePointer(i, xv) default: xv := reflect.ValueOf(x) - want := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) - val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -80,43 +80,60 @@ func (a *array) Get(i int) interface{} { switch a.t.Kind() { case reflect.Bool: return a.GetB(i) + case reflect.Int: return a.GetI(i) + case reflect.Int8: return a.GetI8(i) + case reflect.Int16: return a.GetI16(i) + case reflect.Int32: return a.GetI32(i) + case reflect.Int64: return a.GetI64(i) + case reflect.Uint: return a.GetU(i) + case reflect.Uint8: return a.GetU8(i) + case reflect.Uint16: return a.GetU16(i) + case reflect.Uint32: return a.GetU32(i) + case reflect.Uint64: return a.GetU64(i) + case reflect.Uintptr: return a.GetUintptr(i) + case reflect.Float32: return a.GetF32(i) + case reflect.Float64: return a.GetF64(i) + case reflect.Complex64: return a.GetC64(i) + case reflect.Complex128: return a.GetC128(i) + case reflect.String: return a.GetStr(i) + case reflect.UnsafePointer: return a.GetUnsafePointer(i) + default: - at := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) - val := reflect.NewAt(a.t.Type, at) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) return val.Interface() } @@ -290,25 +307,24 @@ func (a *array) Memset(x interface{}) error { } xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } return nil } -func (t *array) memsetIter(x interface{}, it Iterator) (err 
error) { +func (a *array) memsetIter(x interface{}, it Iterator) (err error) { var i int - switch t.t { + switch a.t { case Bool: xv, ok := x.(bool) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Bools() + data := a.Bools() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -316,9 +332,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int: xv, ok := x.(int) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Ints() + data := a.Ints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -326,9 +342,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int8: xv, ok := x.(int8) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int8s() + data := a.Int8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -336,9 +352,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int16: xv, ok := x.(int16) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int16s() + data := a.Int16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -346,9 +362,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int32: xv, ok := x.(int32) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int32s() + data := a.Int32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -356,9 +372,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Int64: xv, ok := x.(int64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Int64s() + data := a.Int64s() for i, err = it.Next(); err 
== nil; i, err = it.Next() { data[i] = xv } @@ -366,9 +382,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint: xv, ok := x.(uint) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uints() + data := a.Uints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -376,9 +392,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint8: xv, ok := x.(uint8) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint8s() + data := a.Uint8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -386,9 +402,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint16: xv, ok := x.(uint16) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint16s() + data := a.Uint16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -396,9 +412,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint32: xv, ok := x.(uint32) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint32s() + data := a.Uint32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -406,9 +422,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uint64: xv, ok := x.(uint64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uint64s() + data := a.Uint64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -416,9 +432,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Uintptr: xv, ok := x.(uintptr) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Uintptrs() + data 
:= a.Uintptrs() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -426,9 +442,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Float32: xv, ok := x.(float32) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Float32s() + data := a.Float32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -436,9 +452,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Float64: xv, ok := x.(float64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Float64s() + data := a.Float64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -446,9 +462,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Complex64: xv, ok := x.(complex64) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Complex64s() + data := a.Complex64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -456,9 +472,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case Complex128: xv, ok := x.(complex128) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Complex128s() + data := a.Complex128s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -466,9 +482,9 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case String: xv, ok := x.(string) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.Strings() + data := a.Strings() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } @@ -476,19 +492,17 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { case UnsafePointer: xv, ok := x.(unsafe.Pointer) if !ok { - return 
errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.UnsafePointers() + data := a.UnsafePointers() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = xv } err = handleNoOp(err) default: xv := reflect.ValueOf(x) - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -504,7 +518,7 @@ func (a array) Eq(other interface{}) bool { return false } - if oa.L != a.L { + if oa.Len() != a.Len() { return false } /* @@ -514,7 +528,7 @@ func (a array) Eq(other interface{}) bool { */ // same exact thing - if uintptr(oa.Ptr) == uintptr(a.Ptr) { + if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == uintptr(unsafe.Pointer(&a.Header.Raw[0])) { return true } @@ -628,7 +642,7 @@ func (a array) Eq(other interface{}) bool { } } default: - for i := 0; i < a.L; i++ { + for i := 0; i < a.Len(); i++ { if !reflect.DeepEqual(a.Get(i), oa.Get(i)) { return false } @@ -639,124 +653,122 @@ func (a array) Eq(other interface{}) bool { return false } -func (t *array) zeroIter(it Iterator) (err error) { +func (a *array) zeroIter(it Iterator) (err error) { var i int - switch t.t { + switch a.t { case Bool: - data := t.Bools() + data := a.Bools() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = false } err = handleNoOp(err) case Int: - data := t.Ints() + data := a.Ints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int8: - data := t.Int8s() + data := a.Int8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int16: - data := t.Int16s() + data := a.Int16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int32: - data := 
t.Int32s() + data := a.Int32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Int64: - data := t.Int64s() + data := a.Int64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint: - data := t.Uints() + data := a.Uints() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint8: - data := t.Uint8s() + data := a.Uint8s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint16: - data := t.Uint16s() + data := a.Uint16s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint32: - data := t.Uint32s() + data := a.Uint32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uint64: - data := t.Uint64s() + data := a.Uint64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Uintptr: - data := t.Uintptrs() + data := a.Uintptrs() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Float32: - data := t.Float32s() + data := a.Float32s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Float64: - data := t.Float64s() + data := a.Float64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Complex64: - data := t.Complex64s() + data := a.Complex64s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case Complex128: - data := t.Complex128s() + data := a.Complex128s() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = 0 } err = handleNoOp(err) case String: - data := t.Strings() + data := a.Strings() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = "" } err = handleNoOp(err) case UnsafePointer: - data := t.UnsafePointers() + data 
:= a.UnsafePointers() for i, err = it.Next(); err == nil; i, err = it.Next() { data[i] = nil } err = handleNoOp(err) default: - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next() { - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) - val.Set(reflect.Zero(t.t)) + val.Set(reflect.Zero(a.t)) } err = handleNoOp(err) } diff --git a/bench.sh b/bench.sh new file mode 100755 index 0000000..8523853 --- /dev/null +++ b/bench.sh @@ -0,0 +1,23 @@ +#!/bin/sh + +old=$1; +new=$2; + +git checkout $old +# https://stackoverflow.com/a/2111099 +branch=$(git symbolic-ref HEAD | sed -e 's,.*/\(.*\),\1,') +echo "Benchmarking $branch (old)" +go test -run=$^ -bench=. > ${branch}.bench +for i in {1..10} + do + go test -run=$^ -bench=. >> ${branch}.bench + done + +git checkout $new +branch=$(git symbolic-ref HEAD | sed -e 's,.*/\(.*\),\1,') +echo "Benchmarking $branch (new)" +go test -run=$^ -bench=. > ${branch}.bench +for i in {1..10} + do + go test -run=$^ -bench=. >> ${branch}.bench + done diff --git a/consopt.go b/consopt.go index 2134896..8cbc54f 100644 --- a/consopt.go +++ b/consopt.go @@ -2,7 +2,8 @@ package tensor import ( "reflect" - "unsafe" + + "gorgonia.org/tensor/internal/storage" ) // ConsOpt is a tensor construction option. 
@@ -106,24 +107,15 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - xt := reflect.TypeOf(x) - xv := reflect.New(xt) - xvi := reflect.Indirect(xv) - xvi.Set(reflect.ValueOf(x)) - uptr := unsafe.Pointer(xv.Pointer()) - var v interface{} - if !tt.Shape().IsScalar() { - sl := reflect.MakeSlice(reflect.SliceOf(xt), 1, 1) - zeroth := sl.Index(0) - zeroth.Set(reflect.ValueOf(x)) - v = sl.Interface() - } - tt.array.Ptr = uptr - tt.array.L = 1 - tt.array.C = 1 - tt.v = v - tt.t = Dtype{xt} + xT := reflect.TypeOf(x) + sxT := reflect.SliceOf(xT) + xv := reflect.MakeSlice(sxT, 1, 1) // []T + xv0 := xv.Index(0) // xv[0] + xv0.Set(reflect.ValueOf(x)) + tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) + tt.t = Dtype{xT} + tt.mask = mask default: @@ -152,17 +144,11 @@ func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - tt.v = nil // if there were any underlying slices it should be GC'd - tt.array.Ptr = unsafe.Pointer(ptr) - tt.array.L = int(memsize / tt.t.Size()) - tt.array.C = int(memsize / tt.t.Size()) - tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged) - - if tt.IsNativelyAccessible() { - tt.array.fix() - } + tt.Header.Raw = nil // GC anything if needed + tt.Header.Raw = storage.FromMemory(ptr, memsize) + tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged) default: panic("Unsupported Tensor type") } diff --git a/defaultengine.go b/defaultengine.go index bc92e8c..d9138ae 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -1,76 +1,69 @@ -package tensor - -import ( - "unsafe" - - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/execution" -) - -// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. 
-type StdEng struct { - execution.E -} - -// makeArray allocates a slice for the array -func (e StdEng) makeArray(arr *array, t Dtype, size int) { - memsize := calcMemSize(t, size) - s := make([]byte, memsize) - arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() -} - -func (e StdEng) AllocAccessible() bool { return true } -func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } -func (e StdEng) Free(mem Memory, size int64) error { return nil } -func (e StdEng) Memset(mem Memory, val interface{}) error { - if ms, ok := mem.(MemSetter); ok { - return ms.Memset(val) - } - return errors.Errorf("Cannot memset %v with StdEng", mem) -} - -func (e StdEng) Memclr(mem Memory) { - if z, ok := mem.(Zeroer); ok { - z.Zero() - } - return -} - -func (e StdEng) Memcpy(dst, src Memory) error { - switch dt := dst.(type) { - case *array: - switch st := src.(type) { - case *array: - copyArray(dt, st) - return nil - case arrayer: - copyArray(dt, st.arrPtr()) - return nil - } - case arrayer: - switch st := src.(type) { - case *array: - copyArray(dt.arrPtr(), st) - return nil - case arrayer: - copyArray(dt.arrPtr(), st.arrPtr()) - return nil - } - } - return errors.Errorf("Failed to copy %T %T", dst, src) -} - -func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } - -func (e StdEng) WorksWith(order DataOrder) bool { return true } - -func (e StdEng) checkAccessible(t Tensor) error { - if !t.IsNativelyAccessible() { - return errors.Errorf(inaccessibleData, t) - } - return nil -} +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/execution" +) + +// StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. 
+type StdEng struct { + execution.E +} + +// makeArray allocates a slice for the array +func (e StdEng) makeArray(arr *array, t Dtype, size int) { + arr.Raw = malloc(t, size) + arr.t = t +} + +func (e StdEng) AllocAccessible() bool { return true } +func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } +func (e StdEng) Free(mem Memory, size int64) error { return nil } +func (e StdEng) Memset(mem Memory, val interface{}) error { + if ms, ok := mem.(MemSetter); ok { + return ms.Memset(val) + } + return errors.Errorf("Cannot memset %v with StdEng", mem) +} + +func (e StdEng) Memclr(mem Memory) { + if z, ok := mem.(Zeroer); ok { + z.Zero() + } + return +} + +func (e StdEng) Memcpy(dst, src Memory) error { + switch dt := dst.(type) { + case *array: + switch st := src.(type) { + case *array: + copyArray(dt, st) + return nil + case arrayer: + copyArray(dt, st.arrPtr()) + return nil + } + case arrayer: + switch st := src.(type) { + case *array: + copyArray(dt.arrPtr(), st) + return nil + case arrayer: + copyArray(dt.arrPtr(), st.arrPtr()) + return nil + } + } + return errors.Errorf("Failed to copy %T %T", dst, src) +} + +func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } + +func (e StdEng) WorksWith(order DataOrder) bool { return true } + +func (e StdEng) checkAccessible(t Tensor) error { + if !t.IsNativelyAccessible() { + return errors.Errorf(inaccessibleData, t) + } + return nil +} diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 72b171d..918e1ca 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -48,7 +48,6 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.AddIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -70,7 +69,6 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Add(typ, retVal.hdr(), dataB) } - return } @@ -115,7 +113,6 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts 
...FuncOpt) (retVal Tensor, err err } err = e.E.SubIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -137,7 +134,6 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Sub(typ, retVal.hdr(), dataB) } - return } @@ -182,7 +178,6 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.MulIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -204,7 +199,6 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Mul(typ, retVal.hdr(), dataB) } - return } @@ -249,7 +243,6 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.DivIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -271,7 +264,6 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Div(typ, retVal.hdr(), dataB) } - return } @@ -316,7 +308,6 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.PowIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -338,7 +329,6 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Pow(typ, retVal.hdr(), dataB) } - return } @@ -383,7 +373,6 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.ModIter(typ, retVal.hdr(), dataB, ait, bit) } - return } switch { @@ -405,7 +394,6 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err } err = e.E.Mod(typ, retVal.hdr(), dataB) } - return } @@ -429,15 +417,15 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, 
dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Add") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Add") } scalarHeader = dataA @@ -471,6 +459,9 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.AddIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -502,6 +493,9 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } err = e.E.Add(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -526,15 +520,15 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Sub") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Sub") } scalarHeader = dataA @@ -568,6 +562,9 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.SubIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -599,6 
+596,9 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } err = e.E.Sub(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -623,15 +623,15 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mul") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mul") } scalarHeader = dataA @@ -665,6 +665,9 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.MulIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -696,6 +699,9 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } err = e.E.Mul(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -720,15 +726,15 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { 
return nil, errors.Wrapf(err, opFail, "StdEng.Div") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Div") } scalarHeader = dataA @@ -762,6 +768,9 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.DivIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -793,6 +802,9 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } err = e.E.Div(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -817,15 +829,15 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Pow") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Pow") } scalarHeader = dataA @@ -859,6 +871,9 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.PowIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -890,6 +905,9 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts 
...Func } err = e.E.Pow(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -914,15 +932,15 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mod") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Mod") } scalarHeader = dataA @@ -956,6 +974,9 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.ModIter(typ, dataA, retVal.hdr(), ait, bit) } } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -987,6 +1008,9 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } err = e.E.Mod(typ, retVal.hdr(), dataB) } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 8c2b919..1d6ff48 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -66,7 +66,6 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.GtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -83,7 +82,6 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.Gt(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -146,7 +144,6 @@ func (e StdEng) Gte(a Tensor, b 
Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.GteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -163,7 +160,6 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.Gte(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -226,7 +222,6 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.LtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -243,7 +238,6 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro err = e.E.Lt(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -306,7 +300,6 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.LteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -323,7 +316,6 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.Lte(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -386,7 +378,6 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.EqIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -403,7 +394,6 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.Eq(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -466,7 +456,6 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.NeIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - return } @@ -483,7 +472,6 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er err = e.E.Ne(typ, dataA, dataB, dataReuse) retVal = reuse } - return } @@ -512,15 +500,15 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + 
var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gt") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gt") } scalarHeader = dataA @@ -563,12 +551,15 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.GtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -599,6 +590,9 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Gt(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -628,15 +622,15 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gte") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, 
reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Gte") } scalarHeader = dataA @@ -679,12 +673,15 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.GteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -715,6 +712,9 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.Gte(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -744,15 +744,15 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lt") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lt") } scalarHeader = dataA @@ -795,12 +795,15 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.LtIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { 
+ freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -831,6 +834,9 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Lt(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -860,15 +866,15 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lte") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Lte") } scalarHeader = dataA @@ -911,12 +917,15 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func err = e.E.LteIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -947,6 +956,9 @@ func (e StdEng) LteScalar(t Tensor, s 
interface{}, leftTensor bool, opts ...Func err = e.E.Lte(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -972,15 +984,15 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Eq") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Eq") } scalarHeader = dataA @@ -1023,12 +1035,15 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.EqIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1059,6 +1074,9 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Eq(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } @@ -1084,15 +1102,15 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader 
*storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Ne") } scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Ne") } scalarHeader = dataA @@ -1135,12 +1153,15 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.NeIter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1171,6 +1192,9 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO err = e.E.Ne(typ, dataA, dataB, dataReuse) retVal = reuse } + if newAlloc { + freeScalar(scalarHeader.Raw) + } returnHeader(scalarHeader) return } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index f45c9bb..59c3b69 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -180,8 +180,8 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, bDestEnd := (destStart + tmp) * int(darr.t.Size()) // then we get the data as a slice of raw bytes - sBS := storage.AsByteSlice(&sarr.Header, sarr.t.Type) - dBS := storage.AsByteSlice(&darr.Header, darr.t.Type) + sBS := sarr.Header.Raw + dBS := darr.Header.Raw // recall that len(src) < len(dest) // it's easier to 
understand if we define the ranges. diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 368ddb5..879ca28 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -2,7 +2,6 @@ package tensor import ( "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" ) // This file contains code for the execution engine to stack tensors @@ -366,7 +365,7 @@ func (e StdEng) doViewStack8(t, retVal DenseTensor, axisStride, batches int, it func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { dt := t.Dtype() - data := storage.AsByteSlice(retVal.hdr(), dt.Type)[:0] + data := retVal.hdr().Raw[:0] // truncate to 0 size := int(dt.Size()) var mask []bool var retIsMasked bool @@ -385,8 +384,7 @@ func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches tmask = mt.Mask() isMasked = mt.IsMasked() } - dt := t.Dtype() - bs := storage.AsByteSlice(t.hdr(), dt.Type) + bs := t.hdr().Raw for last = 0; last < axisStride; last++ { id, err := it.Next() diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index 76a2f0a..cef220e 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -4,7 +4,6 @@ package tensor import ( "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" ) func (e StdEng) Transpose(a Tensor, expStrides []int) error { @@ -140,7 +139,7 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { // arbs := storage.AsByteSlice(tmpArr.hdr(), rtype) arbs := tmpArr.byteSlice() - orig := storage.AsByteSlice(a.hdr(), rtype) + orig := a.hdr().Raw it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 612e1cc..8627927 100644 --- a/defaultengine_matop_transpose_inplace.go +++ 
b/defaultengine_matop_transpose_inplace.go @@ -4,7 +4,6 @@ package tensor import ( "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" ) func (e StdEng) Transpose(a Tensor, expStrides []int) error { @@ -290,7 +289,7 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { saved := make([]byte, typeSize, typeSize) tmp := make([]byte, typeSize, typeSize) var i int - data := storage.AsByteSlice(a.hdr(), rtype) + data := a.arr().Raw if len(data) < 4*typeSize { return } diff --git a/defaultengine_prep.go b/defaultengine_prep.go index ea2b2f5..261367a 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -159,10 +159,10 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea return } -func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) { +func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, newAlloc bool, err error) { // get data dataA = a.hdr() - dataB = scalarToHeader(b) + dataB, newAlloc = scalarToHeader(b) if reuse != nil { dataReuse = reuse.hdr() } @@ -182,9 +182,9 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse return } -func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, err error) { +func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, newAlloc bool, err error) { // get data - dataA = scalarToHeader(a) + dataA, newAlloc = scalarToHeader(a) dataB = b.hdr() if reuse != nil { dataReuse = reuse.hdr() diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 82d48f2..45859a4 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -1,8 +1,6 @@ package tensor import ( - "unsafe" - "github.com/pkg/errors" 
"gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" @@ -118,12 +116,11 @@ func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { if t != Float32 { panic("Float32Engine only creates float32s") } - s := make([]float32, size) + if size < 0 { + panic("Cannot have negative sizes when making array") + } + arr.Header.Raw = make([]byte, size*4) arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() } func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { diff --git a/defaultenginefloat32_test.go b/defaultenginefloat32_test.go new file mode 100644 index 0000000..0ebd016 --- /dev/null +++ b/defaultenginefloat32_test.go @@ -0,0 +1,42 @@ +package tensor + +import ( + "testing" + "testing/quick" +) + +func TestFloat32Engine_makeArray(t *testing.T) { + + // the uint16 is just to make sure that tests are correctly run. + // we don't want the quicktest to randomly generate a size that is so large + // that Go takes a long time just to allocate. We'll test the other sizes (like negative numbers) + // after the quick test. + f := func(sz uint16) bool { + size := int(sz) + e := Float32Engine{StdEng{}} + dt := Float32 + arr := array{} + + e.makeArray(&arr, dt, size) + + if len(arr.Raw) != size*4 { + t.Errorf("Expected raw to be size*4. Got %v instead", len(arr.Raw)) + return false + } + v, ok := arr.Data().([]float32) + if !ok { + t.Errorf("Expected v to be []float32. 
Got %T instead", arr.Data()) + return false + } + + if len(v) != size { + return false + } + return true + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Quick test failed %v", err) + } + +} diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index b0d9466..21bba43 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -1,8 +1,6 @@ package tensor import ( - "unsafe" - "github.com/pkg/errors" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" @@ -118,12 +116,8 @@ func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { if t != Float64 { panic("Float64Engine only creates float64s") } - s := make([]float64, size) + arr.Header.Raw = make([]byte, size*8) arr.t = t - arr.L = size - arr.C = size - arr.Ptr = unsafe.Pointer(&s[0]) - arr.fix() } func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { diff --git a/defaultenginefloat64_test.go b/defaultenginefloat64_test.go new file mode 100644 index 0000000..2d9391a --- /dev/null +++ b/defaultenginefloat64_test.go @@ -0,0 +1,42 @@ +package tensor + +import ( + "testing" + "testing/quick" +) + +func TestFloat64Engine_makeArray(t *testing.T) { + + // the uint16 is just to make sure that tests are correctly run. + // we don't want the quicktest to randomly generate a size that is so large + // that Go takes a long time just to allocate. We'll test the other sizes (like negative numbers) + // after the quick test. + f := func(sz uint16) bool { + size := int(sz) + e := Float64Engine{StdEng{}} + dt := Float64 + arr := array{} + + e.makeArray(&arr, dt, size) + + if len(arr.Raw) != size*8 { + t.Errorf("Expected raw to be size*8. Got %v instead", len(arr.Raw)) + return false + } + v, ok := arr.Data().([]float64) + if !ok { + t.Errorf("Expected v to be []float32. 
Got %T instead", arr.Data()) + return false + } + + if len(v) != size { + return false + } + return true + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Quick test failed %v", err) + } + +} diff --git a/dense.go b/dense.go index fa9693d..d647ab5 100644 --- a/dense.go +++ b/dense.go @@ -6,6 +6,7 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" ) const ( @@ -47,14 +48,12 @@ func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { } func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { - size := shape.TotalSize() - if shape.IsScalar() { - size = 1 - } + // size := shape.TotalSize() + //if shape.IsScalar() { + // size = 1 + //} retVal = borrowDense() retVal.array.t = dt - retVal.array.L = size - retVal.array.C = size retVal.AP.zeroWithDims(shape.Dims()) for _, opt := range opts { @@ -65,8 +64,7 @@ func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) } func (t *Dense) fromSlice(x interface{}) { - t.array.Ptr = nil - t.array.v = nil + t.array.Header.Raw = nil // GC anything else t.array.fromSlice(x) } @@ -88,15 +86,13 @@ func (t *Dense) makeArray(size int) { default: } - mem, err := t.e.Alloc(calcMemSize(t.t, size)) + memsize := calcMemSize(t.t, size) + mem, err := t.e.Alloc(memsize) if err != nil { panic(err) } - t.array.Ptr = mem.Pointer() - t.array.L = size - t.array.C = size - t.array.fix() + t.array.Raw = storage.FromMemory(mem.Uintptr(), uintptr(memsize)) return } @@ -111,28 +107,25 @@ func (t *Dense) Data() interface{} { if t.IsScalar() { return t.Get(0) } - if t.v == nil { - // build a type of []T - shdr := reflect.SliceHeader{ - Data: uintptr(t.Header.Ptr), - Len: t.Header.L, - Cap: t.Header.C, - } - sliceT := reflect.SliceOf(t.t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - t.v = val.Interface() + // build a type of []T + shdr := reflect.SliceHeader{ + Data: t.array.Uintptr(), + Len: 
t.array.Len(), + Cap: t.array.Cap(), } - return t.v + sliceT := reflect.SliceOf(t.t.Type) + ptr := unsafe.Pointer(&shdr) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + return val.Interface() } // DataSize returns the size of the underlying array. Typically t.DataSize() == t.Shape().TotalSize() func (t *Dense) DataSize() int { if t.IsScalar() { - return 0 + return 0 // DOUBLE CHECK } - return t.L + return t.array.Len() } // Engine returns the execution engine associated with this Tensor @@ -212,7 +205,7 @@ func (t *Dense) Clone() interface{} { retVal.e = t.e retVal.oe = t.oe retVal.flag = t.flag - retVal.makeArray(t.L) + retVal.makeArray(t.Len()) if !t.old.IsZero() { retVal.old = t.old.Clone() @@ -270,8 +263,8 @@ func (t *Dense) MaskFromDense(tts ...*Dense) { // Private methods -func (t *Dense) cap() int { return t.array.C } -func (t *Dense) len() int { return t.array.L } // exactly the same as DataSize +func (t *Dense) cap() int { return t.array.Cap() } +func (t *Dense) len() int { return t.array.Len() } // exactly the same as DataSize func (t *Dense) arr() array { return t.array } func (t *Dense) arrPtr() *array { return &t.array } @@ -294,16 +287,16 @@ func (t *Dense) fix() { } switch { - case t.IsScalar() && t.array.Ptr == nil: + case t.IsScalar() && t.array.Header.Raw == nil: t.makeArray(1) - case t.Shape() == nil && t.array.Ptr != nil: - size := t.L + case t.Shape() == nil && t.array.Header.Raw != nil: + size := t.Len() if size == 1 { t.SetShape() // scalar } else { t.SetShape(size) // vector } - case t.array.Ptr == nil && t.t != Dtype{}: + case t.array.Header.Raw == nil && t.t != Dtype{}: size := t.Shape().TotalSize() t.makeArray(size) @@ -330,11 +323,11 @@ func (t *Dense) makeMask() { // sanity is a function that sanity checks that a tensor is correct. 
func (t *Dense) sanity() error { - if !t.AP.IsZero() && t.Shape() == nil && t.array.Ptr == nil { + if !t.AP.IsZero() && t.Shape() == nil && t.array.Header.Raw == nil { return errors.New(emptyTensor) } - size := t.L + size := t.Len() expected := t.Size() if t.viewOf == 0 && size != expected && !t.IsScalar() { return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") diff --git a/dense_assign.go b/dense_assign.go index bd8bceb..5f44897 100644 --- a/dense_assign.go +++ b/dense_assign.go @@ -10,14 +10,14 @@ func overlaps(a, b DenseTensor) bool { } aarr := a.arr() barr := b.arr() - if aarr.Ptr == barr.Ptr { + if aarr.Uintptr() == barr.Uintptr() { return true } - aptr := uintptr(aarr.Ptr) - bptr := uintptr(barr.Ptr) + aptr := aarr.Uintptr() + bptr := barr.Uintptr() - capA := aptr + uintptr(aarr.C)*a.Dtype().Size() - capB := bptr + uintptr(barr.C)*b.Dtype().Size() + capA := aptr + uintptr(cap(aarr.Header.Raw)) + capB := bptr + uintptr(cap(barr.Header.Raw)) switch { case aptr < bptr: diff --git a/dense_io.go b/dense_io.go index e4717f8..7bb9608 100644 --- a/dense_io.go +++ b/dense_io.go @@ -808,7 +808,7 @@ func (t *Dense) FBDecode(buf []byte) error { // allocated data. 
Now time to actually copy over the data db := t.byteSlice() copy(db, serialized.DataBytes()) - t.forcefix() + t.fix() return t.sanity() } diff --git a/dense_matop.go b/dense_matop.go index 5ce693b..7e81419 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -243,7 +243,6 @@ func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) } view.AP.zero() - view.array.v = nil // reset view.t = t.t view.e = t.e diff --git a/dense_matop_test.go b/dense_matop_test.go index 755309f..cf2ce7a 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -607,7 +607,9 @@ func TestDense_Slice(t *testing.T) { assert.True(V.(*Dense).old.IsZero()) // slice a sliced + t.Logf("%v", V) V, err = V.Slice(makeRS(1, 2)) + t.Logf("%v", V) assert.True(ScalarShape().Eq(V.Shape())) assert.Equal(float32(3), V.Data()) diff --git a/engine.go b/engine.go index a8ec63c..1ac8400 100644 --- a/engine.go +++ b/engine.go @@ -1,9 +1,5 @@ package tensor -import ( - "unsafe" -) - // Memory is a representation of memory of the value. // // The main reason for requiring both Uintptr() and Pointer() methods is because while Go currently does not have a compacting @@ -13,7 +9,6 @@ import ( type Memory interface { Uintptr() uintptr MemSize() uintptr - Pointer() unsafe.Pointer } // Engine is a representation of an execution engine. 
diff --git a/genlib2/agg1_body.go b/genlib2/agg1_body.go index 2ca9d96..f738d0c 100644 --- a/genlib2/agg1_body.go +++ b/genlib2/agg1_body.go @@ -5,8 +5,8 @@ import "text/template" // level 1 aggregation (internal.E) templates const ( - eArithRaw = `as := isScalar(a) - bs := isScalar(b) + eArithRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -25,18 +25,18 @@ const ( default: {{if and $isDiv $p}} err = {{end}} Vec{{$name}}{{short .}}(at, bt) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eArithIncrRaw = `as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + eArithIncrRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} switch t { @@ -60,14 +60,14 @@ const ( default: {{$name}}Incr{{short .}}(at, bt,it) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eArithIterRaw = `as := isScalar(a) - bs := isScalar(b) + eArithIterRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -91,12 +91,12 @@ const ( } ` - eArithIterIncrRaw = `as :=isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + eArithIterIncrRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} switch t { @@ -126,7 +126,7 @@ const ( } ` - eMapRaw = `as := isScalar(a) + eMapRaw = `as := isScalar(a, t) switch t { {{range .Kinds -}} case {{reflectKind .}}: @@ -181,11 +181,11 @@ const ( Map{{short .}}(f0, at) } {{end -}} - + {{end -}} default: return errors.Errorf("Cannot map t of %v", t) - + } ` @@ -233,8 +233,8 @@ const ( } ` - eCmpSameRaw = `as := isScalar(a) - bs := isScalar(b) + eCmpSameRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -252,20 +252,20 @@ const ( default: {{$name}}Same{{short .}}(at, bt) } - return + return {{end -}} {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) }` - eCmpBoolRaw = `as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + eCmpBoolRaw = `as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} @@ -285,15 +285,15 @@ const ( default: {{$name}}{{short .}}(at, bt, rt) } - return + return {{end -}} default: return errors.Errorf("Unsupported type %v for {{$name}}", t) } ` - eCmpSameIterRaw = `as := isScalar(a) - bs := isScalar(b) + eCmpSameIterRaw = `as := isScalar(a, t) + bs := isScalar(b, t) {{$name := .Name}} switch t { {{range .Kinds -}} @@ -319,13 +319,13 @@ const ( } ` - eCmpBoolIterRaw = `as :=isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + eCmpBoolIterRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } {{$name := .Name}} @@ -478,7 +478,7 @@ const ( return errors.Wrap(errors.Errorf(typeMismatch, max, maxVal), "Clamp() max") } Clamp{{short .}}(a.{{sliceOf .}}, min, max) - return nil + return nil {{end -}} default: return errors.Errorf("Unsupported type %v for Clamp", t) @@ -553,7 +553,7 @@ const ( if _, ok := err.(NoOpError); ok { err = nil } - return + return {{end -}} default: return nil, errors.Errorf("Unsupported type %v for Arg{{.Name}}", t) diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index cf87bb0..1e16123 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -51,15 +51,15 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); typ := t.Dtype().Type var ait, bit, iit Iterator var dataA, dataB, dataReuse, scalarHeader *storage.Header - var useIter bool + var useIter, newAlloc bool if leftTensor { - if dataA, dataB, dataReuse, ait, iit, useIter, err = prepDataVS(t, s, reuse); err != nil { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}") } 
scalarHeader = dataB } else { - if dataA, dataB, dataReuse, bit, iit, useIter, err = prepDataSV(s, t, reuse); err != nil { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}") } scalarHeader = dataA @@ -133,7 +133,12 @@ const agg2BodyRaw = `if useIter { } {{end -}} } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return } switch { @@ -184,7 +189,12 @@ const agg2BodyRaw = `if useIter { err = e.E.{{.Name}}(typ, retVal.hdr(), dataB) {{end -}} } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return ` @@ -242,13 +252,18 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Name}}Iter(typ, dataA, dataB, dataReuse, ait, bit, iit) retVal = reuse } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return } {{if not .VV -}} // handle special case where A and B have both len 1 - if dataA.L == 1 && dataB.L == 1 { + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ,dataReuse,dataA) @@ -288,7 +303,12 @@ const agg2CmpBodyRaw = `// check to see if anything needs to be created err = e.E.{{.Name}}(typ, dataA, dataB, dataReuse) retVal = reuse } - {{if not .VV -}}returnHeader(scalarHeader){{end}} + {{if not .VV -}} + if newAlloc{ + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + {{end -}} return ` diff --git a/genlib2/array_getset.go b/genlib2/array_getset.go index 73a686b..a75edd5 100644 --- a/genlib2/array_getset.go +++ b/genlib2/array_getset.go @@ -6,7 +6,7 @@ import ( "text/template" ) 
-const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} { return *(*[]{{asType .}})(unsafe.Pointer(h)) } +const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} {return (*(*[]{{asType .}})(unsafe.Pointer(&h.Raw)))[:h.TypedLen({{short . | unexport}}Type):h.TypedLen({{short . | unexport}}Type)]} ` const setBasicRaw = `func (h *Header) Set{{short . }}(i int, x {{asType . }}) { h.{{sliceOf .}}[i] = x } @@ -23,11 +23,10 @@ func (a *array) Get(i int) interface{} { {{else -}} case reflect.{{reflectKind .}}: return a.{{getOne .}}(i) - {{end -}} + {{end -}}; {{end -}} default: - at := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i) * a.t.Size()) - val := reflect.NewAt(a.t.Type, at) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) return val.Interface() } @@ -47,8 +46,7 @@ func (a *array) Set(i int, x interface{}) { {{end -}} default: xv := reflect.ValueOf(x) - want := unsafe.Pointer(uintptr(a.Ptr) + uintptr(i)*a.t.Size()) - val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -76,10 +74,9 @@ func (a *array) Memset(x interface{}) error { } xv := reflect.ValueOf(x) - ptr := uintptr(a.Ptr) - for i := 0; i < a.L; i++ { - want := ptr + uintptr(i)*a.t.Size() - val := reflect.NewAt(a.t.Type, unsafe.Pointer(want)) + l := a.Len() + for i := 0; i < l; i++ { + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -94,7 +91,7 @@ func (a array) Eq(other interface{}) bool { return false } - if oa.L != a.L { + if oa.Len() != a.Len() { return false } /* @@ -104,7 +101,7 @@ func (a array) Eq(other interface{}) bool { */ // same exact thing - if uintptr(oa.Ptr) == uintptr(a.Ptr){ + if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == 
uintptr(unsafe.Pointer(&a.Header.Raw[0])){ return true } @@ -121,7 +118,7 @@ func (a array) Eq(other interface{}) bool { {{end -}} {{end -}} default: - for i := 0; i < a.L; i++{ + for i := 0; i < a.Len(); i++{ if !reflect.DeepEqual(a.Get(i), oa.Get(i)){ return false } @@ -179,18 +176,18 @@ const copyArrayIterRaw = `func copyArrayIter(dst, src array, diter, siter Iterat ` const memsetIterRaw = ` -func (t *array) memsetIter(x interface{}, it Iterator) (err error) { +func (a *array) memsetIter(x interface{}, it Iterator) (err error) { var i int - switch t.t{ + switch a.t{ {{range .Kinds -}} {{if isParameterized . -}} {{else -}} case {{reflectKind .}}: xv, ok := x.({{asType .}}) if !ok { - return errors.Errorf(dtypeMismatch, t.t, x) + return errors.Errorf(dtypeMismatch, a.t, x) } - data := t.{{sliceOf .}} + data := a.{{sliceOf .}} for i, err = it.Next(); err == nil; i, err = it.Next(){ data[i] = xv } @@ -199,10 +196,8 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { {{end -}} default: xv := reflect.ValueOf(x) - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) val.Set(xv) } @@ -213,14 +208,14 @@ func (t *array) memsetIter(x interface{}, it Iterator) (err error) { ` -const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ +const zeroIterRaw = `func (a *array) zeroIter(it Iterator) (err error){ var i int - switch t.t { + switch a.t { {{range .Kinds -}} {{if isParameterized . 
-}} {{else -}} case {{reflectKind .}}: - data := t.{{sliceOf .}} + data := a.{{sliceOf .}} for i, err = it.Next(); err == nil; i, err = it.Next(){ data[i] = {{if eq .String "bool" -}} false @@ -232,12 +227,10 @@ const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ {{end -}} {{end -}} default: - ptr := uintptr(t.Ptr) for i, err = it.Next(); err == nil; i, err = it.Next(){ - want := ptr + uintptr(i)*t.t.Size() - val := reflect.NewAt(t.t.Type, unsafe.Pointer(want)) + val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) val = reflect.Indirect(val) - val.Set(reflect.Zero(t.t)) + val.Set(reflect.Zero(a.t)) } err = handleNoOp(err) } @@ -245,16 +238,26 @@ const zeroIterRaw = `func (t *array) zeroIter(it Iterator) (err error){ } ` +const reflectConstTemplateRaw = `var ( + {{range .Kinds -}} + {{if isParameterized . -}} + {{else -}} + {{short . | unexport}}Type = reflect.TypeOf({{asType .}}({{if eq .String "bool" -}} false {{else if eq .String "string" -}}"" {{else if eq .String "unsafe.Pointer" -}}nil {{else -}}0{{end -}})) + {{end -}} + {{end -}} +)` + var ( - AsSlice *template.Template - SimpleSet *template.Template - SimpleGet *template.Template - Get *template.Template - Set *template.Template - Memset *template.Template - MemsetIter *template.Template - Eq *template.Template - ZeroIter *template.Template + AsSlice *template.Template + SimpleSet *template.Template + SimpleGet *template.Template + Get *template.Template + Set *template.Template + Memset *template.Template + MemsetIter *template.Template + Eq *template.Template + ZeroIter *template.Template + ReflectType *template.Template ) func init() { @@ -267,6 +270,7 @@ func init() { MemsetIter = template.Must(template.New("MemsetIter").Funcs(funcs).Parse(memsetIterRaw)) Eq = template.Must(template.New("ArrayEq").Funcs(funcs).Parse(arrayEqRaw)) ZeroIter = template.Must(template.New("Zero").Funcs(funcs).Parse(zeroIterRaw)) + ReflectType = 
template.Must(template.New("ReflectType").Funcs(funcs).Parse(reflectConstTemplateRaw)) } func generateArrayMethods(f io.Writer, ak Kinds) { @@ -295,3 +299,8 @@ func generateHeaderGetSet(f io.Writer, ak Kinds) { } } } + +func generateReflectTypes(f io.Writer, ak Kinds) { + ReflectType.Execute(f, ak) + fmt.Fprintf(f, "\n\n\n") +} diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index e6e4b0f..814067f 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -560,7 +560,7 @@ func (t *Dense) FBDecode(buf []byte) error { // allocated data. Now time to actually copy over the data db := t.byteSlice() copy(db, serialized.DataBytes()) - t.forcefix() + t.fix() return t.sanity() } ` diff --git a/genlib2/main.go b/genlib2/main.go index fafd74c..328cd19 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -52,6 +52,7 @@ func main() { pregenerate() // storage + pipeline(storageLoc, "consts.go", Kinds{allKinds}, generateReflectTypes) pipeline(storageLoc, "getset.go", Kinds{allKinds}, generateHeaderGetSet) pipeline(tensorPkgLoc, "array_getset.go", Kinds{allKinds}, generateArrayMethods) diff --git a/internal/execution/e.go b/internal/execution/e.go index 670ae0b..83fcc1f 100644 --- a/internal/execution/e.go +++ b/internal/execution/e.go @@ -38,7 +38,7 @@ var ( UnsafePointer = reflect.TypeOf(unsafe.Pointer(&Uintptr)) ) -func isScalar(a *storage.Header) bool { return a.L == 1 } +func isScalar(a *storage.Header, t reflect.Type) bool { return a.TypedLen(t) == 1 } type errorIndices []int diff --git a/internal/execution/eng_arith.go b/internal/execution/eng_arith.go index f626a3d..f3de110 100644 --- a/internal/execution/eng_arith.go +++ b/internal/execution/eng_arith.go @@ -10,8 +10,8 @@ import ( ) func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -230,8 +230,8 @@ func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err 
error) } func (e E) Sub(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -436,8 +436,8 @@ func (e E) Sub(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Mul(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -642,8 +642,8 @@ func (e E) Mul(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Div(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -848,8 +848,8 @@ func (e E) Div(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Pow(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Float32: @@ -914,8 +914,8 @@ func (e E) Pow(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) Mod(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1092,11 +1092,11 @@ func (e E) Mod(t reflect.Type, a *storage.Header, b *storage.Header) (err error) } func (e E) AddIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1406,11 +1406,11 @@ func (e E) AddIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) SubIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1700,11 +1700,11 @@ func (e E) SubIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) MulIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1994,11 +1994,11 @@ func (e E) MulIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) DivIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2288,11 +2288,11 @@ func (e E) DivIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) PowIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2382,11 +2382,11 @@ func (e E) PowIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) ModIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on scalar increment. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on scalar increment. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2636,8 +2636,8 @@ func (e E) ModIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *s } func (e E) AddIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2856,8 +2856,8 @@ func (e E) AddIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) SubIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3062,8 +3062,8 @@ func (e E) SubIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) MulIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3268,8 +3268,8 @@ func (e E) MulIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) DivIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -3474,8 +3474,8 @@ func (e E) DivIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) PowIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Float32: @@ -3540,8 +3540,8 @@ func (e E) PowIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) ModIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, 
t) switch t { case Int: @@ -3718,12 +3718,12 @@ func (e E) ModIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } func (e E) AddIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4018,12 +4018,12 @@ func (e E) AddIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc } func (e E) SubIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4299,12 +4299,12 @@ func (e E) SubIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc } func (e E) MulIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4580,12 +4580,12 @@ func (e E) MulIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc } func (e E) DivIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4861,12 +4861,12 @@ func (e E) DivIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc } func (e E) PowIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4952,12 +4952,12 @@ func (e E) PowIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc } func (e E) ModIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, incr *storage.Header, ait Iterator, bit Iterator, iit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - is := isScalar(incr) + as := isScalar(a, t) + bs := isScalar(b, t) + is := isScalar(incr, t) if ((as && !bs) || (bs && !as)) && is { - return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { diff --git a/internal/execution/eng_arith_manual.go b/internal/execution/eng_arith_manual.go index 941d644..3a620e6 100644 --- a/internal/execution/eng_arith_manual.go +++ b/internal/execution/eng_arith_manual.go @@ -8,26 +8,20 @@ import ( ) func (e E) AddSliced(t reflect.Type, dataA *storage.Header, dstStart, dstEnd int, dataB *storage.Header, srcStart, srcEnd int) (err error) { + ds := dstStart * int(t.Size()) + de := dstEnd * int(t.Size()) a := &storage.Header{ - Ptr: storage.ElementAt(dstStart, dataA.Ptr, t.Size()), - L: dstEnd - dstStart, - C: dataA.C - dstStart, - } - if a.C == 0 { - a.C = 1 + Raw: dataA.Raw[ds:de], } + ss := srcStart * int(t.Size()) + se := srcEnd * int(t.Size()) b := &storage.Header{ - Ptr: storage.ElementAt(srcStart, dataB.Ptr, t.Size()), - L: srcEnd - srcStart, - C: dataB.C - srcStart, - } - if b.C == 0 { - b.C = 1 + Raw: dataB.Raw[ss:se], } - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: diff --git a/internal/execution/eng_cmp.go b/internal/execution/eng_cmp.go index 9514f61..b2c4ece 100644 --- a/internal/execution/eng_cmp.go +++ b/internal/execution/eng_cmp.go @@ -10,13 +10,13 @@ import ( ) func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -221,13 +221,13 @@ func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Gte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -432,13 +432,13 @@ func (e E) Gte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *sto } func (e E) Lt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -643,13 +643,13 @@ func (e E) Lt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Lte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -854,13 +854,13 @@ func (e E) Lte(t reflect.Type, a *storage.Header, b *storage.Header, retVal *sto } func (e E) Eq(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1140,13 +1140,13 @@ func (e E) Eq(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) Ne(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is a scalar. a: %d, b %d", a.Len(), b.Len()) + return errors.Errorf("retVal is a scalar. 
a: %d, b %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -1426,8 +1426,8 @@ func (e E) Ne(t reflect.Type, a *storage.Header, b *storage.Header, retVal *stor } func (e E) GtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1618,8 +1618,8 @@ func (e E) GtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) GteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -1810,8 +1810,8 @@ func (e E) GteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err er } func (e E) LtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2002,8 +2002,8 @@ func (e E) LtSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) LteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -2194,8 +2194,8 @@ func (e E) LteSame(t reflect.Type, a *storage.Header, b *storage.Header) (err er } func (e E) EqSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -2442,8 +2442,8 @@ func (e E) EqSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) NeSame(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -2690,13 +2690,13 @@ func (e E) NeSame(t reflect.Type, a *storage.Header, b *storage.Header) (err err } func (e E) GtIter(t reflect.Type, a 
*storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -2888,13 +2888,13 @@ func (e E) GtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) GteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3086,13 +3086,13 @@ func (e E) GteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal } func (e E) LtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3284,13 +3284,13 @@ func (e E) LtIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) LteIter(t reflect.Type, a 
*storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3482,13 +3482,13 @@ func (e E) LteIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal } func (e E) EqIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -3750,13 +3750,13 @@ func (e E) EqIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) NeIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header, ait Iterator, bit Iterator, rit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) - rs := isScalar(retVal) + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(retVal, Bool) rt := retVal.Bools() if ((as && !bs) || (bs && !as)) && rs { - return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.Len(), b.Len()) + return errors.Errorf("retVal is scalar while len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) } switch t { @@ -4018,8 +4018,8 @@ func (e E) NeIter(t reflect.Type, a *storage.Header, b *storage.Header, retVal * } func (e E) GtSameIter(t reflect.Type, a 
*storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4210,8 +4210,8 @@ func (e E) GtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) GteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4402,8 +4402,8 @@ func (e E) GteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) LtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4594,8 +4594,8 @@ func (e E) LtSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) LteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Int: @@ -4786,8 +4786,8 @@ func (e E) LteSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) EqSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: @@ -5034,8 +5034,8 @@ func (e E) EqSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait } func (e E) NeSameIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { - as := isScalar(a) - bs := isScalar(b) + as := isScalar(a, t) + bs := isScalar(b, t) switch t { case Bool: diff --git a/internal/execution/eng_map.go b/internal/execution/eng_map.go index 17ca682..81cb2c4 100644 --- a/internal/execution/eng_map.go +++ 
b/internal/execution/eng_map.go @@ -11,7 +11,7 @@ import ( ) func (e E) Map(t reflect.Type, fn interface{}, a *storage.Header, incr bool) (err error) { - as := isScalar(a) + as := isScalar(a, t) switch t { case Bool: var f0 func(bool) bool diff --git a/internal/storage/consts.go b/internal/storage/consts.go new file mode 100644 index 0000000..7304ac5 --- /dev/null +++ b/internal/storage/consts.go @@ -0,0 +1,29 @@ +// Code generated by genlib2. DO NOT EDIT. + +package storage + +import ( + "reflect" + "unsafe" +) + +var ( + bType = reflect.TypeOf(bool(false)) + iType = reflect.TypeOf(int(0)) + i8Type = reflect.TypeOf(int8(0)) + i16Type = reflect.TypeOf(int16(0)) + i32Type = reflect.TypeOf(int32(0)) + i64Type = reflect.TypeOf(int64(0)) + uType = reflect.TypeOf(uint(0)) + u8Type = reflect.TypeOf(uint8(0)) + u16Type = reflect.TypeOf(uint16(0)) + u32Type = reflect.TypeOf(uint32(0)) + u64Type = reflect.TypeOf(uint64(0)) + uintptrType = reflect.TypeOf(uintptr(0)) + f32Type = reflect.TypeOf(float32(0)) + f64Type = reflect.TypeOf(float64(0)) + c64Type = reflect.TypeOf(complex64(0)) + c128Type = reflect.TypeOf(complex128(0)) + strType = reflect.TypeOf(string("")) + unsafePointerType = reflect.TypeOf(unsafe.Pointer(nil)) +) diff --git a/internal/storage/getset.go b/internal/storage/getset.go index 879a5e3..c60d61c 100644 --- a/internal/storage/getset.go +++ b/internal/storage/getset.go @@ -6,108 +6,144 @@ import "unsafe" /* bool */ -func (h *Header) Bools() []bool { return *(*[]bool)(unsafe.Pointer(h)) } +func (h *Header) Bools() []bool { + return (*(*[]bool)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(bType):h.TypedLen(bType)] +} func (h *Header) SetB(i int, x bool) { h.Bools()[i] = x } func (h *Header) GetB(i int) bool { return h.Bools()[i] } /* int */ -func (h *Header) Ints() []int { return *(*[]int)(unsafe.Pointer(h)) } +func (h *Header) Ints() []int { + return (*(*[]int)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(iType):h.TypedLen(iType)] +} func (h *Header) SetI(i int, x int) { 
h.Ints()[i] = x } func (h *Header) GetI(i int) int { return h.Ints()[i] } /* int8 */ -func (h *Header) Int8s() []int8 { return *(*[]int8)(unsafe.Pointer(h)) } +func (h *Header) Int8s() []int8 { + return (*(*[]int8)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i8Type):h.TypedLen(i8Type)] +} func (h *Header) SetI8(i int, x int8) { h.Int8s()[i] = x } func (h *Header) GetI8(i int) int8 { return h.Int8s()[i] } /* int16 */ -func (h *Header) Int16s() []int16 { return *(*[]int16)(unsafe.Pointer(h)) } +func (h *Header) Int16s() []int16 { + return (*(*[]int16)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i16Type):h.TypedLen(i16Type)] +} func (h *Header) SetI16(i int, x int16) { h.Int16s()[i] = x } func (h *Header) GetI16(i int) int16 { return h.Int16s()[i] } /* int32 */ -func (h *Header) Int32s() []int32 { return *(*[]int32)(unsafe.Pointer(h)) } +func (h *Header) Int32s() []int32 { + return (*(*[]int32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i32Type):h.TypedLen(i32Type)] +} func (h *Header) SetI32(i int, x int32) { h.Int32s()[i] = x } func (h *Header) GetI32(i int) int32 { return h.Int32s()[i] } /* int64 */ -func (h *Header) Int64s() []int64 { return *(*[]int64)(unsafe.Pointer(h)) } +func (h *Header) Int64s() []int64 { + return (*(*[]int64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(i64Type):h.TypedLen(i64Type)] +} func (h *Header) SetI64(i int, x int64) { h.Int64s()[i] = x } func (h *Header) GetI64(i int) int64 { return h.Int64s()[i] } /* uint */ -func (h *Header) Uints() []uint { return *(*[]uint)(unsafe.Pointer(h)) } +func (h *Header) Uints() []uint { + return (*(*[]uint)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(uType):h.TypedLen(uType)] +} func (h *Header) SetU(i int, x uint) { h.Uints()[i] = x } func (h *Header) GetU(i int) uint { return h.Uints()[i] } /* uint8 */ -func (h *Header) Uint8s() []uint8 { return *(*[]uint8)(unsafe.Pointer(h)) } +func (h *Header) Uint8s() []uint8 { + return (*(*[]uint8)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u8Type):h.TypedLen(u8Type)] +} func (h *Header) SetU8(i int, x 
uint8) { h.Uint8s()[i] = x } func (h *Header) GetU8(i int) uint8 { return h.Uint8s()[i] } /* uint16 */ -func (h *Header) Uint16s() []uint16 { return *(*[]uint16)(unsafe.Pointer(h)) } +func (h *Header) Uint16s() []uint16 { + return (*(*[]uint16)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u16Type):h.TypedLen(u16Type)] +} func (h *Header) SetU16(i int, x uint16) { h.Uint16s()[i] = x } func (h *Header) GetU16(i int) uint16 { return h.Uint16s()[i] } /* uint32 */ -func (h *Header) Uint32s() []uint32 { return *(*[]uint32)(unsafe.Pointer(h)) } +func (h *Header) Uint32s() []uint32 { + return (*(*[]uint32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u32Type):h.TypedLen(u32Type)] +} func (h *Header) SetU32(i int, x uint32) { h.Uint32s()[i] = x } func (h *Header) GetU32(i int) uint32 { return h.Uint32s()[i] } /* uint64 */ -func (h *Header) Uint64s() []uint64 { return *(*[]uint64)(unsafe.Pointer(h)) } +func (h *Header) Uint64s() []uint64 { + return (*(*[]uint64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(u64Type):h.TypedLen(u64Type)] +} func (h *Header) SetU64(i int, x uint64) { h.Uint64s()[i] = x } func (h *Header) GetU64(i int) uint64 { return h.Uint64s()[i] } /* uintptr */ -func (h *Header) Uintptrs() []uintptr { return *(*[]uintptr)(unsafe.Pointer(h)) } +func (h *Header) Uintptrs() []uintptr { + return (*(*[]uintptr)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(uintptrType):h.TypedLen(uintptrType)] +} func (h *Header) SetUintptr(i int, x uintptr) { h.Uintptrs()[i] = x } func (h *Header) GetUintptr(i int) uintptr { return h.Uintptrs()[i] } /* float32 */ -func (h *Header) Float32s() []float32 { return *(*[]float32)(unsafe.Pointer(h)) } +func (h *Header) Float32s() []float32 { + return (*(*[]float32)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(f32Type):h.TypedLen(f32Type)] +} func (h *Header) SetF32(i int, x float32) { h.Float32s()[i] = x } func (h *Header) GetF32(i int) float32 { return h.Float32s()[i] } /* float64 */ -func (h *Header) Float64s() []float64 { return *(*[]float64)(unsafe.Pointer(h)) } +func (h 
*Header) Float64s() []float64 { + return (*(*[]float64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(f64Type):h.TypedLen(f64Type)] +} func (h *Header) SetF64(i int, x float64) { h.Float64s()[i] = x } func (h *Header) GetF64(i int) float64 { return h.Float64s()[i] } /* complex64 */ -func (h *Header) Complex64s() []complex64 { return *(*[]complex64)(unsafe.Pointer(h)) } +func (h *Header) Complex64s() []complex64 { + return (*(*[]complex64)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(c64Type):h.TypedLen(c64Type)] +} func (h *Header) SetC64(i int, x complex64) { h.Complex64s()[i] = x } func (h *Header) GetC64(i int) complex64 { return h.Complex64s()[i] } /* complex128 */ -func (h *Header) Complex128s() []complex128 { return *(*[]complex128)(unsafe.Pointer(h)) } +func (h *Header) Complex128s() []complex128 { + return (*(*[]complex128)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(c128Type):h.TypedLen(c128Type)] +} func (h *Header) SetC128(i int, x complex128) { h.Complex128s()[i] = x } func (h *Header) GetC128(i int) complex128 { return h.Complex128s()[i] } /* string */ -func (h *Header) Strings() []string { return *(*[]string)(unsafe.Pointer(h)) } +func (h *Header) Strings() []string { + return (*(*[]string)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(strType):h.TypedLen(strType)] +} func (h *Header) SetStr(i int, x string) { h.Strings()[i] = x } func (h *Header) GetStr(i int) string { return h.Strings()[i] } /* unsafe.Pointer */ -func (h *Header) UnsafePointers() []unsafe.Pointer { return *(*[]unsafe.Pointer)(unsafe.Pointer(h)) } +func (h *Header) UnsafePointers() []unsafe.Pointer { + return (*(*[]unsafe.Pointer)(unsafe.Pointer(&h.Raw)))[:h.TypedLen(unsafePointerType):h.TypedLen(unsafePointerType)] +} func (h *Header) SetUnsafePointer(i int, x unsafe.Pointer) { h.UnsafePointers()[i] = x } func (h *Header) GetUnsafePointer(i int) unsafe.Pointer { return h.UnsafePointers()[i] } diff --git a/internal/storage/header.go b/internal/storage/header.go index 249f2fc..99414a2 100644 --- 
a/internal/storage/header.go +++ b/internal/storage/header.go @@ -9,22 +9,22 @@ import ( // With this, we wouldn't need to keep the uintptr. // This usually means additional pressure for the GC though, especially when passing around Headers type Header struct { - Ptr unsafe.Pointer - L int - C int + Raw []byte } -func (h *Header) Pointer() unsafe.Pointer { return h.Ptr } -func (h *Header) Len() int { return h.L } +func (h *Header) TypedLen(t reflect.Type) int { + sz := int(t.Size()) + return len(h.Raw) / sz +} func Copy(t reflect.Type, dst, src *Header) int { - if dst.L == 0 || src.L == 0 { + if len(dst.Raw) == 0 || len(src.Raw) == 0 { return 0 } - n := src.L - if dst.L < n { - n = dst.L + n := src.TypedLen(t) + if len(dst.Raw) < n { + n = dst.TypedLen(t) } // handle struct{} type @@ -37,15 +37,15 @@ func Copy(t reflect.Type, dst, src *Header) int { // otherwise, just copy bytes. // FUTURE: implement memmove - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw copied := copy(dstBA, srcBA) return copied / int(t.Size()) } func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, sstart, send int) int { - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw size := int(t.Size()) ds := dstart * size @@ -57,8 +57,8 @@ func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, ssta } func Fill(t reflect.Type, dst, src *Header) int { - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw size := int(t.Size()) lenSrc := len(srcBA) @@ -74,8 +74,8 @@ func Fill(t reflect.Type, dst, src *Header) int { } func CopyIter(t reflect.Type, dst, src *Header, diter, siter Iterator) int { - dstBA := AsByteSlice(dst, t) - srcBA := AsByteSlice(src, t) + dstBA := dst.Raw + srcBA := src.Raw size := int(t.Size()) var idx, jdx, i, j, count int @@ -102,17 +102,30 @@ func CopyIter(t reflect.Type, dst, src *Header, diter, 
siter Iterator) int { return count } -func AsByteSlice(a *Header, t reflect.Type) []byte { - size := a.L * int(t.Size()) - b := make([]byte, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - hdr.Data = uintptr(a.Ptr) - hdr.Cap = size - hdr.Len = size - return b -} - // Element gets the pointer of ith element func ElementAt(i int, base unsafe.Pointer, typeSize uintptr) unsafe.Pointer { return unsafe.Pointer(uintptr(base) + uintptr(i)*typeSize) } + +// AsByteSlice takes a slice of anything and returns a casted-as-byte-slice view of it. +// This function panics if input is not a slice. +func AsByteSlice(x interface{}) []byte { + xV := reflect.ValueOf(x) + xT := reflect.TypeOf(x).Elem() // expects a []T + + hdr := reflect.SliceHeader{ + Data: xV.Pointer(), + Len: xV.Len() * int(xT.Size()), + Cap: xV.Cap() * int(xT.Size()), + } + return *(*[]byte)(unsafe.Pointer(&hdr)) +} + +func FromMemory(ptr uintptr, memsize uintptr) []byte { + hdr := reflect.SliceHeader{ + Data: ptr, + Len: int(memsize), + Cap: int(memsize), + } + return *(*[]byte)(unsafe.Pointer(&hdr)) +} diff --git a/known_race_test.go b/known_race_test.go index cb9e265..f6d5616 100644 --- a/known_race_test.go +++ b/known_race_test.go @@ -1,3 +1,4 @@ +// +build ignore // +build !race package tensor diff --git a/perf.go b/perf.go index 2d20df2..bc5c3aa 100644 --- a/perf.go +++ b/perf.go @@ -56,9 +56,7 @@ func returnHeader(hdr *storage.Header) { } func destroyHeader(hdr *storage.Header) { - hdr.Ptr = nil - hdr.L = 0 - hdr.C = 0 + hdr.Raw = nil } var densePool = make(chan *Dense, PoolSize) @@ -92,10 +90,7 @@ func ReturnTensor(t Tensor) { // array reset tt.t = Dtype{} - tt.array.Ptr = nil - tt.array.L = 0 - tt.array.C = 0 - tt.array.v = nil + tt.array.Header.Raw = nil // engine and flag reset tt.e = StdEng{} diff --git a/sparse.go b/sparse.go index 3126843..1a9da7c 100644 --- a/sparse.go +++ b/sparse.go @@ -2,7 +2,6 @@ package tensor import ( "reflect" - "unsafe" "sort" @@ -29,7 +28,7 @@ type coo struct { data 
array } -func (c *coo) Len() int { return c.data.L } +func (c *coo) Len() int { return c.data.Len() } func (c *coo) Less(i, j int) bool { if c.o.IsColMajor() { return c.colMajorLess(i, j) @@ -187,7 +186,7 @@ func (t *CS) Strides() []int { return nil } func (t *CS) Dtype() Dtype { return t.t } func (t *CS) Dims() int { return 2 } func (t *CS) Size() int { return t.s.TotalSize() } -func (t *CS) DataSize() int { return t.L } +func (t *CS) DataSize() int { return t.Len() } func (t *CS) Engine() Engine { return t.e } func (t *CS) DataOrder() DataOrder { return t.o } @@ -289,7 +288,7 @@ func (t *CS) Clone() interface{} { retVal.indptr = make([]int, len(t.indptr)) copy(retVal.indices, t.indices) copy(retVal.indptr, t.indptr) - retVal.array = makeArray(t.t, t.array.L) + retVal.array = makeArray(t.t, t.array.Len()) copyArray(&retVal.array, &t.array) retVal.e = t.e return retVal @@ -298,12 +297,11 @@ func (t *CS) Clone() interface{} { func (t *CS) IsScalar() bool { return false } func (t *CS) ScalarValue() interface{} { panic("Sparse Matrices cannot represent Scalar Values") } -func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.L)) } -func (t *CS) Uintptr() uintptr { return uintptr(t.array.Ptr) } -func (t *CS) Pointer() unsafe.Pointer { return t.array.Ptr } +func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.Len())) } +func (t *CS) Uintptr() uintptr { return t.array.Uintptr() } // NonZeroes returns the nonzeroes. In academic literature this is often written as NNZ. 
-func (t *CS) NonZeroes() int { return t.L } +func (t *CS) NonZeroes() int { return t.Len() } func (t *CS) RequiresIterator() bool { return true } func (t *CS) Iterator() Iterator { return NewFlatSparseIterator(t) } diff --git a/tensor.go b/tensor.go index ff7e347..071ca67 100644 --- a/tensor.go +++ b/tensor.go @@ -6,7 +6,6 @@ import ( "encoding/gob" "fmt" "io" - "unsafe" "github.com/pkg/errors" ) @@ -62,10 +61,8 @@ type Tensor interface { // engine/memory related stuff // all Tensors should be able to be expressed of as a slab of memory // Note: the size of each element can be acquired by T.Dtype().Size() + Memory // Tensors all implement Memory Engine() Engine // Engine can be nil - MemSize() uintptr // the size in memory - Uintptr() uintptr // the pointer to the first element, as a uintptr - Pointer() unsafe.Pointer // the pointer to the first elemment as a unsafe.Ponter IsNativelyAccessible() bool // Can Go access the memory IsManuallyManaged() bool // Must Go manage the memory diff --git a/testutils_test.go b/testutils_test.go index 20cdda9..3a0d466 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -477,17 +477,10 @@ func (e dummyEngine) Memclr(mem Memory) {} func (e dummyEngine) Memcpy(dst, src Memory) error { if e { var a, b storage.Header - a.Ptr = src.Pointer() - a.L = int(src.MemSize()) - a.C = int(src.MemSize()) + a.Raw = storage.FromMemory(src.Uintptr(), src.MemSize()) + b.Raw = storage.FromMemory(dst.Uintptr(), dst.MemSize()) - b.Ptr = dst.Pointer() - b.L = int(dst.MemSize()) - b.C = int(dst.MemSize()) - - abs := *(*[]byte)(unsafe.Pointer(&a)) - bbs := *(*[]byte)(unsafe.Pointer(&b)) - copy(bbs, abs) + copy(b.Raw, a.Raw) return nil } return errors.New("Unable to copy ") From 0364c45520f04451db1c848de9393a9e69aac57d Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 30 Dec 2020 08:18:48 +1100 Subject: [PATCH 068/154] Added an unsafe checker. 
--- go.mod | 1 + go.sum | 2 ++ internal/storage/header.go | 2 ++ unsafe.go | 3 +++ 4 files changed, 8 insertions(+) create mode 100644 unsafe.go diff --git a/go.mod b/go.mod index 7c1db9f..2a99f4b 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.6.0 github.com/xtgo/set v1.0.0 // indirect + go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 gonum.org/v1/gonum v0.8.1 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 diff --git a/go.sum b/go.sum index d1d71b5..5728d89 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgh github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= diff --git a/internal/storage/header.go b/internal/storage/header.go index 99414a2..4137f01 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -3,6 +3,8 @@ package storage // import "gorgonia.org/tensor/internal/storage" import ( "reflect" "unsafe" + + _ "go4.org/unsafe/assume-no-moving-gc" ) // Header is runtime representation of a slice. It's a cleaner version of reflect.SliceHeader. 
diff --git a/unsafe.go b/unsafe.go new file mode 100644 index 0000000..5260605 --- /dev/null +++ b/unsafe.go @@ -0,0 +1,3 @@ +package tensor + +import _ "go4.org/unsafe/assume-no-moving-gc" From 098c4c377dcf4bec0854d2d4f11f373731d373f2 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 30 Dec 2020 11:00:57 +1100 Subject: [PATCH 069/154] Use Github Actions for testing instead of travis (#101) * Fixed #90 * Added gihub action * Removed travis * Updated tests for numpy loading to be skipped * Updated Test_FromMemory to skip if OS is not linux * added pip dependency for coverage * Ugh s/action/actions --- .github/FUNDING.yml | 12 +++++ .github/workflows/.go.yml | 108 ++++++++++++++++++++++++++++++++++++++ .travis.yml | 31 ----------- .travis/test.sh | 18 ------- consopt_test.go | 2 + dense_io_test.go | 4 +- 6 files changed, 124 insertions(+), 51 deletions(-) create mode 100644 .github/FUNDING.yml create mode 100644 .github/workflows/.go.yml delete mode 100644 .travis.yml delete mode 100644 .travis/test.sh diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..efb5abf --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: [chewxy, owulveryck, dcu] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/workflows/.go.yml b/.github/workflows/.go.yml new file 
mode 100644 index 0000000..c323ab7 --- /dev/null +++ b/.github/workflows/.go.yml @@ -0,0 +1,108 @@ +on: + push: + branches: [ master ] + pull_request: +name: test and build +env: + GOPROXY: "https://proxy.golang.org" + CI_NO_PYTHON: "true" +jobs: + test: + strategy: + matrix: + go: [1.13.x, 1.14.x, 1.15.x] + os: [ubuntu-latest, macos-latest, windows-latest] + tags: [avx, sse] + allowfail: [false] + include: + - go: tip + os: ubuntu-latest + allowfail: true + runs-on: ${{ matrix.os }} + continue-on-error: ${{ matrix.allowfail }} + steps: + - name: Install Go ${{ matrix.go }} on ${{ matrix.os }} + if: matrix.go != 'tip' + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} +# tempoary hack: +# https://github.com/actions/setup-go/issues/21#issuecomment-565704236 + - name: Install Go ${{ matrix.go }} on ${{ matrix.os }} + if: matrix.go == 'tip' + run: | + git clone --depth=1 https://go.googlesource.com/go $HOME/gotip + cd $HOME/gotip/src + ./make.bash + echo "::set-env name=GOROOT::$HOME/gotip" + echo "::add-path::$HOME/gotip/bin" + - name: Checkout code + uses: actions/checkout@v2 + - name: Run tests + run: | + go test -v -race + go test -race -tags=${{ matrix.tags }} + + coverage: + env: + CI_NO_PYTHON: "false" + PYTHON_COMMAND: python + strategy: + matrix: + tags: [avx, sse] + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.14.x + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + architecture: 'x64' + - name: Install Pip + uses: BSFishy/pip-action@v1 + with: + packages: numpy + - name: Checkout code + uses: actions/checkout@v2 + - name: Calc coverage + run: | + export PATH=$PATH:$(go env GOPATH)/bin + go test -v -covermode=atomic -coverprofile=coverage.out + - name: Convert coverage to lcov + uses: jandelgado/gcov2lcov-action@v1.0.0 + with: + infile: coverage.out + outfile: coverage.lcov + - name: Coveralls + uses: coverallsapp/github-action@v1.0.1 + with: + 
github-token: ${{ secrets.github_token }} + path-to-lcov: coverage.lcov + + build: + strategy: + matrix: + go: [1.13, 1.14] + goos: [linux, darwin] + goarch: [amd64, arm] + exclude: + # windows/386 and darwin/386 seems useless + - goarch: "arm" + goos: darwin + runs-on: ubuntu-latest + needs: [test] + steps: + - name: Install Go ${{ matrix.go }} + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Checkout code + uses: actions/checkout@v2 + - name: build + run: go build . + env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 7933b1a..0000000 --- a/.travis.yml +++ /dev/null @@ -1,31 +0,0 @@ -sudo: required -language: go - -branches: - only: - - master -go: - - 1.13.x - - 1.14.x - - 1.15.x - - tip - -env: - global: - - GOARCH=amd64 - - BLAS_LIB=OpenBLAS - - TRAVISTEST=true - - CUDA=9.1.85-1 - -before_install: - - go get github.com/mattn/goveralls - -go_import_path: gorgonia.org/tensor - -script: - source ${TRAVIS_BUILD_DIR}/.travis/test.sh - - $HOME/gopath/bin/goveralls -service=travis-ci -package=gorgonia.org/tensor -covermode=atomic - -matrix: - allow_failures: - - go: tip diff --git a/.travis/test.sh b/.travis/test.sh deleted file mode 100644 index 381a409..0000000 --- a/.travis/test.sh +++ /dev/null @@ -1,18 +0,0 @@ -set -ex - -go env - -go test -v -a -covermode=atomic -coverprofile=test.cover . -go test -tags='avx' -a -covermode=atomic -coverprofile=avx.cover . -go test -tags='sse' -a -covermode=atomic -coverprofile=sse.cover . -go test -tags='inplacetranspose' -a -covermode=atomic -coverprofile=inplacetranspose.cover . -go test -race -a . -go test -a -covermode=atomic -coverprofile=native.cover ./native/. - -# because coveralls only accepts one coverage file at one time... 
we combine them into one gigantic one -covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover ./native.cover) -echo "mode: set" > ./final.cover -tail -q -n +2 "${covers[@]}" >> ./final.cover -goveralls -coverprofile=./final.cover -service=travis-ci - -set +ex diff --git a/consopt_test.go b/consopt_test.go index 65d5396..67ad664 100644 --- a/consopt_test.go +++ b/consopt_test.go @@ -1,3 +1,5 @@ +// +build linux + package tensor import ( diff --git a/dense_io_test.go b/dense_io_test.go index 99afa43..0d65884 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -12,8 +12,8 @@ import ( ) func TestSaveLoadNumpy(t *testing.T) { - if os.Getenv("TRAVISTEST") == "true" { - t.Skip("skipping test; This is being run on TravisCI") + if os.Getenv("CI_NO_PYTHON") == "true" { + t.Skip("skipping test; This is being run on a CI tool that does not have Python") } assert := assert.New(t) From ff9c22f0e63dbb96a11e26ac3dfa872914f49976 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Wed, 30 Dec 2020 14:08:57 +1100 Subject: [PATCH 070/154] Clarifysemantics 2 (#100) * Added a parentTensor check, as well as unsafe.go which will go into v0.10.0 as well * Some weirdness from the master branch is fixed * Added more tests for dense_format * Added some more formatting niceness --- ap_test.go | 3 ++- consopt.go | 4 ---- dense.go | 3 +++ dense_format.go | 19 ++++++++++++++++--- dense_format_test.go | 16 ++++++++++++++-- example_dense_matop_test.go | 6 ++---- go.mod | 11 ++++++----- go.sum | 18 ++++++++++++------ shape.go | 2 +- shape_test.go | 13 ++++++++++++- unsafe.go | 3 +++ 11 files changed, 71 insertions(+), 27 deletions(-) create mode 100644 unsafe.go diff --git a/ap_test.go b/ap_test.go index f5a230e..b813d1f 100644 --- a/ap_test.go +++ b/ap_test.go @@ -120,7 +120,8 @@ func TestAccessPatternIsX(t *testing.T) { ap = dummyScalar2() assert.False(ap.IsScalar()) assert.True(ap.IsScalarEquiv()) - assert.False(ap.IsVector()) + assert.True(ap.IsVectorLike()) + 
assert.True(ap.IsVector()) assert.False(ap.IsColVec()) assert.False(ap.IsRowVec()) diff --git a/consopt.go b/consopt.go index 8cbc54f..25c157a 100644 --- a/consopt.go +++ b/consopt.go @@ -107,7 +107,6 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - xT := reflect.TypeOf(x) sxT := reflect.SliceOf(xT) xv := reflect.MakeSlice(sxT, 1, 1) // []T @@ -115,7 +114,6 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { xv0.Set(reflect.ValueOf(x)) tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) tt.t = Dtype{xT} - tt.mask = mask default: @@ -144,10 +142,8 @@ func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - tt.Header.Raw = nil // GC anything if needed tt.Header.Raw = storage.FromMemory(ptr, memsize) - tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged) default: panic("Unsupported Tensor type") diff --git a/dense.go b/dense.go index d647ab5..0e8e684 100644 --- a/dense.go +++ b/dense.go @@ -376,12 +376,15 @@ func (t *Dense) ShallowClone() *Dense { func (t *Dense) oldAP() *AP { return &t.old } func (t *Dense) setOldAP(ap *AP) { t.old = *ap } func (t *Dense) transposeAxes() []int { return t.transposeWith } + +//go:nocheckptr func (t *Dense) parentTensor() *Dense { if t.viewOf != 0 { return (*Dense)(unsafe.Pointer(t.viewOf)) } return nil } + func (t *Dense) setParentTensor(d *Dense) { if d == nil { t.viewOf = 0 diff --git a/dense_format.go b/dense_format.go index d20133d..b7f8611 100644 --- a/dense_format.go +++ b/dense_format.go @@ -121,10 +121,14 @@ func (f *fmtState) cleanFmt() string { // does the calculation for metadata func (f *fmtState) populate(t *Dense) { - if t.IsVector() { + switch { + case t.IsVector(): f.rows = 1 f.cols = t.Size() - } else { + case t.IsScalarEquiv(): + f.rows = 1 + f.cols = 1 + default: f.rows = t.Shape()[t.Dims()-2] f.cols = t.Shape()[t.Dims()-1] } @@ -281,6 +285,7 @@ func (t *Dense) Format(s 
fmt.State, c rune) { } fmt.Fprintf(f, " %v %v\n", t.Shape(), t.Strides()) } + if f.c == 'H' { return } @@ -367,7 +372,6 @@ func (t *Dense) Format(s fmt.State, c rune) { firstVal := true var lastRow, lastCol int var expected int - for next, err := it.Next(); err == nil; next, err = it.Next() { if next < expected { continue @@ -389,6 +393,10 @@ func (t *Dense) Format(s fmt.State, c rune) { f.Write(rowVecStart) case t.IsVector(): f.Write(vecStart) + case t.IsScalarEquiv(): + for i := 0; i < t.Dims(); i++ { + f.Write(vecStart) + } default: f.Write(matFirstStart) } @@ -439,6 +447,11 @@ func (t *Dense) Format(s fmt.State, c rune) { case t.IsVector(): f.Write(vecEnd) return + case t.IsScalarEquiv(): + for i := 0; i < t.Dims(); i++ { + f.Write(vecEnd) + } + return case firstRow: f.Write(matFirstEnd) case eom: diff --git a/dense_format_test.go b/dense_format_test.go index b4b230c..50d4acb 100644 --- a/dense_format_test.go +++ b/dense_format_test.go @@ -21,6 +21,16 @@ func TestDense_Format(t *testing.T) { res = fmt.Sprintf("%3.3f", T) assert.Equal("3.140", res) + // Scalar-equiv (vector) + T = New(WithBacking([]float64{3.14}), WithShape(1)) + res = fmt.Sprintf("%3.3f", T) + assert.Equal("[3.140]", res) + + // Scalar-equiv (n-dimensional) + T = New(WithBacking([]float64{3.14}), WithShape(1, 1, 1, 1)) + res = fmt.Sprintf("%3.3f", T) + assert.Equal("[[[[3.140]]]]", res) + // short vector T = New(Of(Float64), WithShape(4)) res = fmt.Sprintf("%v", T) @@ -73,11 +83,13 @@ Matrix (2, 2) [2 1] // many cols, rows, compressed T = New(WithShape(16, 14), WithBacking(Range(Float64, 0, 16*14))) res = fmt.Sprintf("\n%s", T) + // this clunky string addition thing is because some editors like to trim whitespace. + // There should be two spaces after ` ⋮` . expected = ` ⎡ 0 1 ⋯ 12 13⎤ ⎢ 14 15 ⋯ 26 27⎥ - ⋮ -⎢196 197 ⋯ 208 209⎥ +` + ` ⋮ ` + ` +` + `⎢196 197 ⋯ 208 209⎥ ⎣210 211 ⋯ 222 223⎦ ` assert.Equal(expected, res, "expected %v. 
Got %v", expected, res) diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 30df83c..91b819e 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -358,11 +358,9 @@ func ExampleT_scalarlike() { // S: 3.14 S2 3.14 S == S2: false // error when the axes are more than the shape's dims: Dimension mismatch. Expected 0, got 2 // S: - // ⎡3.14⎤ - // + // [[3.14]] // S2: - // ⎡3.14⎤ - // + // [[3.14]] // S == S2: false } diff --git a/go.mod b/go.mod index 7c1db9f..7106ca9 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,17 @@ module gorgonia.org/tensor go 1.13 require ( - github.com/apache/arrow/go/arrow v0.0.0-20201027203332-c3091dd3f8ca + github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.6 github.com/gogo/protobuf v1.3.1 - github.com/golang/protobuf v1.4.2 - github.com/google/flatbuffers v1.11.0 + github.com/golang/protobuf v1.4.3 + github.com/google/flatbuffers v1.12.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.0 + github.com/stretchr/testify v1.6.1 github.com/xtgo/set v1.0.0 // indirect - gonum.org/v1/gonum v0.8.1 + go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 + gonum.org/v1/gonum v0.8.2 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) diff --git a/go.sum b/go.sum index d1d71b5..21b3359 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/apache/arrow/go/arrow v0.0.0-20201027203332-c3091dd3f8ca h1:OYqlohQ0r1GB7SeG03ct5Xox668iVXgThaNyKLeC01E= -github.com/apache/arrow/go/arrow v0.0.0-20201027203332-c3091dd3f8ca/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= +github.com/apache/arrow/go/arrow 
v0.0.0-20201229220542-30ce2eb5d4dc h1:zvQ6w7KwtQWgMQiewOF9tFtundRMVZFSAksNV6ogzuY= +github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= @@ -35,8 +35,12 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v1.12.0 h1:/PtAHvnBY4Kqnx/xCQ3OIV9uYcSFGScBsWI3Oogeh6w= +github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -54,10 +58,12 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod 
h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho= -github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -100,8 +106,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.1 h1:wGtP3yGpc5mCLOLeTeBdjeui9oZSz5De0eOjMLC/QuQ= -gonum.org/v1/gonum v0.8.1/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= +gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib 
v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= diff --git a/shape.go b/shape.go index a0396bb..24c1897 100644 --- a/shape.go +++ b/shape.go @@ -176,7 +176,7 @@ func (s Shape) IsScalarEquiv() bool { // vanilla vector (not a row or a col) // column vector // row vector -func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1 && s[0] > 1) } +func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1) } // IsColVec returns true when the access pattern has the shape (x, 1) func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } diff --git a/shape_test.go b/shape_test.go index cb21bed..4fcf75f 100644 --- a/shape_test.go +++ b/shape_test.go @@ -55,7 +55,10 @@ func TestShapeIsX(t *testing.T) { s = Shape{1} assert.False(s.IsScalar()) assert.True(s.IsScalarEquiv()) - assert.False(s.IsVector()) + assert.True(s.IsVector()) + assert.True(s.IsVectorLike()) + assert.True(s.IsVector()) + assert.False(s.IsColVec()) assert.False(s.IsRowVec()) @@ -69,12 +72,14 @@ func TestShapeIsX(t *testing.T) { s = Shape{2, 1} assert.False(s.IsScalar()) assert.True(s.IsVector()) + assert.True(s.IsVectorLike()) assert.True(s.IsColVec()) assert.False(s.IsRowVec()) s = Shape{1, 2} assert.False(s.IsScalar()) assert.True(s.IsVector()) + assert.True(s.IsVectorLike()) assert.False(s.IsColVec()) assert.True(s.IsRowVec()) @@ -84,6 +89,12 @@ func TestShapeIsX(t *testing.T) { assert.False(s.IsVector()) assert.False(s.IsColVec()) assert.False(s.IsRowVec()) + + s = Shape{1, 1} + assert.False(s.IsScalar()) + assert.True(s.IsScalarEquiv()) + assert.True(s.IsVectorLike()) + assert.False(s.IsVector()) } func TestShapeCalcStride(t *testing.T) { diff --git a/unsafe.go b/unsafe.go new file mode 100644 index 0000000..5260605 --- /dev/null +++ b/unsafe.go @@ -0,0 +1,3 @@ +package tensor + 
+import _ "go4.org/unsafe/assume-no-moving-gc" From a53dcb2b7b10ac879dee56d787693879b5950677 Mon Sep 17 00:00:00 2001 From: David Cuadrado Date: Sun, 3 Jan 2021 21:47:40 -0500 Subject: [PATCH 071/154] Fix slicing vectors with shape 1 (#103) Fixes #102 --- dense_matop_test.go | 6 +++++- shape.go | 2 +- shape_test.go | 13 +++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/dense_matop_test.go b/dense_matop_test.go index cf2ce7a..e53c8b8 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -531,6 +531,11 @@ var denseSliceTests = []struct { correctStride []int correctData interface{} }{ + // scalar-equiv vector (issue 102) + {"a[0], a is scalar-equiv", []float64{2}, + Shape{1}, []Slice{ss(0)}, ScalarShape(), nil, 2.0}, + + // vector {"a[0]", []bool{true, true, false, false, false}, Shape{5}, []Slice{ss(0)}, ScalarShape(), nil, true}, {"a[0:2]", Range(Byte, 0, 5), Shape{5}, []Slice{makeRS(0, 2)}, Shape{2}, []int{1}, []byte{0, 1}}, @@ -632,7 +637,6 @@ func TestDense_Slice(t *testing.T) { if err == nil { t.Error("Expected a IndexError") } - } func TestDense_SliceInto(t *testing.T) { diff --git a/shape.go b/shape.go index 24c1897..c1347b4 100644 --- a/shape.go +++ b/shape.go @@ -26,7 +26,7 @@ func (s Shape) TotalSize() int { // CalcStrides calculates the default strides for a shape func (s Shape) CalcStrides() []int { - if s.IsScalarEquiv() { + if s.IsScalar() { return nil } diff --git a/shape_test.go b/shape_test.go index 4fcf75f..9cbc370 100644 --- a/shape_test.go +++ b/shape_test.go @@ -52,23 +52,26 @@ func TestShapeIsX(t *testing.T) { assert.False(s.IsColVec()) assert.False(s.IsRowVec()) + // vectors + + // scalar-equiv vector s = Shape{1} assert.False(s.IsScalar()) assert.True(s.IsScalarEquiv()) assert.True(s.IsVector()) assert.True(s.IsVectorLike()) assert.True(s.IsVector()) - assert.False(s.IsColVec()) assert.False(s.IsRowVec()) - // vector + // vanila vector s = Shape{2} assert.False(s.IsScalar()) assert.True(s.IsVector()) 
assert.False(s.IsColVec()) assert.False(s.IsRowVec()) + // col vec s = Shape{2, 1} assert.False(s.IsScalar()) assert.True(s.IsVector()) @@ -76,6 +79,7 @@ func TestShapeIsX(t *testing.T) { assert.True(s.IsColVec()) assert.False(s.IsRowVec()) + // row vec s = Shape{1, 2} assert.False(s.IsScalar()) assert.True(s.IsVector()) @@ -90,6 +94,7 @@ func TestShapeIsX(t *testing.T) { assert.False(s.IsColVec()) assert.False(s.IsRowVec()) + // scalar equiv matrix s = Shape{1, 1} assert.False(s.IsScalar()) assert.True(s.IsScalarEquiv()) @@ -105,10 +110,10 @@ func TestShapeCalcStride(t *testing.T) { s = Shape{} assert.Nil(s.CalcStrides()) + // vector shape s = Shape{1} - assert.Nil(s.CalcStrides()) + assert.Equal([]int{1}, s.CalcStrides()) - // vector shape s = Shape{2, 1} assert.Equal([]int{1, 1}, s.CalcStrides()) From 537c38912e63864df7f3839e130fee7c2aec46e8 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 5 Jan 2021 13:30:31 +1100 Subject: [PATCH 072/154] Clarifysemantics2 (#104) * Fixed a bug found by @dcu on concat with scalarlikes * Added test --- defaultengine_matop_misc.go | 1 + dense_matop_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 59c3b69..0ab392a 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -299,6 +299,7 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen if mt, ok := T.(MaskedTensor); ok { copy(v.mask, mt.Mask()) } + start = end continue default: diff := retVal.Shape().Dims() - v.Shape().Dims() diff --git a/dense_matop_test.go b/dense_matop_test.go index e53c8b8..d9de697 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -823,6 +823,33 @@ func TestDense_Concat(t *testing.T) { } } +func TestDense_Concat_sliced(t *testing.T) { + v := New( + WithShape(1, 5), + WithBacking([]float64{0, 1, 2, 3, 4}), + ) + cols := make([]Tensor, v.Shape().TotalSize()) + for i := 0; i < v.Shape().TotalSize(); 
i++ { + sliced, err := v.Slice(nil, ss(i)) + if err != nil { + t.Fatalf("Failed to slice %d. Error: %v", i, err) + } + if err = sliced.Reshape(sliced.Shape().TotalSize(), 1); err != nil { + t.Fatalf("Failed to reshape %d. Error %v", i, err) + } + cols[i] = sliced + } + result, err := Concat(1, cols[0], cols[1:]...) + if err != nil { + t.Error(err) + } + assert.Equal(t, v.Data(), result.Data()) + if v.Uintptr() == result.Uintptr() { + t.Error("They should not share the same backing data!") + } + +} + var simpleStackTests = []struct { name string dt Dtype From e30cbd9b83ab4c62d2d0a88465a51084452f5891 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Thu, 14 Jan 2021 15:30:06 +1100 Subject: [PATCH 073/154] Fixed #90 (#105) From 41c1060f61a2974ade42d89059a61fe9a56a238d Mon Sep 17 00:00:00 2001 From: Chewxy Date: Fri, 15 Jan 2021 06:32:06 +1100 Subject: [PATCH 074/154] Fix by indices bug (#106) There was a subtle bug in `ByIndices`. The tests have also been updated to detect a wider class of bugs. --- api_matop.go | 7 ++ defaultengine_selbyidx.go | 18 +++- dense_selbyidx_test.go | 193 ++++++++++++++++++-------------------- 3 files changed, 111 insertions(+), 107 deletions(-) diff --git a/api_matop.go b/api_matop.go index 75c2452..bf412ea 100644 --- a/api_matop.go +++ b/api_matop.go @@ -127,13 +127,20 @@ func Diag(t Tensor) (retVal Tensor, err error) { // ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor. // The `indices` tensor has to be a vector-like tensor of ints. func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndices(a, indices, axis, opts...) } return nil, errors.Errorf("Unable to select by indices. 
Egnine %T does not support that.", a.Engine()) } +// ByIndicesB is the backpropagation of ByIndices. func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndicesB(a, b, indices, axis, opts...) } diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index ab7f4f1..cdcc318 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -86,8 +86,13 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da dstCoord := make([]int, apRet.shape.Dims()) if isInnermost { - prevStride := apA.strides[axis-1] - retPrevStride := apRet.strides[axis-1] + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + prevStride := apA.strides[prevAxis] + retPrevStride := apRet.strides[prevAxis] for i, idx := range indices { srcCoord[axis] = idx dstCoord[axis] = i @@ -194,8 +199,13 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data srcCoord := make([]int, apRet.shape.Dims()) if isInnermost { - retPrevStride := apB.strides[axis-1] - prevStride := apRet.strides[axis-1] + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + retPrevStride := apB.strides[prevAxis] + prevStride := apRet.strides[prevAxis] for i, idx := range indices { dstCoord[axis] = idx srcCoord[axis] = i diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index ca6b34f..86369be 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -6,121 +6,108 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDense_SelectByIndices(t *testing.T) { - assert := assert.New(t) - - a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) - indices := New(WithBacking([]int{1, 1})) - - e := StdEng{} - - a1, err := e.SelectByIndices(a, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - correct1 := []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23} - assert.Equal(correct1, a1.Data()) - - a0, err := e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) - } - correct0 := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correct0, a0.Data()) +type selByIndicesTest struct { + Name string + Data interface{} + Shape Shape + Indices []int + Axis int + WillErr bool + + Correct interface{} + CorrectShape Shape +} - a2, err := e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - correct2 := []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21} - assert.Equal(correct2, a2.Data()) +var selByIndicesTests = []selByIndicesTest{ + {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, - // !safe - aUnsafe := a.Clone().(*Dense) - indices = New(WithBacking([]int{1, 1, 1})) - aUnsafeSelect, err := e.SelectByIndices(aUnsafe, indices, 0, UseUnsafe()) - if err != nil { - t.Errorf("%v", err) - } - correctUnsafe := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correctUnsafe, aUnsafeSelect.Data()) + {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, + Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, - // 3 indices, just to make sure the sanity of the algorithm - indices = New(WithBacking([]int{1, 1, 1})) - a1, err = e.SelectByIndices(a, 
indices, 1) - if err != nil { - t.Errorf("%v", err) - } - correct1 = []float64{ - 4, 5, 6, 7, - 4, 5, 6, 7, - 4, 5, 6, 7, + {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, + Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, - 12, 13, 14, 15, - 12, 13, 14, 15, - 12, 13, 14, 15, + {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + Correct: []int{1, 1}, CorrectShape: Shape{2}}, - 20, 21, 22, 23, - 20, 21, 22, 23, - 20, 21, 22, 23, - } - assert.Equal(correct1, a1.Data()) + {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, + Correct: []int{1, 1}, CorrectShape: Shape{2}}, + {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, + Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, + {Name: "(2,1) Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, + }, +} - a0, err = e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) +func TestDense_SelectByIndices(t *testing.T) { + assert := assert.New(t) + for i, tc := range selByIndicesTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.Correct, ret.Data()) + assert.True(tc.CorrectShape.Eq(ret.Shape())) } - correct0 = []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correct0, a0.Data()) +} - 
a2, err = e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - correct2 = []float64{1, 1, 1, 5, 5, 5, 9, 9, 9, 13, 13, 13, 17, 17, 17, 21, 21, 21} - assert.Equal(correct2, a2.Data()) +var selByIndicesBTests = []struct { + selByIndicesTest + + CorrectGrad interface{} + CorrectGradShape Shape +}{ + { + selByIndicesTest: selByIndicesTests[0], + CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[1], + CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[2], + CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[3], + CorrectGrad: []int{0, 2, 0, 0, 0}, + CorrectGradShape: Shape{5}, + }, + { + selByIndicesTest: selByIndicesTests[5], + CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, + CorrectGradShape: Shape{4, 2}, + }, + { + selByIndicesTest: selByIndicesTests[6], + CorrectGrad: []float64{0, 10}, + CorrectGradShape: Shape{2, 1}, + }, } func TestDense_SelectByIndicesB(t *testing.T) { - a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) - indices := New(WithBacking([]int{1, 1})) - - t.Logf("a\n%v", a) - - e := StdEng{} - - a1, err := e.SelectByIndices(a, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a1\n%v", a1) - - a1Grad, err := e.SelectByIndicesB(a, a1, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a1Grad \n%v", a1Grad) - - a0, err := e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a0\n%v", a0) - a0Grad, err := e.SelectByIndicesB(a, a0, indices, 0) - if err != nil { - t.Errorf("%v", 
err) + assert := assert.New(t) + for i, tc := range selByIndicesBTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + grad, err := ByIndicesB(T, ret, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name) + assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape()) } - t.Logf("a0Grad\n%v", a0Grad) - a2, err := e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("\n%v", a2) - a2Grad, err := e.SelectByIndicesB(a, a2, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a2Grad\n%v", a2Grad) } From 0523aa4835f75ff8bf78271f246dc37956bf9cb9 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 18 Jan 2021 11:56:37 +1100 Subject: [PATCH 075/154] Removed `standardEngine` requirement from the Tensor definition --- api_arith.go | 90 +++++++++++++++++------------------------ defaultenginefloat32.go | 7 +++- defaultenginefloat64.go | 7 +++- interfaces.go | 8 ++++ tensor.go | 13 +++--- 5 files changed, 62 insertions(+), 63 deletions(-) diff --git a/api_arith.go b/api_arith.go index 8ef78db..2d7546d 100644 --- a/api_arith.go +++ b/api_arith.go @@ -23,14 +23,14 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.Add(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Add(at, bt, opts...) 
} if adder, ok = at.Engine().(Adder); ok { @@ -55,7 +55,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.AddScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.AddScalar(at, bt, leftTensor, opts...) } if adder, ok = at.Engine().(Adder); ok { @@ -79,7 +79,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.AddScalar(bt, at, false, opts...) } if adder, ok = bt.Engine().(Adder); ok { @@ -104,14 +104,14 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor substraction if oe != nil { return oe.Sub(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Sub(at, bt, opts...) } if suber, ok = at.Engine().(Suber); ok { @@ -136,7 +136,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.SubScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.SubScalar(at, bt, leftTensor, opts...) } if suber, ok = at.Engine().(Suber); ok { @@ -160,7 +160,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.SubScalar(bt, at, false, opts...) 
} if suber, ok = bt.Engine().(Suber); ok { @@ -185,14 +185,14 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication if oe != nil { return oe.Mul(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Mul(at, bt, opts...) } if muler, ok = at.Engine().(Muler); ok { @@ -217,7 +217,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MulScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(at, bt, leftTensor, opts...) } if muler, ok = at.Engine().(Muler); ok { @@ -242,7 +242,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: // b Tensor * a interface - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(bt, at, false, opts...) } if muler, ok = bt.Engine().(Muler); ok { @@ -268,14 +268,14 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division if oe != nil { return oe.Div(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Div(at, bt, opts...) } if diver, ok = at.Engine().(Diver); ok { @@ -300,7 +300,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.DivScalar(at, bt, leftTensor, opts...) 
} - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.DivScalar(at, bt, leftTensor, opts...) } if diver, ok = at.Engine().(Diver); ok { @@ -324,7 +324,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.DivScalar(bt, at, false, opts...) } if diver, ok = bt.Engine().(Diver); ok { @@ -349,14 +349,14 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation if oe != nil { return oe.Pow(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Pow(at, bt, opts...) } if power, ok = at.Engine().(Power); ok { @@ -381,7 +381,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.PowScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.PowScalar(at, bt, leftTensor, opts...) } if power, ok = at.Engine().(Power); ok { @@ -405,7 +405,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.PowScalar(bt, at, false, opts...) 
} if power, ok = bt.Engine().(Power); ok { @@ -430,14 +430,14 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo if oe != nil { return oe.Mod(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Mod(at, bt, opts...) } if moder, ok = at.Engine().(Moder); ok { @@ -462,7 +462,7 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.ModScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.ModScalar(at, bt, leftTensor, opts...) } if moder, ok = at.Engine().(Moder); ok { @@ -486,7 +486,7 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.ModScalar(bt, at, false, opts...) } if moder, ok = bt.Engine().(Moder); ok { @@ -526,41 +526,23 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { // FMA performs Y = A * X + Y. 
func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + var e FMAer if xTensor, ok := x.(Tensor); ok { - if oe := a.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := xTensor.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - - if e, ok := a.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := xTensor.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) + for _, T := range [3]Tensor{a, xTensor, y} { + e, ok = T.Engine().(FMAer) + if ok { + return e.FMA(a, xTensor, y) + } } } else { - if oe := a.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - - if e, ok := a.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) + for _, T := range [2]Tensor{a, y} { + e, ok = T.Engine().(FMAer) + if ok { + return e.FMAScalar(a, x, y) + } } } + return Mul(a, x, WithIncr(y)) } diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 45859a4..1618a4f 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -209,8 +209,11 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, vecf32.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf32.Add(ret.hdr().Float32s(), dataB) + ret, ok := a.Clone().(float32ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float32 data from `a`, of %T", a) + } + vecf32.Add(ret.Float32s(), dataB) retVal = ret.(Tensor) } return diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 21bba43..3186408 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -206,8 +206,11 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, 
vecf64.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf64.Add(ret.hdr().Float64s(), dataB) + ret, ok := a.Clone().(float64ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float64 data from `a`, of %T", a) + } + vecf64.Add(ret.Float64s(), dataB) retVal = ret.(Tensor) } return diff --git a/interfaces.go b/interfaces.go index c0fd7e3..ed2f400 100644 --- a/interfaces.go +++ b/interfaces.go @@ -149,3 +149,11 @@ type unsafeMem interface { Complex64s() []complex64 Complex128s() []complex128 } + +type float64ser interface { + Float64s() []float64 +} + +type float32ser interface { + Float32s() []float32 +} diff --git a/tensor.go b/tensor.go index eb5ca01..addf65d 100644 --- a/tensor.go +++ b/tensor.go @@ -5,7 +5,6 @@ package tensor // import "gorgonia.org/tensor" import ( "encoding/gob" "fmt" - "unsafe" "github.com/pkg/errors" ) @@ -21,9 +20,8 @@ func init() { gob.Register(&CS{}) } -// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. -// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. -type Tensor interface { +// Desc is a description of a tensor. It does not actually deal with data. +type Desc interface { // info about the ndarray Shape() Shape Strides() []int @@ -31,6 +29,12 @@ type Tensor interface { Dims() int Size() int DataSize() int +} + +// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. +// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. 
+type Tensor interface { + Desc // Data access related RequiresIterator() bool @@ -72,7 +76,6 @@ type Tensor interface { //gob.GobEncoder //gob.GobDecoder - standardEngine() StandardEngine headerer arrayer From 094bda9af68208bc4c37d7c0cc869f7efab35581 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 18 Jan 2021 12:12:06 +1100 Subject: [PATCH 076/154] Removed requirement for fmt.Stringer and fmt.Formatter --- tensor.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensor.go b/tensor.go index addf65d..6d1b60a 100644 --- a/tensor.go +++ b/tensor.go @@ -4,7 +4,6 @@ package tensor // import "gorgonia.org/tensor" import ( "encoding/gob" - "fmt" "github.com/pkg/errors" ) @@ -67,8 +66,8 @@ type Tensor interface { IsManuallyManaged() bool // Must Go manage the memory // formatters - fmt.Formatter - fmt.Stringer + // fmt.Formatter + // fmt.Stringer // all Tensors are serializable to these formats //WriteNpy(io.Writer) error From 3b5b6f4493cbf35df002913d9627763879d0fc7b Mon Sep 17 00:00:00 2001 From: Guillaume Simonneau <2980507+khezen@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:43:29 +0100 Subject: [PATCH 077/154] handleReuse: add `safe` flag to skip expensive call to BorrowInt (#107) * handleReuse: add unsafe flag to skip expensive call to BorrowInt * handleReuse: add safe flag to skip expensive call to BorrowInt --- dense_linalg.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dense_linalg.go b/dense_linalg.go index 7478cae..3caa8e7 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -82,7 +82,7 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err // check whether retVal has the same size as the resulting matrix would be: mx1 fo := ParseFuncOpts(opts...) 
defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "MatVecMul") return } @@ -131,7 +131,7 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "MatMul") return } @@ -170,7 +170,7 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { err = errors.Wrapf(err, opFail, "Outer") return } @@ -380,13 +380,15 @@ func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { /* UTILITY FUNCTIONS */ // handleReuse extracts a *Dense from Tensor, and checks the shape of the reuse Tensor -func handleReuse(reuse Tensor, expectedShape Shape) (retVal *Dense, err error) { +func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) { if reuse != nil { if retVal, err = assertDense(reuse); err != nil { err = errors.Wrapf(err, opFail, "handling reuse") return } - + if !safe { + return + } if err = reuseCheckShape(retVal, expectedShape); err != nil { err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. 
Shape error.") return From 572a2255f0dd5c1855cebeb11368f27e34d03fa7 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Mon, 29 Mar 2021 10:59:06 +1100 Subject: [PATCH 078/154] Fix#111 (#112) * Fixed #90 * Fixed #111 * Boyscout commit to fix python/numpy testing --- defaultengine_mapreduce.go | 18 +++++++++-- dense_io_test.go | 40 +++++++++++++++--------- dense_reduction_test.go | 14 +++++++++ example_dense_matop_test.go | 62 +++++++++++++++++++++++++++++++++++++ example_mapreduce_test.go | 43 +++++++++++++++++++++++++ 5 files changed, 159 insertions(+), 18 deletions(-) diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 03b5c0e..9c1443c 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -178,15 +178,27 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, } func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { - return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a, along...) + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() + } + return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...) } func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { - return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a, along...) + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() + } + return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...) } func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { - return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a, along...) + a2 := a + if v, ok := a.(View); ok && v.IsMaterializable() { + a2 = v.Materialize() + } + return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a2, along...) 
} func (e StdEng) reduce( diff --git a/dense_io_test.go b/dense_io_test.go index 0d65884..d2a6548 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -3,6 +3,7 @@ package tensor import ( "bytes" "encoding/gob" + "io/ioutil" "os" "os/exec" "regexp" @@ -30,6 +31,19 @@ func TestSaveLoadNumpy(t *testing.T) { T1D.WriteNpy(f1D) f1D.Close() + defer func() { + // cleanup + err := os.Remove("test.npy") + if err != nil { + t.Error(err) + } + + err = os.Remove("test1D.npy") + if err != nil { + t.Error(err) + } + }() + script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)" // Configurable python command, in order to be able to use python or python3 pythonCommand := os.Getenv("PYTHON_COMMAND") @@ -42,6 +56,10 @@ func TestSaveLoadNumpy(t *testing.T) { if err != nil { t.Error(err) } + stderr, err := cmd.StderrPipe() + if err != nil { + t.Error(err) + } go func() { defer stdin.Close() @@ -56,8 +74,14 @@ func TestSaveLoadNumpy(t *testing.T) { t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. Current value: PYTHON_COMMAND=%s", pythonCommand) } + importError := `ImportError: No module named numpy` + slurpErr, _ := ioutil.ReadAll(stderr) + if ok, _ := regexp.Match(importError, slurpErr); ok { + t.Skipf("Skipping numpy test. 
It would appear that you do not have Numpy installed.") + } + if err := cmd.Wait(); err != nil { - t.Error(err) + t.Errorf("%q", err.Error()) } expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n` @@ -65,20 +89,6 @@ func TestSaveLoadNumpy(t *testing.T) { t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) } - if buf.String() != expected { - } - - // cleanup - err = os.Remove("test.npy") - if err != nil { - t.Error(err) - } - - err = os.Remove("test1D.npy") - if err != nil { - t.Error(err) - } - // ok now to test if it can read T2 := new(Dense) buf = new(bytes.Buffer) diff --git a/dense_reduction_test.go b/dense_reduction_test.go index b10e3ac..f83d0c6 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -547,3 +547,17 @@ func TestDense_Min(t *testing.T) { _, err = T.Min(1000) assert.NotNil(err) } + +func TestSlicedSum(t *testing.T) { + T := New(WithShape(4, 4), WithBacking([]int{ + 1, 2, 3, 4, + 5, 6, 7, 8, + 1, 2, 3, 4, + 5, 6, 7, 8, + })) + s, _ := T.Slice(sli(1, 3), sli(1, 3)) + sum, _ := Sum(s) + if sum.Data().(int) != 18 { + t.Errorf("Expected the sum of %v to be 18. 
Got %v instead", s, sum) + } +} diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 91b819e..497e9d1 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -85,6 +85,68 @@ func ExampleDense_Slice_viewMutation() { // } +func ExampleView() { + // Slicing creates a "view" on the original tensor + T := New(WithBacking(Range(Int, 0, 16)), WithShape(4, 4)) + fmt.Printf("T:\n%v\n", T) + V, _ := T.Slice(makeRS(1, 3), makeRS(1, 3)) + fmt.Printf("V:\n%v\n", V) + + // Now we modify V's 0th value + V.(*Dense).Set(0, 1000) + fmt.Printf("V[0] = 1000:\n%v\n", V) + fmt.Printf("T is also mutated:\n%v\n", T) + + // Now we materialize the views + fmt.Printf("V is Materializable: %v\n", V.IsMaterializable()) + T2 := V.Materialize() + fmt.Printf("T2 == V:\n%v\n", T2) + + // Once materialized, it is decoupled from the original tensor + T2.(*Dense).Set(0, 999) + fmt.Printf("T2 is mutated:\n%v\nBut T is not mutated:\n%v\nNeither is V:\n%v", T2, T, V) + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎢ 8 9 10 11⎥ + // ⎣12 13 14 15⎦ + // + // V: + // ⎡ 5 6⎤ + // ⎣ 9 10⎦ + // + // V[0] = 1000: + // ⎡1000 6⎤ + // ⎣ 9 10⎦ + // + // T is also mutated: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 1000 6 7⎥ + // ⎢ 8 9 10 11⎥ + // ⎣ 12 13 14 15⎦ + // + // V is Materializable: true + // T2 == V: + // ⎡1000 6⎤ + // ⎣ 9 10⎦ + // + // T2 is mutated: + // ⎡999 6⎤ + // ⎣ 9 10⎦ + // + // But T is not mutated: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 1000 6 7⎥ + // ⎢ 8 9 10 11⎥ + // ⎣ 12 13 14 15⎦ + // + // Neither is V: + // ⎡1000 6⎤ + // ⎣ 9 10⎦ + +} + func ExampleDense_Hstack() { var T, T1, T2, T3 *Dense var err error diff --git a/example_mapreduce_test.go b/example_mapreduce_test.go index 4f42a72..e08c6da 100644 --- a/example_mapreduce_test.go +++ b/example_mapreduce_test.go @@ -31,6 +31,27 @@ func ExampleSum() { // Summed along (1, 0): 6 } +func ExampleSum_sliced() { + T := New(WithBacking([]float64{0, 1, 2, 3}), WithShape(2, 2)) + fmt.Printf("T:\n%v\n", T) + + V, _ := 
T.Slice(nil, sli(1)) + fmt.Printf("V:\n%v\n", V) + + Σ, _ := Sum(V) + fmt.Printf("Σ: %v", Σ) + + // Output: + // T: + // ⎡0 1⎤ + // ⎣2 3⎦ + // + // V: + // [1 3] + // Σ: 4 + +} + func ExampleArgmax() { T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) @@ -49,6 +70,28 @@ func ExampleArgmax() { // Argmax is *tensor.Dense of int } +func ExampleArgmax_sliced() { + T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2)) + fmt.Printf("T:\n%v\n", T) + + // slice creates a view + V, _ := T.Slice(nil, sli(1)) + + // argmax along the x-axis + am, _ := Argmax(V, 0) + fmt.Printf("Argmax: %v\n", am) + fmt.Printf("Argmax is %T of %v", am, am.Dtype()) + + // Output: + // T: + // ⎡ 0 100⎤ + // ⎣200 3⎦ + // + // Argmax: 0 + // Argmax is *tensor.Dense of int + +} + func ExampleArgmin() { T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) From 2a53de8bb7938826062f1d116ec9d90398bbaaaa Mon Sep 17 00:00:00 2001 From: Mark Kremer Date: Mon, 29 Mar 2021 18:58:11 +0200 Subject: [PATCH 079/154] Add subdirectories to tested folders in Github workflow (#108) * Add subdirectories to tested folders in Github workflow * Fix Header test * Fix Fprint redundant newline errors * Increase Github workflow timeout for test job * Update set-env and add-path in Github workflow https://github.blog/changelog/2020-10-01-github-actions-deprecating-set-env-and-add-path-commands/ --- .github/workflows/.go.yml | 11 ++++++----- genlib2/dense_io.go | 10 +++++----- internal/storage/header_test.go | 17 ++++++----------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/.github/workflows/.go.yml b/.github/workflows/.go.yml index c323ab7..f658e9a 100644 --- a/.github/workflows/.go.yml +++ b/.github/workflows/.go.yml @@ -20,6 +20,7 @@ jobs: allowfail: true runs-on: ${{ matrix.os }} continue-on-error: ${{ matrix.allowfail }} + timeout-minutes: 5 steps: - name: Install Go ${{ matrix.go }} on ${{ matrix.os }} if: 
matrix.go != 'tip' @@ -34,14 +35,14 @@ jobs: git clone --depth=1 https://go.googlesource.com/go $HOME/gotip cd $HOME/gotip/src ./make.bash - echo "::set-env name=GOROOT::$HOME/gotip" - echo "::add-path::$HOME/gotip/bin" + echo "GOROOT=$HOME/gotip" >> $GITHUB_ENV + echo "$HOME/gotip/bin" >> $GITHUB_PATH - name: Checkout code uses: actions/checkout@v2 - name: Run tests run: | - go test -v -race - go test -race -tags=${{ matrix.tags }} + go test ./... -v -race + go test ./... -race -tags=${{ matrix.tags }} coverage: env: @@ -70,7 +71,7 @@ jobs: - name: Calc coverage run: | export PATH=$PATH:$(go env GOPATH)/bin - go test -v -covermode=atomic -coverprofile=coverage.out + go test ./... -v -covermode=atomic -coverprofile=coverage.out - name: Convert coverage to lcov uses: jandelgado/gcov2lcov-action@v1.0.0 with: diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 814067f..4a63ddd 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -657,12 +657,12 @@ func init() { func generateDenseIO(f io.Writer, generic Kinds) { mk := Kinds{Kinds: filter(generic.Kinds, isNumber)} - fmt.Fprintln(f, "/* GOB SERIALIZATION */\n") + fmt.Fprint(f, "/* GOB SERIALIZATION */\n\n") gobEncode.Execute(f, mk) gobDecode.Execute(f, mk) fmt.Fprint(f, "\n") - fmt.Fprintln(f, "/* NPY SERIALIZATION */\n") + fmt.Fprint(f, "/* NPY SERIALIZATION */\n\n") fmt.Fprintln(f, npyDescRE) fmt.Fprintln(f, rowOrderRE) fmt.Fprintln(f, shapeRE) @@ -670,16 +670,16 @@ func generateDenseIO(f io.Writer, generic Kinds) { readNpy.Execute(f, mk) fmt.Fprint(f, "\n") - fmt.Fprintln(f, "/* CSV SERIALIZATION */\n") + fmt.Fprint(f, "/* CSV SERIALIZATION */\n\n") fmt.Fprintln(f, writeCSVRaw) readCSV.Execute(f, mk) fmt.Fprint(f, "\n") - fmt.Fprintln(f, "/* FB SERIALIZATION */\n") + fmt.Fprint(f, "/* FB SERIALIZATION */\n\n") fmt.Fprintln(f, fbEncodeDecodeRaw) fmt.Fprint(f, "\n") - fmt.Fprintln(f, "/* PB SERIALIZATION */\n") + fmt.Fprint(f, "/* PB SERIALIZATION */\n\n") fmt.Fprintln(f, pbEncodeDecodeRaw) 
fmt.Fprint(f, "\n") diff --git a/internal/storage/header_test.go b/internal/storage/header_test.go index c59fe58..bf6ed73 100644 --- a/internal/storage/header_test.go +++ b/internal/storage/header_test.go @@ -3,7 +3,6 @@ package storage import ( "reflect" "testing" - "unsafe" "github.com/stretchr/testify/assert" ) @@ -14,16 +13,16 @@ func TestFill(t *testing.T) { b := headerFromSlice([]int{10, 11}) copied := Fill(reflect.TypeOf(1), &a, &b) - assert.Equal(t, copied, 5) - assert.Equal(t, a.Ints(), []int{10, 11, 10, 11, 10}) + assert.Equal(t, 5, copied) + assert.Equal(t, []int{10, 11, 10, 11, 10}, a.Ints()) // B longer than A a = headerFromSlice([]int{10, 11}) b = headerFromSlice([]int{0, 1, 2, 3, 4}) copied = Fill(reflect.TypeOf(1), &a, &b) - assert.Equal(t, copied, 2) - assert.Equal(t, a.Ints(), []int{0, 1}) + assert.Equal(t, 2, copied) + assert.Equal(t, []int{0, 1}, a.Ints()) } func headerFromSlice(x interface{}) Header { @@ -31,13 +30,9 @@ func headerFromSlice(x interface{}) Header { if xT.Kind() != reflect.Slice { panic("Expected a slice") } - xV := reflect.ValueOf(x) - uptr := unsafe.Pointer(xV.Pointer()) - + size := uintptr(xV.Len()) * xT.Elem().Size() return Header{ - Ptr: uptr, - L: xV.Len(), - C: xV.Cap(), + Raw: FromMemory(xV.Pointer(), size), } } From 58db8c4e987c498a68936f3cb996e9a2bfd4f0cd Mon Sep 17 00:00:00 2001 From: Mark Kremer Date: Tue, 30 Mar 2021 01:01:24 +0200 Subject: [PATCH 080/154] Remove redundant code from header copy func (#109) * Add subdirectories to tested folders in Github workflow * Fix Header test * Fix Fprint redundant newline errors * Remove redundant code from Header copy func Co-authored-by: Chewxy --- internal/storage/header.go | 27 +++---------------------- internal/storage/header_test.go | 36 +++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/internal/storage/header.go b/internal/storage/header.go index 99414a2..0e05a1d 100644 --- a/internal/storage/header.go +++ 
b/internal/storage/header.go @@ -12,34 +12,13 @@ type Header struct { Raw []byte } +// TypedLen returns the length of data as if it was a slice of type t func (h *Header) TypedLen(t reflect.Type) int { - sz := int(t.Size()) - return len(h.Raw) / sz + return len(h.Raw) / int(t.Size()) } func Copy(t reflect.Type, dst, src *Header) int { - if len(dst.Raw) == 0 || len(src.Raw) == 0 { - return 0 - } - - n := src.TypedLen(t) - if len(dst.Raw) < n { - n = dst.TypedLen(t) - } - - // handle struct{} type - if t.Size() == 0 { - return n - } - - // memmove(dst.Pointer(), src.Pointer(), t.Size()) - // return n - - // otherwise, just copy bytes. - // FUTURE: implement memmove - dstBA := dst.Raw - srcBA := src.Raw - copied := copy(dstBA, srcBA) + copied := copy(dst.Raw, src.Raw) return copied / int(t.Size()) } diff --git a/internal/storage/header_test.go b/internal/storage/header_test.go index bf6ed73..ab28ac5 100644 --- a/internal/storage/header_test.go +++ b/internal/storage/header_test.go @@ -1,12 +1,44 @@ package storage import ( + "github.com/stretchr/testify/assert" "reflect" "testing" - - "github.com/stretchr/testify/assert" ) +func TestCopy(t *testing.T) { + // A longer than B + a := headerFromSlice([]int{0, 1, 2, 3, 4}) + b := headerFromSlice([]int{10, 11}) + copied := Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 2, copied) + assert.Equal(t, []int{10, 11, 2, 3, 4}, a.Ints()) + + // B longer than A + a = headerFromSlice([]int{10, 11}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 2, copied) + assert.Equal(t, []int{0, 1}, a.Ints()) + + // A is empty + a = headerFromSlice([]int{}) + b = headerFromSlice([]int{0, 1, 2, 3, 4}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 0, copied) + + // B is empty + a = headerFromSlice([]int{0, 1, 2, 3, 4}) + b = headerFromSlice([]int{}) + copied = Copy(reflect.TypeOf(1), &a, &b) + + assert.Equal(t, 0, copied) + assert.Equal(t, []int{0, 1, 2, 3, 4}, 
a.Ints()) +} + func TestFill(t *testing.T) { // A longer than B a := headerFromSlice([]int{0, 1, 2, 3, 4}) From 26ad71333f720bae4dc4aa2762173542b21431d9 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 5 Apr 2021 08:23:51 +1000 Subject: [PATCH 081/154] Initial work to separate out View into its own type... --- api_arith.go | 2 +- api_matop.go | 18 ++++++++++++++++-- dense.go | 14 ++------------ dense_compat.go | 5 ++--- dense_mask_filling.go | 4 ++-- dense_matop.go | 5 ++--- dense_views.go | 28 +++++++++++++++++++++++++++- interfaces.go | 9 +++++++++ tensor.go | 26 +++++++++++++++++++------- 9 files changed, 80 insertions(+), 31 deletions(-) diff --git a/api_arith.go b/api_arith.go index 2d7546d..9aa86a8 100644 --- a/api_arith.go +++ b/api_arith.go @@ -597,7 +597,7 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var reuse *Dense fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if reuse, err = handleReuse(fo.Reuse(), expectedShape); err != nil { + if reuse, err = handleReuse(fo.Reuse(), expectedShape, true); err != nil { err = errors.Wrapf(err, opFail, "MatMul") return } diff --git a/api_matop.go b/api_matop.go index bf412ea..6b609ae 100644 --- a/api_matop.go +++ b/api_matop.go @@ -53,15 +53,29 @@ func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { } switch T := t.(type) { case *Dense: + // IF YOU UPDATE THIS, UPDATE THE DENSE VIEW CASE TOO. ts := make([]*Dense, len(others)) for i, o := range others { - if ot, ok := o.(*Dense); ok { + ot, err := assertDense(o) + if err == nil { ts[i] = ot continue } - return nil, errors.Errorf("Expected all Tensors to be *Dense") + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) } return T.Concat(axis, ts...) + case DenseView: + ts := make([]*Dense, len(others)) + for i, o := range others { + ot, err := assertDense(o) + if err == nil { + ts[i] = ot + continue + } + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. 
Got %T instead", o) + } + return T.Concat(axis, ts...) + } panic("Unreachable") } diff --git a/dense.go b/dense.go index eedb9d7..482d9c2 100644 --- a/dense.go +++ b/dense.go @@ -182,16 +182,6 @@ func (t *Dense) ScalarValue() interface{} { return t.Get(0) } -// IsView indicates if the Tensor is a view of another (typically from slicing) -func (t *Dense) IsView() bool { - return t.viewOf != 0 -} - -// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing -func (t *Dense) IsMaterializable() bool { - return t.viewOf != 0 || !t.old.IsZero() -} - // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) func (t *Dense) IsManuallyManaged() bool { return t.flag.manuallyManaged() } @@ -575,7 +565,7 @@ func (t *Dense) Memset(x interface{}) error { if !t.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, t) } - if t.IsMaterializable() { + if t.RequiresIterator() { it := newFlatIterator(&t.AP) return t.array.memsetIter(x, it) } @@ -598,7 +588,7 @@ func (t *Dense) Eq(other interface{}) bool { } func (t *Dense) Zero() { - if t.IsMaterializable() { + if t.RequiresIterator() { it := newFlatIterator(&t.AP) if err := t.zeroIter(it); err != nil { panic(err) diff --git a/dense_compat.go b/dense_compat.go index dcbefa2..3fc923b 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -416,10 +416,10 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator(): data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator(): data = convToFloat64s(t) default: it := newFlatIterator(&t.AP) @@ -431,7 +431,6 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { data = append(data, convToFloat64(t.Get(next))) } err = nil - } retVal = 
mat.NewDense(r, c, data) diff --git a/dense_mask_filling.go b/dense_mask_filling.go index f5d45c7..a31b5aa 100644 --- a/dense_mask_filling.go +++ b/dense_mask_filling.go @@ -72,7 +72,7 @@ func (t *Dense) Filled(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := tc.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } @@ -107,7 +107,7 @@ func (t *Dense) FilledInplace(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := t.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } diff --git a/dense_matop.go b/dense_matop.go index 7e81419..cb976cc 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -228,7 +228,7 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { view.mask = t.mask[ndStart:ndEnd] } - return view, err + return DenseView{view}, err } // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. @@ -256,8 +256,7 @@ func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) view.mask = t.mask[ndStart:ndEnd] } - return view, err - + return DenseView{view}, err } // RollAxis rolls the axis backwards until it lies in the given position. diff --git a/dense_views.go b/dense_views.go index 201ff20..ec8fe8b 100644 --- a/dense_views.go +++ b/dense_views.go @@ -3,8 +3,34 @@ package tensor // a View is a *Tensor with customized strides. The reason for not splitting them up into different types is complicated // this file contains all the methods that deals with Views +type DenseView struct { + *Dense +} + +// RequiresIterator returns true if an iterator is required to read the data in the correct fashion. +func (t DenseView) RequiresIterator() bool { + if t.len() == 1 { + return false + } + // non continuous slice, transpose, or masked. 
If it's a slice and contiguous, then iterator is not required + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { + return true + } + return false +} + +// IsView indicates if the Tensor is a view of another (typically from slicing) +func (t DenseView) IsView() bool { + return t.viewOf != 0 +} + +// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing +func (t DenseView) IsMaterializable() bool { + return t.viewOf != 0 || !t.old.IsZero() +} + // Materialize takes a view, copies its data and puts it in a new *Tensor. -func (t *Dense) Materialize() Tensor { +func (t DenseView) Materialize() Tensor { if !t.IsMaterializable() { return t } diff --git a/interfaces.go b/interfaces.go index ed2f400..345698a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -71,6 +71,15 @@ type Slicer interface { Slice(...Slice) (View, error) } +// Reslicer is any tensor that can reslice. +// To reslice is to reuse the container (*Dense, *CS) etc, but with new `Slice`s applied to it. +// +// e.g: A is a (3,3) matrix that has been sliced at [1:3, 1:3]. Call it B. So now B's shape is (2,2). +// B.Reslice(S(0,2), S(0,2)) would reslice the original tensor (A) with the new slices. +type Reslicer interface { + Reslice(...Slice) (View, error) +} + // DenseTensor is the interface for any Dense tensor. type DenseTensor interface { Tensor diff --git a/tensor.go b/tensor.go index 6d1b60a..24be135 100644 --- a/tensor.go +++ b/tensor.go @@ -11,7 +11,7 @@ import ( var ( _ Tensor = &Dense{} _ Tensor = &CS{} - _ View = &Dense{} + _ View = &DenseView{} ) func init() { @@ -96,15 +96,26 @@ func New(opts ...ConsOpt) *Dense { return d } +// MustGetDense gets a *Dense from a given Tensor. Panics otherwise. 
+func MustGetDense(T Tensor) *Dense { + d, err := assertDense(T) + if err != nil { + panic(err) + } + return d +} + func assertDense(t Tensor) (*Dense, error) { if t == nil { return nil, errors.New("nil is not a *Dense") } - if retVal, ok := t.(*Dense); ok { - return retVal, nil - } - if retVal, ok := t.(Densor); ok { - return retVal.Dense(), nil + switch tt := t.(type) { + case *Dense: + return tt, nil + case DenseView: + return tt.Dense, nil + case Densor: + return tt.Dense(), nil } return nil, errors.Errorf("%T is not *Dense", t) } @@ -162,10 +173,11 @@ func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { return } +// sliceDense returns a *Dense. func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) { var sliced Tensor if sliced, err = t.Slice(slices...); err != nil { return nil, err } - return sliced.(*Dense), nil + return sliced.(DenseView).Dense, nil } From 781e213d1d3b7c7a3d8a21822dab56b05f630d82 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 5 Apr 2021 14:10:47 +1000 Subject: [PATCH 082/154] Finished moving View out. 
Now time to deal with the transposey views --- defaultengine_mapreduce.go | 13 ++- dense.go | 1 - dense_compat.go | 5 +- dense_mask_inspection.go | 4 +- dense_mask_inspection_test.go | 8 +- dense_matop_test.go | 24 +++--- dense_norms_test.go | 20 ++--- dense_views.go | 5 +- example_dense_arith_test.go | 144 +++++++++++++++++----------------- example_dense_cmp_test.go | 60 +++++++------- example_dense_linalg_test.go | 4 +- example_dense_matop_test.go | 4 +- genlib2/dense_compat.go | 6 +- 13 files changed, 149 insertions(+), 149 deletions(-) diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 9c1443c..a70af74 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -15,6 +15,9 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e err = errors.Wrap(err, "Failed Map()") return } + if _, ok := a.(DenseTensor); !ok { + return nil, errors.Errorf("StdEng's Map method only supports dense tensors for now. Please put in a Pull Request to support other forms of Tensors. 
The file is: defaultengine_mapreduce.go") + } var reuse DenseTensor var safe, _, incr bool @@ -24,14 +27,10 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e switch { case safe && reuse == nil: // create reuse - if v, ok := a.(View); ok { - if v.IsMaterializable() { - reuse = v.Materialize().(DenseTensor) - } else { - reuse = v.Clone().(DenseTensor) - } + if v, ok := a.(View); ok && v.IsMaterializable() { + reuse = v.Materialize().(DenseTensor) } else { - reuse = New(Of(a.Dtype()), WithShape(a.Shape().Clone()...)) + reuse = a.Clone().(DenseTensor) } case reuse != nil: if !reuse.IsNativelyAccessible() { diff --git a/dense.go b/dense.go index 482d9c2..92a535c 100644 --- a/dense.go +++ b/dense.go @@ -205,7 +205,6 @@ func (t *Dense) Clone() interface{} { } copyDense(retVal, t) retVal.lock() - return retVal } panic("Unreachable: No engine") diff --git a/dense_compat.go b/dense_compat.go index 3fc923b..0d6073a 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -416,10 +416,10 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.RequiresIterator(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.RequiresIterator(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: it := newFlatIterator(&t.AP) @@ -432,7 +432,6 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { } err = nil } - retVal = mat.NewDense(r, c, data) return } diff --git a/dense_mask_inspection.go b/dense_mask_inspection.go index d2e7843..39be72f 100644 --- a/dense_mask_inspection.go +++ b/dense_mask_inspection.go @@ -18,7 +18,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter // calculate shape of tensor to be returned slices[ax] = makeRS(0, 0) tt, _ := t.Slice(slices...) 
- ts := tt.(*Dense) + ts := MustGetDense(tt) retVal := NewDense(retType, ts.shape) //retVal is array to be returned it := NewIterator(retVal.Info()) @@ -37,7 +37,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter } } tt, _ = t.Slice(slices...) - ts = tt.(*Dense) + ts = MustGetDense(tt) retVal.SetAt(fn(ts), coord...) } diff --git a/dense_mask_inspection_test.go b/dense_mask_inspection_test.go index 7bd118f..ea3574f 100644 --- a/dense_mask_inspection_test.go +++ b/dense_mask_inspection_test.go @@ -124,7 +124,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } retSL = T.FlatNotMaskedContiguous() @@ -137,7 +137,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } retSL = T.FlatMaskedContiguous() @@ -158,7 +158,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } start, end := T.FlatNotMaskedEdges() @@ -169,7 +169,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } start, end = T.FlatMaskedEdges() diff --git a/dense_matop_test.go b/dense_matop_test.go index d9de697..abab26e 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -503,7 +503,7 @@ func TestDense_CopyTo(t *testing.T) { T = New(Of(Byte), WithShape(3, 3)) T2 = New(Of(Byte), WithShape(2, 2)) T3, _ = T.Slice(makeRS(0, 2), makeRS(0, 2)) // T[0:2, 0:2], shape == (2,2) - if err = T2.CopyTo(T3.(*Dense)); err != nil { + if err = T2.CopyTo(MustGetDense(T3)); err != nil { t.Log(err) // for now it's a not yet 
implemented error. TODO: FIX THIS } @@ -609,7 +609,7 @@ func TestDense_Slice(t *testing.T) { assert.True(Shape{2}.Eq(V.Shape())) assert.Equal([]int{3}, V.Strides()) assert.Equal([]float32{0, 1, 2, 3}, V.Data()) - assert.True(V.(*Dense).old.IsZero()) + assert.True(MustGetDense(V).old.IsZero()) // slice a sliced t.Logf("%v", V) @@ -960,12 +960,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) } @@ -976,12 +976,12 @@ func TestDense_Stack(t *testing.T) { T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -1027,12 +1027,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) 
} @@ -1044,12 +1044,12 @@ func TestDense_Stack(t *testing.T) { T1.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt)) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -1077,12 +1077,12 @@ func TestDense_Stack(t *testing.T) { var stacked []*Dense for i := 0; i < 1; i++ { T1 := New(WithShape(2, 2), WithBacking([]string{"blah1", "blah2", "blah3", "blah4"})) - var sliced Tensor + var sliced View if sliced, err = T1.Slice(nil, nil); err != nil { t.Error(err) break } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) stacked = append(stacked, T1) } T2, err := T.Stack(0, stacked...) diff --git a/dense_norms_test.go b/dense_norms_test.go index 316b32a..69879ee 100644 --- a/dense_norms_test.go +++ b/dense_norms_test.go @@ -120,12 +120,13 @@ func TestTensor_Norm(t *testing.T) { t.Error(err) } } + } func TestTensor_Norm_Axis(t *testing.T) { assert := assert.New(t) var T, s, expected, retVal *Dense - var sliced Tensor + var sliced View var err error var backing []float64 var ords []NormOrder @@ -149,7 +150,7 @@ func TestTensor_Norm_Axis(t *testing.T) { var expecteds []*Dense for k := 0; k < T.Shape()[1]; k++ { sliced, _ = T.Slice(nil, ss(k)) - s = sliced.(View).Materialize().(*Dense) + s = sliced.Materialize().(*Dense) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -162,8 +163,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize() - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 0; Ord = %v; Expected %v. Got %v instead. 
ret %v, i: %d", ord, e.Data(), sliced.Data(), retVal, i) } } @@ -173,7 +174,7 @@ func TestTensor_Norm_Axis(t *testing.T) { expecteds = expecteds[:0] for k := 0; k < T.Shape()[0]; k++ { sliced, _ = T.Slice(ss(k)) - s = sliced.(*Dense) + s = MustGetDense(sliced) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -185,8 +186,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize().(*Dense) - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 1; Ord = %v; Expected %v. Got %v instead", ord, e.Data(), sliced.Data()) } } @@ -249,9 +250,8 @@ func TestTensor_Norm_Axis(t *testing.T) { if rowAxis > colAxis { sliced.T() } - sliced = sliced.(View).Materialize().(*Dense) - s = sliced.(*Dense) - expected, _ = s.Norm(ord) + mat := sliced.Materialize().(*Dense) + expected, _ = mat.Norm(ord) expecteds = append(expecteds, expected) } diff --git a/dense_views.go b/dense_views.go index ec8fe8b..ab3c537 100644 --- a/dense_views.go +++ b/dense_views.go @@ -3,6 +3,9 @@ package tensor // a View is a *Tensor with customized strides. The reason for not splitting them up into different types is complicated // this file contains all the methods that deals with Views +var _ View = DenseView{} + +// Dense type DenseView struct { *Dense } @@ -32,7 +35,7 @@ func (t DenseView) IsMaterializable() bool { // Materialize takes a view, copies its data and puts it in a new *Tensor. 
func (t DenseView) Materialize() Tensor { if !t.IsMaterializable() { - return t + return t.Dense } retVal := recycledDense(t.t, t.shape.Clone(), WithEngine(t.e)) diff --git a/example_dense_arith_test.go b/example_dense_arith_test.go index 1ea0c1d..4c17d1a 100644 --- a/example_dense_arith_test.go +++ b/example_dense_arith_test.go @@ -13,7 +13,7 @@ func ExampleDense_Add_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Add(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] + T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -57,7 +57,7 @@ func ExampleDense_Add_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Add(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -100,7 +100,7 @@ func ExampleDense_Add_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Add(T2, WithReuse(Reuse)) @@ -137,7 +137,7 @@ func ExampleDense_Add_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Add(T2, WithIncr(Incr)) @@ -175,7 +175,7 @@ func ExampleDense_Sub_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Sub(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] + T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -219,7 +219,7 @@ func ExampleDense_Sub_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Sub(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -262,7 +262,7 @@ func ExampleDense_Sub_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Sub(T2, WithReuse(Reuse)) @@ -299,7 +299,7 @@ func ExampleDense_Sub_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Sub(T2, WithIncr(Incr)) @@ -337,7 +337,7 @@ func ExampleDense_Mul_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Mul(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] × T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -381,7 +381,7 @@ func ExampleDense_Mul_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Mul(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -424,7 +424,7 @@ func ExampleDense_Mul_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Mul(T2, WithReuse(Reuse)) @@ -461,7 +461,7 @@ func ExampleDense_Mul_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Mul(T2, WithIncr(Incr)) @@ -499,7 +499,7 @@ func ExampleDense_Div_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Div(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] ÷ T2\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) @@ -543,7 +543,7 @@ func ExampleDense_Div_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Div(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -586,7 +586,7 @@ func ExampleDense_Div_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Div(T2, WithReuse(Reuse)) @@ -623,7 +623,7 @@ func ExampleDense_Div_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Div(T2, WithIncr(Incr)) @@ -661,7 +661,7 @@ func ExampleDense_Pow_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Pow(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] ^ T2\nT3:\n%1.1v\nT1 is unchanged:\n%v\n", T3, T1) @@ -705,7 +705,7 @@ func ExampleDense_Pow_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Pow(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -749,7 +749,7 @@ func ExampleDense_Pow_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Pow(T2, WithReuse(Reuse)) @@ -786,7 +786,7 @@ func ExampleDense_Pow_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Pow(T2, WithIncr(Incr)) @@ -824,7 +824,7 @@ func ExampleDense_Mod_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Mod(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] %% T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -868,7 +868,7 @@ func ExampleDense_Mod_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Mod(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -912,7 +912,7 @@ func ExampleDense_Mod_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Mod(T2, WithReuse(Reuse)) @@ -949,7 +949,7 @@ func ExampleDense_Mod_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Mod(T2, WithIncr(Incr)) @@ -991,13 +991,13 @@ func ExampleDense_AddScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1067,15 +1067,15 @@ func ExampleDense_AddScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = 
T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1114,7 +1114,7 @@ func ExampleDense_AddScalar_unsafe() { // ⎢ 8 9⎥ // ⎣11 12⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 6 2⎤ // ⎢ 8 9 5⎥ @@ -1128,7 +1128,7 @@ func ExampleDense_AddScalar_unsafe() { // ⎢ 8 9⎥ // ⎣11 12⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 6 2⎤ // ⎢ 8 9 5⎥ @@ -1155,7 +1155,7 @@ func ExampleDense_AddScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.AddScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1163,7 +1163,7 @@ func ExampleDense_AddScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.AddScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left 
operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1213,7 +1213,7 @@ func ExampleDense_AddScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.AddScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 + T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1250,13 +1250,13 @@ func ExampleDense_SubScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1327,15 +1327,15 @@ func ExampleDense_SubScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, 
sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1374,7 +1374,7 @@ func ExampleDense_SubScalar_unsafe() { // ⎢-2 -1⎥ // ⎣ 1 2⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡-5 -4 2⎤ // ⎢-2 -1 5⎥ @@ -1388,7 +1388,7 @@ func ExampleDense_SubScalar_unsafe() { // ⎢ 2 1⎥ // ⎣-1 -2⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 4 2⎤ // ⎢ 2 1 5⎥ @@ -1415,7 +1415,7 @@ func ExampleDense_SubScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.SubScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1423,7 +1423,7 @@ func ExampleDense_SubScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 
2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.SubScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1473,7 +1473,7 @@ func ExampleDense_SubScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.SubScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 - T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1512,13 +1512,13 @@ func ExampleDense_MulScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1588,15 +1588,15 @@ func ExampleDense_MulScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = 
sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1635,7 +1635,7 @@ func ExampleDense_MulScalar_unsafe() { // ⎢15 20⎥ // ⎣30 35⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 5 2⎤ // ⎢15 20 5⎥ @@ -1649,7 +1649,7 @@ func ExampleDense_MulScalar_unsafe() { // ⎢15 20⎥ // ⎣30 35⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 5 2⎤ // ⎢15 20 5⎥ @@ -1676,7 +1676,7 @@ func ExampleDense_MulScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.MulScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced 
tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1684,7 +1684,7 @@ func ExampleDense_MulScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.MulScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1734,7 +1734,7 @@ func ExampleDense_MulScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.MulScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 * T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1771,13 +1771,13 @@ func ExampleDense_DivScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right 
operand)\n=============================================\nT3 = 5 / T1[:, 1:3]\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) @@ -1847,15 +1847,15 @@ func ExampleDense_DivScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%1.1v\nsliced == T3: %t\nT1 is changed:\n%1.1v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%1.1v\nV == T3: %t\nT1 is changed:\n%1.1v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 / T1[:, 0:2]\nT3:\n%1.1v\nsliced == T3: %t\nT1 is changed:\n%1.1v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 / T1[:, 0:2]\nT3:\n%1.1v\nV == T3: %t\nT1 is changed:\n%1.1v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1894,7 +1894,7 @@ func ExampleDense_DivScalar_unsafe() { // ⎢0.6 0.8⎥ // ⎣ 1 1⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 0.2 2⎤ // ⎢0.6 0.8 5⎥ @@ -1908,7 +1908,7 @@ func ExampleDense_DivScalar_unsafe() { // ⎢ 2 1⎥ // ⎣ 0.8 0.7⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡+Inf 5 2⎤ // ⎢ 2 1 5⎥ @@ -1935,7 +1935,7 @@ func ExampleDense_DivScalar_reuse() { // Tensor is left operand T1 = 
New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.DivScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%1.1v\n", T3 == Reuse, T3) @@ -1943,7 +1943,7 @@ func ExampleDense_DivScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.DivScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%1.1v", T3 == Reuse, T3) @@ -1993,7 +1993,7 @@ func ExampleDense_DivScalar_incr() { // Operations on sliced tensor is also allowed. 
Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.DivScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 / T2\nIncr == T3: %t\nT3:\n%3.1v\n", Incr == T3, T3) @@ -2030,13 +2030,13 @@ func ExampleDense_PowScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.PowScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[0:2, 0:2] ^ 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.PowScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 ^ T1[0:2, 0:2]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -2105,13 +2105,13 @@ func ExampleDense_ModScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.ModScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[0:2, 0:2] %% 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = 
V.ModScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 %% T1[0:2, 0:2]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) diff --git a/example_dense_cmp_test.go b/example_dense_cmp_test.go index 6d72c4d..9166821 100644 --- a/example_dense_cmp_test.go +++ b/example_dense_cmp_test.go @@ -20,7 +20,7 @@ func ExampleDense_Gt_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -28,7 +28,7 @@ func ExampleDense_Gt_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gt(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -83,7 +83,7 @@ func ExampleDense_Gt_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Gt(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -129,7 +129,7 @@ func ExampleDense_Gt_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 
= New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Gt(T2, WithReuse(T3)) @@ -138,7 +138,7 @@ func ExampleDense_Gt_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Gt(T2, WithReuse(T3), AsSameType()) @@ -192,7 +192,7 @@ func ExampleDense_Gte_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -200,7 +200,7 @@ func ExampleDense_Gte_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gte(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -255,7 +255,7 @@ func ExampleDense_Gte_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Gte(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -301,7 +301,7 @@ func ExampleDense_Gte_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), 
makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Gte(T2, WithReuse(T3)) @@ -310,7 +310,7 @@ func ExampleDense_Gte_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Gte(T2, WithReuse(T3), AsSameType()) @@ -364,7 +364,7 @@ func ExampleDense_Lt_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -372,7 +372,7 @@ func ExampleDense_Lt_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lt(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -427,7 +427,7 @@ func ExampleDense_Lt_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Lt(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -473,7 +473,7 @@ func ExampleDense_Lt_reuse() { 
// Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Lt(T2, WithReuse(T3)) @@ -482,7 +482,7 @@ func ExampleDense_Lt_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Lt(T2, WithReuse(T3), AsSameType()) @@ -535,7 +535,7 @@ func ExampleDense_Lte_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -543,7 +543,7 @@ func ExampleDense_Lte_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lte(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -598,7 +598,7 @@ func ExampleDense_Lte_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Lte(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with 
a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -644,7 +644,7 @@ func ExampleDense_Lte_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Lte(T2, WithReuse(T3)) @@ -653,7 +653,7 @@ func ExampleDense_Lte_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Lte(T2, WithReuse(T3), AsSameType()) @@ -707,7 +707,7 @@ func ExampleDense_ElEq_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElEq(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -715,7 +715,7 @@ func ExampleDense_ElEq_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElEq(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -770,7 +770,7 @@ func ExampleDense_ElEq_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = 
MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.ElEq(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -817,7 +817,7 @@ func ExampleDense_ElEq_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.ElEq(T2, WithReuse(T3)) @@ -826,7 +826,7 @@ func ExampleDense_ElEq_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.ElEq(T2, WithReuse(T3), AsSameType()) @@ -880,7 +880,7 @@ func ExampleDense_ElNe_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElNe(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -888,7 +888,7 @@ func ExampleDense_ElNe_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElNe(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -943,7 +943,7 @@ func ExampleDense_ElNe_unsafe() { T1 
= New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.ElNe(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -990,7 +990,7 @@ func ExampleDense_ElNe_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.ElNe(T2, WithReuse(T3)) @@ -999,7 +999,7 @@ func ExampleDense_ElNe_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.ElNe(T2, WithReuse(T3), AsSameType()) diff --git a/example_dense_linalg_test.go b/example_dense_linalg_test.go index d558481..13c9dcf 100644 --- a/example_dense_linalg_test.go +++ b/example_dense_linalg_test.go @@ -76,7 +76,7 @@ func ExampleDense_MatVecMul_rowMajorSliced() { // here we print the underlying slice of T3 just to show that it's actually a much larger slice fmt.Printf("Underlying Slice: %v\n", T3.Data()) - T4, err := T2.(*Dense).MatVecMul(T3) + T4, err := MustGetDense(T2).MatVecMul(T3) handleErr(err) fmt.Printf("T4:\n%v\n", T4) @@ -120,7 +120,7 @@ func ExampleDense_MatMul_sliced() { handleErr(err) fmt.Printf("T4:\n%v", T4) - T5, err := T3.(*Dense).MatMul(T4) + T5, err := MustGetDense(T3).MatMul(T4) handleErr(err) fmt.Printf("T3xT4:\n%v", T5) diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 497e9d1..1e81271 100644 --- 
a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -58,7 +58,7 @@ func ExampleDense_Slice_viewMutation() { fmt.Printf("V:\n%v\n", V) // Now we modify V's 0th value - V.(*Dense).Set(0, 1000) + MustGetDense(V).Set(0, 1000) fmt.Printf("V[0] = 1000:\n%v\n", V) fmt.Printf("T is also mutated:\n%v", T) @@ -93,7 +93,7 @@ func ExampleView() { fmt.Printf("V:\n%v\n", V) // Now we modify V's 0th value - V.(*Dense).Set(0, 1000) + MustGetDense(V).Set(0, 1000) fmt.Printf("V[0] = 1000:\n%v\n", V) fmt.Printf("T is also mutated:\n%v\n", T) diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index e3b5b52..8ce7fe8 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -220,10 +220,10 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: it := newFlatIterator(&t.AP) @@ -235,7 +235,7 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { data = append(data, convToFloat64(t.Get(next))) } err = nil - + } retVal = mat.NewDense(r, c, data) From 82b7912a13a4a478d376b7e09c5369e93130a3ab Mon Sep 17 00:00:00 2001 From: Mark Kremer Date: Mon, 5 Apr 2021 15:44:54 +0200 Subject: [PATCH 083/154] Move unused noopError to test. Simplify handleNoOp function (#110) * Move unused noopError to test. 
Simplify handleNoOp function * Revert removal of nil check in handleNoOp func * Add test for handleNoOp func in main package * Invert if statements in handleNoOp for consistency * Retrigger checks --- errors.go | 7 +++---- errors_test.go | 15 +++++++++++++++ internal/execution/keepsync.go | 12 +++--------- internal/execution/keepsync_test.go | 20 ++++++++++++++++++++ internal/storage/keepsync.go | 12 +++--------- internal/storage/keepsync_test.go | 20 ++++++++++++++++++++ 6 files changed, 64 insertions(+), 22 deletions(-) create mode 100644 errors_test.go create mode 100644 internal/execution/keepsync_test.go create mode 100644 internal/storage/keepsync_test.go diff --git a/errors.go b/errors.go index 461baf4..cd6a297 100644 --- a/errors.go +++ b/errors.go @@ -21,11 +21,10 @@ func handleNoOp(err error) error { if err == nil { return nil } - - if _, ok := err.(NoOpError); !ok { - return err + if _, ok := err.(NoOpError); ok { + return nil } - return nil + return err } type errorIndices []int diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..ec12185 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,15 @@ +package tensor + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestHandleNoOp(t *testing.T) { + otherErr := errors.New("other error") + + assert.Equal(t, nil, handleNoOp(noopError{})) + assert.Equal(t, nil, handleNoOp(nil)) + assert.Equal(t, otherErr, handleNoOp(otherErr)) +} diff --git a/internal/execution/keepsync.go b/internal/execution/keepsync.go index 5b49f7d..8921d1c 100644 --- a/internal/execution/keepsync.go +++ b/internal/execution/keepsync.go @@ -19,18 +19,12 @@ type NoOpError interface { NoOp() bool } -type noopError struct{} - -func (e noopError) NoOp() bool { return true } -func (e noopError) Error() string { return "NoOp" } - func handleNoOp(err error) error { if err == nil { return nil } - - if _, ok := err.(NoOpError); !ok { - return err + if _, ok := err.(NoOpError); ok 
{ + return nil } - return nil + return err } diff --git a/internal/execution/keepsync_test.go b/internal/execution/keepsync_test.go new file mode 100644 index 0000000..2e2c693 --- /dev/null +++ b/internal/execution/keepsync_test.go @@ -0,0 +1,20 @@ +package execution + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +type noopError struct{} + +func (e noopError) NoOp() bool { return true } +func (e noopError) Error() string { return "NoOp" } + +func TestHandleNoOp(t *testing.T) { + otherErr := errors.New("other error") + + assert.Equal(t, nil, handleNoOp(noopError{})) + assert.Equal(t, nil, handleNoOp(nil)) + assert.Equal(t, otherErr, handleNoOp(otherErr)) +} diff --git a/internal/storage/keepsync.go b/internal/storage/keepsync.go index f008e2a..dde26cd 100644 --- a/internal/storage/keepsync.go +++ b/internal/storage/keepsync.go @@ -19,18 +19,12 @@ type NoOpError interface { NoOp() bool } -type noopError struct{} - -func (e noopError) NoOp() bool { return true } -func (e noopError) Error() string { return "NoOp" } - func handleNoOp(err error) error { if err == nil { return nil } - - if _, ok := err.(NoOpError); !ok { - return err + if _, ok := err.(NoOpError); ok { + return nil } - return nil + return err } diff --git a/internal/storage/keepsync_test.go b/internal/storage/keepsync_test.go new file mode 100644 index 0000000..00b2182 --- /dev/null +++ b/internal/storage/keepsync_test.go @@ -0,0 +1,20 @@ +package storage + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +type noopError struct{} + +func (e noopError) NoOp() bool { return true } +func (e noopError) Error() string { return "NoOp" } + +func TestHandleNoOp(t *testing.T) { + otherErr := errors.New("other error") + + assert.Equal(t, nil, handleNoOp(noopError{})) + assert.Equal(t, nil, handleNoOp(nil)) + assert.Equal(t, otherErr, handleNoOp(otherErr)) +} From d7be7e722b8b83805debe7991cebc9cc63c33299 Mon Sep 17 00:00:00 
2001 From: Chewxy Date: Tue, 6 Apr 2021 11:08:22 +1000 Subject: [PATCH 084/154] Example iterators (#114) * Fixed #90 * Added an example for SliceIterator --- api_matop.go | 4 +- example_iterator_test.go | 129 +++++++++++++++++++++++---------------- 2 files changed, 80 insertions(+), 53 deletions(-) diff --git a/api_matop.go b/api_matop.go index bf412ea..e0f479d 100644 --- a/api_matop.go +++ b/api_matop.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // this file handles matops. While by default most of these matops should already have been defined as part of the // Tensor interface, not all are possible(for example, concatenating a sparse tensor), hence the need for the following functions diff --git a/example_iterator_test.go b/example_iterator_test.go index a6b31da..aff34e3 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -1,52 +1,77 @@ -package tensor - -import "fmt" - -// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor -func Example_iteratorRowmajor() { - T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) - it := IteratorFromDense(T) - fmt.Printf("T:\n%v\n", T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - - // Output: - // T: - // ⎡0 1 2⎤ - // ⎣3 4 5⎦ - // - // i: 0, coord: [0 1] - // i: 1, coord: [0 2] - // i: 2, coord: [1 0] - // i: 3, coord: [1 1] - // i: 4, coord: [1 2] - // i: 5, coord: [0 0] - -} - -// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly -// this example shows the order of the iteration. 
-func Example_iteratorcolMajor() { - T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) - it := IteratorFromDense(T) - fmt.Printf("T:\n%v\n", T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - - // Output: - // T: - // ⎡0 2 4⎤ - // ⎣1 3 5⎦ - // - // i: 0, coord: [0 1] - // i: 2, coord: [0 2] - // i: 4, coord: [1 0] - // i: 1, coord: [1 1] - // i: 3, coord: [1 2] - // i: 5, coord: [0 0] - -} +package tensor + +import "fmt" + +// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor +func Example_iteratorRowmajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // + // i: 0, coord: [0 1] + // i: 1, coord: [0 2] + // i: 2, coord: [1 0] + // i: 3, coord: [1 1] + // i: 4, coord: [1 2] + // i: 5, coord: [0 0] + +} + +// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly +// this example shows the order of the iteration. +func Example_iteratorcolMajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // + // i: 0, coord: [0 1] + // i: 2, coord: [0 2] + // i: 4, coord: [1 0] + // i: 1, coord: [1 1] + // i: 3, coord: [1 2] + // i: 5, coord: [0 0] + +} + +func ExampleSliceIter() { + T := New(WithShape(3, 3), WithBacking(Range(Float64, 0, 9))) + S, err := T.Slice(makeRS(1, 3), makeRS(1, 3)) + if err != nil { + fmt.Printf("Err %v\n", err) + return + } + fmt.Printf("S (requires iterator? 
%t)\n%v\n", S.(*Dense).RequiresIterator(), S) + it := IteratorFromDense(S.(*Dense)) + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i %d, coord %v\n", i, it.Coord()) + } + + // Output: + // S (requires iterator? true) + // ⎡4 5⎤ + // ⎣7 8⎦ + // + // i 0, coord [0 1] + // i 1, coord [1 0] + // i 3, coord [1 1] + // i 4, coord [0 0] + +} From 1cfb599134e531220f393c410e43a90458972803 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 9 Jul 2021 10:57:05 +1000 Subject: [PATCH 085/154] Started to move out Dtype to its own package. --- api_unary_test.go | 13 +- array.go | 15 +- collections.go | 60 +-- consopt.go | 5 +- defaultengine.go | 3 +- defaultengine_misc.go | 162 +++---- defaultengine_prep.go | 7 +- defaultenginefloat32.go | 3 +- defaultenginefloat64.go | 3 +- dense.go | 11 +- dense_apply_test.go | 446 +++++++++--------- dense_compat.go | 5 +- dense_compat_test.go | 3 +- dense_generated.go | 10 +- dense_generated_test.go | 5 +- dense_getset_test.go | 7 +- dense_io.go | 3 +- dense_mask_inspection.go | 4 +- dense_matop_test.go | 9 +- dense_reduction_test.go | 23 +- engine.go | 4 +- example_iterator_test.go | 104 ++-- flags.go | 6 +- flags_test.go | 180 +++---- generic_utils.go | 5 +- genlib2/dense_compat.go | 2 +- genlib2/dense_compat_tests.go | 574 +++++++++++------------ genlib2/dense_cons.go | 4 +- genlib2/dense_cons_tests.go | 170 +++---- genlib2/dense_getset_tests.go | 18 +- genlib2/dense_io.go | 2 +- genlib2/dense_reduction_methods_tests.go | 328 ++++++------- genlib2/dense_reduction_tests.go | 2 +- genlib2/generic_utils.go | 6 +- genlib2/native_iterator.go | 2 +- genlib2/native_select.go | 2 +- genlib2/testutils.go | 10 +- go.mod | 7 +- go.sum | 7 +- interfaces.go | 3 +- native/iterator_native.go | 3 +- native/iterator_native2.go | 3 +- optimizations_test.go | 30 +- perf.go | 4 +- scalar.go | 2 +- sparse.go | 3 +- sparse_test.go | 210 ++++----- tensor.go | 4 +- test_test.go | 3 +- testutils_test.go | 8 +- type_test.go | 6 +- types.go 
| 333 +------------ 52 files changed, 1285 insertions(+), 1557 deletions(-) diff --git a/api_unary_test.go b/api_unary_test.go index 9c735e6..5f453a5 100644 --- a/api_unary_test.go +++ b/api_unary_test.go @@ -1,14 +1,14 @@ package tensor import ( + "math" "math/rand" "testing" "testing/quick" "time" - "math" - "github.com/stretchr/testify/assert" "github.com/chewxy/math32" + "github.com/stretchr/testify/assert" ) /* @@ -683,7 +683,6 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests for Log10 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -720,7 +719,6 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -836,7 +834,6 @@ func TestAbs(t *testing.T) { } } - func TestTanh(t *testing.T) { var r *rand.Rand // default @@ -926,7 +923,6 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -973,7 +969,6 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // incr invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -1062,7 +1057,6 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests for Log2 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -1099,7 +1093,6 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -1177,4 +1170,4 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } -} \ No newline at end of file +} diff --git a/array.go b/array.go index ca948d6..e805405 100644 --- a/array.go +++ b/array.go @@ -7,17 +7,18 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // array is the underlying generic array. 
type array struct { - storage.Header // the header - the Go representation (a slice) - t Dtype // the element type + storage.Header // the header - the Go representation (a slice) + t dtype.Dtype // the element type } // makeArray makes an array. The memory allocation is handled by Go -func makeArray(t Dtype, length int) array { +func makeArray(t dtype.Dtype, length int) array { v := malloc(t, length) hdr := storage.Header{ Raw: v, @@ -41,7 +42,7 @@ func arrayFromSlice(x interface{}) array { Header: storage.Header{ Raw: storage.AsByteSlice(x), }, - t: Dtype{elT}, + t: dtype.Dtype{elT}, } } @@ -57,7 +58,7 @@ func (a *array) fromSlice(x interface{}) { } elT := xT.Elem() a.Raw = storage.AsByteSlice(x) - a.t = Dtype{elT} + a.t = dtype.Dtype{elT} } // fromSliceOrTensor populates the value from a slice or anything that can form an array @@ -206,13 +207,13 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ // malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory -func malloc(t Dtype, length int) []byte { +func malloc(t dtype.Dtype, length int) []byte { size := int(calcMemSize(t, length)) return make([]byte, size) } // calcMemSize calulates the memory size of an array (given its size) -func calcMemSize(dt Dtype, size int) int64 { +func calcMemSize(dt dtype.Dtype, size int) int64 { return int64(dt.Size()) * int64(size) } diff --git a/collections.go b/collections.go index 5f4d075..34e8284 100644 --- a/collections.go +++ b/collections.go @@ -1,30 +1,30 @@ -package tensor - -import "github.com/pkg/errors" - -func densesToTensors(a []*Dense) []Tensor { - retVal := make([]Tensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func densesToDenseTensors(a []*Dense) []DenseTensor { - retVal := make([]DenseTensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { - retVal := 
make([]DenseTensor, len(a)) - var ok bool - for i, t := range a { - if retVal[i], ok = t.(DenseTensor); !ok { - return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", t, i) - } - } - return retVal, nil -} +package tensor + +import "github.com/pkg/errors" + +func densesToTensors(a []*Dense) []Tensor { + retVal := make([]Tensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func densesToDenseTensors(a []*Dense) []DenseTensor { + retVal := make([]DenseTensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { + retVal := make([]DenseTensor, len(a)) + var ok bool + for i, t := range a { + if retVal[i], ok = t.(DenseTensor); !ok { + return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", t, i) + } + } + return retVal, nil +} diff --git a/consopt.go b/consopt.go index 7297a0f..ab4135d 100644 --- a/consopt.go +++ b/consopt.go @@ -3,6 +3,7 @@ package tensor import ( "reflect" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -10,7 +11,7 @@ import ( type ConsOpt func(Tensor) // Of is a construction option for a Tensor. 
-func Of(a Dtype) ConsOpt { +func Of(a dtype.Dtype) ConsOpt { Register(a) f := func(t Tensor) { switch tt := t.(type) { @@ -113,7 +114,7 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { xv0 := xv.Index(0) // xv[0] xv0.Set(reflect.ValueOf(x)) tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) - tt.t = Dtype{xT} + tt.t = dtype.Dtype{xT} tt.mask = mask default: diff --git a/defaultengine.go b/defaultengine.go index 91f674b..d1450c7 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -2,6 +2,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" ) @@ -11,7 +12,7 @@ type StdEng struct { } // makeArray allocates a slice for the array -func (e StdEng) makeArray(arr *array, t Dtype, size int) { +func (e StdEng) makeArray(arr *array, t dtype.Dtype, size int) { arr.Raw = malloc(t, size) arr.t = t diff --git a/defaultengine_misc.go b/defaultengine_misc.go index bb70e57..8ce04db 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -1,81 +1,81 @@ -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nonComplexNumberTypes); err != nil { - return nil, errors.Wrap(err, "Clamp failed") - } - - var reuse DenseTensor - var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { - return nil, errors.Wrap(err, "Unable to handle funcOpts") - } - - typ := a.Dtype().Type - var ait, rit Iterator - var dataA, dataReuse *storage.Header - var useIter bool - - if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { - return nil, errors.Wrapf(err, opFail, "StdEng.Neg") - } - - if useIter { - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != 
nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - ait.Reset() - err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) - retVal = reuse - case toReuse: - storage.CopyIter(typ, dataReuse, dataA, rit, ait) - rit.Reset() - err = e.E.ClampIter(typ, dataReuse, rit, min, max) - retVal = reuse - case !safe: - err = e.E.ClampIter(typ, dataA, ait, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) - retVal = cloned - } - return - } - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - err = e.E.Add(typ, dataReuse, cloned.hdr()) - retVal = reuse - case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Clamp(typ, dataReuse, min, max) - retVal = reuse - case !safe: - err = e.E.Clamp(typ, dataA, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.Clamp(typ, cloned.hdr(), min, max) - retVal = cloned - } - return -} - -func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { - return e.Mul(a, x, WithIncr(y)) -} -func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { - return e.MulScalar(a, x, true, WithIncr(y)) -} +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(a, nonComplexNumberTypes); err != nil { + return nil, errors.Wrap(err, "Clamp failed") + } + + var reuse DenseTensor + var safe, toReuse, incr bool + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + + typ := a.Dtype().Type + var ait, rit Iterator + var dataA, dataReuse *storage.Header + var useIter bool + + if dataA, dataReuse, ait, rit, 
useIter, err = prepDataUnary(a, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.Neg") + } + + if useIter { + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + ait.Reset() + err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) + retVal = reuse + case toReuse: + storage.CopyIter(typ, dataReuse, dataA, rit, ait) + rit.Reset() + err = e.E.ClampIter(typ, dataReuse, rit, min, max) + retVal = reuse + case !safe: + err = e.E.ClampIter(typ, dataA, ait, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) + retVal = cloned + } + return + } + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + err = e.E.Add(typ, dataReuse, cloned.hdr()) + retVal = reuse + case toReuse: + storage.Copy(typ, dataReuse, dataA) + err = e.E.Clamp(typ, dataReuse, min, max) + retVal = reuse + case !safe: + err = e.E.Clamp(typ, dataA, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.Clamp(typ, cloned.hdr(), min, max) + retVal = cloned + } + return +} + +func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { + return e.Mul(a, x, WithIncr(y)) +} +func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { + return e.MulScalar(a, x, true, WithIncr(y)) +} diff --git a/defaultengine_prep.go b/defaultengine_prep.go index 261367a..a3df181 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -4,11 +4,12 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" // "log" ) -func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { +func 
handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -106,13 +107,13 @@ func unaryCheck(a Tensor, tc *typeclass) error { // scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor. func scalarDtypeCheck(a Tensor, b interface{}) error { - var dt Dtype + var dt dtype.Dtype switch bt := b.(type) { case Dtyper: dt = bt.Dtype() default: t := reflect.TypeOf(b) - dt = Dtype{t} + dt = dtype.Dtype{t} } if a.Dtype() != dt { diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 1618a4f..9f1ebf7 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -2,6 +2,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" @@ -112,7 +113,7 @@ type Float32Engine struct { } // makeArray allocates a slice for the array -func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float32Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float32 { panic("Float32Engine only creates float32s") } diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 3186408..4e2167a 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -2,6 +2,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" @@ -112,7 +113,7 @@ type Float64Engine struct { } // makeArray allocates a slice for the array -func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float64Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float64 { panic("Float64Engine only creates float64s") } diff --git a/dense.go b/dense.go index 92a535c..f873f61 100644 --- a/dense.go +++ b/dense.go @@ -6,6 +6,7 @@ import ( "unsafe" 
"github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -34,11 +35,11 @@ type Dense struct { } // NewDense creates a new *Dense. It tries its best to get from the tensor pool. -func NewDense(dt Dtype, shape Shape, opts ...ConsOpt) *Dense { +func NewDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) *Dense { return recycledDense(dt, shape, opts...) } -func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { +func recycledDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { retVal = recycledDenseNoFix(dt, shape, opts...) retVal.fix() if err := retVal.sanity(); err != nil { @@ -47,7 +48,7 @@ func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { return } -func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { +func recycledDenseNoFix(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { // size := shape.TotalSize() //if shape.IsScalar() { // size = 1 @@ -102,7 +103,7 @@ func (t *Dense) makeArray(size int) { func (t *Dense) Info() *AP { return &t.AP } // Dtype returns the data type of the *Dense tensor. -func (t *Dense) Dtype() Dtype { return t.t } +func (t *Dense) Dtype() dtype.Dtype { return t.t } // Data returns the underlying array. 
If the *Dense represents a scalar value, the scalar value is returned instead func (t *Dense) Data() interface{} { @@ -287,7 +288,7 @@ func (t *Dense) fix() { } else { t.SetShape(size) // vector } - case t.array.Header.Raw == nil && t.t != Dtype{}: + case t.array.Header.Raw == nil && t.t != dtype.Dtype{}: size := t.Shape().TotalSize() t.makeArray(size) diff --git a/dense_apply_test.go b/dense_apply_test.go index 5e8c23d..8d73631 100644 --- a/dense_apply_test.go +++ b/dense_apply_test.go @@ -1,222 +1,224 @@ -package tensor - -import ( - "math/rand" - "testing" - "testing/quick" - "time" - "unsafe" -) - -func getMutateVal(dt Dtype) interface{} { - switch dt { - case Int: - return int(1) - case Int8: - return int8(1) - case Int16: - return int16(1) - case Int32: - return int32(1) - case Int64: - return int64(1) - case Uint: - return uint(1) - case Uint8: - return uint8(1) - case Uint16: - return uint16(1) - case Uint32: - return uint32(1) - case Uint64: - return uint64(1) - case Float32: - return float32(1) - case Float64: - return float64(1) - case Complex64: - var c complex64 = 1 - return c - case Complex128: - var c complex128 = 1 - return c - case Bool: - return true - case String: - return "Hello World" - case Uintptr: - return uintptr(0xdeadbeef) - case UnsafePointer: - return unsafe.Pointer(uintptr(0xdeadbeef)) - } - return nil -} - -func getMutateFn(dt Dtype) interface{} { - switch dt { - case Int: - return mutateI - case Int8: - return mutateI8 - case Int16: - return mutateI16 - case Int32: - return mutateI32 - case Int64: - return mutateI64 - case Uint: - return mutateU - case Uint8: - return mutateU8 - case Uint16: - return mutateU16 - case Uint32: - return mutateU32 - case Uint64: - return mutateU64 - case Float32: - return mutateF32 - case Float64: - return mutateF64 - case Complex64: - return mutateC64 - case Complex128: - return mutateC128 - case Bool: - return mutateB - case String: - return mutateStr - case Uintptr: - return mutateUintptr - case 
UnsafePointer: - return mutateUnsafePointer - } - return nil -} - -func TestDense_Apply(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - - // wrong fn type/illogical values - if _, err = a.Apply(getMutateFn); err == nil { - t.Error("Expected an error") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_unsafe(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, UseUnsafe()) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) 
{ - return false - } - if ret != a { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_reuse(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - reuse := q.Clone().(*Dense) - reuse.Zero() - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, WithReuse(reuse)) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - if ret != reuse { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} +package tensor + +import ( + "math/rand" + "testing" + "testing/quick" + "time" + "unsafe" + + "gorgonia.org/dtype" +) + +func getMutateVal(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return int(1) + case Int8: + return int8(1) + case Int16: + return int16(1) + case Int32: + return int32(1) + case Int64: + return int64(1) + case Uint: + return uint(1) + case Uint8: + return uint8(1) + case Uint16: + return uint16(1) + case Uint32: + return uint32(1) + case Uint64: + return uint64(1) + 
case Float32: + return float32(1) + case Float64: + return float64(1) + case Complex64: + var c complex64 = 1 + return c + case Complex128: + var c complex128 = 1 + return c + case Bool: + return true + case String: + return "Hello World" + case Uintptr: + return uintptr(0xdeadbeef) + case UnsafePointer: + return unsafe.Pointer(uintptr(0xdeadbeef)) + } + return nil +} + +func getMutateFn(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return mutateI + case Int8: + return mutateI8 + case Int16: + return mutateI16 + case Int32: + return mutateI32 + case Int64: + return mutateI64 + case Uint: + return mutateU + case Uint8: + return mutateU8 + case Uint16: + return mutateU16 + case Uint32: + return mutateU32 + case Uint64: + return mutateU64 + case Float32: + return mutateF32 + case Float64: + return mutateF64 + case Complex64: + return mutateC64 + case Complex128: + return mutateC128 + case Bool: + return mutateB + case String: + return mutateStr + case Uintptr: + return mutateUintptr + case UnsafePointer: + return mutateUnsafePointer + } + return nil +} + +func TestDense_Apply(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nil, nil) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + + // wrong fn type/illogical values + if _, err = a.Apply(getMutateFn); err == nil { + t.Error("Expected an error") + 
return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_unsafe(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nil, nil) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, UseUnsafe()) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != a { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_reuse(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nil, nil) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + reuse := q.Clone().(*Dense) + reuse.Zero() + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, WithReuse(reuse)) + if err, 
retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != reuse { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} diff --git a/dense_compat.go b/dense_compat.go index 0d6073a..1161cf1 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -15,9 +15,10 @@ import ( "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) -func convFromFloat64s(to Dtype, data []float64) interface{} { +func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { case Int: retVal := make([]int, len(data)) @@ -431,7 +432,9 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { data = append(data, convToFloat64(t.Get(next))) } err = nil + } + retVal = mat.NewDense(r, c, data) return } diff --git a/dense_compat_test.go b/dense_compat_test.go index c641203..581563c 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -11,13 +11,14 @@ import ( arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/stretchr/testify/assert" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) var toMat64Tests = []struct { data interface{} sliced interface{} shape Shape - dt Dtype + dt dtype.Dtype }{ {Range(Int, 0, 6), []int{0, 1, 3, 4}, Shape{2, 3}, Int}, {Range(Int8, 0, 6), []int8{0, 1, 3, 4}, Shape{2, 3}, Int8}, diff --git a/dense_generated.go b/dense_generated.go index 6349bfb..5a44a10 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -2,10 +2,14 @@ package tensor -import "reflect" +import ( + "reflect" + + "gorgonia.org/dtype" +) // Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, 
shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { case reflect.Int: @@ -68,7 +72,7 @@ func Ones(dt Dtype, shape ...int) *Dense { // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense { +func I(dt dtype.Dtype, r, c, k int) *Dense { ret := New(Of(dt), WithShape(r, c)) i := k if k < 0 { diff --git a/dense_generated_test.go b/dense_generated_test.go index e87baa0..7332bf4 100644 --- a/dense_generated_test.go +++ b/dense_generated_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) var onesTests = []struct { - of Dtype + of dtype.Dtype shape Shape correct interface{} }{ @@ -56,7 +57,7 @@ func TestOnes(t *testing.T) { // yes, it's a pun on eye tests, stop asking and go see your optometrist var eyeTests = []struct { - E Dtype + E dtype.Dtype R, C, K int correct interface{} diff --git a/dense_getset_test.go b/dense_getset_test.go index 8ab8e44..cde0542 100644 --- a/dense_getset_test.go +++ b/dense_getset_test.go @@ -8,10 +8,11 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) var denseSetGetTests = []struct { - of Dtype + of dtype.Dtype data interface{} set interface{} @@ -48,7 +49,7 @@ func TestDense_setget(t *testing.T) { } var denseMemsetTests = []struct { - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -88,7 +89,7 @@ func TestDense_memset(t *testing.T) { } var denseZeroTests = []struct { - of Dtype + of dtype.Dtype data interface{} correct interface{} diff --git a/dense_io.go b/dense_io.go index 7bb9608..f5a1abd 100644 --- a/dense_io.go +++ b/dense_io.go @@ -16,6 +16,7 @@ import ( flatbuffers "github.com/google/flatbuffers/go" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/serialization/fb" "gorgonia.org/tensor/internal/serialization/pb" ) @@ -423,7 +424,7 @@ func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err 
error) { // convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. // If into is nil, then a backing slice will be created. -func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { case reflect.Int: diff --git a/dense_mask_inspection.go b/dense_mask_inspection.go index 39be72f..7e1c30c 100644 --- a/dense_mask_inspection.go +++ b/dense_mask_inspection.go @@ -1,10 +1,12 @@ package tensor +import "gorgonia.org/dtype" + type maskedReduceFn func(Tensor) interface{} // MaskedReduce applies a reduction function of type maskedReduceFn to mask, and returns // either an int, or another array -func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) interface{} { +func MaskedReduce(t *Dense, retType dtype.Dtype, fn maskedReduceFn, axis ...int) interface{} { if len(axis) == 0 || t.IsVector() { return fn(t) } diff --git a/dense_matop_test.go b/dense_matop_test.go index abab26e..dd38660 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/vecf64" ) @@ -41,7 +42,7 @@ func cloneArray(a interface{}) interface{} { return nil } -func castToDt(val float64, dt Dtype) interface{} { +func castToDt(val float64, dt dtype.Dtype) interface{} { switch dt { case Bool: return false @@ -694,7 +695,7 @@ func TestDense_RollAxis(t *testing.T) { var concatTests = []struct { name string - dt Dtype + dt dtype.Dtype a interface{} b interface{} shape Shape @@ -852,7 +853,7 @@ func TestDense_Concat_sliced(t *testing.T) { var simpleStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape axis int stackCount int @@ -903,7 +904,7 @@ var simpleStackTests = []struct { var viewStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape transform 
[]int slices []Slice diff --git a/dense_reduction_test.go b/dense_reduction_test.go index f83d0c6..05de324 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -6,11 +6,12 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" ) var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int @@ -116,7 +117,7 @@ func TestDense_Reduce(t *testing.T) { var sumTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -273,7 +274,7 @@ func TestDense_Sum(t *testing.T) { var maxTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -411,7 +412,7 @@ func TestDense_Max(t *testing.T) { var minTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -547,17 +548,3 @@ func TestDense_Min(t *testing.T) { _, err = T.Min(1000) assert.NotNil(err) } - -func TestSlicedSum(t *testing.T) { - T := New(WithShape(4, 4), WithBacking([]int{ - 1, 2, 3, 4, - 5, 6, 7, 8, - 1, 2, 3, 4, - 5, 6, 7, 8, - })) - s, _ := T.Slice(sli(1, 3), sli(1, 3)) - sum, _ := Sum(s) - if sum.Data().(int) != 18 { - t.Errorf("Expected the sum of %v to be 18. Got %v instead", s, sum) - } -} diff --git a/engine.go b/engine.go index 88bd6e8..a1efa5b 100644 --- a/engine.go +++ b/engine.go @@ -1,5 +1,7 @@ package tensor +import "gorgonia.org/dtype" + // Memory is a representation of memory of the value. 
// // The main reason for requiring both Uintptr() and Pointer() methods is because while Go currently does not have a compacting @@ -51,7 +53,7 @@ type StandardEngine interface { } type arrayMaker interface { - makeArray(arr *array, t Dtype, size int) + makeArray(arr *array, t dtype.Dtype, size int) } // NonStdEngine are any engines that do not allocate using the default built in allocator diff --git a/example_iterator_test.go b/example_iterator_test.go index a6b31da..21932da 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -1,52 +1,52 @@ -package tensor - -import "fmt" - -// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor -func Example_iteratorRowmajor() { - T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) - it := IteratorFromDense(T) - fmt.Printf("T:\n%v\n", T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - - // Output: - // T: - // ⎡0 1 2⎤ - // ⎣3 4 5⎦ - // - // i: 0, coord: [0 1] - // i: 1, coord: [0 2] - // i: 2, coord: [1 0] - // i: 3, coord: [1 1] - // i: 4, coord: [1 2] - // i: 5, coord: [0 0] - -} - -// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly -// this example shows the order of the iteration. 
-func Example_iteratorcolMajor() { - T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) - it := IteratorFromDense(T) - fmt.Printf("T:\n%v\n", T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - - // Output: - // T: - // ⎡0 2 4⎤ - // ⎣1 3 5⎦ - // - // i: 0, coord: [0 1] - // i: 2, coord: [0 2] - // i: 4, coord: [1 0] - // i: 1, coord: [1 1] - // i: 3, coord: [1 2] - // i: 5, coord: [0 0] - -} +package tensor + +import "fmt" + +// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor +func Example_iteratorRowmajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // + // i: 0, coord: [0 1] + // i: 1, coord: [0 2] + // i: 2, coord: [1 0] + // i: 3, coord: [1 1] + // i: 4, coord: [1 2] + // i: 5, coord: [0 0] + +} + +// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly +// this example shows the order of the iteration. +func Example_iteratorcolMajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // + // i: 0, coord: [0 1] + // i: 2, coord: [0 2] + // i: 4, coord: [1 0] + // i: 1, coord: [1 1] + // i: 3, coord: [1 2] + // i: 5, coord: [0 0] + +} diff --git a/flags.go b/flags.go index 22fed67..c8000d1 100644 --- a/flags.go +++ b/flags.go @@ -1,5 +1,7 @@ package tensor +import "gorgonia.org/dtype" + // DataOrder is a flag that indicates the order of data. 
The default DataOrder (0) // is what this package uses by default. type DataOrder byte @@ -123,7 +125,7 @@ type OpOpt struct { incr Tensor unsafe bool same bool - t Dtype + t dtype.Dtype } // ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. @@ -164,4 +166,4 @@ func (fo *OpOpt) Same() bool { return fo.same } // a.Add(b, As(Int)) // indicates that the result of `Add()` should be converted to a Tensor of Int. // Note that this function is not yet supported in most operations. -func (fo *OpOpt) As() Dtype { return fo.t } +func (fo *OpOpt) As() dtype.Dtype { return fo.t } diff --git a/flags_test.go b/flags_test.go index 83dd3be..26d10e8 100644 --- a/flags_test.go +++ b/flags_test.go @@ -1,90 +1,90 @@ -package tensor - -import "testing" - -func TestMemoryFlag(t *testing.T) { - var defaultFlag MemoryFlag - if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { - t.Errorf("Something went wrong with the creation of flags") - } - - a := ManuallyManaged - if !a.manuallyManaged() { - t.Errorf("Expected ManuallyManaged to be true") - } - if !a.nativelyAccessible() { - t.Errorf("Expected ManuallyManaged to be nativelyAccessible") - } - - b := NativelyInaccessible - if b.manuallyManaged() { - t.Errorf("Expected NativelyInaccessible to not be manually managed") - } - if b.nativelyAccessible() { - t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) - } - - c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) - if !c.manuallyManaged() { - t.Errorf("Expected c to be manually managed") - } - if c.nativelyAccessible() { - t.Errorf("Expected c to be natively inaccessible") - } -} - -func TestDataOrder(t *testing.T) { - var defaultFlag DataOrder - if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { - t.Error("Expected default flag to be row major and contiguous and not transposed") - } - if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { - 
t.Error("Expected default flag to be row major and contiguous") - } - if defaultFlag.String() != "Contiguous, RowMajor" { - t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) - } - - ncrm := MakeDataOrder(NonContiguous) - if ncrm.IsColMajor() || ncrm.IsContiguous() { - t.Error("Expected noncontiguous row major.") - } - if ncrm.String() != "NonContiguous, RowMajor" { - t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) - } - - cm := ColMajor - if cm.IsRowMajor() { - t.Error("colMajor cannot be rowMajor") - } - if cm.IsNotContiguous() { - t.Error("ColMajor by default is contiguous") - } - if cm.String() != "Contiguous, ColMajor" { - t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) - } - - // check toggle - rm := cm.toggleColMajor() - if rm.IsColMajor() { - t.Errorf("toggled cm should be rm") - } - - cm = rm.toggleColMajor() - if cm.IsRowMajor() { - t.Errorf("toggled rm should be cm") - } - - transposed := MakeDataOrder(Transposed) - if !transposed.IsTransposed() { - t.Error("Expected transposed flag to be set") - } - if transposed.String() != "Contiguous, RowMajorᵀ" { - t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". 
Got %q", defaultFlag.String()) - } - untransposed := transposed.clearTransposed() - if untransposed != defaultFlag { - t.Error("Expected default flag after untransposing") - } - -} +package tensor + +import "testing" + +func TestMemoryFlag(t *testing.T) { + var defaultFlag MemoryFlag + if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { + t.Errorf("Something went wrong with the creation of flags") + } + + a := ManuallyManaged + if !a.manuallyManaged() { + t.Errorf("Expected ManuallyManaged to be true") + } + if !a.nativelyAccessible() { + t.Errorf("Expected ManuallyManaged to be nativelyAccessible") + } + + b := NativelyInaccessible + if b.manuallyManaged() { + t.Errorf("Expected NativelyInaccessible to not be manually managed") + } + if b.nativelyAccessible() { + t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) + } + + c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) + if !c.manuallyManaged() { + t.Errorf("Expected c to be manually managed") + } + if c.nativelyAccessible() { + t.Errorf("Expected c to be natively inaccessible") + } +} + +func TestDataOrder(t *testing.T) { + var defaultFlag DataOrder + if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { + t.Error("Expected default flag to be row major and contiguous and not transposed") + } + if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { + t.Error("Expected default flag to be row major and contiguous") + } + if defaultFlag.String() != "Contiguous, RowMajor" { + t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + ncrm := MakeDataOrder(NonContiguous) + if ncrm.IsColMajor() || ncrm.IsContiguous() { + t.Error("Expected noncontiguous row major.") + } + if ncrm.String() != "NonContiguous, RowMajor" { + t.Errorf("Expected string is \"NonContiguous, RowMajor\". 
Got %q", defaultFlag.String()) + } + + cm := ColMajor + if cm.IsRowMajor() { + t.Error("colMajor cannot be rowMajor") + } + if cm.IsNotContiguous() { + t.Error("ColMajor by default is contiguous") + } + if cm.String() != "Contiguous, ColMajor" { + t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) + } + + // check toggle + rm := cm.toggleColMajor() + if rm.IsColMajor() { + t.Errorf("toggled cm should be rm") + } + + cm = rm.toggleColMajor() + if cm.IsRowMajor() { + t.Errorf("toggled rm should be cm") + } + + transposed := MakeDataOrder(Transposed) + if !transposed.IsTransposed() { + t.Error("Expected transposed flag to be set") + } + if transposed.String() != "Contiguous, RowMajorᵀ" { + t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". Got %q", defaultFlag.String()) + } + untransposed := transposed.clearTransposed() + if untransposed != defaultFlag { + t.Error("Expected default flag after untransposing") + } + +} diff --git a/generic_utils.go b/generic_utils.go index 24310b5..ca00bd9 100644 --- a/generic_utils.go +++ b/generic_utils.go @@ -7,6 +7,7 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/vecf32" "gorgonia.org/vecf64" ) @@ -14,7 +15,7 @@ import ( // Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -172,7 +173,7 @@ func Range(dt Dtype, start, end int) interface{} { // For complex Dtypes, the imaginary component will be 0. 
// // This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { case reflect.Int: diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index 8ce7fe8..afb3fda 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -13,7 +13,7 @@ const importsArrowRaw = `import ( ) ` -const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} { +const conversionsRaw = `func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { {{range .Kinds -}} {{if isNumber . -}} diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index d21831a..334f2a8 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -1,287 +1,287 @@ -package main - -import ( - "io" - "text/template" -) - -const compatTestsRaw = `var toMat64Tests = []struct{ - data interface{} - sliced interface{} - shape Shape - dt Dtype -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . 
| title | strip}} }, - {{end -}} - {{end -}} -} -func TestToMat64(t *testing.T){ - assert := assert.New(t) - for i, tmt := range toMat64Tests { - T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) - var m *mat.Dense - var err error - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat basic test %d failed : %v", i, err) - continue - } - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) - - if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ - t.Errorf("Slice failed %v", err) - continue - } - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat of slice test %d failed : %v", i, err) - continue - } - conv = anyToFloat64s(tmt.sliced) - assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) - t.Logf("Done") - - if tmt.dt == Float64 { - T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) - if m, err = ToMat64(T, UseUnsafe()); err != nil { - t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) - } - conv = anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 1000 - assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 0 // reset for future tests that use the same backing - } - } - // idiocy test - T := New(Of(Float64), WithShape(2,3,4)) - _, err := ToMat64(T) - if err == nil { - t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") - } -} - -func TestFromMat64(t *testing.T){ - assert := assert.New(t) - var m *mat.Dense - var T *Dense - var backing []float64 - - - for i, tmt := range toMat64Tests { - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m) - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) - assert.True(T.Shape().Eq(tmt.shape)) - - T = FromMat64(m, As(tmt.dt)) - assert.Equal(tmt.data, T.Data()) - assert.True(T.Shape().Eq(tmt.shape)) - - if tmt.dt 
== Float64{ - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m, UseUnsafe()) - assert.Equal(backing, T.Float64s()) - assert.True(T.Shape().Eq(tmt.shape)) - backing[0] = 1000 - assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) - } - } -} -` - -const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ - data interface{} - valid []bool - dt arrow.DataType - shape Shape -}{ - {{range .PrimitiveTypes -}} - { - data: Range({{.}}, 0, 6), - valid: []bool{true, true, true, false, true, true}, - dt: arrow.PrimitiveTypes.{{ . }}, - shape: Shape{6,1}, - }, - {{end -}} -} -func TestFromArrowArray(t *testing.T){ - assert := assert.New(t) - var T *Dense - pool := memory.NewGoAllocator() - - for i, taat := range toArrowArrayTests { - var m arrowArray.Interface - - switch taat.dt { - {{range .BinaryTypes -}} - case arrow.BinaryTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - {{if eq . "String" -}} - []string{"0", "1", "2", "3", "4", "5"}, - {{else -}} - Range({{ . }}, 0, 6).([]{{lower . }}), - {{end -}} - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - {{range .FixedWidthTypes -}} - case arrow.FixedWidthTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - {{if eq . "Boolean" -}} - []bool{true, false, true, false, true, false}, - {{else -}} - Range({{ . }}, 0, 6).([]{{lower . }}), - {{end -}} - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - Range({{ . }}, 0, 6).([]{{lower . }}), - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - - T = FromArrowArray(m) - switch taat.dt { - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . 
}}: - conv := taat.data.([]{{lower . }}) - assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - for i, invalid := range T.Mask() { - assert.Equal(taat.valid[i], !invalid) - } - assert.True(T.Shape().Eq(taat.shape)) - } -} -` - -const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ - rowMajorData interface{} - colMajorData interface{} - rowMajorValid []bool - colMajorValid []bool - dt arrow.DataType - shape Shape -}{ - {{range .PrimitiveTypes -}} - { - rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, - colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, - dt: arrow.PrimitiveTypes.{{ . }}, - shape: Shape{2,5}, - }, - {{end -}} -} -func TestFromArrowTensor(t *testing.T){ - assert := assert.New(t) - var rowMajorT *Dense - var colMajorT *Dense - pool := memory.NewGoAllocator() - - for i, taat := range toArrowTensorTests { - var rowMajorArr arrowArray.Interface - var colMajorArr arrowArray.Interface - var rowMajor arrowTensor.Interface - var colMajor arrowTensor.Interface - - switch taat.dt { - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - taat.rowMajorValid, - ) - rowMajorArr = b.NewArray() - defer rowMajorArr.Release() - - b.AppendValues( - []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - taat.rowMajorValid, - ) - colMajorArr = b.NewArray() - defer colMajorArr.Release() - - rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) - defer rowMajor.Release() - colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . 
}}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) - defer colMajor.Release() - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - - rowMajorT = FromArrowTensor(rowMajor) - colMajorT = FromArrowTensor(colMajor) - - assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) - assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) - for i, invalid := range rowMajorT.Mask() { - assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) - } - assert.True(colMajorT.Shape().Eq(taat.shape)) - - assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) - assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) - for i, invalid := range colMajorT.Mask() { - assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) - } - assert.True(rowMajorT.Shape().Eq(taat.shape)) - } -} -` - -var ( - compatTests *template.Template - compatArrowArrayTests *template.Template - compatArrowTensorTests *template.Template -) - -func init() { - compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) - compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) - compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) -} - -func generateDenseCompatTests(f io.Writer, generic Kinds) { - // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming - // collisions - importsArrow.Execute(f, generic) - compatTests.Execute(f, generic) - arrowData := ArrowData{ - BinaryTypes: arrowBinaryTypes, - FixedWidthTypes: arrowFixedWidthTypes, - PrimitiveTypes: arrowPrimitiveTypes, - } - 
compatArrowArrayTests.Execute(f, arrowData) - compatArrowTensorTests.Execute(f, arrowData) -} +package main + +import ( + "io" + "text/template" +) + +const compatTestsRaw = `var toMat64Tests = []struct{ + data interface{} + sliced interface{} + shape Shape + dt dtype.Dtype +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} }, + {{end -}} + {{end -}} +} +func TestToMat64(t *testing.T){ + assert := assert.New(t) + for i, tmt := range toMat64Tests { + T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) + var m *mat.Dense + var err error + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat basic test %d failed : %v", i, err) + continue + } + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) + + if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ + t.Errorf("Slice failed %v", err) + continue + } + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat of slice test %d failed : %v", i, err) + continue + } + conv = anyToFloat64s(tmt.sliced) + assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) + t.Logf("Done") + + if tmt.dt == Float64 { + T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) + if m, err = ToMat64(T, UseUnsafe()); err != nil { + t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) + } + conv = anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 1000 + assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 0 // reset for future tests that use the same backing + } + } + // idiocy test + T := New(Of(Float64), WithShape(2,3,4)) + _, err := ToMat64(T) + if err == nil { + t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") + } +} + +func TestFromMat64(t *testing.T){ + assert := assert.New(t) + var m *mat.Dense + var T *Dense + var backing 
[]float64 + + + for i, tmt := range toMat64Tests { + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m) + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) + assert.True(T.Shape().Eq(tmt.shape)) + + T = FromMat64(m, As(tmt.dt)) + assert.Equal(tmt.data, T.Data()) + assert.True(T.Shape().Eq(tmt.shape)) + + if tmt.dt == Float64{ + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m, UseUnsafe()) + assert.Equal(backing, T.Float64s()) + assert.True(T.Shape().Eq(tmt.shape)) + backing[0] = 1000 + assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) + } + } +} +` + +const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ + data interface{} + valid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + data: Range({{.}}, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{6,1}, + }, + {{end -}} +} +func TestFromArrowArray(t *testing.T){ + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "String" -}} + []string{"0", "1", "2", "3", "4", "5"}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "Boolean" -}} + []bool{true, false, true, false, true, false}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . 
}}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + Range({{ . }}, 0, 6).([]{{lower . }}), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + conv := taat.data.([]{{lower . }}) + assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} +` + +const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ + rowMajorData interface{} + colMajorData interface{} + rowMajorValid []bool + colMajorValid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{2,5}, + }, + {{end -}} +} +func TestFromArrowTensor(t *testing.T){ + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr arrowArray.Interface + var colMajorArr arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . 
}}Builder(pool) + defer b.Release() + b.AppendValues( + []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) + } + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) + } + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} +` + +var ( + compatTests *template.Template + compatArrowArrayTests *template.Template + compatArrowTensorTests *template.Template +) + +func init() { + compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) + compatArrowArrayTests = 
template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) + compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) +} + +func generateDenseCompatTests(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // collisions + importsArrow.Execute(f, generic) + compatTests.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatArrowArrayTests.Execute(f, arrowData) + compatArrowTensorTests.Execute(f, arrowData) +} diff --git a/genlib2/dense_cons.go b/genlib2/dense_cons.go index fee0df5..aa6bab8 100644 --- a/genlib2/dense_cons.go +++ b/genlib2/dense_cons.go @@ -6,7 +6,7 @@ import ( ) const onesRaw = `// Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { {{range .Kinds -}} @@ -48,7 +48,7 @@ const Iraw = `// I creates the identity matrix (usually a square) matrix with 1s // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense{ +func I(dt dtype.Dtype, r, c, k int) *Dense{ ret := New(Of(dt), WithShape(r,c)) i := k if k < 0 { diff --git a/genlib2/dense_cons_tests.go b/genlib2/dense_cons_tests.go index 938d6fa..29d1366 100644 --- a/genlib2/dense_cons_tests.go +++ b/genlib2/dense_cons_tests.go @@ -1,85 +1,85 @@ -package main - -import ( - "io" - "text/template" -) - -const onesTestsRaw = `var onesTests = []struct { - of Dtype - shape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, - { {{asType . 
| title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, - {{end -}} - {{end -}} - {Bool, ScalarShape(), true}, - {Bool, Shape{2,2}, []bool{true, true, true, true}}, -} - -func TestOnes(t *testing.T){ - assert := assert.New(t) - for _, ot := range onesTests{ - T := Ones(ot.of, ot.shape...) - assert.True(ot.shape.Eq(T.Shape())) - assert.Equal(ot.correct, T.Data()) - } -} -` - -const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist -var eyeTests = []struct{ - E Dtype - R, C, K int - - - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, - { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, - { {{asType . 
| title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, - {{end -}} - {{end -}} -} - -func TestI(t *testing.T){ - assert := assert.New(t) - var T Tensor - - for i, it := range eyeTests { - T = I(it.E, it.R, it.C, it.K) - assert.True(Shape{it.R, it.C}.Eq(T.Shape())) - assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) - } - -} -` - -var ( - onesTests *template.Template - eyeTests *template.Template -) - -func init() { - onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) - eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) -} - -func generateDenseConsTests(f io.Writer, generic Kinds) { - onesTests.Execute(f, generic) - eyeTests.Execute(f, generic) -} +package main + +import ( + "io" + "text/template" +) + +const onesTestsRaw = `var onesTests = []struct { + of dtype.Dtype + shape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, + { {{asType . | title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, + {{end -}} + {{end -}} + {Bool, ScalarShape(), true}, + {Bool, Shape{2,2}, []bool{true, true, true, true}}, +} + +func TestOnes(t *testing.T){ + assert := assert.New(t) + for _, ot := range onesTests{ + T := Ones(ot.of, ot.shape...) + assert.True(ot.shape.Eq(T.Shape())) + assert.Equal(ot.correct, T.Data()) + } +} +` + +const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist +var eyeTests = []struct{ + E dtype.Dtype + R, C, K int + + + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . | title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, + { {{asType . 
| title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, + {{end -}} + {{end -}} +} + +func TestI(t *testing.T){ + assert := assert.New(t) + var T Tensor + + for i, it := range eyeTests { + T = I(it.E, it.R, it.C, it.K) + assert.True(Shape{it.R, it.C}.Eq(T.Shape())) + assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) + } + +} +` + +var ( + onesTests *template.Template + eyeTests *template.Template +) + +func init() { + onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) + eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) +} + +func generateDenseConsTests(f io.Writer, generic Kinds) { + onesTests.Execute(f, generic) + eyeTests.Execute(f, generic) +} diff --git a/genlib2/dense_getset_tests.go b/genlib2/dense_getset_tests.go index 15cc820..50bafb3 100644 --- a/genlib2/dense_getset_tests.go +++ b/genlib2/dense_getset_tests.go @@ -102,8 +102,8 @@ func 
makeZeroTests(generic Kinds) []testData { } const getTestRaw = `var denseSetGetTests = []struct { - of Dtype - data interface{} + of dtype.Dtype + data interface{} set interface{} correct []interface{} @@ -129,7 +129,7 @@ func TestDense_setget(t *testing.T) { ` const memsetTestRaw = `var denseMemsetTests = []struct{ - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -139,7 +139,7 @@ const memsetTestRaw = `var denseMemsetTests = []struct{ {{range . -}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } @@ -159,7 +159,7 @@ func TestDense_memset(t *testing.T){ ` const zeroTestRaw = `var denseZeroTests = []struct{ - of Dtype + of dtype.Dtype data interface{} correct interface{} @@ -167,18 +167,18 @@ const zeroTestRaw = `var denseZeroTests = []struct{ {{range . 
-}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } func TestDense_Zero(t *testing.T) { assert := assert.New(t) for _, mts := range denseZeroTests { - + typ := reflect.TypeOf(mts.data) val := reflect.ValueOf(mts.data) data := reflect.MakeSlice(typ, val.Len(), val.Cap()) - reflect.Copy(data, val) + reflect.Copy(data, val) T := New(Of(mts.of), WithBacking(data.Interface())) T.Zero() @@ -188,7 +188,7 @@ func TestDense_Zero(t *testing.T) { T2, _ := T.Slice(nil) T2.Zero() assert.Equal(mts.correct, T2.Data()) - } + } } ` diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 4a63ddd..513d657 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -348,7 +348,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. // If into is nil, then a backing slice will be created. 
-func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { {{range .Kinds -}} diff --git a/genlib2/dense_reduction_methods_tests.go b/genlib2/dense_reduction_methods_tests.go index 30defc9..342b1d4 100644 --- a/genlib2/dense_reduction_methods_tests.go +++ b/genlib2/dense_reduction_methods_tests.go @@ -1,164 +1,164 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const testDenseSumRaw = `var sumTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, - {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, - {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, - {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, - {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, - {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, - {"4T.Sum(0, 2, 3) for {{.}}", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, - {{end -}} - {{end -}} -} -func TestDense_Sum(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, sts := range sumTests { - T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) - if T2, err = T.Sum(sts.along ...); err != nil { - t.Error(err) - continue - } - assert.True(sts.correctShape.Eq(T2.Shape())) - assert.Equal(sts.correct, T2.Data()) - } - - // idiots - _,err =T.Sum(1000) - assert.NotNil(err) -} -` - -const testDenseMaxRaw = `var maxTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, - {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}}, - {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, - {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, - {"4T.Max()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, - {"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, - {"4T.Max(0, 2, 3)", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Max(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range maxTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Max(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - /* IDIOT TESTING TIME */ - _, err = T.Max(1000) - assert.NotNil(err) -} -` - -const testDenseMinRaw = `var minTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, - {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, - {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, - {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, - {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, - {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, - {"4T.Min(0, 2, 3)", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Min(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range minTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Min(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - - /* IDIOT TESTING TIME */ - _, err = T.Min(1000) - assert.NotNil(err) -} -` - -var ( - testDenseSum *template.Template - testDenseMax *template.Template - testDenseMin *template.Template -) - -func init() { - testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) - testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) - testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) -} - -func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { - testDenseSum.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMax.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMin.Execute(f, generic) -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const testDenseSumRaw = `var sumTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, + {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, + {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(1,0) for {{.}}", {{asType . 
| title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, + {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, + {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, + {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, + {{end -}} + {{end -}} +} +func TestDense_Sum(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, sts := range sumTests { + T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) + if T2, err = T.Sum(sts.along ...); err != nil { + t.Error(err) + continue + } + assert.True(sts.correctShape.Eq(T2.Shape())) + assert.Equal(sts.correct, T2.Data()) + } + + // idiots + _,err =T.Sum(1000) + assert.NotNil(err) +} +` + +const testDenseMaxRaw = `var maxTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, + {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}}, + {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, + {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, + {"4T.Max()", {{asType . 
| title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, + {"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Max(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range maxTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Max(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + /* IDIOT TESTING TIME */ + _, err = T.Max(1000) + assert.NotNil(err) +} +` + +const testDenseMinRaw = `var minTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, + {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, + {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, + {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, + {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, + {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", {{asType . 
| title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Min(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range minTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Min(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + + /* IDIOT TESTING TIME */ + _, err = T.Min(1000) + assert.NotNil(err) +} +` + +var ( + testDenseSum *template.Template + testDenseMax *template.Template + testDenseMin *template.Template +) + +func init() { + testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) + testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) + testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) +} + +func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { + testDenseSum.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMax.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMin.Execute(f, generic) +} diff --git a/genlib2/dense_reduction_tests.go b/genlib2/dense_reduction_tests.go index 2c35efa..06f78c0 100644 --- a/genlib2/dense_reduction_tests.go +++ b/genlib2/dense_reduction_tests.go @@ -6,7 +6,7 @@ import ( ) const testDenseReduceRaw = `var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int diff --git a/genlib2/generic_utils.go b/genlib2/generic_utils.go index 7c207fa..8d5f87b 100644 --- a/genlib2/generic_utils.go +++ b/genlib2/generic_utils.go @@ -8,7 +8,7 @@ import ( const rangeRaw = `// Range creates a ranged array with a given type. 
It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -58,8 +58,8 @@ func Range(dt Dtype, start, end int) interface{} { const randomRaw = `// Random creates an array of random numbers of the given type. // For complex Dtypes, the imaginary component will be 0. // -// This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +// This function is only useful in cases where the randomness is not vital. +func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { {{range .Kinds -}} diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index 565d9e9..a727253 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -6,7 +6,7 @@ import ( "text/template" ) -const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dtype) error { +const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { // checks: if !t.IsNativelyAccessible() { return errors.Errorf("Cannot convert *Dense to *mat.Dense. 
Data is inaccessible") diff --git a/genlib2/native_select.go b/genlib2/native_select.go index 6b1e277..a05ce18 100644 --- a/genlib2/native_select.go +++ b/genlib2/native_select.go @@ -6,7 +6,7 @@ import ( "text/template" ) -const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { +const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { if !t.IsNativelyAccessible() { return errors.New("Cannot select on non-natively accessible data") } diff --git a/genlib2/testutils.go b/genlib2/testutils.go index 177333f..c7dbe81 100644 --- a/genlib2/testutils.go +++ b/genlib2/testutils.go @@ -90,7 +90,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { s[i] = randomString() {{else if eq .String "unsafe.Pointer" -}} s[i] = nil - {{end -}} + {{end -}} } {{end -}} {{end -}} @@ -99,7 +99,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { ` const testQCRaw = `type QCDense{{short .}} struct { - *Dense + *Dense } func (*QCDense{{short .}}) Generate(r *rand.Rand, size int) reflect.Value { s := make([]{{asType .}}, size) @@ -137,11 +137,11 @@ const mutateFnsRaw = `func mutate{{short .}}(a {{asType . 
}}){{asType .}} { {{if {{else if eq .String "bool" -}}return true } {{else if eq .String "string" -}}return "Hello World"} {{else if eq .String "uintptr" -}}return 0xdeadbeef} -{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} -{{end -}} +{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} +{{end -}} ` -const identityValsRaw = `func identityVal(x int, dt Dtype) interface{} { +const identityValsRaw = `func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { {{range .Kinds -}} case {{reflectKind .}}: diff --git a/go.mod b/go.mod index 7106ca9..0aea516 100644 --- a/go.mod +++ b/go.mod @@ -2,18 +2,19 @@ module gorgonia.org/tensor go 1.13 +replace gorgonia.org/dtype => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/dtype + require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc - github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.6 github.com/gogo/protobuf v1.3.1 github.com/golang/protobuf v1.4.3 github.com/google/flatbuffers v1.12.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 - github.com/xtgo/set v1.0.0 // indirect + github.com/stretchr/testify v1.7.0 go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 gonum.org/v1/gonum v0.8.2 + gorgonia.org/dtype v0.0.0-00010101000000-000000000000 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) diff --git a/go.sum b/go.sum index 21b3359..1fbe637 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGw github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod 
h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -33,11 +32,9 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v1.12.0 h1:/PtAHvnBY4Kqnx/xCQ3OIV9uYcSFGScBsWI3Oogeh6w= github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= @@ -58,8 +55,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify 
v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= diff --git a/interfaces.go b/interfaces.go index 345698a..5327583 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,12 +3,13 @@ package tensor import ( "reflect" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // Dtyper is any type that has a Dtype type Dtyper interface { - Dtype() Dtype + Dtype() dtype.Dtype } // Eq is any type where you can perform an equality test diff --git a/native/iterator_native.go b/native/iterator_native.go index d9727fe..b820159 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -7,10 +7,11 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" . "gorgonia.org/tensor" ) -func checkNativeIterable(t *Dense, dims int, dt Dtype) error { +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { // checks: if !t.IsNativelyAccessible() { return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") diff --git a/native/iterator_native2.go b/native/iterator_native2.go index 934863d..d47bfb3 100644 --- a/native/iterator_native2.go +++ b/native/iterator_native2.go @@ -7,10 +7,11 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" . 
"gorgonia.org/tensor" ) -func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { if !t.IsNativelyAccessible() { return errors.New("Cannot select on non-natively accessible data") } diff --git a/optimizations_test.go b/optimizations_test.go index 18bb677..9b8afcb 100644 --- a/optimizations_test.go +++ b/optimizations_test.go @@ -1,15 +1,15 @@ -package tensor - -import ( - "testing" -) - -// this file contains tests to make sure certain algorithms/optimizations aren't crazy - -func TestRequiresIterator(t *testing.T) { - T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) - sliced, _ := T.Slice(makeRS(1, 3)) - if sliced.RequiresIterator() { - t.Errorf("Slicing on rows should not require Iterator") - } -} +package tensor + +import ( + "testing" +) + +// this file contains tests to make sure certain algorithms/optimizations aren't crazy + +func TestRequiresIterator(t *testing.T) { + T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) + sliced, _ := T.Slice(makeRS(1, 3)) + if sliced.RequiresIterator() { + t.Errorf("Slicing on rows should not require Iterator") + } +} diff --git a/perf.go b/perf.go index bc5c3aa..200aeac 100644 --- a/perf.go +++ b/perf.go @@ -89,7 +89,7 @@ func ReturnTensor(t Tensor) { } // array reset - tt.t = Dtype{} + tt.t = dtype.Dtype{} tt.array.Header.Raw = nil // engine and flag reset @@ -262,7 +262,7 @@ func returnOpOpt(oo *OpOpt) { oo.incr = nil oo.unsafe = false oo.same = false - oo.t = Dtype{} + oo.t = dtype.Dtype{} // if len(optPool) < cap(optPool) { // optPool <- oo // } diff --git a/scalar.go b/scalar.go index 3721cbc..8c7c025 100644 --- a/scalar.go +++ b/scalar.go @@ -40,7 +40,7 @@ func MakeScalar(v interface{}) Scalar { func (s Scalar) Shape() Shape { return ScalarShape() } func (s Scalar) Strides() []int { return nil } -func (s Scalar) Dtype() Dtype { return Dtype{reflect.TypeOf(s.v)} } +func (s Scalar) Dtype() dtype.Dtype { return 
dtype.Dtype{reflect.TypeOf(s.v)} } func (s Scalar) Dims() int { return 0 } func (s Scalar) Size() int { return 0 } // TODO func (s Scalar) DataSize() int { return 0 } diff --git a/sparse.go b/sparse.go index 2710602..9d4884f 100644 --- a/sparse.go +++ b/sparse.go @@ -6,6 +6,7 @@ import ( "sort" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( @@ -183,7 +184,7 @@ func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { func (t *CS) Shape() Shape { return t.s } func (t *CS) Strides() []int { return nil } -func (t *CS) Dtype() Dtype { return t.t } +func (t *CS) Dtype() dtype.Dtype { return t.t } func (t *CS) Dims() int { return 2 } func (t *CS) Size() int { return t.s.TotalSize() } func (t *CS) DataSize() int { return t.Len() } diff --git a/sparse_test.go b/sparse_test.go index 86cdad1..34b22dd 100644 --- a/sparse_test.go +++ b/sparse_test.go @@ -1,105 +1,105 @@ -package tensor - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCS_Basics(t *testing.T) { - assert := assert.New(t) - xs0 := []int{1, 2, 6, 8} - ys0 := []int{1, 2, 1, 6} - xs1 := []int{1, 2, 6, 8} - ys1 := []int{1, 2, 1, 6} - vals0 := []float64{3, 1, 4, 1} - vals1 := []float64{3, 1, 4, 1} - - var T0, T1 *CS - var d0, d1 *Dense - var dp0, dp1 *Dense - var err error - fails := func() { - CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) - } - assert.Panics(fails) - - // Test CSC - T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) - d0 = T0.Dense() - T0.T() - dp0 = T0.Dense() - T0.UT() // untranspose as Materialize() will be called below - - // Test CSR - fails = func() { - CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) - } - T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) - d1 = T1.Dense() - T1.T() - dp1 = T1.Dense() - T1.UT() - - t.Logf("%v %v", T0.indptr, T0.indices) - t.Logf("%v %v", T1.indptr, T1.indices) - - assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) - assert.True(dp0.Eq(dp1)) - assert.True(T1.Eq(T1)) - assert.False(T0.Eq(T1)) - - // At - var got interface{} - 
correct := float64(3.0) - if got, err = T0.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) - } - if got, err = T1.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) - } - - correct = 0.0 - if got, err = T0.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) - } - - if got, err = T1.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) - } - - // Test clone - T2 := T0.Clone() - assert.True(T0.Eq(T2)) - - // Scalar representation - assert.False(T0.IsScalar()) - fails = func() { - T0.ScalarValue() - } - assert.Panics(fails) - assert.Equal(len(vals0), T0.NonZeroes()) - - // Sparse Iterator - it := T0.Iterator() - var valids []int - correctValids := []int{0, 2, 1, 3} - for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { - if valid { - valids = append(valids, i) - } - } - assert.Equal(correctValids, valids) -} +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCS_Basics(t *testing.T) { + assert := assert.New(t) + xs0 := []int{1, 2, 6, 8} + ys0 := []int{1, 2, 1, 6} + xs1 := []int{1, 2, 6, 8} + ys1 := []int{1, 2, 1, 6} + vals0 := []float64{3, 1, 4, 1} + vals1 := []float64{3, 1, 4, 1} + + var T0, T1 *CS + var d0, d1 *Dense + var dp0, dp1 *Dense + var err error + fails := func() { + CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) + } + assert.Panics(fails) + + // Test CSC + T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) + d0 = T0.Dense() + T0.T() + dp0 = T0.Dense() + T0.UT() // untranspose as Materialize() will be called below + + // Test CSR + fails = func() { + CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) + } + T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) + d1 = T1.Dense() + 
T1.T() + dp1 = T1.Dense() + T1.UT() + + t.Logf("%v %v", T0.indptr, T0.indices) + t.Logf("%v %v", T1.indptr, T1.indices) + + assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) + assert.True(dp0.Eq(dp1)) + assert.True(T1.Eq(T1)) + assert.False(T0.Eq(T1)) + + // At + var got interface{} + correct := float64(3.0) + if got, err = T0.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) + } + if got, err = T1.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) + } + + correct = 0.0 + if got, err = T0.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) + } + + if got, err = T1.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) + } + + // Test clone + T2 := T0.Clone() + assert.True(T0.Eq(T2)) + + // Scalar representation + assert.False(T0.IsScalar()) + fails = func() { + T0.ScalarValue() + } + assert.Panics(fails) + assert.Equal(len(vals0), T0.NonZeroes()) + + // Sparse Iterator + it := T0.Iterator() + var valids []int + correctValids := []int{0, 2, 1, 3} + for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { + if valid { + valids = append(valids, i) + } + } + assert.Equal(correctValids, valids) +} diff --git a/tensor.go b/tensor.go index 24be135..5b60671 100644 --- a/tensor.go +++ b/tensor.go @@ -6,6 +6,7 @@ import ( "encoding/gob" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( @@ -24,7 +25,8 @@ type Desc interface { // info about the ndarray Shape() Shape Strides() []int - Dtype() Dtype + Dtype() dtype.Dtype + Dims() int Size() int DataSize() int diff --git a/test_test.go b/test_test.go index 5f76d8a..772a71f 100644 --- a/test_test.go +++ b/test_test.go @@ -9,6 +9,7 @@ import ( "unsafe" "github.com/chewxy/math32" + 
"gorgonia.org/dtype" ) func anyToFloat64s(x interface{}) (retVal []float64) { @@ -120,7 +121,7 @@ func anyToFloat64s(x interface{}) (retVal []float64) { panic("Unreachable") } -func identityVal(x int, dt Dtype) interface{} { +func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { case Int: return int(x) diff --git a/testutils_test.go b/testutils_test.go index 3a0d466..58814bf 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -14,6 +14,8 @@ import ( "github.com/chewxy/math32" "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/dtype" ) func randomBool() bool { @@ -330,7 +332,7 @@ func shuffleInts(a []int, r *rand.Rand) { type TensorGenerator struct { ShapeConstraint Shape - DtypeConstraint Dtype + DtypeConstraint dtype.Dtype } func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { @@ -539,14 +541,14 @@ func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we bool, err return nil, false } -func qcIsFloat(dt Dtype) bool { +func qcIsFloat(dt dtype.Dtype) bool { if err := typeclassCheck(dt, floatcmplxTypes); err == nil { return true } return false } -func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{}) bool { +func qcEqCheck(t *testing.T, dt dtype.Dtype, willFailEq bool, correct, got interface{}) bool { isFloatTypes := qcIsFloat(dt) if !willFailEq && (isFloatTypes && !allClose(correct, got) || (!isFloatTypes && !reflect.DeepEqual(correct, got))) { t.Errorf("q.Dtype: %v", dt) diff --git a/type_test.go b/type_test.go index d616b8f..12c5678 100644 --- a/type_test.go +++ b/type_test.go @@ -3,12 +3,14 @@ package tensor import ( "reflect" "testing" + + "gorgonia.org/dtype" ) type Float16 uint16 func TestRegisterType(t *testing.T) { - dt := Dtype{reflect.TypeOf(Float16(0))} + dt := dtype.Dtype{reflect.TypeOf(Float16(0))} RegisterFloat(dt) if err := typeclassCheck(dt, floatTypes); err != nil { @@ -34,7 +36,7 @@ func TestDtypeConversions(t *testing.T) { t.Errorf("Error: %v", err) } } - dt := 
Dtype{reflect.TypeOf(Float16(0))} + dt := dtype.Dtype{reflect.TypeOf(Float16(0))} if _, err := dt.numpyDtype(); err == nil { t.Errorf("Expected an error when passing in type unknown to np") } diff --git a/types.go b/types.go index 69740cf..2af1b72 100644 --- a/types.go +++ b/types.go @@ -4,101 +4,10 @@ import ( "fmt" "math" "reflect" - "unsafe" - "github.com/chewxy/hm" - "github.com/pkg/errors" + "gorgonia.org/dtype" ) -// Dtype represents a data type of a Tensor. Concretely it's implemented as an embedded reflect.Type -// which allows for easy reflection operations. It also implements hm.Type, for type inference in Gorgonia -type Dtype struct { - reflect.Type -} - -// note: the Name() and String() methods are already defined in reflect.Type. Might as well use the composed methods - -func (dt Dtype) Apply(hm.Subs) hm.Substitutable { return dt } -func (dt Dtype) FreeTypeVar() hm.TypeVarSet { return nil } -func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { return dt, nil } -func (dt Dtype) Types() hm.Types { return nil } -func (dt Dtype) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%s", dt.Name()) } -func (dt Dtype) Eq(other hm.Type) bool { return other == dt } - -var numpyDtypes map[Dtype]string -var reverseNumpyDtypes map[string]Dtype - -func init() { - numpyDtypes = map[Dtype]string{ - Bool: "b1", - Int: fmt.Sprintf("i%d", Int.Size()), - Int8: "i1", - Int16: "i2", - Int32: "i4", - Int64: "i8", - Uint: fmt.Sprintf("u%d", Uint.Size()), - Uint8: "u1", - Uint16: "u2", - Uint32: "u4", - Uint64: "u8", - Float32: "f4", - Float64: "f8", - Complex64: "c8", - Complex128: "c16", - } - - reverseNumpyDtypes = map[string]Dtype{ - "b1": Bool, - "i1": Int8, - "i2": Int16, - "i4": Int32, - "i8": Int64, - "u1": Uint8, - "u2": Uint16, - "u4": Uint32, - "u8": Uint64, - "f4": Float32, - "f8": Float64, - "c8": Complex64, - "c16": Complex128, - } -} - -// NumpyDtype returns the Numpy's Dtype equivalent. 
This is predominantly used in converting a Tensor to a Numpy ndarray, -// however, not all Dtypes are supported -func (dt Dtype) numpyDtype() (string, error) { - retVal, ok := numpyDtypes[dt] - if !ok { - return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt) - } - return retVal, nil -} - -func fromNumpyDtype(t string) (Dtype, error) { - retVal, ok := reverseNumpyDtypes[t] - if !ok { - return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t) - } - if t == "i4" && Int.Size() == 4 { - return Int, nil - } - if t == "i8" && Int.Size() == 8 { - return Int, nil - } - if t == "u4" && Uint.Size() == 4 { - return Uint, nil - } - if t == "u8" && Uint.Size() == 8 { - return Uint, nil - } - return retVal, nil -} - -type typeclass struct { - name string - set []Dtype -} - var parameterizedKinds = [...]reflect.Kind{ reflect.Array, reflect.Chan, @@ -119,227 +28,29 @@ func isParameterizedKind(k reflect.Kind) bool { return false } -// oh how nice it'd be if I could make them immutable +// type aliases var ( - Bool = Dtype{reflect.TypeOf(true)} - Int = Dtype{reflect.TypeOf(int(1))} - Int8 = Dtype{reflect.TypeOf(int8(1))} - Int16 = Dtype{reflect.TypeOf(int16(1))} - Int32 = Dtype{reflect.TypeOf(int32(1))} - Int64 = Dtype{reflect.TypeOf(int64(1))} - Uint = Dtype{reflect.TypeOf(uint(1))} - Uint8 = Dtype{reflect.TypeOf(uint8(1))} - Uint16 = Dtype{reflect.TypeOf(uint16(1))} - Uint32 = Dtype{reflect.TypeOf(uint32(1))} - Uint64 = Dtype{reflect.TypeOf(uint64(1))} - Float32 = Dtype{reflect.TypeOf(float32(1))} - Float64 = Dtype{reflect.TypeOf(float64(1))} - Complex64 = Dtype{reflect.TypeOf(complex64(1))} - Complex128 = Dtype{reflect.TypeOf(complex128(1))} - String = Dtype{reflect.TypeOf("")} - - // aliases - Byte = Uint8 - - // extras - Uintptr = Dtype{reflect.TypeOf(uintptr(0))} - UnsafePointer = Dtype{reflect.TypeOf(unsafe.Pointer(&Uintptr))} + Bool = dtype.Bool + Int = dtype.Int + Int8 = dtype.Int8 + Int16 = dtype.Int16 + Int32 = 
dtype.Int32 + Int64 = dtype.Int64 + Uint = dtype.Uint + Uint8 = dtype.Uint8 + Uint16 = dtype.Uint16 + Uint32 = dtype.Uint32 + Uint64 = dtype.Uint64 + Float32 = dtype.Float32 + Float64 = dtype.Float64 + Complex64 = dtype.Complex64 + Complex128 = dtype.Complex128 + String = dtype.String + Byte = dtype.Byte + Uintptr = dtype.Uintptr + UnsafePointer = dtype.UnsafePointer ) -// allTypes for indexing -var allTypes = &typeclass{ - name: "τ", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -// specialized types indicate that there are specialized code generated for these types -var specializedTypes = &typeclass{ - name: "Specialized", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var addableTypes = &typeclass{ - name: "Addable", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var numberTypes = &typeclass{ - name: "Number", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, - }, -} - -var ordTypes = &typeclass{ - name: "Ord", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -var eqTypes = &typeclass{ - name: "Eq", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -var unsignedTypes = &typeclass{ - name: "Unsigned", - set: []Dtype{Uint, Uint8, Uint16, Uint32, Uint64}, -} - -var signedTypes = &typeclass{ - name: "Signed", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128, - }, -} - -// this typeclass is ever only used by Sub tests -var 
signedNonComplexTypes = &typeclass{ - name: "Signed NonComplex", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, - }, -} - -var floatTypes = &typeclass{ - name: "Float", - set: []Dtype{ - Float32, Float64, - }, -} - -var complexTypes = &typeclass{ - name: "Complex Numbers", - set: []Dtype{Complex64, Complex128}, -} - -var floatcmplxTypes = &typeclass{ - name: "Real", - set: []Dtype{ - Float32, Float64, Complex64, Complex128, - }, -} - -var nonComplexNumberTypes = &typeclass{ - name: "Non complex numbers", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, - }, -} - -// this typeclass is ever only used by Pow tests -var generatableTypes = &typeclass{ - name: "Generatable types", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -func isFloat(dt Dtype) bool { - return dt == Float64 || dt == Float32 -} - -func typeclassCheck(a Dtype, tc *typeclass) error { - if tc == nil { - return nil - } - for _, s := range tc.set { - if s == a { - return nil - } - } - return errors.Errorf("Type %v is not a member of %v", a, tc.name) -} - -// RegisterNumber is a function required to register a new numerical Dtype. -// This package provides the following Dtype: -// Int -// Int8 -// Int16 -// Int32 -// Int64 -// Uint -// Uint8 -// Uint16 -// Uint32 -// Uint64 -// Float32 -// Float64 -// Complex64 -// Complex128 -// -// If a Dtype that is registered already exists on the list, it will not be added to the list. 
-func RegisterNumber(a Dtype) { - for _, dt := range numberTypes.set { - if dt == a { - return - } - } - numberTypes.set = append(numberTypes.set, a) - RegisterEq(a) -} - -func RegisterFloat(a Dtype) { - for _, dt := range floatTypes.set { - if dt == a { - return - } - } - floatTypes.set = append(floatTypes.set, a) - RegisterNumber(a) - RegisterOrd(a) -} - -// RegisterOrd registers a dtype as a type that can be typed -func RegisterOrd(a Dtype) { - for _, dt := range ordTypes.set { - if dt == a { - return - } - } - ordTypes.set = append(ordTypes.set, a) - RegisterEq(a) -} - -// RegisterEq registers a dtype as a type that can be compared for equality -func RegisterEq(a Dtype) { - for _, dt := range eqTypes.set { - if dt == a { - return - } - } - eqTypes.set = append(eqTypes.set, a) - Register(a) -} - -// Register registers a new Dtype -func Register(a Dtype) { - for _, dt := range allTypes.set { - if a == dt { - return - } - } - allTypes.set = append(allTypes.set, a) -} - -func dtypeID(a Dtype) int { - for i, v := range allTypes.set { - if a == v { - return i - } - } - return -1 -} - // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte. // But there are norm types that are outside numerical types, such as nuclear norm and fobenius norm. // So it is internally represented by a float. If Go could use NaN and Inf as consts, it would have been best, @@ -455,7 +166,7 @@ func AsSameType() FuncOpt { } // As makes sure that the the return Tensor is of the type specified. 
Currently only works for FromMat64 -func As(t Dtype) FuncOpt { +func As(t dtype.Dtype) FuncOpt { f := func(opt *OpOpt) { opt.t = t } From 8f33ec22a1fd6d8a8333dec733242db1c6d7ef44 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 9 Jul 2021 12:11:10 +1000 Subject: [PATCH 086/154] More work to move Dtype related stuff into its own package Fixed all the relevant generators --- consopt.go | 2 +- genlib2/agg2_body.go | 6 ++--- genlib2/arith_tests.go | 14 +++++----- genlib2/declarations.go | 60 ++++++++++++++++++++--------------------- genlib2/dense_io.go | 4 +-- genlib2/engine.go | 24 ++++++++--------- genlib2/unary_tests.go | 4 +-- perf.go | 1 + tensor.go | 4 +-- testutils_test.go | 12 ++++++--- type_test.go | 60 ----------------------------------------- 11 files changed, 69 insertions(+), 122 deletions(-) diff --git a/consopt.go b/consopt.go index ab4135d..1d263b0 100644 --- a/consopt.go +++ b/consopt.go @@ -12,7 +12,7 @@ type ConsOpt func(Tensor) // Of is a construction option for a Tensor. 
func Of(a dtype.Dtype) ConsOpt { - Register(a) + dtype.Register(a) f := func(t Tensor) { switch tt := t.(type) { case *Dense: diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 1e16123..41220ef 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -19,7 +19,7 @@ const arithPrepRaw = `var safe, toReuse, incr bool } ` -const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); err != nil { +const prepVVRaw = `if err = binaryCheck(a, b, dtype.{{.TypeClassCheck}}); err != nil { return nil, errors.Wrapf(err, "{{.Name}} failed") } @@ -36,7 +36,7 @@ const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); } ` -const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); err != nil { +const prepMixedRaw = `if err = unaryCheck(t, dtype.{{.TypeClassCheck}}); err != nil { return nil, errors.Wrapf(err, "{{.Name}} failed") } @@ -67,7 +67,7 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); ` -const prepUnaryRaw = `if err = unaryCheck(a, {{.TypeClassCheck | lower}}Types); err != nil { +const prepUnaryRaw = `if err = unaryCheck(a, dtype.{{.TypeClassCheck}}); err != nil { err = errors.Wrapf(err, "{{.Name}} failed") return } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index 369cc0f..6d0bd88 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -176,10 +176,10 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complex" } tests = append(tests, t) } @@ -223,9 +223,9 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complex" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = 
append(tests, t) } @@ -267,7 +267,7 @@ func generateDenseMethodArithTests(f io.Writer, ak Kinds) { EqFailTypeClassName: "nil", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complex" } tests = append(tests, t) } @@ -311,9 +311,9 @@ func generateDenseMethodScalarTests(f io.Writer, ak Kinds) { } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complex" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = append(tests, t) } diff --git a/genlib2/declarations.go b/genlib2/declarations.go index 7bcd6bc..130794a 100644 --- a/genlib2/declarations.go +++ b/genlib2/declarations.go @@ -25,7 +25,7 @@ var cmpSymbolTemplates = [...]string{ } var nonFloatConditionalUnarySymbolTemplates = [...]string{ - `{{if isFloat .Kind -}} + `{{if isFloat .Kind -}} {{.Range}}[{{.Index0}}] = {{mathPkg .Kind}}Abs({{.Range}}[{{.Index0}}]) {{else -}} if {{.Range}}[{{.Index0}}] < 0 { {{.Range}}[{{.Index0}}] = -{{.Range}}[{{.Index0}}] @@ -85,7 +85,7 @@ var funcOptDecl = map[string]string{ "reuse": "reuse := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "incr": "incr := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "unsafe": "", - "assame": `if err := typeclassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { + "assame": `if err := dtype.TypeClassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { return true // we exit early if the generated type is not something we can handle } `, @@ -427,51 +427,51 @@ func init() { // ops arithBinOps = []arithOp{ - {basicBinOp{"", "Add", false, isAddable}, "numberTypes", true, 0, false, "", true, false}, - {basicBinOp{"", "Sub", false, isNumber}, "numberTypes", false, 0, true, "Add", false, true}, - {basicBinOp{"", "Mul", false, isNumber}, "numberTypes", true, 1, false, "", true, false}, - {basicBinOp{"", "Div", false, isNumber}, "numberTypes", false, 1, true, "Mul", false, false}, - {basicBinOp{"", 
"Pow", true, isFloatCmplx}, "floatcmplxTypes", true, 1, false, "", false, false}, - {basicBinOp{"", "Mod", false, isNonComplexNumber}, "nonComplexNumberTypes", false, 0, false, "", false, false}, + {basicBinOp{"", "Add", false, isAddable}, "dtype.Number", true, 0, false, "", true, false}, + {basicBinOp{"", "Sub", false, isNumber}, "dtype.Number", false, 0, true, "Add", false, true}, + {basicBinOp{"", "Mul", false, isNumber}, "dtype.Number", true, 1, false, "", true, false}, + {basicBinOp{"", "Div", false, isNumber}, "dtype.Number", false, 1, true, "Mul", false, false}, + {basicBinOp{"", "Pow", true, isFloatCmplx}, "dtype.FloatComplex", true, 1, false, "", false, false}, + {basicBinOp{"", "Mod", false, isNonComplexNumber}, "dtype.NonComplexNumber", false, 0, false, "", false, false}, } for i := range arithBinOps { arithBinOps[i].symbol = arithSymbolTemplates[i] } cmpBinOps = []cmpOp{ - {basicBinOp{"", "Gt", false, isOrd}, "ordTypes", "Lt", true, false}, - {basicBinOp{"", "Gte", false, isOrd}, "ordTypes", "Lte", true, false}, - {basicBinOp{"", "Lt", false, isOrd}, "ordTypes", "Gt", true, false}, - {basicBinOp{"", "Lte", false, isOrd}, "ordTypes", "Gte", true, false}, - {basicBinOp{"", "Eq", false, isEq}, "eqTypes", "Eq", true, true}, - {basicBinOp{"", "Ne", false, isEq}, "eqTypes", "Ne", false, true}, + {basicBinOp{"", "Gt", false, isOrd}, "dtype.Ord", "Lt", true, false}, + {basicBinOp{"", "Gte", false, isOrd}, "dtype.Ord", "Lte", true, false}, + {basicBinOp{"", "Lt", false, isOrd}, "dtype.Ord", "Gt", true, false}, + {basicBinOp{"", "Lte", false, isOrd}, "dtype.Ord", "Gte", true, false}, + {basicBinOp{"", "Eq", false, isEq}, "dtype.Eq", "Eq", true, true}, + {basicBinOp{"", "Ne", false, isEq}, "dtype.Eq", "Ne", false, true}, } for i := range cmpBinOps { cmpBinOps[i].symbol = cmpSymbolTemplates[i] } conditionalUnaries = []unaryOp{ - {"", "Abs", false, isSignedNumber, "signedTypes", ""}, - {"", "Sign", false, isSignedNumber, "signedTypes", ""}, + {"", "Abs", false, 
isSignedNumber, "dtype.Signed", ""}, + {"", "Sign", false, isSignedNumber, "dtype.Signed", ""}, } for i := range conditionalUnaries { conditionalUnaries[i].symbol = nonFloatConditionalUnarySymbolTemplates[i] } unconditionalUnaries = []unaryOp{ - {"", "Neg", false, isNumber, "numberTypes", "Neg"}, - {"", "Inv", false, isNumber, "numberTypes", ""}, - {"", "Square", false, isNumber, "numberTypes", "Sqrt"}, - {"", "Cube", false, isNumber, "numberTypes", "Cbrt"}, - - {"", "Exp", true, isFloatCmplx, "floatcmplxTypes", "Log"}, - {"", "Tanh", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Log", true, isFloatCmplx, "floatcmplxTypes", "Exp"}, - {"", "Log2", true, isFloat, "floatTypes", ""}, - {"", "Log10", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Sqrt", true, isFloatCmplx, "floatcmplxTypes", "Square"}, - {"", "Cbrt", true, isFloat, "floatTypes", "Cube"}, - {"", "InvSqrt", true, isFloat, "floatTypes", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later + {"", "Neg", false, isNumber, "dtype.Number", "Neg"}, + {"", "Inv", false, isNumber, "dtype.Number", ""}, + {"", "Square", false, isNumber, "dtype.Number", "Sqrt"}, + {"", "Cube", false, isNumber, "dtype.Number", "Cbrt"}, + + {"", "Exp", true, isFloatCmplx, "dtype.FloatComplex", "Log"}, + {"", "Tanh", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Log", true, isFloatCmplx, "dtype.FloatComplex", "Exp"}, + {"", "Log2", true, isFloat, "dtype.Floats", ""}, + {"", "Log10", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Sqrt", true, isFloatCmplx, "dtype.FloatComplex", "Square"}, + {"", "Cbrt", true, isFloat, "dtype.Floats", "Cube"}, + {"", "InvSqrt", true, isFloat, "dtype.Floats", ""}, // TODO: cmplx requires to much finagling to the template. 
Come back to it later } nonF := len(unconditionalNumUnarySymbolTemplates) for i := range unconditionalNumUnarySymbolTemplates { @@ -482,7 +482,7 @@ func init() { } specialUnaries = []UnaryOp{ - specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "nonComplexNumberTypes", ""}, []string{"min", "max"}}, + specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "dtype.NonComplexNumber", ""}, []string{"min", "max"}}, } // typed operations diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 513d657..fd51c46 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -63,7 +63,7 @@ func (r *binaryReader) Err() error { // If tensor is masked, invalid values are replaced by the default fill value. func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil{ + if npdt, err = t.t.NumpyDtype(); err != nil{ return } @@ -290,7 +290,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ } // TODO: check for endianness. 
For now we assume everything is little endian - if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } diff --git a/genlib2/engine.go b/genlib2/engine.go index 8792e2e..6323765 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -269,18 +269,18 @@ func (fn *EngineUnary) Write(w io.Writer) { func generateStdEngUncondUnary(f io.Writer, ak Kinds) { tcc := []string{ - "Number", // Neg - "Number", // Inv - "Number", // Square - "Number", // Cube - "FloatCmplx", // Exp - "FloatCmplx", // Tanhh - "FloatCmplx", // Log - "Float", // Log2 - "FloatCmplx", // Log10 - "FloatCmplx", // Sqrt - "Float", // Cbrt - "Float", // InvSqrt + "Number", // Neg + "Number", // Inv + "Number", // Square + "Number", // Cube + "FloatComplex", // Exp + "FloatComplex", // Tanhh + "FloatComplex", // Log + "Float", // Log2 + "FloatComplex", // Log10 + "FloatComplex", // Sqrt + "Float", // Cbrt + "Float", // InvSqrt } var gen []*EngineUnary for i, u := range unconditionalUnaries { diff --git a/genlib2/unary_tests.go b/genlib2/unary_tests.go index dedd02d..20fa176 100644 --- a/genlib2/unary_tests.go +++ b/genlib2/unary_tests.go @@ -15,7 +15,7 @@ const unaryTestBodyRaw = `invFn := func(q *Dense) bool { we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - + ret, err := {{.Name}}(a {{template "funcoptuse"}}) if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ if err != nil { @@ -24,7 +24,7 @@ const unaryTestBodyRaw = `invFn := func(q *Dense) bool { return true } {{if ne .InvTypeClass "" -}} - if err := typeclassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { return true // uninvertible due to type class implementation issues } {{end -}} diff --git a/perf.go b/perf.go index 200aeac..4d5ffd7 100644 --- a/perf.go +++ 
b/perf.go @@ -4,6 +4,7 @@ import ( "runtime" "sync" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) diff --git a/tensor.go b/tensor.go index 5b60671..8445a39 100644 --- a/tensor.go +++ b/tensor.go @@ -138,7 +138,7 @@ func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { if t == nil { return } - if err = typeclassCheck(t.Dtype(), floatTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Floats); err != nil { err = errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype()) return } @@ -159,7 +159,7 @@ func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { if t == nil { return } - if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype()) return } diff --git a/testutils_test.go b/testutils_test.go index 58814bf..4490102 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -507,11 +507,17 @@ func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.Wor func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } -func willerr(a *Dense, tc, eqtc *typeclass) (retVal, willFailEq bool) { - if err := typeclassCheck(a.Dtype(), eqtc); err == nil { +var nilTC dtype.TypeClass = -1 + +func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { + if eqtc == nilTC { willFailEq = true + } else { + if err := dtype.TypeClassCheck(a.Dtype(), eqtc); err == nil { + willFailEq = true + } } - if err := typeclassCheck(a.Dtype(), tc); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), tc); err != nil { return true, willFailEq } diff --git a/type_test.go b/type_test.go index 12c5678..54b9acd 100644 --- a/type_test.go +++ b/type_test.go @@ -6,63 +6,3 @@ 
import ( "gorgonia.org/dtype" ) - -type Float16 uint16 - -func TestRegisterType(t *testing.T) { - dt := dtype.Dtype{reflect.TypeOf(Float16(0))} - RegisterFloat(dt) - - if err := typeclassCheck(dt, floatTypes); err != nil { - t.Errorf("Expected %v to be in floatTypes: %v", dt, err) - } - if err := typeclassCheck(dt, numberTypes); err != nil { - t.Errorf("Expected %v to be in numberTypes: %v", dt, err) - } - if err := typeclassCheck(dt, ordTypes); err != nil { - t.Errorf("Expected %v to be in ordTypes: %v", dt, err) - } - if err := typeclassCheck(dt, eqTypes); err != nil { - t.Errorf("Expected %v to be in eqTypes: %v", dt, err) - } - -} - -func TestDtypeConversions(t *testing.T) { - for k, v := range reverseNumpyDtypes { - if npdt, err := v.numpyDtype(); npdt != k { - t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - dt := dtype.Dtype{reflect.TypeOf(Float16(0))} - if _, err := dt.numpyDtype(); err == nil { - t.Errorf("Expected an error when passing in type unknown to np") - } - - for k, v := range numpyDtypes { - if dt, err := fromNumpyDtype(v); dt != k { - // special cases - if Int.Size() == 4 && v == "i4" && dt == Int { - continue - } - if Int.Size() == 8 && v == "i8" && dt == Int { - continue - } - - if Uint.Size() == 4 && v == "u4" && dt == Uint { - continue - } - if Uint.Size() == 8 && v == "u8" && dt == Uint { - continue - } - t.Errorf("Expected %q to return %v. 
Got %v instead", v, k, dt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - if _, err := fromNumpyDtype("EDIUH"); err == nil { - t.Error("Expected error when nonsense is passed into fromNumpyDtype") - } -} From da916ae4dcd3a60721b297a0d2d62a8c8e151fe1 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 9 Jul 2021 13:33:03 +1000 Subject: [PATCH 087/154] More work on moving Dtype to its own package Generated the things correctly --- api_arith_generated_test.go | 104 ++-- api_arith_test.go | 5 +- api_cmp_generated_test.go | 86 ++-- api_unary_generated_test.go | 74 +-- api_unary_test.go | 85 ++-- defaultengine_argmethods.go | 9 +- defaultengine_arith.go | 25 +- defaultengine_cmp.go | 25 +- defaultengine_linalg.go | 5 +- defaultengine_mapreduce.go | 4 +- defaultengine_matop_misc.go | 3 +- defaultengine_misc.go | 3 +- defaultengine_prep.go | 14 +- defaultengine_unary.go | 29 +- dense_apply_test.go | 6 +- dense_arith_test.go | 104 ++-- dense_cmp_test.go | 86 ++-- dense_io.go | 4 +- dense_linalg.go | 5 +- genlib2/arith_tests.go | 16 +- genlib2/cmp_tests.go | 942 ++++++++++++++++++------------------ genlib2/engine.go | 6 +- genlib2/unary_tests.go | 302 ++++++------ testutils_test.go | 4 + 24 files changed, 986 insertions(+), 960 deletions(-) diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index 1120fba..ba019e8 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -5,13 +5,15 @@ package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) func TestAdd(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +39,7 @@ func TestSub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := 
a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +66,7 @@ func TestMul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +93,7 @@ func TestDiv(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +120,7 @@ func TestPow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +146,7 @@ func TestAdd_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -175,7 +177,7 @@ func TestSub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -207,7 +209,7 @@ func TestMul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, 
a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -239,7 +241,7 @@ func TestDiv_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -271,7 +273,7 @@ func TestPow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -303,7 +305,7 @@ func TestAdd_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -335,7 +337,7 @@ func TestSub_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -368,7 +370,7 @@ func TestMul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -401,7 +403,7 @@ func TestDiv_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) 
reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -434,7 +436,7 @@ func TestPow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -468,7 +470,7 @@ func TestAdd_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -497,7 +499,7 @@ func TestSub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -527,7 +529,7 @@ func TestMul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -557,7 +559,7 @@ func TestDiv_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -587,7 +589,7 @@ func TestPow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, 
dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -615,7 +617,7 @@ func TestAddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -641,7 +643,7 @@ func TestAddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -695,7 +697,7 @@ func TestSubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -721,7 +723,7 @@ func TestSubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -775,7 +777,7 @@ func TestMulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -801,7 +803,7 @@ func TestMulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -855,7 +857,7 @@ func TestDivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -910,7 +912,7 @@ func 
TestPowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -965,7 +967,7 @@ func TestAddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -996,7 +998,7 @@ func TestAddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1029,7 +1031,7 @@ func TestSubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1060,7 +1062,7 @@ func TestSubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1093,7 +1095,7 @@ func TestMulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1124,7 +1126,7 @@ func TestMulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1157,7 +1159,7 @@ func TestDivScalar_unsafe(t 
*testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1191,7 +1193,7 @@ func TestPowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1226,7 +1228,7 @@ func TestAddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1258,7 +1260,7 @@ func TestAddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1292,7 +1294,7 @@ func TestSubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1324,7 +1326,7 @@ func TestSubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1358,7 +1360,7 @@ func TestMulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, 
dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1390,7 +1392,7 @@ func TestMulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1424,7 +1426,7 @@ func TestDivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1459,7 +1461,7 @@ func TestPowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1496,7 +1498,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1525,7 +1527,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1556,7 +1558,7 @@ func TestSubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1588,7 +1590,7 @@ func TestMulScalar_incr(t *testing.T) { correct := 
a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1617,7 +1619,7 @@ func TestMulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1648,7 +1650,7 @@ func TestDivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1680,7 +1682,7 @@ func TestPowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok diff --git a/api_arith_test.go b/api_arith_test.go index 75a4838..3a3cf67 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) // This file contains the tests for API functions that aren't generated by genlib @@ -40,7 +41,7 @@ func TestFMA(t *testing.T) { WithEngine(q.Engine())(y) y2 := y.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok1 := q.Engine().(FMAer) _, ok2 := q.Engine().(Muler) _, ok3 := q.Engine().(Adder) @@ -55,7 +56,7 @@ func TestFMA(t *testing.T) { return true } - we, _ = willerr(a, numberTypes, nil) + we, _ = willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok wi, err := Mul(a, x, WithIncr(y2)) diff --git 
a/api_cmp_generated_test.go b/api_cmp_generated_test.go index 002587b..e4ddd7b 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -6,11 +6,13 @@ import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) func TestGt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -68,7 +70,7 @@ func TestGt(t *testing.T) { } func TestGte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -126,7 +128,7 @@ func TestGte(t *testing.T) { } func TestLt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -184,7 +186,7 @@ func TestLt(t *testing.T) { } func TestLte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -242,7 +244,7 @@ func TestLte(t *testing.T) { } func TestEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -298,7 +300,7 @@ func TestEq(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -333,7 +335,7 @@ func TestEq(t *testing.T) { } func TestNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -368,11 +370,11 @@ func TestNe(t *testing.T) { } func TestGt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, 
ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -428,11 +430,11 @@ func TestGt_assame(t *testing.T) { } func TestGte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -488,11 +490,11 @@ func TestGte_assame(t *testing.T) { } func TestLt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -548,11 +550,11 @@ func TestLt_assame(t *testing.T) { } func TestLte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -608,11 +610,11 @@ func TestLte_assame(t *testing.T) { } func TestEq_assame(t *testing.T) { transFn := func(q 
*Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -666,11 +668,11 @@ func TestEq_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -704,11 +706,11 @@ func TestEq_assame(t *testing.T) { } func TestNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -742,7 +744,7 @@ func TestNe_assame(t *testing.T) { } func TestGtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -798,7 +800,7 @@ func TestGtScalar(t *testing.T) { } func TestGteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -854,7 +856,7 @@ func TestGteScalar(t *testing.T) { } func 
TestLtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -910,7 +912,7 @@ func TestLtScalar(t *testing.T) { } func TestLteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -966,7 +968,7 @@ func TestLteScalar(t *testing.T) { } func TestEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1020,7 +1022,7 @@ func TestEqScalar(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1053,7 +1055,7 @@ func TestEqScalar(t *testing.T) { } func TestNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1086,11 +1088,11 @@ func TestNeScalar(t *testing.T) { } func TestGtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1144,11 +1146,11 @@ func TestGtScalar_assame(t *testing.T) { } func TestGteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), 
nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1202,11 +1204,11 @@ func TestGteScalar_assame(t *testing.T) { } func TestLtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1260,11 +1262,11 @@ func TestLtScalar_assame(t *testing.T) { } func TestLteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1318,11 +1320,11 @@ func TestLteScalar_assame(t *testing.T) { } func TestEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1374,11 +1376,11 @@ func TestEqScalar_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + 
we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1410,11 +1412,11 @@ func TestEqScalar_assame(t *testing.T) { } func TestNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() diff --git a/api_unary_generated_test.go b/api_unary_generated_test.go index 31a23f2..e6fc8a9 100644 --- a/api_unary_generated_test.go +++ b/api_unary_generated_test.go @@ -5,13 +5,15 @@ package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) func TestNeg(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -37,7 +39,7 @@ func TestSquare(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -48,7 +50,7 @@ func TestSquare(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, 
UseUnsafe()) @@ -66,7 +68,7 @@ func TestCube(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -77,7 +79,7 @@ func TestCube(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -95,7 +97,7 @@ func TestExp(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -121,7 +123,7 @@ func TestLog(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -147,7 +149,7 @@ func TestSqrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -173,7 +175,7 @@ func TestCbrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -199,7 +201,7 @@ func TestNeg_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we 
= we || !ok @@ -230,7 +232,7 @@ func TestSquare_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -241,7 +243,7 @@ func TestSquare_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -264,7 +266,7 @@ func TestCube_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -275,7 +277,7 @@ func TestCube_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -298,7 +300,7 @@ func TestExp_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -329,7 +331,7 @@ func TestLog_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -360,7 +362,7 @@ func TestSqrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, 
willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -391,7 +393,7 @@ func TestCbrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -423,7 +425,7 @@ func TestNeg_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -455,7 +457,7 @@ func TestSquare_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -466,7 +468,7 @@ func TestSquare_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -490,7 +492,7 @@ func TestCube_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -501,7 +503,7 @@ func TestCube_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -525,7 +527,7 @@ func TestExp_reuse(t *testing.T) { a := 
q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -557,7 +559,7 @@ func TestLog_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -589,7 +591,7 @@ func TestSqrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -621,7 +623,7 @@ func TestCbrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -655,7 +657,7 @@ func TestNeg_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -688,7 +690,7 @@ func TestSquare_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -699,7 +701,7 @@ func TestSquare_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), 
dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -724,7 +726,7 @@ func TestCube_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -735,7 +737,7 @@ func TestCube_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -760,7 +762,7 @@ func TestExp_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -793,7 +795,7 @@ func TestLog_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -826,7 +828,7 @@ func TestSqrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -859,7 +861,7 @@ func TestCbrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) 
_, ok := q.Engine().(Cbrter) we = we || !ok diff --git a/api_unary_test.go b/api_unary_test.go index 5f453a5..25b68f7 100644 --- a/api_unary_test.go +++ b/api_unary_test.go @@ -9,6 +9,7 @@ import ( "github.com/chewxy/math32" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) /* @@ -354,12 +355,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a) @@ -387,12 +388,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, UseUnsafe()) @@ -426,12 +427,12 @@ func TestInvSqrt(t *testing.T) { reuse := q.Clone().(*Dense) reuse.Zero() correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithReuse(reuse)) @@ -466,12 +467,12 @@ func TestInvSqrt(t *testing.T) { incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := 
willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithIncr(incr)) @@ -509,12 +510,12 @@ func TestInv(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a) @@ -541,12 +542,12 @@ func TestInv(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, UseUnsafe()) @@ -577,12 +578,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithReuse(reuse)) @@ -613,12 +614,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, 
nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithIncr(incr)) @@ -654,12 +655,12 @@ func TestLog10(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a) @@ -688,12 +689,12 @@ func TestLog10(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, UseUnsafe()) @@ -725,12 +726,12 @@ func TestLog10(t *testing.T) { correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithReuse(reuse)) @@ -762,12 +763,12 @@ func TestLog10(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, 
willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithIncr(incr)) @@ -806,10 +807,10 @@ func TestAbs(t *testing.T) { correct := New(Of(Bool), WithShape(q.Shape().Clone()...)) correct.Memset(true) // we'll exclude everything other than ordtypes because complex numbers cannot be abs'd - if err := typeclassCheck(a.Dtype(), ordTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Ord); err != nil { return true } - we, willFailEq := willerr(a, signedTypes, nil) + we, willFailEq := willerr(a, dtype.Signed, nilTC) _, ok := q.Engine().(Abser) we = we || !ok @@ -840,12 +841,12 @@ func TestTanh(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a) @@ -882,12 +883,12 @@ func TestTanh(t *testing.T) { invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, UseUnsafe()) @@ -929,12 +930,12 @@ func TestTanh(t *testing.T) { correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) 
reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, WithReuse(reuse)) @@ -976,12 +977,12 @@ func TestTanh(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, WithIncr(incr)) @@ -1028,12 +1029,12 @@ func TestLog2(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a) @@ -1062,12 +1063,12 @@ func TestLog2(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, UseUnsafe()) @@ -1099,12 +1100,12 @@ func TestLog2(t *testing.T) { correct := 
a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithReuse(reuse)) @@ -1136,12 +1137,12 @@ func TestLog2(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithIncr(incr)) diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 5632fa6..21373fd 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -1,6 +1,9 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" +) func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { @@ -13,7 +16,7 @@ func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { } func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } @@ -100,7 +103,7 @@ func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) { } func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmin") } diff 
--git a/defaultengine_arith.go b/defaultengine_arith.go index 918e1ca..2e4e55e 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -4,13 +4,14 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // Add performs a + b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Add failed") } @@ -75,7 +76,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Sub performs a - b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Sub failed") } @@ -140,7 +141,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Mul performs a × b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mul failed") } @@ -205,7 +206,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Div performs a ÷ b elementwise. Both a and b must have the same shape. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Div failed") } @@ -270,7 +271,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Pow performs a ^ b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Pow failed") } @@ -335,7 +336,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Mod performs a % b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mod failed") } @@ -400,7 +401,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // AddScalar performs t + s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Add failed") } @@ -503,7 +504,7 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // SubScalar performs t - s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Sub failed") } @@ -606,7 +607,7 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // MulScalar performs t × s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mul failed") } @@ -709,7 +710,7 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // DivScalar performs t ÷ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Div failed") } @@ -812,7 +813,7 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // PowScalar performs t ^ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Pow failed") } @@ -915,7 +916,7 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // ModScalar performs t % s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mod failed") } diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 1d6ff48..3d6a7f0 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -4,6 +4,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -12,7 +13,7 @@ import ( //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. 
func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gt failed") } @@ -90,7 +91,7 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gte failed") } @@ -168,7 +169,7 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lt failed") } @@ -246,7 +247,7 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lte failed") } @@ -324,7 +325,7 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. 
func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Eq failed") } @@ -402,7 +403,7 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Ne failed") } @@ -480,7 +481,7 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gt failed") } @@ -602,7 +603,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gte failed") } @@ -724,7 +725,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. 
func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lt failed") } @@ -846,7 +847,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lte failed") } @@ -964,7 +965,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Eq failed") } @@ -1082,7 +1083,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Ne failed") } diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index d9a16aa..4a6073b 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/errors" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) // Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). 
If the Tensor provided is not a matrix, it will return an error @@ -15,7 +16,7 @@ func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { return } - if err = typeclassCheck(t.Dtype(), numberTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Number); err != nil { return nil, errors.Wrap(err, "Trace") } @@ -317,7 +318,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { if t, ok = a.(*Dense); !ok { return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a) } - if err = typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err = dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype()) } diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index a70af74..60dcd29 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -11,7 +11,7 @@ import ( ) func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "Failed Map()") return } @@ -254,7 +254,7 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe return } - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "prepReduce failed") return } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 0ab392a..dde54a9 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -2,6 +2,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -370,7 +371,7 @@ func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { return } - if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { + if err = dtype.TypeClassCheck(a.Dtype(), 
dtype.Number); err != nil { return nil, errors.Wrap(err, "Diagonal") } diff --git a/defaultengine_misc.go b/defaultengine_misc.go index 8ce04db..f642f67 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -2,11 +2,12 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nonComplexNumberTypes); err != nil { + if err = unaryCheck(a, dtype.NonComplexNumber); err != nil { return nil, errors.Wrap(err, "Clamp failed") } diff --git a/defaultengine_prep.go b/defaultengine_prep.go index a3df181..1a13a6f 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -62,7 +62,7 @@ func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict boo return } -func binaryCheck(a, b Tensor, tc *typeclass) (err error) { +func binaryCheck(a, b Tensor, tc dtype.TypeClass) (err error) { // check if the tensors are accessible if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) @@ -74,11 +74,11 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { at := a.Dtype() bt := b.Dtype() - if tc != nil { - if err = typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err = dtype.TypeClassCheck(at, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "a") } - if err = typeclassCheck(bt, tc); err != nil { + if err = dtype.TypeClassCheck(bt, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "b") } } @@ -92,13 +92,13 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { return nil } -func unaryCheck(a Tensor, tc *typeclass) error { +func unaryCheck(a Tensor, tc dtype.TypeClass) error { if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) } at := a.Dtype() - if tc != nil { - if err := typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err := dtype.TypeClassCheck(at, tc); err != nil { return 
errors.Wrapf(err, typeclassMismatch, "a") } } diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 986e246..d38cf57 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -4,11 +4,12 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Neg failed") return } @@ -76,7 +77,7 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Inv failed") return } @@ -144,7 +145,7 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Square failed") return } @@ -212,7 +213,7 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Cube failed") return } @@ -280,7 +281,7 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Exp failed") return } @@ -348,7 +349,7 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Tanh(a 
Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Tanh failed") return } @@ -416,7 +417,7 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log failed") return } @@ -484,7 +485,7 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Log2 failed") return } @@ -552,7 +553,7 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log10 failed") return } @@ -620,7 +621,7 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Sqrt failed") return } @@ -688,7 +689,7 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Cbrt failed") return } @@ -756,7 +757,7 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e 
StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "InvSqrt failed") return } @@ -824,7 +825,7 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Abs failed") return } @@ -892,7 +893,7 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Sign failed") return } diff --git a/dense_apply_test.go b/dense_apply_test.go index 8d73631..793f2c5 100644 --- a/dense_apply_test.go +++ b/dense_apply_test.go @@ -108,7 +108,7 @@ func TestDense_Apply(t *testing.T) { return true // we'll skip those that we cannot mutate } - we, eqFail := willerr(q, nil, nil) + we, eqFail := willerr(q, nilTC, nilTC) _, ok := q.Engine().(Mapper) we = we || !ok @@ -151,7 +151,7 @@ func TestDense_Apply_unsafe(t *testing.T) { return true // we'll skip those that we cannot mutate } - we, eqFail := willerr(q, nil, nil) + we, eqFail := willerr(q, nilTC, nilTC) _, ok := q.Engine().(Mapper) we = we || !ok @@ -192,7 +192,7 @@ func TestDense_Apply_reuse(t *testing.T) { return true // we'll skip those that we cannot mutate } - we, eqFail := willerr(q, nil, nil) + we, eqFail := willerr(q, nilTC, nilTC) _, ok := q.Engine().(Mapper) we = we || !ok diff --git a/dense_arith_test.go b/dense_arith_test.go index d414dd2..9b80873 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -5,13 +5,15 @@ package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) func TestDense_Add(t 
*testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +39,7 @@ func TestDense_Sub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +66,7 @@ func TestDense_Mul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +93,7 @@ func TestDense_Div(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +120,7 @@ func TestDense_Pow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +146,7 @@ func TestDense_Add_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ 
-175,7 +177,7 @@ func TestDense_Sub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -207,7 +209,7 @@ func TestDense_Mul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -239,7 +241,7 @@ func TestDense_Div_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -271,7 +273,7 @@ func TestDense_Pow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -303,7 +305,7 @@ func TestDense_Add_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -335,7 +337,7 @@ func TestDense_Sub_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := 
willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -368,7 +370,7 @@ func TestDense_Mul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -401,7 +403,7 @@ func TestDense_Div_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -434,7 +436,7 @@ func TestDense_Pow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -468,7 +470,7 @@ func TestDense_Add_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -497,7 +499,7 @@ func TestDense_Sub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -527,7 +529,7 @@ func TestDense_Mul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) 
we = we || !ok @@ -557,7 +559,7 @@ func TestDense_Div_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -587,7 +589,7 @@ func TestDense_Pow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -615,7 +617,7 @@ func TestDense_AddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -641,7 +643,7 @@ func TestDense_AddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -695,7 +697,7 @@ func TestDense_SubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -721,7 +723,7 @@ func TestDense_SubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -775,7 +777,7 @@ func TestDense_MulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := 
q.Engine().(Muler) we = we || !ok @@ -801,7 +803,7 @@ func TestDense_MulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -855,7 +857,7 @@ func TestDense_DivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -910,7 +912,7 @@ func TestDense_PowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -965,7 +967,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -996,7 +998,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1029,7 +1031,7 @@ func TestDense_SubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1060,7 +1062,7 @@ func TestDense_SubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := 
q.Engine().(Suber) we = we || !ok @@ -1093,7 +1095,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1124,7 +1126,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1157,7 +1159,7 @@ func TestDense_DivScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1191,7 +1193,7 @@ func TestDense_PowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1226,7 +1228,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1258,7 +1260,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1292,7 +1294,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + 
we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1324,7 +1326,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1358,7 +1360,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1390,7 +1392,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1424,7 +1426,7 @@ func TestDense_DivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1459,7 +1461,7 @@ func TestDense_PowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1496,7 +1498,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok 
@@ -1525,7 +1527,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1556,7 +1558,7 @@ func TestDense_SubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1588,7 +1590,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1617,7 +1619,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1648,7 +1650,7 @@ func TestDense_DivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1680,7 +1682,7 @@ func TestDense_PowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok diff --git a/dense_cmp_test.go b/dense_cmp_test.go index a0bc5b6..384e250 100644 --- 
a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -6,11 +6,13 @@ import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) func TestDense_Gt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -68,7 +70,7 @@ func TestDense_Gt(t *testing.T) { } func TestDense_Gte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -126,7 +128,7 @@ func TestDense_Gte(t *testing.T) { } func TestDense_Lt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -184,7 +186,7 @@ func TestDense_Lt(t *testing.T) { } func TestDense_Lte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -242,7 +244,7 @@ func TestDense_Lte(t *testing.T) { } func TestDense_ElEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -298,7 +300,7 @@ func TestDense_ElEq(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -333,7 +335,7 @@ func TestDense_ElEq(t *testing.T) { } func TestDense_ElNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -368,11 +370,11 @@ func TestDense_ElNe(t *testing.T) { } func TestDense_Gt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := 
q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -428,11 +430,11 @@ func TestDense_Gt_assame(t *testing.T) { } func TestDense_Gte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -488,11 +490,11 @@ func TestDense_Gte_assame(t *testing.T) { } func TestDense_Lt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -548,11 +550,11 @@ func TestDense_Lt_assame(t *testing.T) { } func TestDense_Lte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -608,11 +610,11 @@ func TestDense_Lte_assame(t *testing.T) { } func 
TestDense_ElEq_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -666,11 +668,11 @@ func TestDense_ElEq_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -704,11 +706,11 @@ func TestDense_ElEq_assame(t *testing.T) { } func TestDense_ElNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -742,7 +744,7 @@ func TestDense_ElNe_assame(t *testing.T) { } func TestDense_GtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -798,7 +800,7 @@ func TestDense_GtScalar(t *testing.T) { } func TestDense_GteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := 
q.Engine().(Gteer) we = we || !ok @@ -854,7 +856,7 @@ func TestDense_GteScalar(t *testing.T) { } func TestDense_LtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -910,7 +912,7 @@ func TestDense_LtScalar(t *testing.T) { } func TestDense_LteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -966,7 +968,7 @@ func TestDense_LteScalar(t *testing.T) { } func TestDense_ElEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1020,7 +1022,7 @@ func TestDense_ElEqScalar(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1053,7 +1055,7 @@ func TestDense_ElEqScalar(t *testing.T) { } func TestDense_ElNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1086,11 +1088,11 @@ func TestDense_ElNeScalar(t *testing.T) { } func TestDense_GtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1144,11 +1146,11 @@ func TestDense_GtScalar_assame(t *testing.T) { } func TestDense_GteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := 
willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1202,11 +1204,11 @@ func TestDense_GteScalar_assame(t *testing.T) { } func TestDense_LtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1260,11 +1262,11 @@ func TestDense_LtScalar_assame(t *testing.T) { } func TestDense_LteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1318,11 +1320,11 @@ func TestDense_LteScalar_assame(t *testing.T) { } func TestDense_ElEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the 
generated type is not something we can handle } r := newRand() @@ -1374,11 +1376,11 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1410,11 +1412,11 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } func TestDense_ElNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() diff --git a/dense_io.go b/dense_io.go index f5a1abd..b0d5395 100644 --- a/dense_io.go +++ b/dense_io.go @@ -164,7 +164,7 @@ func (r *binaryReader) Err() error { // If tensor is masked, invalid values are replaced by the default fill value. func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil { + if npdt, err = t.t.NumpyDtype(); err != nil { return } @@ -243,7 +243,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error) { } // TODO: check for endianness. 
For now we assume everything is little endian - if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } diff --git a/dense_linalg.go b/dense_linalg.go index 4757d97..417358a 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -2,6 +2,7 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" ) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices @@ -17,7 +18,7 @@ func (t *Dense) Trace() (retVal interface{}, err error) { // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { // check that the data is a float - if err = typeclassCheck(t.t, floatcmplxTypes); err != nil { + if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") } @@ -413,7 +414,7 @@ func handleIncr(res *Dense, reuse, incr Tensor, expectedShape Shape) (retVal *De return } - if err = typeclassCheck(incrD.t, numberTypes); err != nil { + if err = dtype.TypeClassCheck(incrD.t, dtype.Number); err != nil { err = errors.Wrapf(err, "handleIncr only handles Number types. 
Got %v instead", incrD.t) return } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index 6d0bd88..596cd8a 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -176,10 +176,10 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: API, - EqFailTypeClassName: "", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "dtype.Complex" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -219,11 +219,11 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "dtype.Complex" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": t.EqFailTypeClassName = "dtype.Unsigned" } @@ -264,10 +264,10 @@ func generateDenseMethodArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "dtype.Complex" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -307,11 +307,11 @@ func generateDenseMethodScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "dtype.Complex" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": t.EqFailTypeClassName = "dtype.Unsigned" } diff --git a/genlib2/cmp_tests.go b/genlib2/cmp_tests.go index 8d3d8f6..1110e6e 100644 --- a/genlib2/cmp_tests.go +++ b/genlib2/cmp_tests.go @@ -1,471 +1,471 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const ( - APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . 
-}})` - APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` - APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - - DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` - DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` - DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` - DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . 
-}})` -) - -const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} - if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ - t.Errorf("a: %-v", a) - t.Errorf("b: %-v", b) - t.Errorf("c: %-v", c) - t.Errorf("axb.Data() %v", axb.Data()) - t.Errorf("bxc.Data() %v", bxc.Data()) - t.Errorf("axc.Data() %v", axc.Data()) - return false - } -{{else -}} - {{if eq .Level "API" -}} - ab := axb.(*Dense).Bools() - bc := bxc.(*Dense).Bools() - ac := axc.(*Dense).Bools() - {{else -}} - ab := axb.Bools() - bc := bxc.Bools() - ac := axc.Bools() - {{end -}} - for i, vab := range ab { - if vab && bc[i] { - if !ac[i]{ - return false - } - } - } -{{end -}} -` - -const transitivityBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - c := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - cv, _ := quick.Value(c.Dtype().Type, r) - b.Memset(bv.Interface()) - c.Memset(cv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . 
}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - c := q.Clone().(*Dense) - cv, _ := quick.Value(c.Dtype().Type, r) - c.Memset(cv.Interface()) - - {{template "axb" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . 
-}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - b.Memset(bv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Symmetry test for {{.Name}} failed: %v", err) -} -` - -type CmpTest struct { - cmpOp - scalars bool - lvl Level - FuncOpt string - EqFailTypeClassName string -} - -func (fn *CmpTest) Name() string { - if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { - return "El" + fn.cmpOp.Name() - } - return fn.cmpOp.Name() -} - -func (fn *CmpTest) Level() string { - switch fn.lvl { - case API: - return "API" - case Dense: - return "Dense" - } - return "" -} - -func (fn *CmpTest) 
Signature() *Signature { - var name string - switch fn.lvl { - case API: - name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) - case Dense: - name = fmt.Sprintf("TestDense_%s", fn.Name()) - } - if fn.scalars { - name += "Scalar" - } - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *CmpTest) canWrite() bool { - return fn.IsTransitive || fn.IsSymmetric -} - -func (fn *CmpTest) WriteBody(w io.Writer) { - if fn.IsTransitive { - fn.writeTransitivity(w) - fmt.Fprintf(w, "\n") - } - if fn.IsSymmetric { - fn.writeSymmetry(w) - } -} - -func (fn *CmpTest) writeTransitivity(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) - template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) - template.Must(t.New("axc").Parse(APICallVVaxcRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) - } - } - template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - 
template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) writeSymmetry(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) - } - } - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPICmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range 
tests { - if fn.canWrite() { - fn.Write(f) - } - } - -} - -func generateAPICmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const ( + APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` + APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . 
-}})` + APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + + DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` + DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` +) + +const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} + if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ + t.Errorf("a: %-v", a) + t.Errorf("b: %-v", b) + t.Errorf("c: %-v", c) + t.Errorf("axb.Data() %v", axb.Data()) + t.Errorf("bxc.Data() %v", bxc.Data()) + t.Errorf("axc.Data() %v", axc.Data()) + return false + } +{{else -}} + {{if eq .Level "API" -}} + ab := axb.(*Dense).Bools() + bc := bxc.(*Dense).Bools() + ac := axc.(*Dense).Bools() + {{else -}} + ab := axb.Bools() + bc := bxc.Bools() + ac := axc.Bools() + {{end -}} + for i, vab := range ab { + if vab && bc[i] { + if !ac[i]{ + return false + } + } + } +{{end -}} +` + +const transitivityBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . 
-}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + c := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + cv, _ := quick.Value(c.Dtype().Type, r) + b.Memset(bv.Interface()) + c.Memset(cv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + c := q.Clone().(*Dense) + cv, _ := quick.Value(c.Dtype().Type, r) + c.Memset(cv.Interface()) + + {{template "axb" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . 
}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + b.Memset(bv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . 
-}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Symmetry test for {{.Name}} failed: %v", err) +} +` + +type CmpTest struct { + cmpOp + scalars bool + lvl Level + FuncOpt string + EqFailTypeClassName string +} + +func (fn *CmpTest) Name() string { + if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { + return "El" + fn.cmpOp.Name() + } + return fn.cmpOp.Name() +} + +func (fn *CmpTest) Level() string { + switch fn.lvl { + case API: + return "API" + case Dense: + return "Dense" + } + return "" +} + +func (fn *CmpTest) Signature() *Signature { + var name string + switch fn.lvl { + case API: + name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) + case Dense: + name = fmt.Sprintf("TestDense_%s", fn.Name()) + } + if fn.scalars { + name += "Scalar" + } + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *CmpTest) canWrite() bool { + return fn.IsTransitive || fn.IsSymmetric +} + +func (fn *CmpTest) WriteBody(w io.Writer) { + if fn.IsTransitive { + fn.writeTransitivity(w) + fmt.Fprintf(w, "\n") + } + if fn.IsSymmetric { + fn.writeSymmetry(w) + } +} + +func (fn *CmpTest) writeTransitivity(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) + } else { + t = 
template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) + template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) + template.Must(t.New("axc").Parse(APICallVVaxcRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) + } + } + template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) writeSymmetry(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) + } else { + t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + 
template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) + } + } + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPICmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } + +} + +func generateAPICmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if 
fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/genlib2/engine.go b/genlib2/engine.go index 6323765..7b7a207 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -276,11 +276,11 @@ func generateStdEngUncondUnary(f io.Writer, ak Kinds) { "FloatComplex", // Exp "FloatComplex", // Tanhh "FloatComplex", // Log - "Float", // Log2 + "Floats", // Log2 "FloatComplex", // Log10 "FloatComplex", // Sqrt - "Float", // Cbrt - "Float", // InvSqrt + "Floats", // Cbrt + "Floats", // InvSqrt } var gen []*EngineUnary for i, u := range unconditionalUnaries { diff --git a/genlib2/unary_tests.go b/genlib2/unary_tests.go index 20fa176..5153f2b 100644 --- a/genlib2/unary_tests.go +++ b/genlib2/unary_tests.go @@ -1,151 +1,151 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const unaryTestBodyRaw = `invFn := func(q *Dense) bool { - a := q.Clone().(*Dense) - {{template "funcoptdecl" -}} - correct := a.Clone().(*Dense) - {{template "funcoptcorrect" -}} - - - we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - ret, err := {{.Name}}(a {{template "funcoptuse"}}) - if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ - if err != nil { - return false - } - return true - } - {{if ne .InvTypeClass "" -}} - if err := dtype.TypeClassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { - return true // uninvertible due to type class implementation issues - } - {{end -}} - {{if eq .FuncOpt "incr" -}} - 
if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { - t.Errorf("err while subtracting incr: %v", err) - return false - } - {{end -}} - {{.Inv}}(ret, UseUnsafe()) - if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { - return false - } - {{template "funcoptcheck" -}} - return true -} - -if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ - t.Errorf("Inv tests for {{.Name}} failed: %v", err) -} -` - -type unaryTest struct { - unaryOp - FuncOpt string - EqFailTypeClassName string - InvTypeClass string -} - -func (fn *unaryTest) Name() string { - if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { - return "El" + fn.unaryOp.Name() - } - return fn.unaryOp.Name() -} - -func (fn *unaryTest) Signature() *Signature { - name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *unaryTest) WriteBody(w io.Writer) { - t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - t.Execute(w, fn) -} - -func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } - -func (fn *unaryTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPIUnaryTests(f io.Writer, ak Kinds) { - var tests []*unaryTest - for _, op := range conditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - - tests = append(tests, t) - } - - for _, op 
:= range unconditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - switch op.name { - case "Square": - t.InvTypeClass = "floatcmplxTypes" - case "Cube": - t.InvTypeClass = "floatTypes" - } - - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "unsafe" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "reuse" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "incr" - } - - // for now incr cannot be quickchecked - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const unaryTestBodyRaw = `invFn := func(q *Dense) bool { + a := q.Clone().(*Dense) + {{template "funcoptdecl" -}} + correct := a.Clone().(*Dense) + {{template "funcoptcorrect" -}} + + + we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + ret, err := {{.Name}}(a {{template "funcoptuse"}}) + if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ + if err != nil { + return false + } + return true + } + {{if ne .InvTypeClass "" -}} + if err := dtype.TypeClassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { + return true // uninvertible due to type class implementation issues + } + {{end -}} + {{if eq .FuncOpt "incr" -}} + if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { + t.Errorf("err while subtracting incr: %v", err) + return false + } + {{end -}} + {{.Inv}}(ret, UseUnsafe()) + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + {{template "funcoptcheck" -}} + return true +} + +if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ + t.Errorf("Inv tests for {{.Name}} failed: %v", err) +} +` + +type unaryTest struct { + unaryOp + 
FuncOpt string + EqFailTypeClassName string + InvTypeClass string +} + +func (fn *unaryTest) Name() string { + if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { + return "El" + fn.unaryOp.Name() + } + return fn.unaryOp.Name() +} + +func (fn *unaryTest) Signature() *Signature { + name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *unaryTest) WriteBody(w io.Writer) { + t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + t.Execute(w, fn) +} + +func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } + +func (fn *unaryTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPIUnaryTests(f io.Writer, ak Kinds) { + var tests []*unaryTest + for _, op := range conditionalUnaries { + t := &unaryTest{ + unaryOp: op, + EqFailTypeClassName: "nilTC", + } + + tests = append(tests, t) + } + + for _, op := range unconditionalUnaries { + t := &unaryTest{ + unaryOp: op, + EqFailTypeClassName: "nilTC", + } + switch op.name { + case "Square": + t.InvTypeClass = "dtype.FloatComplex" + case "Cube": + t.InvTypeClass = "dtype.Floats" + } + + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "unsafe" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "reuse" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) 
+ } + fn.FuncOpt = "incr" + } + + // for now incr cannot be quickchecked + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/testutils_test.go b/testutils_test.go index 4490102..5748917 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -517,6 +517,10 @@ func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { willFailEq = true } } + if tc == nilTC { + retVal = !a.IsNativelyAccessible() + return + } if err := dtype.TypeClassCheck(a.Dtype(), tc); err != nil { return true, willFailEq } From 617557adff98d31c8a9f67286a29854c83749cb2 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 9 Jul 2021 13:50:17 +1000 Subject: [PATCH 088/154] Finished move of Dtype to its own package --- dense_io.go | 18 ++- example_extension_test.go | 5 +- genlib2/dense_io.go | 18 ++- genlib2/dense_maskedmethods.go | 206 ++++++++++++++++----------------- junkyard_test.go | 2 +- known_issues_test.go | 3 +- testutils_test.go | 6 +- type_test.go | 11 +- types.go | 5 + 9 files changed, 141 insertions(+), 133 deletions(-) diff --git a/dense_io.go b/dense_io.go index b0d5395..c9e8f7c 100644 --- a/dense_io.go +++ b/dense_io.go @@ -794,12 +794,11 @@ func (t *Dense) FBDecode(buf []byte) error { t.strides[i] = int(serialized.Strides(i)) } typ := string(serialized.Type()) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") } + t.t = dt if t.e == nil { t.e = StdEng{} @@ -871,12 +870,11 @@ func (t *Dense) PBDecode(buf []byte) error { } t.Δ = Triangle(toSerialize.T) typ := string(toSerialize.Type) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") } + t.t = dt if t.e == nil { t.e = StdEng{} diff --git a/example_extension_test.go b/example_extension_test.go index 
e5c2b22..23be0f7 100644 --- a/example_extension_test.go +++ b/example_extension_test.go @@ -6,6 +6,7 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor" ) @@ -21,7 +22,7 @@ type MyType struct { func (T MyType) Format(s fmt.State, c rune) { fmt.Fprintf(s, "(%d, %d)", T.x, T.y) } // MyDtype this the dtype of MyType. This value is populated in the init() function below -var MyDtype tensor.Dtype +var MyDtype dtype.Dtype // MyEngine supports additions of MyType, as well as other Dtypes type MyEngine struct { @@ -73,7 +74,7 @@ func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor } func init() { - MyDtype = tensor.Dtype{reflect.TypeOf(&MyType{})} + MyDtype = dtype.Dtype{reflect.TypeOf(&MyType{})} } func Example_extension() { diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index fd51c46..cc6855e 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -545,12 +545,11 @@ func (t *Dense) FBDecode(buf []byte) error { t.strides[i] = int(serialized.Strides(i)) } typ := string(serialized.Type()) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") } + t.t = dt if t.e == nil { t.e = StdEng{} @@ -621,12 +620,11 @@ func (t *Dense) PBDecode(buf []byte) error { } t.Δ = Triangle(toSerialize.T) typ := string(toSerialize.Type) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") } + t.t = dt if t.e == nil { t.e = StdEng{} diff --git a/genlib2/dense_maskedmethods.go b/genlib2/dense_maskedmethods.go index ce1133c..644e37a 100644 --- a/genlib2/dense_maskedmethods.go +++ b/genlib2/dense_maskedmethods.go @@ -1,103 +1,103 @@ -package main - -import ( - "fmt" - "io" - "reflect" - "text/template" -) - -var maskcmpMethods = []struct { - 
Name string - Desc string - NumArgs int - CmpFn string - ReqFloat bool - Kinds []reflect.Kind -}{ - {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, - {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, - {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, - {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, - {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, - {"MaskedLess", " less than ", 1, "a < x", false, nil}, - {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, - {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, - {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, -} - -const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val -// Any values must be the same type as the tensor -func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ - {{if .ReqFloat}} - if !isFloat(t.t) { - err = errors.Errorf("Can only do {{.Name}} with floating point types") - return - } - {{end}} - - if !t.IsMasked() { - t.makeMask() - } - - {{$numargs := .NumArgs}} - {{$name := .Name}} - {{$fn := .CmpFn}} - {{$reqFloat := .ReqFloat}} - switch t.t.Kind(){ - {{range .Kinds -}} - {{if isParameterized . 
-}} - {{else -}} - {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} - {{else -}} - case reflect.{{reflectKind .}}: - data := t.{{sliceOf .}} - mask := t.mask - {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} - {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} - {{if ge $numargs 3 -}} - {{if eq $name "MaskedValues"}} - delta := float64(1.0e-8) - if len(val3) > 0 { - delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) - } - {{else}} - z := val3.({{asType .}}) - {{end}} - {{end}} - if t.maskIsSoft{ - for i := range data { - a := data[i] - mask[i] = ({{$fn}}) - } - } else { - for i := range data { - a := data[i] - mask[i] = mask[i] || ({{$fn}}) - } - } - - {{end}} - {{end}} - {{end}} -} -return nil -} -` - -var ( - maskCmpMethod *template.Template -) - -func init() { - maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) -} - -func generateDenseMaskedMethods(f io.Writer, generic Kinds) { - for _, mm := range maskcmpMethods { - mm.Kinds = generic.Kinds - fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) - maskCmpMethod.Execute(f, mm) - - } -} +package main + +import ( + "fmt" + "io" + "reflect" + "text/template" +) + +var maskcmpMethods = []struct { + Name string + Desc string + NumArgs int + CmpFn string + ReqFloat bool + Kinds []reflect.Kind +}{ + {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, + {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, + {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, + {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, + {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, + {"MaskedLess", " less than ", 1, "a < x", false, nil}, + {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, + {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, + {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, +} + +const 
maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val +// Any values must be the same type as the tensor +func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ + {{if .ReqFloat}} + if !isFloat(t.t) { + err = errors.Errorf("Can only do {{.Name}} with floating point types") + return + } + {{end}} + + if !t.IsMasked() { + t.makeMask() + } + + {{$numargs := .NumArgs}} + {{$name := .Name}} + {{$fn := .CmpFn}} + {{$reqFloat := .ReqFloat}} + switch t.t.Kind(){ + {{range .Kinds -}} + {{if isParameterized . -}} + {{else -}} + {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} + {{else -}} + case reflect.{{reflectKind .}}: + data := t.{{sliceOf .}} + mask := t.mask + {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} + {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} + {{if ge $numargs 3 -}} + {{if eq $name "MaskedValues"}} + delta := float64(1.0e-8) + if len(val3) > 0 { + delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) + } + {{else}} + z := val3.({{asType .}}) + {{end}} + {{end}} + if t.maskIsSoft{ + for i := range data { + a := data[i] + mask[i] = ({{$fn}}) + } + } else { + for i := range data { + a := data[i] + mask[i] = mask[i] || ({{$fn}}) + } + } + + {{end}} + {{end}} + {{end}} +} +return nil +} +` + +var ( + maskCmpMethod *template.Template +) + +func init() { + maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) +} + +func generateDenseMaskedMethods(f io.Writer, generic Kinds) { + for _, mm := range maskcmpMethods { + mm.Kinds = generic.Kinds + fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) + maskCmpMethod.Execute(f, mm) + + } +} diff --git a/junkyard_test.go b/junkyard_test.go index 428178a..6d4b43a 100644 --- a/junkyard_test.go +++ b/junkyard_test.go @@ -9,7 +9,7 @@ import ( func TestRandom(t *testing.T) 
{ const size = 50 - for _, typ := range numberTypes.set { + for _, typ := range numberTypes { r := Random(typ, size) typR := reflect.TypeOf(r).Elem() diff --git a/known_issues_test.go b/known_issues_test.go index 36d4125..3175ce7 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -5,6 +5,7 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) func TestIssue70(t *testing.T) { @@ -43,7 +44,7 @@ func TestIssue72(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a.Shape(), reuse.Shape()) diff --git a/testutils_test.go b/testutils_test.go index 5748917..32bf70b 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -344,8 +344,8 @@ func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { func (t *Dense) Generate(r *rand.Rand, size int) reflect.Value { // generate type - ri := r.Intn(len(specializedTypes.set)) - of := specializedTypes.set[ri] + ri := r.Intn(len(specializedTypes)) + of := specializedTypes[ri] datatyp := reflect.SliceOf(of.Type) gendat, _ := quick.Value(datatyp, r) // generate dims @@ -552,7 +552,7 @@ func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we bool, err } func qcIsFloat(dt dtype.Dtype) bool { - if err := typeclassCheck(dt, floatcmplxTypes); err == nil { + if err := dtype.TypeClassCheck(dt, dtype.FloatComplex); err == nil { return true } return false diff --git a/type_test.go b/type_test.go index 54b9acd..7200f66 100644 --- a/type_test.go +++ b/type_test.go @@ -1,8 +1,13 @@ package tensor import ( - "reflect" - "testing" - "gorgonia.org/dtype" ) + +var numberTypes = []dtype.Dtype{ + Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, 
Complex128, +} + +var specializedTypes = []dtype.Dtype{ + Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, +} diff --git a/types.go b/types.go index 2af1b72..64032f9 100644 --- a/types.go +++ b/types.go @@ -8,6 +8,9 @@ import ( "gorgonia.org/dtype" ) +// Dtype is an alias for dtype.Dtype. This alias is here for backward compatibility purposes, for when users are transitioning out of the older tensor libraries. +type Dtype = dtype.Dtype + var parameterizedKinds = [...]reflect.Kind{ reflect.Array, reflect.Chan, @@ -28,6 +31,8 @@ func isParameterizedKind(k reflect.Kind) bool { return false } +func isFloat(dt dtype.Dtype) bool { return dt == Float64 || dt == Float32 } + // type aliases var ( Bool = dtype.Bool From 077808a6731c79b6a9da3789838d9b8196f7626f Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 16 Jul 2021 10:19:54 +1000 Subject: [PATCH 089/154] Moved Shapes out and use the shapes package --- ap.go | 4 +- ap_test.go | 2 +- consopt.go | 2 +- defaultengine_matop_misc.go | 11 +- defaultengine_matop_stack.go | 4 +- dense_matop_memmove.go | 4 +- shape.go | 302 +---------------------------------- shape_test.go | 294 ++-------------------------------- testutils_test.go | 2 - types.go | 3 + 10 files changed, 34 insertions(+), 594 deletions(-) diff --git a/ap.go b/ap.go index 145af0a..31962aa 100644 --- a/ap.go +++ b/ap.go @@ -364,9 +364,9 @@ func (ap *AP) unlock() { ap.fin = false } func (ap *AP) calcStrides() []int { switch { case ap.o.IsRowMajor(): - return ap.shape.CalcStrides() + return CalcStrides(ap.shape) case ap.o.IsColMajor(): - return ap.shape.CalcStridesColMajor() + return CalcStridesColMajor(ap.shape) } panic("unreachable") } diff --git a/ap_test.go b/ap_test.go index b813d1f..791a8c5 100644 --- a/ap_test.go +++ b/ap_test.go @@ -228,7 +228,7 @@ func TestAccessPatternS(t *testing.T) { var err error for _, sts := range sliceTests { - ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 
0) + ap = MakeAP(sts.shape, CalcStrides(sts.shape), 0, 0) if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { t.Errorf("%v errored: %v", sts.name, err) continue diff --git a/consopt.go b/consopt.go index 1d263b0..4b2de84 100644 --- a/consopt.go +++ b/consopt.go @@ -235,7 +235,7 @@ func AsDenseDiag(backing interface{}) ConsOpt { sli := reflect.MakeSlice(xT, l*l, l*l) shape := Shape{l, l} - strides := shape.CalcStrides() + strides := CalcStrides(shape) for i := 0; i < l; i++ { idx, err := Ltoi(shape, strides, i, i) if err != nil { diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index dde54a9..85924b6 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -4,6 +4,8 @@ import ( "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/shapes" ) var ( @@ -52,9 +54,11 @@ func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (T } func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { - if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { + var newShapelike shapes.Shapelike + if newShapelike, newRepeats, size, err = t.Shape().Repeat(shapes.Axis(axis), repeats...); err != nil { return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") } + newShape = newShapelike.(Shape) newAxis = axis if axis == AllAxes { newAxis = 0 @@ -253,10 +257,11 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen } } - var newShape Shape - if newShape, err = a.Shape().Concat(axis, ss...); err != nil { + var newShapelike shapes.Shapelike + if newShapelike, err = a.Shape().Concat(shapes.Axis(axis), shapes.ShapesToShapelikes(ss)...); err != nil { return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") } + newShape := newShapelike.(Shape) retVal := 
recycledDense(a.Dtype(), newShape, WithEngine(e)) if isMasked { diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 879ca28..d5e661a 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -28,9 +28,9 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV info := t.Info() var newStrides []int if info.o.IsColMajor() { - newStrides = newShape.CalcStridesColMajor() + newStrides = CalcStridesColMajor(newShape) } else { - newStrides = newShape.CalcStrides() + newStrides = CalcStrides(newShape) } ap := MakeAP(newShape, newStrides, info.o, info.Δ) diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index fe05f2a..f2a54e2 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -27,9 +27,9 @@ func (t *Dense) Transpose() error { // important! because the strides would have changed once the underlying data changed var expStrides []int if t.AP.o.IsColMajor() { - expStrides = expShape.CalcStridesColMajor() + expStrides = CalcStridesColMajor(expShape) } else { - expStrides = expShape.CalcStrides() + expStrides = CalcStrides(expShape) } defer ReturnInts(expStrides) defer func() { diff --git a/shape.go b/shape.go index c1347b4..e5e7b0a 100644 --- a/shape.go +++ b/shape.go @@ -1,9 +1,7 @@ package tensor import ( - "fmt" - - "github.com/pkg/errors" + "gorgonia.org/shapes" ) var scalarShape = Shape{} @@ -11,21 +9,11 @@ var scalarShape = Shape{} // ScalarShape represents a scalar. It has no dimensions, no sizes func ScalarShape() Shape { return scalarShape } -// Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns. -// Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns. -// -// Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and -// a (1, x) as a row vector. Row vectors and column vectors are matrices as well. 
It is important to note that -// row and column vectors and vanilla vectors are comparable under some circumstances -type Shape []int - -// TotalSize returns the number of elements expected in a Tensor of a certain shape -func (s Shape) TotalSize() int { - return ProdInts([]int(s)) -} +// Shape represents a Shape. See the package shapes +type Shape = shapes.Shape // CalcStrides calculates the default strides for a shape -func (s Shape) CalcStrides() []int { +func CalcStrides(s Shape) []int { if s.IsScalar() { return nil } @@ -51,7 +39,7 @@ func (s Shape) CalcStrides() []int { // CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions // during calculation of stride -func (s Shape) CalcStridesWithMask(mask []bool) []int { +func CalcStridesWithMask(s Shape, mask []bool) []int { if s.IsScalarEquiv() { return nil } @@ -86,7 +74,7 @@ func (s Shape) CalcStridesWithMask(mask []bool) []int { } // CalcStridesColMajor is like CalcStrides, but assumes a col major layout -func (s Shape) CalcStridesColMajor() []int { +func CalcStridesColMajor(s Shape) []int { if s.IsScalarEquiv() { return nil } @@ -109,281 +97,3 @@ func (s Shape) CalcStridesColMajor() []int { } return retVal } - -// Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors. 
-// -// If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size; -// if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size -func (s Shape) Eq(other Shape) bool { - if s.IsScalar() && other.IsScalar() { - return true - } - - if s.IsVector() && other.IsVector() { - switch { - case len(s) == 2 && len(other) == 1: - if (s.IsColVec() && s[0] == other[0]) || (s.IsRowVec() && s[1] == other[0]) { - return true - } - return false - case len(s) == 1 && len(other) == 2: - if (other.IsColVec() && other[0] == s[0]) || (other.IsRowVec() && other[1] == s[0]) { - return true - } - return false - } - } - - if len(s) != len(other) { - return false - } - - for i, v := range s { - if other[i] != v { - return false - } - } - return true -} - -// Clone clones a shape. -func (s Shape) Clone() Shape { - retVal := BorrowInts(len(s)) - copy(retVal, s) - return retVal -} - -// IsScalar returns true if the access pattern indicates it's a scalar value -func (s Shape) IsScalar() bool { - return len(s) == 0 -} - -// IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value -func (s Shape) IsScalarEquiv() bool { - if len(s) == 0 { - return true - } - isEquiv := true - for i := range s { - if s[i] != 1 { - return false - } - } - return isEquiv -} - -// IsVector returns whether the access pattern falls into one of three possible definitions of vectors: -// vanilla vector (not a row or a col) -// column vector -// row vector -func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1) } - -// IsColVec returns true when the access pattern has the shape (x, 1) -func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } - -// IsRowVec returns true when the access pattern has the shape (1, x) -func (s Shape) IsRowVec() bool { return len(s) == 2 && (s[0] == 
1 && s[1] > 1) } - -// IsVectorLike returns true when the shape looks like a vector -// e.g. a number that is surrounded by 1s: -// (1, 1, ... 1, 10, 1, 1... 1) -func (s Shape) IsVectorLike() bool { - var nonOnes int - for _, i := range s { - if i != 1 { - nonOnes++ - } - } - return nonOnes == 1 || nonOnes == 0 // if there is only one non-one then it's a vector or a scalarlike. -} - -// IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices -func (s Shape) IsMatrix() bool { return len(s) == 2 } - -// Dims returns the number of dimensions in the shape -func (s Shape) Dims() int { return len(s) } - -// DimSize returns the size of the dimension wanted. -// -// This method implemnents the DimSizer interface in Gorgonia. -func (s Shape) DimSize(d int) (size int, err error) { - if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { - err = errors.Errorf(dimMismatch, len(s), d) - return - } - - switch { - case s.IsScalar(): - return 0, nil - default: - return s[d], nil - } -} - -// S gives the new shape after a shape has been sliced. 
It's repeated from the AP S() method mainly because there are other functions in Gorgonia that uses only shape -func (s Shape) S(slices ...Slice) (retVal Shape, err error) { - opDims := len(s) - if len(slices) > opDims { - err = errors.Errorf(dimMismatch, opDims, len(slices)) - return - } - - retVal = s.Clone() - - for d, size := range s { - var sl Slice // default is a nil Slice - if d <= len(slices)-1 { - sl = slices[d] - } - - var start, end, step int - if start, end, step, err = SliceDetails(sl, size); err != nil { - return - } - - if step > 0 { - retVal[d] = (end - start) / step - - //fix - if retVal[d] <= 0 { - retVal[d] = 1 - } - } else { - retVal[d] = (end - start) - } - - } - - // drop any dimension with size 1, except the last dimension - offset := 0 - dims := s.Dims() - for d := 0; d < dims; d++ { - if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { - retVal = append(retVal[:d], retVal[d+1:]...) - d-- - dims-- - offset++ - } - } - - if retVal.IsScalar() { - ReturnInts(retVal) - return ScalarShape(), nil - } - - return -} - -// Repeat returns the expected new shape given the repetition parameters. 
-func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) { - switch { - case axis == AllAxes: - size = s.TotalSize() - newShape = Shape{size} - axis = 0 - case s.IsScalar(): - size = 1 - // special case for row vecs - if axis == 1 { - newShape = Shape{1, 0} - } else { - // otherwise it will be repeated into a vanilla vector - newShape = Shape{0} - } - case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1: - size = 1 - newShape = s.Clone() - newShape = append(newShape, 1) - default: - if axis >= len(s) { - // error - err = errors.Errorf(invalidAxis, axis, s.Dims()) - return - } - size = s[axis] - newShape = s.Clone() - } - - // special case to allow generic repeats - if len(repeats) == 1 { - rep := repeats[0] - repeats = make([]int, size) - for i := range repeats { - repeats[i] = rep - } - } - reps := len(repeats) - if reps != size { - err = errors.Errorf(broadcastError, size, reps) - return - } - - newSize := SumInts(repeats) - newShape[axis] = newSize - finalRepeats = repeats - return -} - -// Concat returns the expected new shape given the concatenation parameters -func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { - dims := s.Dims() - - // check that all the concatenates have the same dimensions - for _, shp := range ss { - if shp.Dims() != dims { - err = errors.Errorf(dimMismatch, dims, shp.Dims()) - return - } - } - - // special case - if axis == AllAxes { - axis = 0 - } - - // nope... no negative indexing here. 
- if axis < 0 { - err = errors.Errorf(invalidAxis, axis, len(s)) - return - } - - if axis >= dims { - err = errors.Errorf(invalidAxis, axis, len(s)) - return - } - - newShape = Shape(BorrowInts(dims)) - copy(newShape, s) - - for _, shp := range ss { - for d := 0; d < dims; d++ { - if d == axis { - newShape[d] += shp[d] - } else { - // validate that the rest of the dimensions match up - if newShape[d] != shp[d] { - err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d) - return - } - } - } - } - return -} - -// Format implements fmt.Formatter, and formats a shape nicely -func (s Shape) Format(st fmt.State, r rune) { - switch r { - case 'v', 's': - st.Write([]byte("(")) - for i, v := range s { - fmt.Fprintf(st, "%d", v) - if i < len(s)-1 { - st.Write([]byte(", ")) - } - } - st.Write([]byte(")")) - default: - fmt.Fprintf(st, "%v", []int(s)) - } -} diff --git a/shape_test.go b/shape_test.go index 9cbc370..9433ba9 100644 --- a/shape_test.go +++ b/shape_test.go @@ -1,323 +1,47 @@ package tensor import ( - "fmt" "testing" "github.com/stretchr/testify/assert" ) -func TestShapeBasics(t *testing.T) { - var s Shape - var ds int - var err error - s = Shape{1, 2} - - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - if ds != 1 { - t.Error("Expected DimSize(0) to be 1") - } - - if ds, err = s.DimSize(2); err == nil { - t.Error("Expected a DimensionMismatch error") - } - - s = ScalarShape() - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - - if ds != 0 { - t.Error("Expected DimSize(0) of a scalar to be 0") - } - - // format for completeness sake - s = Shape{2, 1} - if fmt.Sprintf("%d", s) != "[2 1]" { - t.Error("Shape.Format() error") - } -} - -func TestShapeIsX(t *testing.T) { - assert := assert.New(t) - var s Shape - - // scalar shape - s = Shape{} - assert.True(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) 
- - // vectors - - // scalar-equiv vector - s = Shape{1} - assert.False(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // vanila vector - s = Shape{2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // col vec - s = Shape{2, 1} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.True(s.IsColVec()) - assert.False(s.IsRowVec()) - - // row vec - s = Shape{1, 2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.False(s.IsColVec()) - assert.True(s.IsRowVec()) - - // matrix and up - s = Shape{2, 2} - assert.False(s.IsScalar()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // scalar equiv matrix - s = Shape{1, 1} - assert.False(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.True(s.IsVectorLike()) - assert.False(s.IsVector()) -} - func TestShapeCalcStride(t *testing.T) { assert := assert.New(t) var s Shape // scalar shape s = Shape{} - assert.Nil(s.CalcStrides()) + assert.Nil(CalcStrides(s)) // vector shape s = Shape{1} - assert.Equal([]int{1}, s.CalcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) s = Shape{2, 1} - assert.Equal([]int{1, 1}, s.CalcStrides()) + assert.Equal([]int{1, 1}, CalcStrides(s)) s = Shape{1, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{2} - assert.Equal([]int{1}, s.CalcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) // matrix strides s = Shape{2, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{5, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) // 3D strides s = Shape{2, 3, 4} - assert.Equal([]int{12, 4, 1}, s.CalcStrides()) + 
assert.Equal([]int{12, 4, 1}, CalcStrides(s)) // stupid shape s = Shape{-2, 1, 2} fail := func() { - s.CalcStrides() + CalcStrides(s) } assert.Panics(fail) } - -func TestShapeEquality(t *testing.T) { - assert := assert.New(t) - var s1, s2 Shape - - // scalar - s1 = Shape{} - s2 = Shape{} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - // scalars and scalar equiv are not the same! - s1 = Shape{1} - s2 = Shape{} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // vector - s1 = Shape{3} - s2 = Shape{5} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2, 1} - s2 = Shape{2, 1} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{1, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // matrix - s1 = Shape{2, 3} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{3, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // just for that green coloured code - s1 = Shape{2} - s2 = Shape{1, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) -} - -var shapeSliceTests = []struct { - name string - s Shape - sli []Slice - - expected Shape - err bool -}{ - {"slicing a scalar shape", ScalarShape(), nil, ScalarShape(), false}, - {"slicing a scalar shape", ScalarShape(), []Slice{rs{0, 0, 0}}, nil, true}, - {"vec[0]", Shape{2}, []Slice{rs{0, 1, 0}}, ScalarShape(), false}, - {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, - {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, - {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, - {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false}, - {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false}, - {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, 
nil, nil, nil}, Shape{1, 2, 2}, false}, - {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false}, -} - -func TestShape_Slice(t *testing.T) { - for i, ssts := range shapeSliceTests { - newShape, err := ssts.s.S(ssts.sli...) - if checkErr(t, ssts.err, err, "Shape slice", i) { - continue - } - - if !ssts.expected.Eq(newShape) { - t.Errorf("Test %q: Expected shape %v. Got %v instead", ssts.name, ssts.expected, newShape) - } - } -} - -var shapeRepeatTests = []struct { - name string - s Shape - repeats []int - axis int - - expected Shape - expectedRepeats []int - expectedSize int - err bool -}{ - {"scalar repeat on axis 0", ScalarShape(), []int{3}, 0, Shape{3}, []int{3}, 1, false}, - {"scalar repeat on axis 1", ScalarShape(), []int{3}, 1, Shape{1, 3}, []int{3}, 1, false}, - {"vector repeat on axis 0", Shape{2}, []int{3}, 0, Shape{6}, []int{3, 3}, 2, false}, - {"vector repeat on axis 1", Shape{2}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"colvec repeats on axis 0", Shape{2, 1}, []int{3}, 0, Shape{6, 1}, []int{3, 3}, 2, false}, - {"colvec repeats on axis 1", Shape{2, 1}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"rowvec repeats on axis 0", Shape{1, 2}, []int{3}, 0, Shape{3, 2}, []int{3}, 1, false}, - {"rowvec repeats on axis 1", Shape{1, 2}, []int{3}, 1, Shape{1, 6}, []int{3, 3}, 2, false}, - {"3-Tensor repeats", Shape{2, 3, 2}, []int{1, 2, 1}, 1, Shape{2, 4, 2}, []int{1, 2, 1}, 3, false}, - {"3-Tensor generic repeats", Shape{2, 3, 2}, []int{2}, AllAxes, Shape{24}, []int{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 12, false}, - {"3-Tensor generic repeat, axis specified", Shape{2, 3, 2}, []int{2}, 2, Shape{2, 3, 4}, []int{2, 2}, 2, false}, - - // stupids - {"nonexisting axis 2", Shape{2, 1}, []int{3}, 2, nil, nil, 0, true}, - {"mismatching repeats", Shape{2, 3, 2}, []int{3, 1, 2}, 0, nil, nil, 0, true}, -} - -func TestShape_Repeat(t *testing.T) { - assert := assert.New(t) - for _, srts := range shapeRepeatTests { - newShape, reps, size, err 
:= srts.s.Repeat(srts.axis, srts.repeats...) - - switch { - case srts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !srts.err && err != nil: - t.Error(err) - continue - } - - assert.True(srts.expected.Eq(newShape), "Test %q: Want: %v. Got %v", srts.name, srts.expected, newShape) - assert.Equal(srts.expectedRepeats, reps, "Test %q: ", srts.name) - assert.Equal(srts.expectedSize, size, "Test %q: ", srts.name) - } -} - -var shapeConcatTests = []struct { - name string - s Shape - axis int - ss []Shape - - expected Shape - err bool -}{ - {"standard, axis 0 ", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"standard, axis 1 ", Shape{2, 2}, 1, []Shape{{2, 2}, {2, 2}}, Shape{2, 6}, false}, - {"standard, axis AllAxes ", Shape{2, 2}, -1, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"concat to empty", Shape{2}, 0, nil, Shape{2}, false}, - - {"stupids: different dims", Shape{2, 2}, 0, []Shape{{2, 3, 2}}, nil, true}, - {"stupids: negative axes", Shape{2, 2}, -5, []Shape{{2, 2}}, nil, true}, - {"stupids: toobig axis", Shape{2, 2}, 5, []Shape{{2, 2}}, nil, true}, - {"subtle stupids: dim mismatch", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 3}}, nil, true}, -} - -func TestShape_Concat(t *testing.T) { - assert := assert.New(t) - for _, scts := range shapeConcatTests { - newShape, err := scts.s.Concat(scts.axis, scts.ss...) 
- switch { - case scts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !scts.err && err != nil: - t.Error(err) - continue - } - assert.Equal(scts.expected, newShape) - } -} diff --git a/testutils_test.go b/testutils_test.go index 32bf70b..39a7bf7 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -507,8 +507,6 @@ func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.Wor func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } -var nilTC dtype.TypeClass = -1 - func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { if eqtc == nilTC { willFailEq = true diff --git a/types.go b/types.go index 64032f9..3579146 100644 --- a/types.go +++ b/types.go @@ -11,6 +11,9 @@ import ( // Dtype is an alias for dtype.Dtype. This alias is here for backward compatibility purposes, for when users are transitioning out of the older tensor libraries. type Dtype = dtype.Dtype +// nil type class for skipping type class checks +var nilTC dtype.TypeClass = -1 + var parameterizedKinds = [...]reflect.Kind{ reflect.Array, reflect.Chan, From 0c163843fd077e924c7c4816aa9fb7603b356db9 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 16 Jul 2021 11:47:44 +1000 Subject: [PATCH 090/154] Fixed iterator example --- example_iterator_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example_iterator_test.go b/example_iterator_test.go index aff34e3..e228063 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -58,8 +58,8 @@ func ExampleSliceIter() { fmt.Printf("Err %v\n", err) return } - fmt.Printf("S (requires iterator? %t)\n%v\n", S.(*Dense).RequiresIterator(), S) - it := IteratorFromDense(S.(*Dense)) + fmt.Printf("S (requires iterator? 
%t)\n%v\n", S.(DenseView).RequiresIterator(), S) + it := IteratorFromDense(S.(DenseView)) for i, err := it.Start(); err == nil; i, err = it.Next() { fmt.Printf("i %d, coord %v\n", i, it.Coord()) } From a20d13474e2f93f5b3a6233370ac041b1cc8877a Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 16 Jul 2021 13:44:59 +1000 Subject: [PATCH 091/154] More work to remove Shape stuff from this package --- slice.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/slice.go b/slice.go index ecba60d..f45e9fc 100644 --- a/slice.go +++ b/slice.go @@ -1,5 +1,12 @@ package tensor +import ( + "gorgonia.org/shapes" +) + +var xxx Slice = ss(1) +var _ shapes.Slice = xxx + // A Slice represents a slicing operation for a Tensor. type Slice interface { Start() int From eaefaa75cac0004f58ca03ffa1a291435109cb88 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 21 Jul 2021 10:47:48 +1000 Subject: [PATCH 092/154] Deifne equality for DenseTensor --- dense.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dense.go b/dense.go index f873f61..6d22c20 100644 --- a/dense.go +++ b/dense.go @@ -584,6 +584,14 @@ func (t *Dense) Eq(other interface{}) bool { return t.array.Eq(&ot.array) } + if ot, ok := other.(DenseTensor); ok { + if !t.Shape().Eq(ot.Shape()) { + return false + } + + return t.array.Eq(ot.arrPtr()) + } + return false } From 2a7c40d8344a5c42a67d9ddda91d89a3a47172eb Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 21 Jul 2021 11:35:26 +1000 Subject: [PATCH 093/154] Fixed MatMul API to handle different Tensor types --- api_arith.go | 44 +++++++++++++++++++------------------- defaultengine_mapreduce.go | 2 +- dense_linalg.go | 2 +- interfaces.go | 5 +++++ utils.go | 34 ++++++++++++++--------------- 5 files changed, 45 insertions(+), 42 deletions(-) diff --git a/api_arith.go b/api_arith.go index 9aa86a8..62a30a3 100644 --- a/api_arith.go +++ b/api_arith.go @@ -580,36 +580,36 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { // check whether retVal has 
the same size as the resulting matrix would be: mxn expectedShape := Shape{m, n} - // find an engine - aEng, aok := a.Engine().(MatMuler) - bEng, bok := b.Engine().(MatMuler) - mm := aEng - var eng Engine = a.Engine() - if !aok { - mm = bEng + eng := a.Engine() + mm, ok := eng.(MatMuler) + if !ok { eng = b.Engine() - if !bok { - return nil, errors.Errorf("Neither a or b have an engine that is a MatMuler. a: %T, b: %T", a.Engine(), b.Engine()) - } + mm, ok = eng.(MatMuler) + } + if !ok { + return nil, errors.Errorf("Neither a or b have an engine that is a MatMuler. a: %T, b: %T", a.Engine(), b.Engine()) } - // parse function options, and get a preallocated value - var reuse *Dense + var reuse Tensor fo := ParseFuncOpts(opts...) defer returnOpOpt(fo) - if reuse, err = handleReuse(fo.Reuse(), expectedShape, true); err != nil { - err = errors.Wrapf(err, opFail, "MatMul") - return - } - + reuse = fo.Reuse() if reuse == nil { - reuse = recycledDense(a.Dtype(), expectedShape, WithEngine(eng)) + return nil, errors.Errorf("MatMul requires passing in of a reuse Tensor for now.") } - retVal = reuse - if err = mm.MatMul(a, b, retVal); err != nil { - return + if err := checkFixShape(reuse, expectedShape); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") + } + if err = mm.MatMul(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") } - return handleIncr(retVal.(*Dense), fo.Reuse(), fo.Incr(), expectedShape) + + incr := fo.Incr() + if incr != nil { + return Add(incr, reuse, UseUnsafe()) + } + return reuse, nil + } // MatVecMul performs matrix-vector multiplication between two Tensors. 
`a` is expected to be a matrix, and `b` is expected to be a vector diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 60dcd29..c484bac 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -74,7 +74,7 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e // SET RETVAL switch { case reuse != nil: - if err = reuseCheckShape(reuse, a.Shape()); err != nil { + if err = checkFixShape(reuse, a.Shape()); err != nil { err = errors.Wrapf(err, "Reuse shape check failed") return } diff --git a/dense_linalg.go b/dense_linalg.go index 417358a..756d9b7 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -390,7 +390,7 @@ func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, e if !safe { return } - if err = reuseCheckShape(retVal, expectedShape); err != nil { + if err = checkFixShape(retVal, expectedShape); err != nil { err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.") return } diff --git a/interfaces.go b/interfaces.go index 5327583..4b44154 100644 --- a/interfaces.go +++ b/interfaces.go @@ -141,6 +141,11 @@ type Kinder interface { Kind() reflect.Kind } +// MakeAliker is any Tensor that can make more like itself. +type MakeAliker interface { + MakeAike(opts ...ConsOpt) Tensor +} + type headerer interface { hdr() *storage.Header } diff --git a/utils.go b/utils.go index 064c812..22db57b 100644 --- a/utils.go +++ b/utils.go @@ -244,37 +244,35 @@ func SliceDetails(s Slice, size int) (start, end, step int, err error) { return } -// reuseDenseCheck checks a reuse tensor, and reshapes it to be the correct one -func reuseDenseCheck(reuse DenseTensor, as DenseTensor) (err error) { - if reuse.DataSize() != as.Size() { - err = errors.Errorf("Reused Tensor %p does not have expected shape %v. Got %v instead. 
Reuse Size: %v, as Size %v (real: %d)", reuse, as.Shape(), reuse.Shape(), reuse.DataSize(), as.Size(), as.DataSize()) - return - } - return reuseCheckShape(reuse, as.Shape()) - -} - -// reuseCheckShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. -func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { +// checkFixShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. +func checkFixShape(reuse Tensor, s Shape) (err error) { throw := BorrowInts(len(s)) copy(throw, s) - if err = reuse.reshape(throw...); err != nil { - err = errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + d, ok := reuse.(DenseTensor) + if !ok { + if err = reuse.Reshape(throw...); err != nil { + return errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + } + return nil + } + + if err = d.reshape(throw...); err != nil { + err = errors.Wrapf(err, reuseReshapeErr, s, d.DataSize()) return } // clean up any funny things that may be in the reuse - if oldAP := reuse.oldAP(); !oldAP.IsZero() { + if oldAP := d.oldAP(); !oldAP.IsZero() { oldAP.zero() } - if axes := reuse.transposeAxes(); axes != nil { + if axes := d.transposeAxes(); axes != nil { ReturnInts(axes) } - if viewOf := reuse.parentTensor(); viewOf != nil { - reuse.setParentTensor(nil) + if viewOf := d.parentTensor(); viewOf != nil { + d.setParentTensor(nil) } return nil } From ed054ac63d8ecd3746d119ded7e1519b481c9a30 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 21 Jul 2021 12:09:14 +1000 Subject: [PATCH 094/154] Expanded OpOpt's methods to allow modification --- flags.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/flags.go b/flags.go index c8000d1..c514969 100644 --- a/flags.go +++ b/flags.go @@ -167,3 +167,30 @@ func (fo *OpOpt) Same() bool { return fo.same } // indicates that the result of `Add()` should be converted to a Tensor of Int. // Note that this function is not yet supported in most operations. 
func (fo *OpOpt) As() dtype.Dtype { return fo.t } + +// SetReuse allows the reuse parameter to be set. +func (fo *OpOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } + +// SetIncr allows the incr parameter to be set. +func (fo *OpOpt) SetIncr(incr Tensor) { fo.incr = incr } + +// FuncOpts is the inverse of ParseFuncOpts. +func (fo *OpOpt) FuncOpts() []FuncOpt { + retVal := make([]FuncOpt, 0, 4) + if fo.reuse != nil { + retVal = append(retVal, WithReuse(fo.reuse)) + } + if fo.incr != nil { + retVal = append(retVal, WithIncr(fo.incr)) + } + if fo.unsafe { + retVal = append(retVal, UseUnsafe()) + } + if fo.same { + retVal = append(retVal, AsSameType()) + } + if fo.t != (Dtype{}) { + retVal = append(retVal, As(fo.t)) + } + return retVal +} From a5bc18d8533b46d97272bb459169780bf01afdf7 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 26 Jul 2021 13:53:49 +1000 Subject: [PATCH 095/154] Added axial iterator --- example_iterator_test.go | 138 ++++++++++++++++++++++++++++++++- iterator_axial.go | 160 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 iterator_axial.go diff --git a/example_iterator_test.go b/example_iterator_test.go index e228063..0fa3025 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -1,6 +1,9 @@ package tensor -import "fmt" +import ( + "fmt" + "sync" +) // This is an example of how to use `IteratorFromDense` from a row-major Dense tensor func Example_iteratorRowmajor() { @@ -75,3 +78,136 @@ func ExampleSliceIter() { // i 4, coord [0 0] } + +func ExampleAxialIterator() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + it := AxialIteratorFromDense(T, 1, 0, false) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i %d coord %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 
1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // i 0 coord [0 0 1] + // i 1 coord [0 0 2] + // i 2 coord [0 0 3] + // i 3 coord [1 0 0] + // i 12 coord [1 0 1] + // i 13 coord [1 0 2] + // i 14 coord [1 0 3] + // i 15 coord [0 1 0] + // i 4 coord [0 1 1] + // i 5 coord [0 1 2] + // i 6 coord [0 1 3] + // i 7 coord [1 1 0] + // i 16 coord [1 1 1] + // i 17 coord [1 1 2] + // i 18 coord [1 1 3] + // i 19 coord [0 2 0] + // i 8 coord [0 2 1] + // i 9 coord [0 2 2] + // i 10 coord [0 2 3] + // i 11 coord [1 2 0] + // i 20 coord [1 2 1] + // i 21 coord [1 2 2] + // i 22 coord [1 2 3] + // i 23 coord [0 0 0] +} + +func ExampleAxialIterator_2() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + it := AxialIteratorFromDense(T, 1, 1, true) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i %d coord %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // i 4 coord [0 1 1] + // i 5 coord [0 1 2] + // i 6 coord [0 1 3] + // i 7 coord [1 1 0] + // i 16 coord [1 1 1] + // i 17 coord [1 1 2] + // i 18 coord [1 1 3] + // i 19 coord [0 0 0] +} + +func ExampleAxialIterator_concurrent() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + + axis := 1 + var its []Iterator + for i := 0; i < T.Shape()[axis]; i++ { + it := AxialIteratorFromDense(T, axis, i, true) + its = append(its, it) + } + + done := make(chan float64, T.Shape()[axis]) + var wg sync.WaitGroup + for _, it := range its { + wg.Add(1) + go func(it Iterator, t *Dense, done chan float64, wg *sync.WaitGroup) { + data := t.Data().([]float64) + var sum float64 + for i, err := it.Start(); err == nil; i, err = it.Next() { + sum += data[i] + } + done <- sum + 
wg.Done() + }(it, T, done, &wg) + } + + wg.Wait() + close(done) + + var total float64 + for v := range done { + total += v + } + + fmt.Printf("Total: %v", total) + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // Total: 132 + +} diff --git a/iterator_axial.go b/iterator_axial.go new file mode 100644 index 0000000..0967197 --- /dev/null +++ b/iterator_axial.go @@ -0,0 +1,160 @@ +package tensor + +// AxialIterator iterates based on a given axis +type AxialIterator struct { + *AP + axis int // the axis to iterate along + + // state + axisSz int // if an axis is of size N, then axisSz indicates the current num (0 - N). + nextIndex int + lastIndex int + track []int + isReverse bool + done bool + fixed bool +} + +func AxialIteratorFromDense(t *Dense, axis, axisSz int, fixedAxis bool) *AxialIterator { + ap := t.Info() + return &AxialIterator{ + AP: ap, + track: make([]int, len(ap.shape)), + axis: axis, + axisSz: axisSz, + fixed: fixedAxis, + } +} + +// Start returns the first index +func (it *AxialIterator) Start() (retVal int, err error) { + it.Reset() + + // compute the nextIndex + if it.fixed { + it.track[it.axis] = it.axisSz + it.nextIndex, err = Ltoi(it.shape, it.strides, it.track...) + } + + return it.Next() +} + +// Next returns the next index. +// Example: let's say we're iterating on a tensor with the following +// shape: (2, 3, 4); axis: 1 +// At the start, the coordinates are: +// coordinates: (0, 0, 0) +// Next() will yield: +// coordinates: (0, 0, 1) +// But when the coordinates are: +// coordinates: (0, 0, 4) +// Next() will yield: +// coordinates: (1, 0, 0). +// Note that axis 1 is frozen at 0. 
+func (it *AxialIterator) Next() (int, error) { + if it.done { + return -1, noopError{} + } + + switch { + case it.isReverse: + return it.ndPrevious() + default: + return it.ndNext() + } + +} + +func (it *AxialIterator) ndNext() (int, error) { + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + track := it.track[:v+1] // force bounds check + coord := it.shape[:v+1] // force bounds check + strides := it.strides[:v+1] // fource bounds check + sz := it.axisSz + track[it.axis] = sz + + for i := v; i >= 0; i-- { + if i == it.axis { + continue // we're iterating along an axis. + } + track[i]++ + shapeI := coord[i] + strideI := strides[i] + if track[i] == shapeI { + track[i] = 0 + nextIndex -= (shapeI - 1) * strideI + if i == 0 { + it.axisSz++ + track[it.axis] = it.axisSz + + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + + nextIndex = track[it.axis] * strides[it.axis] + } + + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil +} + +func (it *AxialIterator) ndPrevious() (int, error) { + panic("Not yet implemented") +} + +// NextValidity is like Next, but returns the validity of the value at the index as well. +func (it *AxialIterator) NextValidity() (int, bool, error) { + i, err := it.Next() + return i, true, err +} + +// NextValid returns the next valid index, as well as a skip count. +func (it *AxialIterator) NextValid() (int, int, error) { + if it.done { + return -1, 1, noopError{} + } + + switch { + case it.isReverse: + a, err := it.ndPrevious() + return a, -1, err + default: + a, err := it.ndNext() + return a, 1, err + } +} + +// NextInvalid returns the next invalid index, as well as a skip count. 
+func (it *AxialIterator) NextInvalid() (int, int, error) { + panic("not implemented") // TODO: Implement +} + +// Reset resets the iterator +func (it *AxialIterator) Reset() { + it.nextIndex = 0 + for i := range it.track { + it.track[i] = 0 + } +} + +// SetReverse tells the iterator to iterate in reverse +func (it *AxialIterator) SetReverse() { it.isReverse = true } + +// SetForward tells the iterator to iterate forwards +func (it *AxialIterator) SetForward() { it.isReverse = false } + +// Coord returns the coordinates +func (it *AxialIterator) Coord() []int { return it.track } + +// Done returns true when the iterator is done iterating. +func (it *AxialIterator) Done() bool { return it.done } From 076a972566da13b83077070b30d8bf7b9f3dd049 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 26 Jul 2021 13:54:21 +1000 Subject: [PATCH 096/154] Added max example --- api_reduction.go | 10 +++++++++- example_mapreduce_test.go | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/api_reduction.go b/api_reduction.go index f146972..414abfd 100644 --- a/api_reduction.go +++ b/api_reduction.go @@ -2,7 +2,7 @@ package tensor import "github.com/pkg/errors" -// Sum sums a Tensor along the given axes +// Sum sums a Tensor along the given axes. func Sum(t Tensor, along ...int) (retVal Tensor, err error) { if sumer, ok := t.Engine().(Sumer); ok { return sumer.Sum(t, along...) @@ -10,6 +10,14 @@ func Sum(t Tensor, along ...int) (retVal Tensor, err error) { return nil, errors.New("Engine does not support Sum()") } +// Max finds the maximum value along the given axes. +func Max(t Tensor, along ...int) (retVal Tensor, err error) { + if maxer, ok := t.Engine().(Maxer); ok { + return maxer.Max(t, along...) 
+ } + return nil, errors.New("Engine does not support Max()") +} + // Argmax finds the index of the max value along the axis provided func Argmax(t Tensor, axis int) (retVal Tensor, err error) { if argmaxer, ok := t.Engine().(Argmaxer); ok { diff --git a/example_mapreduce_test.go b/example_mapreduce_test.go index e08c6da..4aa85bf 100644 --- a/example_mapreduce_test.go +++ b/example_mapreduce_test.go @@ -89,7 +89,6 @@ func ExampleArgmax_sliced() { // // Argmax: 0 // Argmax is *tensor.Dense of int - } func ExampleArgmin() { @@ -109,3 +108,22 @@ func ExampleArgmin() { // Argmin: [0 1] // Argmin is *tensor.Dense of int } + +func ExampleMax() { + T := New(WithBacking([]int{1, 2, 5, 3, 4, 1}), WithShape(2, 3)) + fmt.Printf("T\n%v\n", T) + + // Max along all axes + m, _ := Max(T) + fmt.Printf("Max: %v\n", m) + fmt.Printf("Max is %T of %v", m, m.Dtype()) + + // Output: + // T + // ⎡1 2 5⎤ + // ⎣3 4 1⎦ + // + // Max: 5 + // Max is *tensor.Dense of int + +} From f83f4572b0085d8aa220e20b82b3120fd946f711 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 26 Jul 2021 22:28:36 +1000 Subject: [PATCH 097/154] Added prelim sketch of Scatter --- api_matop.go | 7 + defaultengine_matop_gatherscatter.go | 193 +++++++++++++++++++++++++++ defaultengine_matop_misc.go | 1 - engine.go | 4 + example_dense_scatter_test.go | 79 +++++++++++ iterator_axial.go | 12 +- shape.go | 35 +++++ 7 files changed, 329 insertions(+), 2 deletions(-) create mode 100644 defaultengine_matop_gatherscatter.go create mode 100644 example_dense_scatter_test.go diff --git a/api_matop.go b/api_matop.go index 3dc5df4..5df9a8f 100644 --- a/api_matop.go +++ b/api_matop.go @@ -162,3 +162,10 @@ func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, } return nil, errors.Errorf("Unable to select by indices. 
Egnine %T does not support that.", a.Engine()) } + +func Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if sc, ok := a.Engine().(Scatterer); ok { + return sc.Scatter(a, indices, opts...) + } + return nil, errors.Errorf("Unable to scatter. Engine %T does not support Scattering.", a.Engine()) +} diff --git a/defaultengine_matop_gatherscatter.go b/defaultengine_matop_gatherscatter.go new file mode 100644 index 0000000..1119b30 --- /dev/null +++ b/defaultengine_matop_gatherscatter.go @@ -0,0 +1,193 @@ +package tensor + +import ( + "sync" + + "github.com/pkg/errors" +) + +func (e StdEng) Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + reuse := fo.Reuse() + + maxT, err := Max(indices) + if err != nil { + return nil, errors.Wrapf(err, "Cannot find the max of the indices") + } + max, ok := maxT.Data().(int) + if !ok { + return nil, errors.Errorf("Indices must be of ints. Got %v of %T instead", maxT.Data(), maxT.Data()) + } + + // expected shape + shp := indices.Shape().Clone() + shp[len(shp)-1] = max + 1 + + switch { + case reuse == nil && fo.Safe(): + // create reuse + reuse = New(WithShape(shp...), Of(a.Dtype())) + case reuse == nil && !fo.Safe(): + // check shape of `a` + case reuse != nil: + // check shape of `reuse` + } + + oldShape := a.Shape().Clone() + oldIndicesShape := a.Shape().Clone() + reuseOldShape := reuse.Shape().Clone() + defer func() { a.Reshape(oldShape...); indices.Reshape(oldIndicesShape...); reuse.Reshape(reuseOldShape...) }() + + switch { + case indices.Shape().IsVectorLike(): + idx := indices.Data().([]int) + _ = idx + // TODO + default: + // THIS IS ROW MAJOR ONLY + // THIS IS DENSE TENSOR ONLY + + a := a.(DenseTensor) + indices := indices.(DenseTensor) + reuse := reuse.(DenseTensor) + + // reshape everything into a matrix + a.Reshape(asMat(a.Shape(), a.Dims()-1, true)...) + indices.Reshape(asMat(indices.Shape(), indices.Dims()-1, true)...) 
+ reuse.Reshape(asMat(reuse.Shape(), reuse.Dims()-1, true)...) + + // check that indices' shape[0] is <= a.Shape[0] + if indices.Shape()[0] > a.Shape()[0] { + // something is wrong + return nil, errors.Errorf("Cannot scatter") + } + + // now they are all matrices, we can iterate thru them + var ps []iteratorPair + for i := 0; i < indices.Shape()[0]; i++ { + ait := AxialIteratorFromDense(a, 0, i, true) + iit := AxialIteratorFromDense(indices, 0, i, true) + + ps = append(ps, iteratorPair{ait, iit, i}) + } + + errChan := make(chan error, len(ps)) + var wg sync.WaitGroup + for i := range ps { + wg.Add(1) + // note: be careful not to use `for i, p := range ps` + // and then use `go p.coiter`. + // This is because `p` is would not be captured by `go`, + // thus every `p` would be `ps[len(ps)-1]`. + go ps[i].coiter(a, indices, reuse, errChan, &wg) + } + wg.Wait() + close(errChan) + err = <-errChan // maybe get ALL the errors from errChan? + return reuse, err + + } + + panic("NYI") +} + +type iteratorPair struct { + a *AxialIterator + idx *AxialIterator + axis int +} + +func (it *iteratorPair) coiter(a, indices, reuse DenseTensor, errChan chan error, wg *sync.WaitGroup) { + defer wg.Done() + ii, err := it.idx.Start() + if err != nil { + if err = handleNoOp(err); err != nil { + errChan <- err + } + return + } + + iData := indices.Data().([]int) + retStride := reuse.Strides()[0] + switch { + case a.Dtype() == Float64 && reuse.Dtype() == Float64: + aData := a.Data().([]float64) + rData := reuse.Data().([]float64) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + case a.Dtype() == Float32 && reuse.Dtype() == Float32: + aData := 
a.Data().([]float32) + rData := reuse.Data().([]float32) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + + default: + + // generic + for ai, err := it.a.Start(); err == nil; ai, err = it.a.Next() { + if it.idx.Done() { + break + } + idx := iData[ii] + v := a.arrPtr().Get(ai) + reuse.Set(it.axis*retStride+idx, v) + + if ii, err = it.idx.Next(); err != nil { + break + } + } + } + +reterr: + if err = handleNoOp(err); err != nil { + errChan <- err + return + } + +} diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 85924b6..ffcbf4c 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -203,7 +203,6 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, } // we can straightaway broadcast - continue } diff --git a/engine.go b/engine.go index a1efa5b..4696d25 100644 --- a/engine.go +++ b/engine.go @@ -399,6 +399,10 @@ type ByIndiceser interface { SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } +type Scatterer interface { + Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) +} + /* Internal interfaces for faster shit */ type denseArgmaxer interface { diff --git a/example_dense_scatter_test.go b/example_dense_scatter_test.go new file mode 100644 index 0000000..d71d4ff --- /dev/null +++ b/example_dense_scatter_test.go @@ -0,0 +1,79 @@ +package tensor + +import "fmt" + +func ExampleScatter() { + T := New(WithShape(2, 3, 4), WithBacking([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + })) + + indices := New(WithShape(2, 3, 4), WithBacking([]int{ + 
3, 2, 1, 0, + 3, 2, 1, 0, + 4, 3, 2, 1, + + 0, 4, 1, 2, + 4, 4, 4, 4, + 3, 3, 3, 3, + })) + + s, err := Scatter(T, indices) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("%v\n", s) + + // Output: + // ⎡ 3 2 1 0 0⎤ + // ⎢ 7 6 5 4 0⎥ + // ⎣ 0 11 10 9 8⎦ + // + // ⎡ 0 2 3 0 1⎤ + // ⎢ 0 0 0 0 7⎥ + // ⎣ 0 0 0 11 0⎦ + +} + +func ExampleScatter_matrixIndices() { + T := New(WithShape(2, 3, 4), WithBacking([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + })) + + indices := New(WithShape(5, 4), WithBacking([]int{ + 3, 2, 1, 0, + 3, 2, 1, 0, + 4, 3, 2, 1, + 0, 4, 1, 2, + 4, 4, 4, 4, + })) + + s, err := Scatter(T, indices) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("%v\n", s) + + // Output: + // ⎡ 3 2 1 0 0⎤ + // ⎢ 7 6 5 4 0⎥ + // ⎢ 0 11 10 9 8⎥ + // ⎢ 0 2 3 0 1⎥ + // ⎣ 0 0 0 0 7⎦ + +} diff --git a/iterator_axial.go b/iterator_axial.go index 0967197..493ac21 100644 --- a/iterator_axial.go +++ b/iterator_axial.go @@ -15,7 +15,8 @@ type AxialIterator struct { fixed bool } -func AxialIteratorFromDense(t *Dense, axis, axisSz int, fixedAxis bool) *AxialIterator { +// AxialIteratorFromDense creates and axial iterator that will iterate along the given axis. `fixedAxis` defines if the axisSz is fixed. +func AxialIteratorFromDense(t DenseTensor, axis, axisSz int, fixedAxis bool) *AxialIterator { ap := t.Info() return &AxialIterator{ AP: ap, @@ -78,6 +79,15 @@ func (it *AxialIterator) ndNext() (int, error) { for i := v; i >= 0; i-- { if i == it.axis { + if i == 0 { + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + it.axisSz++ + track[it.axis] = it.axisSz + } continue // we're iterating along an axis. 
} track[i]++ diff --git a/shape.go b/shape.go index e5e7b0a..f8d5d0a 100644 --- a/shape.go +++ b/shape.go @@ -97,3 +97,38 @@ func CalcStridesColMajor(s Shape) []int { } return retVal } + +// asMat returns a matrix shape from the given shape and axis. The given axis is which dim it will stop in. +// +// asMat((5), 0, true) = (1, 5) +// asMat((5), 1, true) = (5, 1) +// asMat((3,4,5), 0, true) = (1, 60) +// asMat((3,4,5), 1, true) = (3, 20) +// asMat((3,4,5), 2, true) = (12, 5) +// asMat((3,4,5), 0, false) = (1, 20) +// asMat((3,4,5), 1, false) = (3, 5) +// asMat((3,4,5), 2, false) = (12, 1) +func asMat(a Shape, axis int, inclusive bool) (retVal Shape) { + // no need to do a check because asMat will only ever be used by internal functions. + + retVal = Shape(BorrowInts(2)) + switch { + case a.Dims() == 1 && axis == 0: + retVal[0] = 1 + retVal[1] = a[0] + return + case a.Dims() == 1 && axis == 1: + retVal[0] = a[0] + retVal[1] = 1 + return + } + // outer + retVal[0] = ProdInts(a[:axis]) + aplus := axis + if !inclusive { + aplus++ + } + // inner + retVal[1] = ProdInts(a[aplus:]) + return +} From 99f679cdfb10dfa6a1887619c7a783571cb430b8 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 26 Jul 2021 23:08:05 +1000 Subject: [PATCH 098/154] Added some checks for gather scatter --- defaultengine_matop_gatherscatter.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/defaultengine_matop_gatherscatter.go b/defaultengine_matop_gatherscatter.go index 1119b30..b28d4ea 100644 --- a/defaultengine_matop_gatherscatter.go +++ b/defaultengine_matop_gatherscatter.go @@ -28,9 +28,16 @@ func (e StdEng) Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err // create reuse reuse = New(WithShape(shp...), Of(a.Dtype())) case reuse == nil && !fo.Safe(): - // check shape of `a` + // check shape of `a` - the last dim of a must be at least max+1 + if a.Shape()[a.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter - the last dim of `a` %v must be 
at least %v, which is the maximum value of the indices + 1", a.Shape(), max+1) + } + reuse = a case reuse != nil: - // check shape of `reuse` + // check shape of `reuse` - last dim of `reuse` must at least be as large as max+1 + if reuse.Shape()[reuse.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter. The last dim of `reuse` %v must be at least %v, which is the maximum value off the indices + 1", reuse.Shape(), max+1) + } } oldShape := a.Shape().Clone() @@ -88,7 +95,7 @@ func (e StdEng) Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err } - panic("NYI") + panic("unreachable") } type iteratorPair struct { From d05b65f4457fdbadbc00bbd85bc4d3271b1a9b78 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Jul 2021 13:55:02 +1000 Subject: [PATCH 099/154] Fixed the def'n of StandardEngine, added StandardEngine2 --- defaultengine.go | 1 - dense.go | 2 +- engine.go | 9 +++++++++ go.mod | 3 +++ go.sum | 2 ++ scalar.go | 4 +++- 6 files changed, 18 insertions(+), 3 deletions(-) diff --git a/defaultengine.go b/defaultengine.go index d1450c7..f9f5854 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -13,7 +13,6 @@ type StdEng struct { // makeArray allocates a slice for the array func (e StdEng) makeArray(arr *array, t dtype.Dtype, size int) { - arr.Raw = malloc(t, size) arr.t = t } diff --git a/dense.go b/dense.go index 6d22c20..20750a4 100644 --- a/dense.go +++ b/dense.go @@ -84,7 +84,7 @@ func (t *Dense) makeArray(size int) { case arrayMaker: te.makeArray(&t.array, t.t, size) return - case StandardEngine: + case StandardEngine2: default: } diff --git a/engine.go b/engine.go index 4696d25..f539ab1 100644 --- a/engine.go +++ b/engine.go @@ -26,7 +26,16 @@ type Engine interface { WorksWith(order DataOrder) bool // WorksWith returns true if the data order can be directly worked with } +// StandardEngine is any engine that wraps a StdEng{}. 
type StandardEngine interface { + StandardEngine2 + + // anything that wraps StdEng will contain the following interfaces: + arrayMaker +} + +// StandardEngine2 is any engine that implements the basic operations of a standard engine. +type StandardEngine2 interface { Engine Adder diff --git a/go.mod b/go.mod index 0aea516..c6383cf 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.13 replace gorgonia.org/dtype => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/dtype +replace gorgonia.org/shapes => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/shapes + require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc github.com/chewxy/math32 v1.0.6 @@ -15,6 +17,7 @@ require ( go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 gonum.org/v1/gonum v0.8.2 gorgonia.org/dtype v0.0.0-00010101000000-000000000000 + gorgonia.org/shapes v0.0.0-00010101000000-000000000000 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) diff --git a/go.sum b/go.sum index 1fbe637..f57ec01 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/scalar.go b/scalar.go index 8c7c025..3a344ae 100644 --- a/scalar.go +++ b/scalar.go @@ -8,6 +8,8 @@ import ( 
"reflect" "unsafe" + "gorgonia.org/dtype" + "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" ) @@ -87,7 +89,7 @@ func (s Scalar) GobEncode() ([]byte, error) { } func (s Scalar) GobDecode([]byte) error { return errors.Errorf(methodNYI, "GobDecode", "Scalar") } // TODO -func (s Scalar) standardEngine() standardEngine { return StdEng{} } +func (s Scalar) standardEngine() StandardEngine { return StdEng{} } func (s Scalar) hdr() *storage.Header { return nil } func (s Scalar) arr() array { return array{} } func (s Scalar) arrPtr() *array { return nil } From 0e008e091ddb08348a3b5d4e979f12d51e7ac9f9 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Jul 2021 15:58:00 +1000 Subject: [PATCH 100/154] Moved funcopts into its own file because it should be in its own file. Also, added a WithContext() FuncOpt. --- flags.go | 78 ---------------------------- funcopts.go | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++++ types.go | 51 ------------------ 3 files changed, 146 insertions(+), 129 deletions(-) create mode 100644 funcopts.go diff --git a/flags.go b/flags.go index c514969..5cc0bae 100644 --- a/flags.go +++ b/flags.go @@ -1,7 +1,5 @@ package tensor -import "gorgonia.org/dtype" - // DataOrder is a flag that indicates the order of data. The default DataOrder (0) // is what this package uses by default. type DataOrder byte @@ -118,79 +116,3 @@ func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { func (f MemoryFlag) nativelyAccessible() bool { return !((f & NativelyInaccessible) != 0) } func (f MemoryFlag) manuallyManaged() bool { return (f & ManuallyManaged) != 0 } func (f MemoryFlag) isOverallocated() bool { return (f & IsOverallocated) != 0 } - -// OpOpt are the options used to call ops -type OpOpt struct { - reuse Tensor - incr Tensor - unsafe bool - same bool - t dtype.Dtype -} - -// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. 
-func ParseFuncOpts(opts ...FuncOpt) *OpOpt { - retVal := borrowOpOpt() - for _, opt := range opts { - opt(retVal) - } - return retVal -} - -// Incr returns the tensor to be incremented in the call. Can be nil. -func (fo *OpOpt) Incr() Tensor { return fo.incr } - -// Reuse returns the tensor to be reused in the call. Can be nil. -func (fo *OpOpt) Reuse() Tensor { return fo.reuse } - -// IncReuse returns whether a reuse tensor is to be used as the incr Tensor -func (fo *OpOpt) IncrReuse() (Tensor, bool) { - if fo.incr != nil { - return fo.incr, true - } - return fo.reuse, false -} - -// Safe signals if the op is to be done safely -func (fo *OpOpt) Safe() bool { return !fo.unsafe } - -// Same signals if the op is to return the same type as its inputs -func (fo *OpOpt) Same() bool { return fo.same } - -// As returns the dtype of the return value of the method call. -// For example: -// a.Lt(b, As(Bool)) -// indicates that the result of the `Lt()` should be a Tensor of Bool. -// -// Another example: -// a.Add(b, As(Int)) -// indicates that the result of `Add()` should be converted to a Tensor of Int. -// Note that this function is not yet supported in most operations. -func (fo *OpOpt) As() dtype.Dtype { return fo.t } - -// SetReuse allows the reuse parameter to be set. -func (fo *OpOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } - -// SetIncr allows the incr parameter to be set. -func (fo *OpOpt) SetIncr(incr Tensor) { fo.incr = incr } - -// FuncOpts is the inverse of ParseFuncOpts. 
-func (fo *OpOpt) FuncOpts() []FuncOpt { - retVal := make([]FuncOpt, 0, 4) - if fo.reuse != nil { - retVal = append(retVal, WithReuse(fo.reuse)) - } - if fo.incr != nil { - retVal = append(retVal, WithIncr(fo.incr)) - } - if fo.unsafe { - retVal = append(retVal, UseUnsafe()) - } - if fo.same { - retVal = append(retVal, AsSameType()) - } - if fo.t != (Dtype{}) { - retVal = append(retVal, As(fo.t)) - } - return retVal -} diff --git a/funcopts.go b/funcopts.go new file mode 100644 index 0000000..eb6e96c --- /dev/null +++ b/funcopts.go @@ -0,0 +1,146 @@ +package tensor + +import ( + "context" + + "gorgonia.org/dtype" +) + +// FuncOpt are optionals for calling Tensor functions. +type FuncOpt func(*OpOpt) + +// WithIncr passes in a Tensor to be incremented. +func WithIncr(incr Tensor) FuncOpt { + f := func(opt *OpOpt) { + opt.incr = incr + } + return f +} + +// WithReuse passes in a Tensor to be reused. +func WithReuse(reuse Tensor) FuncOpt { + f := func(opt *OpOpt) { + opt.reuse = reuse + } + return f +} + +// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions +func UseSafe() FuncOpt { + f := func(opt *OpOpt) { + opt.unsafe = false + } + return f +} + +// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace +func UseUnsafe() FuncOpt { + f := func(opt *OpOpt) { + opt.unsafe = true + } + return f +} + +// AsSameType makes sure that the return Tensor is the same type as input Tensors. +func AsSameType() FuncOpt { + f := func(opt *OpOpt) { + opt.same = true + } + return f +} + +// As makes sure that the the return Tensor is of the type specified. 
Currently only works for FromMat64 +func As(t dtype.Dtype) FuncOpt { + f := func(opt *OpOpt) { + opt.t = t + } + return f +} + +// WithContext allows a function to be called with a given context +func WithContext(ctx context.Context) FuncOpt { + f := func(opt *OpOpt) { + opt.ctx = ctx + } + return f +} + +// OpOpt are the options used to call ops +type OpOpt struct { + reuse Tensor + incr Tensor + unsafe bool + same bool + t dtype.Dtype + ctx context.Context +} + +// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. +func ParseFuncOpts(opts ...FuncOpt) *OpOpt { + retVal := borrowOpOpt() + for _, opt := range opts { + opt(retVal) + } + return retVal +} + +// Incr returns the tensor to be incremented in the call. Can be nil. +func (fo *OpOpt) Incr() Tensor { return fo.incr } + +// Reuse returns the tensor to be reused in the call. Can be nil. +func (fo *OpOpt) Reuse() Tensor { return fo.reuse } + +// IncReuse returns whether a reuse tensor is to be used as the incr Tensor +func (fo *OpOpt) IncrReuse() (Tensor, bool) { + if fo.incr != nil { + return fo.incr, true + } + return fo.reuse, false +} + +// Safe signals if the op is to be done safely +func (fo *OpOpt) Safe() bool { return !fo.unsafe } + +// Same signals if the op is to return the same type as its inputs +func (fo *OpOpt) Same() bool { return fo.same } + +// As returns the dtype of the return value of the method call. +// For example: +// a.Lt(b, As(Bool)) +// indicates that the result of the `Lt()` should be a Tensor of Bool. +// +// Another example: +// a.Add(b, As(Int)) +// indicates that the result of `Add()` should be converted to a Tensor of Int. +// Note that this function is not yet supported in most operations. +func (fo *OpOpt) As() dtype.Dtype { return fo.t } + +// Context returns a context.Context that may have been passed in as a function option. +func (fo *OpOpt) Context() context.Context { return fo.ctx } + +// SetReuse allows the reuse parameter to be set. 
+func (fo *OpOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } + +// SetIncr allows the incr parameter to be set. +func (fo *OpOpt) SetIncr(incr Tensor) { fo.incr = incr } + +// FuncOpts is the inverse of ParseFuncOpts. +func (fo *OpOpt) FuncOpts() []FuncOpt { + retVal := make([]FuncOpt, 0, 4) + if fo.reuse != nil { + retVal = append(retVal, WithReuse(fo.reuse)) + } + if fo.incr != nil { + retVal = append(retVal, WithIncr(fo.incr)) + } + if fo.unsafe { + retVal = append(retVal, UseUnsafe()) + } + if fo.same { + retVal = append(retVal, AsSameType()) + } + if fo.t != (Dtype{}) { + retVal = append(retVal, As(fo.t)) + } + return retVal +} diff --git a/types.go b/types.go index 3579146..0cc7ef1 100644 --- a/types.go +++ b/types.go @@ -129,54 +129,3 @@ func (n NormOrder) String() string { } panic("unreachable") } - -// FuncOpt are optionals for calling Tensor function. -type FuncOpt func(*OpOpt) - -// WithIncr passes in a Tensor to be incremented. -func WithIncr(incr Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.incr = incr - } - return f -} - -// WithReuse passes in a Tensor to be reused. -func WithReuse(reuse Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.reuse = reuse - } - return f -} - -// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions -func UseSafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = false - } - return f -} - -// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace -func UseUnsafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = true - } - return f -} - -// AsSameType makes sure that the return Tensor is the same type as input Tensors. -func AsSameType() FuncOpt { - f := func(opt *OpOpt) { - opt.same = true - } - return f -} - -// As makes sure that the the return Tensor is of the type specified. 
Currently only works for FromMat64 -func As(t dtype.Dtype) FuncOpt { - f := func(opt *OpOpt) { - opt.t = t - } - return f -} From cc327c03a491b0033f32f3dfefd32fa38ddfcd00 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Jul 2021 21:01:08 +1000 Subject: [PATCH 101/154] Added more support for context.Context. Now the API functions can be cancelled --- defaultengine_arith.go | 75 +++++++++++++++++++++++++++------ defaultengine_cmp.go | 74 ++++++++++++++++++++++++++------ defaultengine_mapreduce.go | 13 +++++- defaultengine_misc.go | 8 +++- defaultengine_prep.go | 14 ++++++- defaultengine_selbyidx.go | 14 ++++++- defaultengine_unary.go | 86 +++++++++++++++++++++++++++++++------- funcopts.go | 4 ++ genlib2/agg2_body.go | 18 ++++++-- 9 files changed, 258 insertions(+), 48 deletions(-) diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 2e4e55e..d05d05a 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -3,6 +3,9 @@ package tensor import ( + "context" + "log" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" @@ -17,9 +20,13 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -82,9 +89,13 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -147,9 +158,13 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -212,9 +227,13 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -277,9 +296,13 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -342,9 +365,13 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -411,9 +438,13 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -514,9 +545,13 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -617,9 +652,13 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -720,9 +759,13 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -823,9 +866,13 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -926,9 +973,13 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 3d6a7f0..2c2b697 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -3,6 +3,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" @@ -19,12 +21,16 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -97,12 +103,16 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -175,12 +185,16 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -253,12 +267,16 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -331,12 +349,16 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -409,12 +431,16 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -491,12 +517,16 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -613,12 +643,16 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -735,12 +769,16 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -857,12 +895,16 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -975,12 +1017,16 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -1093,12 +1139,16 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index c484bac..f553fb7 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -1,6 +1,7 @@ package tensor import ( + "context" "reflect" "sort" @@ -21,9 +22,13 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e var reuse DenseTensor var safe, _, incr bool - if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. 
+ } switch { case safe && reuse == nil: // create reuse @@ -261,10 +266,14 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe // FUNC PREP var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { err = errors.Wrap(err, "Unable to prep unary tensor") return } + if err = handleCtx(ctx); err != nil { + return + } var newShape Shape for i, s := range a.Shape() { diff --git a/defaultengine_misc.go b/defaultengine_misc.go index f642f67..c7fc933 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" @@ -13,9 +15,13 @@ func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal T var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap.s + } typ := a.Dtype().Type var ait, rit Iterator diff --git a/defaultengine_prep.go b/defaultengine_prep.go index 1a13a6f..6f6927a 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -1,16 +1,17 @@ package tensor import ( + "context" "reflect" "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" - // "log" ) -func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err 
error) { +func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict bool, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() @@ -62,6 +63,15 @@ func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict boo return } +func handleCtx(ctx context.Context) error { + select { + case <-ctx.Done(): + return noopError{} + default: + } + return nil +} + func binaryCheck(a, b Tensor, tc dtype.TypeClass) (err error) { // check if the tensors are accessible if !a.IsNativelyAccessible() { diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index cdcc318..de28b7b 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" @@ -27,9 +29,13 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. 
+ } if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(a.Dtype())) @@ -150,9 +156,13 @@ func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt var reuse DenseTensor var _, toReuse, _ bool - if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. + } if !toReuse && reuse == nil { // create reuse reuse = New(WithShape(expectedShape...), Of(a.Dtype())) diff --git a/defaultengine_unary.go b/defaultengine_unary.go index d38cf57..1398a8a 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -3,6 +3,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" @@ -15,9 +17,13 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -83,9 +89,13 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -151,9 +161,13 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -219,9 +233,13 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -287,9 +305,13 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -355,9 +377,13 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -423,9 +449,13 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -491,9 +521,13 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -559,9 +593,13 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -627,9 +665,13 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -695,9 +737,13 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -763,9 +809,13 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -831,9 +881,13 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator @@ -899,9 +953,13 @@ func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator diff --git a/funcopts.go b/funcopts.go index eb6e96c..6514a9c 100644 --- a/funcopts.go +++ b/funcopts.go @@ -78,9 +78,13 @@ type OpOpt struct { // ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. func ParseFuncOpts(opts ...FuncOpt) *OpOpt { retVal := borrowOpOpt() + for _, opt := range opts { opt(retVal) } + if retVal.ctx == nil { + retVal.ctx = context.Background() // default context - required for no panics. + } return retVal } diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 41220ef..f56f417 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -5,18 +5,26 @@ import "text/template" // level 2 aggregation (tensor.StdEng) templates const cmpPrepRaw = `var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } ` const arithPrepRaw = `var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } ` const prepVVRaw = `if err = binaryCheck(a, b, dtype.{{.TypeClassCheck}}); err != nil { @@ -73,9 +81,13 @@ const prepUnaryRaw = `if err = unaryCheck(a, dtype.{{.TypeClassCheck}}); err != } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil{ + return nil, err // this err will be a noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var ait, rit Iterator From 848a2d9965cd7a376851758872185e10bbfe8f38 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Jul 2021 21:57:00 +1000 Subject: [PATCH 102/154] Generated tests for context func opts --- api_arith_generated_test.go | 537 +++++++++++++++++++++++++++++++++++- api_unary_generated_test.go | 7 - defaultengine_arith.go | 1 - dense_arith_test.go | 13 - genlib2/agg3_body.go | 43 ++- genlib2/arith_tests.go | 15 + genlib2/declarations.go | 28 +- perf.go | 1 + 8 files changed, 599 insertions(+), 46 deletions(-) diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index ba019e8..fbf2f58 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -3,8 +3,10 @@ package tensor import ( + "context" "testing" "testing/quick" + "time" "gorgonia.org/dtype" ) @@ -165,7 +167,6 @@ func TestAdd_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -197,7 +198,6 @@ func TestSub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -228,7 +228,6 @@ func TestMul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -261,7 +260,6 @@ func TestDiv_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -292,7 +290,6 @@ func TestPow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -610,6 +607,204 
@@ func TestPow_incr(t *testing.T) { t.Errorf("Identity test for Pow failed: %v", err) } +} +func TestAdd_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add failed: %v", err) + } + +} +func TestSub_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != 
nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub failed: %v", err) + } +} +func TestMul_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul failed: %v", err) + } + +} +func TestDiv_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := 
a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div failed: %v", err) + } +} +func TestPow_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) + _, ok := a.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow failed: %v", err) + } + } func TestAddScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -986,7 +1181,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") 
return false } - return true } @@ -1017,7 +1211,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1051,7 +1244,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1082,7 +1274,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1114,7 +1305,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1145,7 +1335,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1179,7 +1368,6 @@ func TestDivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1212,7 +1400,6 @@ func TestPowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1705,3 +1892,327 @@ func TestPowScalar_incr(t *testing.T) { } } +func TestAddScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer 
cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestSubScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel 
context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + inv2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Sub(b, ret, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + 
t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) + } +} +func TestMulScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if 
!qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestDivScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "DivVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) + } + +} +func TestPowScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, 
dtype.Complexes) + _, ok := q.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) + } + +} diff --git a/api_unary_generated_test.go b/api_unary_generated_test.go index e6fc8a9..3b52bc9 100644 --- a/api_unary_generated_test.go +++ b/api_unary_generated_test.go @@ -220,7 +220,6 @@ func TestNeg_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -254,7 +253,6 @@ func TestSquare_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -288,7 +286,6 @@ func TestCube_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -319,7 +316,6 @@ func TestExp_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -350,7 +346,6 @@ func TestLog_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -381,7 +376,6 @@ func TestSqrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -412,7 +406,6 @@ func TestCbrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } diff --git a/defaultengine_arith.go b/defaultengine_arith.go index d05d05a..e21fa4b 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -4,7 +4,6 @@ package tensor import ( "context" - "log" "github.com/pkg/errors" "gorgonia.org/dtype" diff --git 
a/dense_arith_test.go b/dense_arith_test.go index 9b80873..5909d11 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -165,7 +165,6 @@ func TestDense_Add_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -197,7 +196,6 @@ func TestDense_Sub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -228,7 +226,6 @@ func TestDense_Mul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -261,7 +258,6 @@ func TestDense_Div_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -292,7 +288,6 @@ func TestDense_Pow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -986,7 +981,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1017,7 +1011,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1051,7 +1044,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1082,7 +1074,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { 
t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1114,7 +1105,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1145,7 +1135,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1179,7 +1168,6 @@ func TestDense_DivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1212,7 +1200,6 @@ func TestDense_PowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } diff --git a/genlib2/agg3_body.go b/genlib2/agg3_body.go index e7b6592..da4a7ec 100644 --- a/genlib2/agg3_body.go +++ b/genlib2/agg3_body.go @@ -66,6 +66,9 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -76,7 +79,9 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -96,6 +101,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . 
}} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -106,7 +114,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -127,6 +137,9 @@ iden2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -137,7 +150,9 @@ iden2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -160,6 +175,9 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -171,7 +189,10 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} + return true } @@ -191,6 +212,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . 
}} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}VS", a, b, we, err); retEarly{ if err != nil { return false @@ -202,7 +226,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -224,6 +250,9 @@ inv2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}SV", a, b, we, err); retEarly{ if err != nil { return false @@ -235,7 +264,9 @@ inv2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index 596cd8a..d97709b 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -143,6 +143,7 @@ func (fn *ArithTest) writeInv(w io.Writer) { t.Execute(w, fn) } + func (fn *ArithTest) WriteScalarWrongType(w io.Writer) { if !fn.scalars { return @@ -205,6 +206,13 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) @@ -251,6 +259,13 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) diff --git a/genlib2/declarations.go b/genlib2/declarations.go index 130794a..9d70630 100644 --- 
a/genlib2/declarations.go +++ b/genlib2/declarations.go @@ -57,10 +57,11 @@ var unconditionalFloatUnarySymbolTemplates = [...]string{ } var funcOptUse = map[string]string{ - "reuse": ",WithReuse(reuse)", - "incr": ",WithIncr(incr)", - "unsafe": ",UseUnsafe()", - "assame": ", AsSameType()", + "reuse": ",WithReuse(reuse)", + "incr": ",WithIncr(incr)", + "unsafe": ",UseUnsafe()", + "assame": ", AsSameType()", + "context": ", WithContext(ctx)", } var funcOptCheck = map[string]string{ @@ -77,7 +78,10 @@ var funcOptCheck = map[string]string{ t.Errorf("Expected ret to be the same as a") return false } - + `, + "context": `if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } `, } @@ -89,6 +93,17 @@ var funcOptDecl = map[string]string{ return true // we exit early if the generated type is not something we can handle } `, + "context": `rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1 * time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r * 100)*time.Second) + } + defer cancel() +`, } var funcOptCorrect = map[string]string{ @@ -96,7 +111,8 @@ var funcOptCorrect = map[string]string{ "incr": `incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) `, - "unsafe": "", + "unsafe": "", + "context": "", } var stdTypes = [...]string{ diff --git a/perf.go b/perf.go index 4d5ffd7..d5a79df 100644 --- a/perf.go +++ b/perf.go @@ -264,6 +264,7 @@ func returnOpOpt(oo *OpOpt) { oo.unsafe = false oo.same = false oo.t = dtype.Dtype{} + oo.ctx = nil // if len(optPool) < cap(optPool) { // optPool <- oo // } From 1c8a1149e9cb6b510655cb8b14bdff327e005b83 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 28 Jul 2021 22:00:09 +1000 Subject: [PATCH 103/154] Unexported OpOpt. 
--- funcopts.go | 47 +++++++++++++++++++++++++---------------------- perf.go | 8 ++++---- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/funcopts.go b/funcopts.go index 6514a9c..65e3d4b 100644 --- a/funcopts.go +++ b/funcopts.go @@ -7,11 +7,14 @@ import ( ) // FuncOpt are optionals for calling Tensor functions. -type FuncOpt func(*OpOpt) +// The `*opOpt` type is unexported, but it's methods are exported. +// This is intentional as use of the `*opOpt` is very specialized. +// See funcopts.go for more information. +type FuncOpt func(*opOpt) // WithIncr passes in a Tensor to be incremented. func WithIncr(incr Tensor) FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.incr = incr } return f @@ -19,7 +22,7 @@ func WithIncr(incr Tensor) FuncOpt { // WithReuse passes in a Tensor to be reused. func WithReuse(reuse Tensor) FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.reuse = reuse } return f @@ -27,7 +30,7 @@ func WithReuse(reuse Tensor) FuncOpt { // UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions func UseSafe() FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.unsafe = false } return f @@ -35,7 +38,7 @@ func UseSafe() FuncOpt { // UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace func UseUnsafe() FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.unsafe = true } return f @@ -43,7 +46,7 @@ func UseUnsafe() FuncOpt { // AsSameType makes sure that the return Tensor is the same type as input Tensors. func AsSameType() FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.same = true } return f @@ -51,7 +54,7 @@ func AsSameType() FuncOpt { // As makes sure that the the return Tensor is of the type specified. 
Currently only works for FromMat64 func As(t dtype.Dtype) FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.t = t } return f @@ -59,14 +62,14 @@ func As(t dtype.Dtype) FuncOpt { // WithContext allows a function to be called with a given context func WithContext(ctx context.Context) FuncOpt { - f := func(opt *OpOpt) { + f := func(opt *opOpt) { opt.ctx = ctx } return f } -// OpOpt are the options used to call ops -type OpOpt struct { +// opOpt are the options used to call ops +type opOpt struct { reuse Tensor incr Tensor unsafe bool @@ -76,7 +79,7 @@ type OpOpt struct { } // ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. -func ParseFuncOpts(opts ...FuncOpt) *OpOpt { +func ParseFuncOpts(opts ...FuncOpt) *opOpt { retVal := borrowOpOpt() for _, opt := range opts { @@ -89,13 +92,13 @@ func ParseFuncOpts(opts ...FuncOpt) *OpOpt { } // Incr returns the tensor to be incremented in the call. Can be nil. -func (fo *OpOpt) Incr() Tensor { return fo.incr } +func (fo *opOpt) Incr() Tensor { return fo.incr } // Reuse returns the tensor to be reused in the call. Can be nil. -func (fo *OpOpt) Reuse() Tensor { return fo.reuse } +func (fo *opOpt) Reuse() Tensor { return fo.reuse } -// IncReuse returns whether a reuse tensor is to be used as the incr Tensor -func (fo *OpOpt) IncrReuse() (Tensor, bool) { +// IncrReuse returns whether a reuse tensor is to be used as the incr Tensor +func (fo *opOpt) IncrReuse() (Tensor, bool) { if fo.incr != nil { return fo.incr, true } @@ -103,10 +106,10 @@ func (fo *OpOpt) IncrReuse() (Tensor, bool) { } // Safe signals if the op is to be done safely -func (fo *OpOpt) Safe() bool { return !fo.unsafe } +func (fo *opOpt) Safe() bool { return !fo.unsafe } // Same signals if the op is to return the same type as its inputs -func (fo *OpOpt) Same() bool { return fo.same } +func (fo *opOpt) Same() bool { return fo.same } // As returns the dtype of the return value of the method call. 
// For example: @@ -117,19 +120,19 @@ func (fo *OpOpt) Same() bool { return fo.same } // a.Add(b, As(Int)) // indicates that the result of `Add()` should be converted to a Tensor of Int. // Note that this function is not yet supported in most operations. -func (fo *OpOpt) As() dtype.Dtype { return fo.t } +func (fo *opOpt) As() dtype.Dtype { return fo.t } // Context returns a context.Context that may have been passed in as a function option. -func (fo *OpOpt) Context() context.Context { return fo.ctx } +func (fo *opOpt) Context() context.Context { return fo.ctx } // SetReuse allows the reuse parameter to be set. -func (fo *OpOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } +func (fo *opOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } // SetIncr allows the incr parameter to be set. -func (fo *OpOpt) SetIncr(incr Tensor) { fo.incr = incr } +func (fo *opOpt) SetIncr(incr Tensor) { fo.incr = incr } // FuncOpts is the inverse of ParseFuncOpts. -func (fo *OpOpt) FuncOpts() []FuncOpt { +func (fo *opOpt) FuncOpts() []FuncOpt { retVal := make([]FuncOpt, 0, 4) if fo.reuse != nil { retVal = append(retVal, WithReuse(fo.reuse)) diff --git a/perf.go b/perf.go index d5a79df..a37c610 100644 --- a/perf.go +++ b/perf.go @@ -239,10 +239,10 @@ func ReturnBools(is []bool) { // var optPool = make(chan *OpOpt, PoolSize) // var optPool = newRingbuffer(PoolSize) var optPool = &sync.Pool{ - New: func() interface{} { return new(OpOpt) }, + New: func() interface{} { return new(opOpt) }, } -func borrowOpOpt() *OpOpt { +func borrowOpOpt() *opOpt { // select { // case fo := <-optPool: // return fo @@ -250,7 +250,7 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) // } - return optPool.Get().(*OpOpt) + return optPool.Get().(*opOpt) // if fo, err := optPool.Get(); err == nil { // return (*OpOpt)(fo) @@ -258,7 +258,7 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) } -func returnOpOpt(oo *OpOpt) { +func returnOpOpt(oo *opOpt) { oo.reuse = nil oo.incr = nil oo.unsafe = false From 
2432a67a3a8c5d02d7bbed54cc5835f6d478b9bf Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 11:54:33 +1000 Subject: [PATCH 104/154] Moved native iterators into package `tensor` proper. It is unexported in the `tensor` package as leaving it exported would lead to pollution of API. For backwards compatibility, the `native` package is still used. golinkname stubs are used in that case --- api_arith_generated_test.go | 4 +- api_cmp_generated_test.go | 4 +- api_unary.go | 4 +- api_unary_generated_test.go | 4 +- array_getset.go | 4 +- defaultengine_arith.go | 4 +- defaultengine_cmp.go | 4 +- defaultengine_unary.go | 4 +- dense_argmethods_test.go | 4 +- dense_arith.go | 4 +- dense_arith_test.go | 4 +- dense_cmp.go | 4 +- dense_cmp_test.go | 4 +- dense_compat.go | 4 +- dense_compat_test.go | 4 +- dense_generated.go | 4 +- dense_generated_test.go | 4 +- dense_getset_test.go | 4 +- dense_io.go | 4 +- dense_maskcmp_methods.go | 4 +- dense_maskcmp_methods_test.go | 4 +- dense_reduction_test.go | 4 +- generic_utils.go | 4 +- genlib2/main.go | 15 +- genlib2/native_iterator.go | 180 ++- genlib2/package.go | 20 +- internal/execution/eng_argmethods.go | 4 +- internal/execution/eng_arith.go | 4 +- internal/execution/eng_cmp.go | 4 +- internal/execution/eng_map.go | 4 +- internal/execution/eng_reduce.go | 4 +- internal/execution/eng_unary.go | 4 +- internal/execution/generic_argmethods.go | 4 +- internal/execution/generic_arith_mixed.go | 4 +- internal/execution/generic_arith_vv.go | 4 +- internal/execution/generic_cmp_mixed.go | 4 +- internal/execution/generic_cmp_vv.go | 4 +- internal/execution/generic_map.go | 4 +- internal/execution/generic_minmax.go | 4 +- internal/execution/generic_reduce.go | 4 +- internal/execution/generic_unary.go | 4 +- .../execution/reduction_specialization.go | 4 +- internal/storage/consts.go | 4 +- internal/storage/getset.go | 4 +- iterator_native.go | 1152 +++++++++++++++++ iterator_native_test.go | 633 +++++++++ native/iterator_native.go | 
1118 +++------------- native/iterator_native2.go | 4 +- native/iterator_native2_test.go | 4 +- native/iterator_native_purego.go | 1132 ++++++++++++++++ native/iterator_native_test.go | 4 +- native/utils.go | 30 + test_test.go | 4 +- 53 files changed, 3346 insertions(+), 1114 deletions(-) create mode 100644 iterator_native.go create mode 100644 iterator_native_test.go create mode 100644 native/iterator_native_purego.go create mode 100644 native/utils.go diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index fbf2f58..3e0ac9a 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -11,6 +9,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestAdd(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) diff --git a/api_cmp_generated_test.go b/api_cmp_generated_test.go index e4ddd7b..4a612d8 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestGt(t *testing.T) { transFn := func(q *Dense) bool { we, _ := willerr(q, dtype.Ord, nilTC) diff --git a/api_unary.go b/api_unary.go index b1afe71..4c81e33 100644 --- a/api_unary.go +++ b/api_unary.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + func Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { e := a.Engine() if neger, ok := e.(Neger); ok { diff --git a/api_unary_generated_test.go b/api_unary_generated_test.go index 3b52bc9..64813ae 100644 --- a/api_unary_generated_test.go +++ b/api_unary_generated_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package tensor import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestNeg(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) diff --git a/array_getset.go b/array_getset.go index c19fe68..69bcf95 100644 --- a/array_getset.go +++ b/array_getset.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Set sets the value of the underlying array at the index i. func (a *array) Set(i int, x interface{}) { switch a.t.Kind() { diff --git a/defaultengine_arith.go b/defaultengine_arith.go index e21fa4b..0a4d800 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Add performs a + b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 2c2b697..749cf11 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Gt performs a > b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 1398a8a..8efe589 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Neg failed") diff --git a/dense_argmethods_test.go b/dense_argmethods_test.go index a4b03bd..a90b957 100644 --- a/dense_argmethods_test.go +++ b/dense_argmethods_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* Test data */ var basicDenseI = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) diff --git a/dense_arith.go b/dense_arith.go index 7218d37..5c4eba9 100644 --- a/dense_arith.go +++ b/dense_arith.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Add performs t + other elementwise. Both t and other must have the same shape. 
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (t *Dense) Add(other *Dense, opts ...FuncOpt) (retVal *Dense, err error) { diff --git a/dense_arith_test.go b/dense_arith_test.go index 5909d11..db70124 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestDense_Add(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) diff --git a/dense_cmp.go b/dense_cmp.go index 4ffaadf..d7770ac 100644 --- a/dense_cmp.go +++ b/dense_cmp.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Gt performs t > other elementwise. Both t and other must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. diff --git a/dense_cmp_test.go b/dense_cmp_test.go index 384e250..82e8518 100644 --- a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestDense_Gt(t *testing.T) { transFn := func(q *Dense) bool { we, _ := willerr(q, dtype.Ord, nilTC) diff --git a/dense_compat.go b/dense_compat.go index 1161cf1..cf6764d 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor +// Code generated by genlib2. DO NOT EDIT. + import ( "fmt" "math" diff --git a/dense_compat_test.go b/dense_compat_test.go index 581563c..442b7e7 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package tensor +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" diff --git a/dense_generated.go b/dense_generated.go index 5a44a10..c9158fa 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -8,6 +6,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + // Ones creates a *Dense with the provided shape and type func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) diff --git a/dense_generated_test.go b/dense_generated_test.go index 7332bf4..edd0850 100644 --- a/dense_generated_test.go +++ b/dense_generated_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var onesTests = []struct { of dtype.Dtype shape Shape diff --git a/dense_getset_test.go b/dense_getset_test.go index cde0542..899e855 100644 --- a/dense_getset_test.go +++ b/dense_getset_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -11,6 +9,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var denseSetGetTests = []struct { of dtype.Dtype data interface{} diff --git a/dense_io.go b/dense_io.go index c9e8f7c..374daf0 100644 --- a/dense_io.go +++ b/dense_io.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -21,6 +19,8 @@ import ( "gorgonia.org/tensor/internal/serialization/pb" ) +// Code generated by genlib2. DO NOT EDIT. + /* GOB SERIALIZATION */ // GobEncode implements gob.GobEncoder diff --git a/dense_maskcmp_methods.go b/dense_maskcmp_methods.go index 4cc3d95..d4b415a 100644 --- a/dense_maskcmp_methods.go +++ b/dense_maskcmp_methods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package tensor import ( @@ -9,6 +7,8 @@ import ( "github.com/pkg/errors" ) +// Code generated by genlib2. DO NOT EDIT. + /* MaskedEqual */ // MaskedEqual sets the mask to true where the corresponding data is equal to val diff --git a/dense_maskcmp_methods_test.go b/dense_maskcmp_methods_test.go index d16a78d..e48e89c 100644 --- a/dense_maskcmp_methods_test.go +++ b/dense_maskcmp_methods_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* MaskedEqual */ func TestDense_MaskedEqual_I(t *testing.T) { diff --git a/dense_reduction_test.go b/dense_reduction_test.go index 05de324..e4ef5ec 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/execution" ) +// Code generated by genlib2. DO NOT EDIT. + var denseReductionTests = []struct { of dtype.Dtype fn interface{} diff --git a/generic_utils.go b/generic_utils.go index ca00bd9..9a44263 100644 --- a/generic_utils.go +++ b/generic_utils.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -12,6 +10,8 @@ import ( "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. + // Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. 
This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i diff --git a/genlib2/main.go b/genlib2/main.go index 328cd19..63f6987 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -114,15 +114,22 @@ func main() { pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests) pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests) - // native iterators - pipeline(nativePkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators) - pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests) + // native iterators - the ones in the tensor package + pipeline(tensorPkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators(false)) + pipeline(tensorPkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(false)) pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect) pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests) + + // native iterators - exported into gorgonia.org/tensor/native + pipeline(nativePkgLoc+"_unsafe", "iterator_native.go", Kinds{allKinds}, generateNativeIteratorStubs) + pipeline(nativePkgLoc+"_purego", "iterator_native_purego.go", Kinds{allKinds}, generateNativeIterators(true)) + pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(true)) + pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeChecks) } func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { - fullpath := path.Join(pkg, filename) + pkgpath := strings.Replace(strings.Replace(pkg, "_unsafe", "", -1), "_purego", "", -1) + fullpath := path.Join(pkgpath, filename) f, err := os.Create(fullpath) if err != nil { log.Printf("fullpath %q", 
fullpath) diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index a727253..f1e14be 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "reflect" "text/template" ) @@ -28,35 +29,45 @@ const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt dty } ` -const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}} +const nativeIterRaw = ` +{{- $vecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $matName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $T3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- if .N -}} + {{- $vecName = ( printf "Vector%s" (short .K) ) -}} + {{- $matName = ( printf "Matrix%s" (short .K) ) -}} + {{- $T3Name = ( printf "Tensor3%s" (short .K) ) -}} +{{- end -}} + +// {{$vecName}} converts a *Dense into a []{{asType .K}} // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { - if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil { +func {{$vecName}}(t *Dense) (retVal []{{asType .K}}, err error) { + if err = checkNativeIterable(t, 1, {{reflectKind .K}}); err != nil { return nil, err } - return t.{{sliceOf .}}, nil + return t.{{sliceOf .K}}, nil } -// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// {{$matName}} converts a *Dense into a [][]{{asType .K}} // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { - if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil { +func {{$matName}}(t *Dense) (retVal [][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 2, {{reflectKind .K}}); err != nil { return nil, err } - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} shape := t.Shape() strides := t.Strides() rows := shape[0] cols := shape[1] rowStride := strides[0] - retVal = make([][]{{asType .}}, rows) + retVal = make([][]{{asType .K}}, rows) for i := range retVal { start := i * rowStride - retVal[i] = make([]{{asType .}}, 0) + retVal[i] = make([]{{asType .K}}, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -65,14 +76,14 @@ func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { return } -// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// {{$T3Name}} converts a *Dense into a [][][]{{asType .K}}. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { - if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { +func {{$T3Name}}(t *Dense) (retVal [][][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 3, {{reflectKind .K}}); err != nil { return nil, err } - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} shape := t.Shape() strides := t.Strides() @@ -81,11 +92,11 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { cols := shape[2] layerStride := strides[0] rowStride := strides[1] - retVal = make([][][]{{asType .}}, layers) + retVal = make([][][]{{asType .K}}, layers) for i := range retVal { - retVal[i] = make([][]{{asType .}}, rows) + retVal[i] = make([][]{{asType .K}}, rows) for j := range retVal[i] { - retVal[i][j] = make([]{{asType .}}, 0) + retVal[i][j] = make([]{{asType .K}}, 0) start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) @@ -97,15 +108,57 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { } ` -const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { +const nativeIterStubsRaw = `//go:linkname Vector{{short .}} gorgonia.org/tensor.nativeDenseVector{{short .}} + +// Vector{{short .}} converts a *Dense into a []{{asType .}} +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func Vector{{short .}}(t *tensor.Dense) (retVal []{{asType .}}, err error) + +//go:linkname Matrix{{short .}} gorgonia.org/tensor.nativeDenseMatrix{{short .}} + +// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func Matrix{{short .}}(t *tensor.Dense) (retVal [][]{{asType .}}, err error) + +//go:linkname Tensor3{{short .}} gorgonia.org/tensor.nativeDenseTensor3{{short .}} + +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3{{short .}}(t *tensor.Dense) (retVal [][][]{{asType .}}, err error) +` + +const nativeIterTestRaw = ` +{{- $pkgTVecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $pkgTMatName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $pkgTT3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- $pkgNVecName := ( printf "Vector%s" (short .K) ) -}} +{{- $pkgNMatName := ( printf "Matrix%s" (short .K) ) -}} +{{- $pkgNT3Name := ( printf "Tensor3%s" (short .K) ) -}} +{{- $vecName := "" -}} +{{- $matName := "" -}} +{{- $T3Name := "" -}} +{{- if .N -}} + {{- $vecName = $pkgNVecName -}} + {{- $matName = $pkgNMatName -}} + {{- $T3Name = $pkgNT3Name -}} +{{- else -}} + {{- $vecName = $pkgTVecName -}} + {{- $matName = $pkgTMatName -}} + {{- $T3Name = $pkgTT3Name -}} +{{end -}} + + +func Test_{{$vecName}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . -}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(6)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(6)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(6)) + T = New(Of({{reflectKind .K}}), WithShape(6)) {{end -}} - it, err := Vector{{short .}}(T) + it, err := {{$vecName}}(T) if err != nil { t.Fatal(err) } @@ -113,15 +166,15 @@ const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { assert.Equal(6, len(it)) } -func Test_Matrix{{short .}}(t *testing.T) { +func Test_{{$matName}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . 
-}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(2, 3)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(2, 3)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(2, 3)) + T = New(Of({{reflectKind .K}}), WithShape(2, 3)) {{end -}} - it, err := Matrix{{short .}}(T) + it, err := {{$matName}}(T) if err != nil { t.Fatal(err) } @@ -130,15 +183,15 @@ func Test_Matrix{{short .}}(t *testing.T) { assert.Equal(3, len(it[0])) } -func Test_Tensor3{{short .}}(t *testing.T) { +func Test_{{$T3Name}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . -}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 24)), WithShape(2, 3, 4)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 24)), WithShape(2, 3, 4)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4)) + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4)) {{end -}} - it, err := Tensor3{{short .}}(T) + it, err := {{$T3Name}}(T) if err != nil { t.Fatal(err) } @@ -150,31 +203,68 @@ func Test_Tensor3{{short .}}(t *testing.T) { ` var ( - NativeIter *template.Template - NativeIterTest *template.Template + NativeIter *template.Template + NativeIterTest *template.Template + NativeIterStubs *template.Template ) func init() { NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw)) NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw)) + NativeIterStubs = template.Must(template.New("NativeStubs").Funcs(funcs).Parse(nativeIterStubsRaw)) } -func generateNativeIterators(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) - fmt.Fprintf(f, "%v\n", checkNativeiterable) - ks := filter(ak.Kinds, isSpecialized) - for _, k := range ks { - fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) - NativeIter.Execute(f, k) - fmt.Fprint(f, "\n\n") +// generateNativeIterators generates the code for native iterators. 
`isNative` represents whether the code is generated for the `native` package or not. +// isNative will only be true for the `purego` build tag. +func generateNativeIterators(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + // checkNativeIteratble is separately generated and placed into util.go in the `native` package + // so there is no need to generate that here. + fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeiterable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) + NativeIter.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } } } -func generateNativeIteratorTests(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) +func generateNativeIteratorTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeIterTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeIteratorStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work ks := filter(ak.Kinds, isSpecialized) for _, k := range ks { - NativeIterTest.Execute(f, k) + NativeIterStubs.Execute(f, k) fmt.Fprint(f, "\n\n") } } + +func generateNativeChecks(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + fmt.Fprintf(f, "%v\n", checkNativeiterable) +} diff --git a/genlib2/package.go b/genlib2/package.go index 4380b6b..8ffcf79 100644 --- a/genlib2/package.go +++ b/genlib2/package.go @@ -8,17 +8,27 @@ import ( func writePkgName(f io.Writer, pkg string) { switch pkg { case tensorPkgLoc: - fmt.Fprintf(f, "// %s\n\npackage tensor\n\n", genmsg) + 
fmt.Fprintf(f, "package tensor\n\n // %s\n\n", genmsg) case nativePkgLoc: - fmt.Fprintf(f, "// %s\n\npackage native\n\n", genmsg) + fmt.Fprintf(f, "package native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_unsafe": + fmt.Fprintf(f, "// +build !purego \n\npackage native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_purego": + fmt.Fprintf(f, "// +build purego \n\npackage native\n\n // %s\n\n", genmsg) case execLoc: - fmt.Fprintf(f, "// %s\n\npackage execution\n\n", genmsg) + fmt.Fprintf(f, "package execution\n\n // %s\n\n", genmsg) case storageLoc: - fmt.Fprintf(f, "// %s\n\npackage storage\n\n", genmsg) + fmt.Fprintf(f, "package storage\n\n // %s\n\n", genmsg) default: - fmt.Fprintf(f, "// %s\n\npackage unknown\n\n", genmsg) + fmt.Fprintf(f, "package unknown\n\n %s\n\n", genmsg) } } const importUnqualifiedTensor = `import . "gorgonia.org/tensor" ` + +const importInternalNative = `import inative "gorgonia.org/tensor/internal/native" +` + +const importUnsafe = `import _ "unsafe" +` diff --git a/internal/execution/eng_argmethods.go b/internal/execution/eng_argmethods.go index 05ed725..9adc173 100644 --- a/internal/execution/eng_argmethods.go +++ b/internal/execution/eng_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ArgmaxIter(t reflect.Type, a *storage.Header, it Iterator, lastSize int) (indices []int, err error) { var next int switch t { diff --git a/internal/execution/eng_arith.go b/internal/execution/eng_arith.go index f3de110..a1681a6 100644 --- a/internal/execution/eng_arith.go +++ b/internal/execution/eng_arith.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. 
+ func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) diff --git a/internal/execution/eng_cmp.go b/internal/execution/eng_cmp.go index b2c4ece..e5d3dd5 100644 --- a/internal/execution/eng_cmp.go +++ b/internal/execution/eng_cmp.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) diff --git a/internal/execution/eng_map.go b/internal/execution/eng_map.go index 81cb2c4..ecd2b64 100644 --- a/internal/execution/eng_map.go +++ b/internal/execution/eng_map.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Map(t reflect.Type, fn interface{}, a *storage.Header, incr bool) (err error) { as := isScalar(a, t) switch t { diff --git a/internal/execution/eng_reduce.go b/internal/execution/eng_reduce.go index 88c7ae5..bebe52f 100644 --- a/internal/execution/eng_reduce.go +++ b/internal/execution/eng_reduce.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ReduceFirst(t reflect.Type, data *storage.Header, retVal *storage.Header, split int, size int, fn interface{}) (err error) { switch t { case Bool: diff --git a/internal/execution/eng_unary.go b/internal/execution/eng_unary.go index bd9bd81..4038190 100644 --- a/internal/execution/eng_unary.go +++ b/internal/execution/eng_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Neg(t reflect.Type, a *storage.Header) (err error) { switch t { case Int: diff --git a/internal/execution/generic_argmethods.go b/internal/execution/generic_argmethods.go index 3edb606..cdf4b7d 100644 --- a/internal/execution/generic_argmethods.go +++ b/internal/execution/generic_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -8,6 +6,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func ArgmaxI(a []int) int { var set bool var f int diff --git a/internal/execution/generic_arith_mixed.go b/internal/execution/generic_arith_mixed.go index 6e8aa72..94f5e8b 100644 --- a/internal/execution/generic_arith_mixed.go +++ b/internal/execution/generic_arith_mixed.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func AddSVI(a int, b []int) { for i := range b { b[i] = a + b[i] diff --git a/internal/execution/generic_arith_vv.go b/internal/execution/generic_arith_vv.go index 26f3772..ea43563 100644 --- a/internal/execution/generic_arith_vv.go +++ b/internal/execution/generic_arith_vv.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -11,6 +9,8 @@ import ( "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. + func VecAddI(a []int, b []int) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_cmp_mixed.go b/internal/execution/generic_cmp_mixed.go index b9a1154..1c53747 100644 --- a/internal/execution/generic_cmp_mixed.go +++ b/internal/execution/generic_cmp_mixed.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. 
+ func GtSVI(a int, b []int, retVal []bool) { for i := range retVal { retVal[i] = a > b[i] diff --git a/internal/execution/generic_cmp_vv.go b/internal/execution/generic_cmp_vv.go index 7d528c4..a501f93 100644 --- a/internal/execution/generic_cmp_vv.go +++ b/internal/execution/generic_cmp_vv.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func GtI(a []int, b []int, retVal []bool) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_map.go b/internal/execution/generic_map.go index 41c7de8..f054239 100644 --- a/internal/execution/generic_map.go +++ b/internal/execution/generic_map.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func MapB(fn func(bool) bool, a []bool) { for i := range a { a[i] = fn(a[i]) diff --git a/internal/execution/generic_minmax.go b/internal/execution/generic_minmax.go index 170f01b..2cc94e4 100644 --- a/internal/execution/generic_minmax.go +++ b/internal/execution/generic_minmax.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution +// Code generated by genlib2. DO NOT EDIT. + func VecMinI(a, b []int) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_reduce.go b/internal/execution/generic_reduce.go index a489f1c..ef94057 100644 --- a/internal/execution/generic_reduce.go +++ b/internal/execution/generic_reduce.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func ReduceB(f func(a, b bool) bool, def bool, l ...bool) (retVal bool) { retVal = def if len(l) == 0 { diff --git a/internal/execution/generic_unary.go b/internal/execution/generic_unary.go index cb3f87f..7c05acd 100644 --- a/internal/execution/generic_unary.go +++ b/internal/execution/generic_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. 
DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func NegI(a []int) { for i := range a { a[i] = -a[i] diff --git a/internal/execution/reduction_specialization.go b/internal/execution/reduction_specialization.go index e83e67e..90cfe69 100644 --- a/internal/execution/reduction_specialization.go +++ b/internal/execution/reduction_specialization.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func MonotonicSum(t reflect.Type, a *storage.Header) (retVal interface{}, err error) { switch t { case Int: diff --git a/internal/storage/consts.go b/internal/storage/consts.go index 7304ac5..b6e03cc 100644 --- a/internal/storage/consts.go +++ b/internal/storage/consts.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package storage import ( @@ -7,6 +5,8 @@ import ( "unsafe" ) +// Code generated by genlib2. DO NOT EDIT. + var ( bType = reflect.TypeOf(bool(false)) iType = reflect.TypeOf(int(0)) diff --git a/internal/storage/getset.go b/internal/storage/getset.go index c60d61c..89421f0 100644 --- a/internal/storage/getset.go +++ b/internal/storage/getset.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package storage import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + /* bool */ func (h *Header) Bools() []bool { diff --git a/iterator_native.go b/iterator_native.go new file mode 100644 index 0000000..470891d --- /dev/null +++ b/iterator_native.go @@ -0,0 +1,1152 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. 
Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} + +/* Native Iterables for bool */ + +// nativeDenseVectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// nativeDenseMatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// nativeDenseVectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// nativeDenseMatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// nativeDenseVectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// nativeDenseMatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// nativeDenseVectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// nativeDenseMatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// nativeDenseVectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// nativeDenseMatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// nativeDenseVectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// nativeDenseMatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// nativeDenseVectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// nativeDenseMatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// nativeDenseVectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// nativeDenseMatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// nativeDenseVectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// nativeDenseMatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// nativeDenseVectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// nativeDenseMatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// nativeDenseVectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// nativeDenseMatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// nativeDenseVectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// nativeDenseMatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// nativeDenseVectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// nativeDenseMatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// nativeDenseVectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// nativeDenseMatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C64 converts a *Dense into a [][][]complex64. 
+// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// nativeDenseVectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// nativeDenseMatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. 
+func nativeDenseMatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// nativeDenseVectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func nativeDenseVectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// nativeDenseMatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
+func nativeDenseTensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/iterator_native_test.go b/iterator_native_test.go new file mode 100644 index 0000000..afcd14d --- /dev/null +++ b/iterator_native_test.go @@ -0,0 +1,633 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. 
+ +func Test_nativeDenseVectorB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(6)) + it, err := nativeDenseVectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3)) + it, err := nativeDenseMatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3B(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3B(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = 
New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I32(t *testing.T) { + assert := assert.New(t) + 
var T *Dense + T = New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func 
Test_nativeDenseVectorU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func 
Test_nativeDenseMatrixU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + 
assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C64(T) + if err != nil { + 
t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(6)) + it, err := nativeDenseVectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3)) + it, err := nativeDenseMatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3Str(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3Str(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} diff --git a/native/iterator_native.go b/native/iterator_native.go index b820159..8ebc0e5 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -1,1153 +1,331 @@ -// Code 
generated by genlib2. DO NOT EDIT. +// +build !purego package native +// Code generated by genlib2. DO NOT EDIT. + import ( - "reflect" - "unsafe" + _ "unsafe" - "github.com/pkg/errors" - "gorgonia.org/dtype" - . "gorgonia.org/tensor" + "gorgonia.org/tensor" ) -func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { - // checks: - if !t.IsNativelyAccessible() { - return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - } - - if t.Shape().Dims() != dims { - return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) - } - - if t.F() || t.RequiresIterator() { - return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") - } - - if t.Dtype() != dt { - return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) - } - - return nil -} - -/* Native Iterables for bool */ +//go:linkname VectorB gorgonia.org/tensor.nativeDenseVectorB // VectorB converts a *Dense into a []bool // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorB(t *Dense) (retVal []bool, err error) { - if err = checkNativeIterable(t, 1, Bool); err != nil { - return nil, err - } - return t.Bools(), nil -} +func VectorB(t *tensor.Dense) (retVal []bool, err error) + +//go:linkname MatrixB gorgonia.org/tensor.nativeDenseMatrixB // MatrixB converts a *Dense into a [][]bool // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixB(t *Dense) (retVal [][]bool, err error) { - if err = checkNativeIterable(t, 2, Bool); err != nil { - return nil, err - } - - data := t.Bools() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]bool, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]bool, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixB(t *tensor.Dense) (retVal [][]bool, err error) + +//go:linkname Tensor3B gorgonia.org/tensor.nativeDenseTensor3B // Tensor3B converts a *Dense into a [][][]bool. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3B(t *Dense) (retVal [][][]bool, err error) { - if err = checkNativeIterable(t, 3, Bool); err != nil { - return nil, err - } - - data := t.Bools() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]bool, layers) - for i := range retVal { - retVal[i] = make([][]bool, rows) - for j := range retVal[i] { - retVal[i][j] = make([]bool, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int */ +func Tensor3B(t *tensor.Dense) (retVal [][][]bool, err error) + +//go:linkname VectorI gorgonia.org/tensor.nativeDenseVectorI // VectorI converts a *Dense into a []int // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorI(t *Dense) (retVal []int, err error) { - if err = checkNativeIterable(t, 1, Int); err != nil { - return nil, err - } - return t.Ints(), nil -} +func VectorI(t *tensor.Dense) (retVal []int, err error) + +//go:linkname MatrixI gorgonia.org/tensor.nativeDenseMatrixI // MatrixI converts a *Dense into a [][]int // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI(t *Dense) (retVal [][]int, err error) { - if err = checkNativeIterable(t, 2, Int); err != nil { - return nil, err - } - - data := t.Ints() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI(t *tensor.Dense) (retVal [][]int, err error) + +//go:linkname Tensor3I gorgonia.org/tensor.nativeDenseTensor3I // Tensor3I converts a *Dense into a [][][]int. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3I(t *Dense) (retVal [][][]int, err error) { - if err = checkNativeIterable(t, 3, Int); err != nil { - return nil, err - } - - data := t.Ints() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int, layers) - for i := range retVal { - retVal[i] = make([][]int, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int8 */ +func Tensor3I(t *tensor.Dense) (retVal [][][]int, err error) + +//go:linkname VectorI8 gorgonia.org/tensor.nativeDenseVectorI8 // VectorI8 converts a *Dense into a []int8 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI8(t *Dense) (retVal []int8, err error) { - if err = checkNativeIterable(t, 1, Int8); err != nil { - return nil, err - } - return t.Int8s(), nil -} +func VectorI8(t *tensor.Dense) (retVal []int8, err error) + +//go:linkname MatrixI8 gorgonia.org/tensor.nativeDenseMatrixI8 // MatrixI8 converts a *Dense into a [][]int8 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixI8(t *Dense) (retVal [][]int8, err error) { - if err = checkNativeIterable(t, 2, Int8); err != nil { - return nil, err - } - - data := t.Int8s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int8, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI8(t *tensor.Dense) (retVal [][]int8, err error) + +//go:linkname Tensor3I8 gorgonia.org/tensor.nativeDenseTensor3I8 // Tensor3I8 converts a *Dense into a [][][]int8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { - if err = checkNativeIterable(t, 3, Int8); err != nil { - return nil, err - } - - data := t.Int8s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int8, layers) - for i := range retVal { - retVal[i] = make([][]int8, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int8, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int16 */ +func Tensor3I8(t *tensor.Dense) (retVal [][][]int8, err error) + +//go:linkname VectorI16 gorgonia.org/tensor.nativeDenseVectorI16 // VectorI16 converts a *Dense into a []int16 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorI16(t *Dense) (retVal []int16, err error) { - if err = checkNativeIterable(t, 1, Int16); err != nil { - return nil, err - } - return t.Int16s(), nil -} +func VectorI16(t *tensor.Dense) (retVal []int16, err error) + +//go:linkname MatrixI16 gorgonia.org/tensor.nativeDenseMatrixI16 // MatrixI16 converts a *Dense into a [][]int16 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI16(t *Dense) (retVal [][]int16, err error) { - if err = checkNativeIterable(t, 2, Int16); err != nil { - return nil, err - } - - data := t.Int16s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int16, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI16(t *tensor.Dense) (retVal [][]int16, err error) + +//go:linkname Tensor3I16 gorgonia.org/tensor.nativeDenseTensor3I16 // Tensor3I16 converts a *Dense into a [][][]int16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { - if err = checkNativeIterable(t, 3, Int16); err != nil { - return nil, err - } - - data := t.Int16s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int16, layers) - for i := range retVal { - retVal[i] = make([][]int16, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int16, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int32 */ +func Tensor3I16(t *tensor.Dense) (retVal [][][]int16, err error) + +//go:linkname VectorI32 gorgonia.org/tensor.nativeDenseVectorI32 // VectorI32 converts a *Dense into a []int32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI32(t *Dense) (retVal []int32, err error) { - if err = checkNativeIterable(t, 1, Int32); err != nil { - return nil, err - } - return t.Int32s(), nil -} +func VectorI32(t *tensor.Dense) (retVal []int32, err error) + +//go:linkname MatrixI32 gorgonia.org/tensor.nativeDenseMatrixI32 // MatrixI32 converts a *Dense into a [][]int32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixI32(t *Dense) (retVal [][]int32, err error) { - if err = checkNativeIterable(t, 2, Int32); err != nil { - return nil, err - } - - data := t.Int32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI32(t *tensor.Dense) (retVal [][]int32, err error) + +//go:linkname Tensor3I32 gorgonia.org/tensor.nativeDenseTensor3I32 // Tensor3I32 converts a *Dense into a [][][]int32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { - if err = checkNativeIterable(t, 3, Int32); err != nil { - return nil, err - } - - data := t.Int32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int32, layers) - for i := range retVal { - retVal[i] = make([][]int32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int64 */ +func Tensor3I32(t *tensor.Dense) (retVal [][][]int32, err error) + +//go:linkname VectorI64 gorgonia.org/tensor.nativeDenseVectorI64 // VectorI64 converts a *Dense into a []int64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorI64(t *Dense) (retVal []int64, err error) { - if err = checkNativeIterable(t, 1, Int64); err != nil { - return nil, err - } - return t.Int64s(), nil -} +func VectorI64(t *tensor.Dense) (retVal []int64, err error) + +//go:linkname MatrixI64 gorgonia.org/tensor.nativeDenseMatrixI64 // MatrixI64 converts a *Dense into a [][]int64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI64(t *Dense) (retVal [][]int64, err error) { - if err = checkNativeIterable(t, 2, Int64); err != nil { - return nil, err - } - - data := t.Int64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI64(t *tensor.Dense) (retVal [][]int64, err error) + +//go:linkname Tensor3I64 gorgonia.org/tensor.nativeDenseTensor3I64 // Tensor3I64 converts a *Dense into a [][][]int64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { - if err = checkNativeIterable(t, 3, Int64); err != nil { - return nil, err - } - - data := t.Int64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int64, layers) - for i := range retVal { - retVal[i] = make([][]int64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint */ +func Tensor3I64(t *tensor.Dense) (retVal [][][]int64, err error) + +//go:linkname VectorU gorgonia.org/tensor.nativeDenseVectorU // VectorU converts a *Dense into a []uint // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU(t *Dense) (retVal []uint, err error) { - if err = checkNativeIterable(t, 1, Uint); err != nil { - return nil, err - } - return t.Uints(), nil -} +func VectorU(t *tensor.Dense) (retVal []uint, err error) + +//go:linkname MatrixU gorgonia.org/tensor.nativeDenseMatrixU // MatrixU converts a *Dense into a [][]uint // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixU(t *Dense) (retVal [][]uint, err error) { - if err = checkNativeIterable(t, 2, Uint); err != nil { - return nil, err - } - - data := t.Uints() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU(t *tensor.Dense) (retVal [][]uint, err error) + +//go:linkname Tensor3U gorgonia.org/tensor.nativeDenseTensor3U // Tensor3U converts a *Dense into a [][][]uint. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U(t *Dense) (retVal [][][]uint, err error) { - if err = checkNativeIterable(t, 3, Uint); err != nil { - return nil, err - } - - data := t.Uints() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint, layers) - for i := range retVal { - retVal[i] = make([][]uint, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint8 */ +func Tensor3U(t *tensor.Dense) (retVal [][][]uint, err error) + +//go:linkname VectorU8 gorgonia.org/tensor.nativeDenseVectorU8 // VectorU8 converts a *Dense into a []uint8 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorU8(t *Dense) (retVal []uint8, err error) { - if err = checkNativeIterable(t, 1, Uint8); err != nil { - return nil, err - } - return t.Uint8s(), nil -} +func VectorU8(t *tensor.Dense) (retVal []uint8, err error) + +//go:linkname MatrixU8 gorgonia.org/tensor.nativeDenseMatrixU8 // MatrixU8 converts a *Dense into a [][]uint8 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU8(t *Dense) (retVal [][]uint8, err error) { - if err = checkNativeIterable(t, 2, Uint8); err != nil { - return nil, err - } - - data := t.Uint8s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint8, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU8(t *tensor.Dense) (retVal [][]uint8, err error) + +//go:linkname Tensor3U8 gorgonia.org/tensor.nativeDenseTensor3U8 // Tensor3U8 converts a *Dense into a [][][]uint8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { - if err = checkNativeIterable(t, 3, Uint8); err != nil { - return nil, err - } - - data := t.Uint8s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint8, layers) - for i := range retVal { - retVal[i] = make([][]uint8, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint8, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint16 */ +func Tensor3U8(t *tensor.Dense) (retVal [][][]uint8, err error) + +//go:linkname VectorU16 gorgonia.org/tensor.nativeDenseVectorU16 // VectorU16 converts a *Dense into a []uint16 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU16(t *Dense) (retVal []uint16, err error) { - if err = checkNativeIterable(t, 1, Uint16); err != nil { - return nil, err - } - return t.Uint16s(), nil -} +func VectorU16(t *tensor.Dense) (retVal []uint16, err error) + +//go:linkname MatrixU16 gorgonia.org/tensor.nativeDenseMatrixU16 // MatrixU16 converts a *Dense into a [][]uint16 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixU16(t *Dense) (retVal [][]uint16, err error) { - if err = checkNativeIterable(t, 2, Uint16); err != nil { - return nil, err - } - - data := t.Uint16s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint16, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU16(t *tensor.Dense) (retVal [][]uint16, err error) + +//go:linkname Tensor3U16 gorgonia.org/tensor.nativeDenseTensor3U16 // Tensor3U16 converts a *Dense into a [][][]uint16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { - if err = checkNativeIterable(t, 3, Uint16); err != nil { - return nil, err - } - - data := t.Uint16s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint16, layers) - for i := range retVal { - retVal[i] = make([][]uint16, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint16, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint32 */ +func Tensor3U16(t *tensor.Dense) (retVal [][][]uint16, err error) + +//go:linkname VectorU32 gorgonia.org/tensor.nativeDenseVectorU32 // VectorU32 converts a *Dense into a []uint32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorU32(t *Dense) (retVal []uint32, err error) { - if err = checkNativeIterable(t, 1, Uint32); err != nil { - return nil, err - } - return t.Uint32s(), nil -} +func VectorU32(t *tensor.Dense) (retVal []uint32, err error) + +//go:linkname MatrixU32 gorgonia.org/tensor.nativeDenseMatrixU32 // MatrixU32 converts a *Dense into a [][]uint32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU32(t *Dense) (retVal [][]uint32, err error) { - if err = checkNativeIterable(t, 2, Uint32); err != nil { - return nil, err - } - - data := t.Uint32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU32(t *tensor.Dense) (retVal [][]uint32, err error) + +//go:linkname Tensor3U32 gorgonia.org/tensor.nativeDenseTensor3U32 // Tensor3U32 converts a *Dense into a [][][]uint32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { - if err = checkNativeIterable(t, 3, Uint32); err != nil { - return nil, err - } - - data := t.Uint32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint32, layers) - for i := range retVal { - retVal[i] = make([][]uint32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint64 */ +func Tensor3U32(t *tensor.Dense) (retVal [][][]uint32, err error) + +//go:linkname VectorU64 gorgonia.org/tensor.nativeDenseVectorU64 // VectorU64 converts a *Dense into a []uint64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU64(t *Dense) (retVal []uint64, err error) { - if err = checkNativeIterable(t, 1, Uint64); err != nil { - return nil, err - } - return t.Uint64s(), nil -} +func VectorU64(t *tensor.Dense) (retVal []uint64, err error) + +//go:linkname MatrixU64 gorgonia.org/tensor.nativeDenseMatrixU64 // MatrixU64 converts a *Dense into a [][]uint64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixU64(t *Dense) (retVal [][]uint64, err error) { - if err = checkNativeIterable(t, 2, Uint64); err != nil { - return nil, err - } - - data := t.Uint64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU64(t *tensor.Dense) (retVal [][]uint64, err error) + +//go:linkname Tensor3U64 gorgonia.org/tensor.nativeDenseTensor3U64 // Tensor3U64 converts a *Dense into a [][][]uint64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { - if err = checkNativeIterable(t, 3, Uint64); err != nil { - return nil, err - } - - data := t.Uint64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint64, layers) - for i := range retVal { - retVal[i] = make([][]uint64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for float32 */ +func Tensor3U64(t *tensor.Dense) (retVal [][][]uint64, err error) + +//go:linkname VectorF32 gorgonia.org/tensor.nativeDenseVectorF32 // VectorF32 converts a *Dense into a []float32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorF32(t *Dense) (retVal []float32, err error) { - if err = checkNativeIterable(t, 1, Float32); err != nil { - return nil, err - } - return t.Float32s(), nil -} +func VectorF32(t *tensor.Dense) (retVal []float32, err error) + +//go:linkname MatrixF32 gorgonia.org/tensor.nativeDenseMatrixF32 // MatrixF32 converts a *Dense into a [][]float32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixF32(t *Dense) (retVal [][]float32, err error) { - if err = checkNativeIterable(t, 2, Float32); err != nil { - return nil, err - } - - data := t.Float32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]float32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]float32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixF32(t *tensor.Dense) (retVal [][]float32, err error) + +//go:linkname Tensor3F32 gorgonia.org/tensor.nativeDenseTensor3F32 // Tensor3F32 converts a *Dense into a [][][]float32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { - if err = checkNativeIterable(t, 3, Float32); err != nil { - return nil, err - } - - data := t.Float32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]float32, layers) - for i := range retVal { - retVal[i] = make([][]float32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]float32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for float64 */ +func Tensor3F32(t *tensor.Dense) (retVal [][][]float32, err error) + +//go:linkname VectorF64 gorgonia.org/tensor.nativeDenseVectorF64 // VectorF64 converts a *Dense into a []float64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorF64(t *Dense) (retVal []float64, err error) { - if err = checkNativeIterable(t, 1, Float64); err != nil { - return nil, err - } - return t.Float64s(), nil -} +func VectorF64(t *tensor.Dense) (retVal []float64, err error) + +//go:linkname MatrixF64 gorgonia.org/tensor.nativeDenseMatrixF64 // MatrixF64 converts a *Dense into a [][]float64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixF64(t *Dense) (retVal [][]float64, err error) { - if err = checkNativeIterable(t, 2, Float64); err != nil { - return nil, err - } - - data := t.Float64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]float64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]float64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixF64(t *tensor.Dense) (retVal [][]float64, err error) + +//go:linkname Tensor3F64 gorgonia.org/tensor.nativeDenseTensor3F64 // Tensor3F64 converts a *Dense into a [][][]float64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { - if err = checkNativeIterable(t, 3, Float64); err != nil { - return nil, err - } - - data := t.Float64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]float64, layers) - for i := range retVal { - retVal[i] = make([][]float64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]float64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for complex64 */ +func Tensor3F64(t *tensor.Dense) (retVal [][][]float64, err error) + +//go:linkname VectorC64 gorgonia.org/tensor.nativeDenseVectorC64 // VectorC64 converts a *Dense into a []complex64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorC64(t *Dense) (retVal []complex64, err error) { - if err = checkNativeIterable(t, 1, Complex64); err != nil { - return nil, err - } - return t.Complex64s(), nil -} +func VectorC64(t *tensor.Dense) (retVal []complex64, err error) + +//go:linkname MatrixC64 gorgonia.org/tensor.nativeDenseMatrixC64 // MatrixC64 converts a *Dense into a [][]complex64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixC64(t *Dense) (retVal [][]complex64, err error) { - if err = checkNativeIterable(t, 2, Complex64); err != nil { - return nil, err - } - - data := t.Complex64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]complex64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]complex64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixC64(t *tensor.Dense) (retVal [][]complex64, err error) + +//go:linkname Tensor3C64 gorgonia.org/tensor.nativeDenseTensor3C64 // Tensor3C64 converts a *Dense into a [][][]complex64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { - if err = checkNativeIterable(t, 3, Complex64); err != nil { - return nil, err - } - - data := t.Complex64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]complex64, layers) - for i := range retVal { - retVal[i] = make([][]complex64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]complex64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for complex128 */ +func Tensor3C64(t *tensor.Dense) (retVal [][][]complex64, err error) + +//go:linkname VectorC128 gorgonia.org/tensor.nativeDenseVectorC128 // VectorC128 converts a *Dense into a []complex128 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorC128(t *Dense) (retVal []complex128, err error) { - if err = checkNativeIterable(t, 1, Complex128); err != nil { - return nil, err - } - return t.Complex128s(), nil -} +func VectorC128(t *tensor.Dense) (retVal []complex128, err error) + +//go:linkname MatrixC128 gorgonia.org/tensor.nativeDenseMatrixC128 // MatrixC128 converts a *Dense into a [][]complex128 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. 
-func MatrixC128(t *Dense) (retVal [][]complex128, err error) { - if err = checkNativeIterable(t, 2, Complex128); err != nil { - return nil, err - } - - data := t.Complex128s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]complex128, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]complex128, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixC128(t *tensor.Dense) (retVal [][]complex128, err error) + +//go:linkname Tensor3C128 gorgonia.org/tensor.nativeDenseTensor3C128 // Tensor3C128 converts a *Dense into a [][][]complex128. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { - if err = checkNativeIterable(t, 3, Complex128); err != nil { - return nil, err - } - - data := t.Complex128s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]complex128, layers) - for i := range retVal { - retVal[i] = make([][]complex128, rows) - for j := range retVal[i] { - retVal[i][j] = make([]complex128, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for string */ +func Tensor3C128(t *tensor.Dense) (retVal [][][]complex128, err error) + +//go:linkname VectorStr gorgonia.org/tensor.nativeDenseVectorStr // VectorStr converts a *Dense into a []string // If the *Dense does not represent a vector of the wanted type, it will return // an error. 
-func VectorStr(t *Dense) (retVal []string, err error) { - if err = checkNativeIterable(t, 1, String); err != nil { - return nil, err - } - return t.Strings(), nil -} +func VectorStr(t *tensor.Dense) (retVal []string, err error) + +//go:linkname MatrixStr gorgonia.org/tensor.nativeDenseMatrixStr // MatrixStr converts a *Dense into a [][]string // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixStr(t *Dense) (retVal [][]string, err error) { - if err = checkNativeIterable(t, 2, String); err != nil { - return nil, err - } - - data := t.Strings() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]string, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]string, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixStr(t *tensor.Dense) (retVal [][]string, err error) + +//go:linkname Tensor3Str gorgonia.org/tensor.nativeDenseTensor3Str // Tensor3Str converts a *Dense into a [][][]string. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 
-func Tensor3Str(t *Dense) (retVal [][][]string, err error) { - if err = checkNativeIterable(t, 3, String); err != nil { - return nil, err - } - - data := t.Strings() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]string, layers) - for i := range retVal { - retVal[i] = make([][]string, rows) - for j := range retVal[i] { - retVal[i][j] = make([]string, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} +func Tensor3Str(t *tensor.Dense) (retVal [][][]string, err error) diff --git a/native/iterator_native2.go b/native/iterator_native2.go index d47bfb3..2d0cdf5 100644 --- a/native/iterator_native2.go +++ b/native/iterator_native2.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package native +// Code generated by genlib2. DO NOT EDIT. + import ( "reflect" "unsafe" diff --git a/native/iterator_native2_test.go b/native/iterator_native2_test.go index df56b5e..a6f247f 100644 --- a/native/iterator_native2_test.go +++ b/native/iterator_native2_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package native +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" diff --git a/native/iterator_native_purego.go b/native/iterator_native_purego.go new file mode 100644 index 0000000..57e03c1 --- /dev/null +++ b/native/iterator_native_purego.go @@ -0,0 +1,1132 @@ +// +build purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +/* Native Iterables for bool */ + +// VectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// MatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// VectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// MatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// VectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// MatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// VectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// MatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// VectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// MatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// VectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// MatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// VectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// MatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// VectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// MatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// VectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// MatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// VectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// MatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// VectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// MatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// VectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// MatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// VectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// MatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// VectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// MatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// VectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// MatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// VectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. 
+func VectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// MatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/native/iterator_native_test.go b/native/iterator_native_test.go index 09236a0..2e99966 100644 --- a/native/iterator_native_test.go +++ b/native/iterator_native_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. 
- package native +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" diff --git a/native/utils.go b/native/utils.go new file mode 100644 index 0000000..78c561e --- /dev/null +++ b/native/utils.go @@ -0,0 +1,30 @@ +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" + . "gorgonia.org/tensor" +) + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} diff --git a/test_test.go b/test_test.go index 772a71f..f5a7e0c 100644 --- a/test_test.go +++ b/test_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -12,6 +10,8 @@ import ( "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func anyToFloat64s(x interface{}) (retVal []float64) { switch xt := x.(type) { case []int: From d1906c96fb69a50f6a41ebd5ef57758bfdb40d0e Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 13:32:50 +1000 Subject: [PATCH 105/154] Moved native select into package `tensor` as well. 
--- genlib2/main.go | 9 +- genlib2/native_iterator.go | 2 +- genlib2/native_select.go | 130 +++-- iterator_native2.go | 635 ++++++++++++++++++++++ iterator_native2_test.go | 841 ++++++++++++++++++++++++++++++ native/iterator_native2.go | 711 +++---------------------- native/iterator_native2_purego.go | 620 ++++++++++++++++++++++ native/utils.go | 16 + 8 files changed, 2294 insertions(+), 670 deletions(-) create mode 100644 iterator_native2.go create mode 100644 iterator_native2_test.go create mode 100644 native/iterator_native2_purego.go diff --git a/genlib2/main.go b/genlib2/main.go index 63f6987..87f2788 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -117,14 +117,17 @@ func main() { // native iterators - the ones in the tensor package pipeline(tensorPkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators(false)) pipeline(tensorPkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(false)) - pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect) - pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests) + pipeline(tensorPkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect(false)) + pipeline(tensorPkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests(false)) // native iterators - exported into gorgonia.org/tensor/native pipeline(nativePkgLoc+"_unsafe", "iterator_native.go", Kinds{allKinds}, generateNativeIteratorStubs) pipeline(nativePkgLoc+"_purego", "iterator_native_purego.go", Kinds{allKinds}, generateNativeIterators(true)) pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(true)) - pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeChecks) + pipeline(nativePkgLoc+"_unsafe", "iterator_native2.go", Kinds{allKinds}, generateNativeSelectStubs) + pipeline(nativePkgLoc+"_purego", "iterator_native2_purego.go", Kinds{allKinds}, generateNativeSelect(true)) + 
pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests(true)) + pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeIterChecks, generateNativeSelChecks) } func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index f1e14be..1d7a85c 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -264,7 +264,7 @@ func generateNativeIteratorStubs(f io.Writer, ak Kinds) { } } -func generateNativeChecks(f io.Writer, ak Kinds) { +func generateNativeIterChecks(f io.Writer, ak Kinds) { fmt.Fprintf(f, importUnqualifiedTensor) fmt.Fprintf(f, "%v\n", checkNativeiterable) } diff --git a/genlib2/native_select.go b/genlib2/native_select.go index a05ce18..1095668 100644 --- a/genlib2/native_select.go +++ b/genlib2/native_select.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "reflect" "text/template" ) @@ -22,29 +23,35 @@ const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt return nil } ` -const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64. -func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) { - if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil { +const nativeSelectRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} + +// {{$selName}} creates a slice of flat data types. See Example of NativeSelectF64. 
+func {{$selName}}(t *Dense, axis int) (retVal [][]{{asType .K}}, err error) { + if err := checkNativeSelectable(t, axis, {{reflectKind .K}}); err != nil { return nil, err } switch t.Shape().Dims() { case 0, 1: - retVal = make([][]{{asType .}}, 1) - retVal[0] = t.{{sliceOf .}} + retVal = make([][]{{asType .K}}, 1) + retVal[0] = t.{{sliceOf .K}} case 2: if axis == 0 { - return Matrix{{short .}}(t) + return {{if .N}}Matrix{{short .K}}{{else}}nativeDenseMatrix{{short .K}}{{end}}(t) } fallthrough default: // size := t.Shape()[axis] - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} stride := t.Strides()[axis] upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]{{asType .}}, 0, upper) + retVal = make([][]{{asType .K}}, 0, upper) for i, r := 0, 0; r < upper; i += stride { - s := make([]{{asType .}}, 0) + s := make([]{{asType .K}}, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) hdr.Len = stride @@ -58,85 +65,132 @@ func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) return } ` -const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) { +const nativeSelectTestRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} +func Test{{$selName}}(t *testing.T) { assert := assert.New(t) var T *Dense var err error - var x [][]{{asType .}} - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 1); err != nil { + var x [][]{{asType .K}} + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 1); err != nil { t.Fatal(err) } assert.Equal(6, len(x)) assert.Equal(20, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(2, len(x)) 
assert.Equal(60, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 3); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 3); err != nil { t.Fatal(err) } assert.Equal(120, len(x)) assert.Equal(1, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3), ) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(2, len(x)) assert.Equal(3, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3), ) - if x, err = Select{{short .}}(T, 1); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 1); err != nil { t.Fatal(err) } assert.Equal(6, len(x)) assert.Equal(1, len(x[0])) - T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} )) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(FromScalar({{if eq .K.String "bool" -}}false{{else if eq .K.String "string" -}}""{{else -}}{{asType .K}}(0) {{end -}} )) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(1, len(x)) assert.Equal(1, len(x[0])) - if _, err = Select{{short .}}(T, 10); err == nil{ + if _, err = {{$selName}}(T, 10); err == nil{ t.Fatal("Expected errors") } } ` +const nativeSelectStubsRaw = `//go:linkname Select{{short .}} gorgonia.org/tensor.nativeSelect{{short .}} + +// Select{{short .}} creates a slice of {{asType .}}s. See Example of NativeSelectF64. 
+func Select{{short .}}(t *tensor.Dense, axis int) (retVal [][]{{asType .}}, err error) +` + var ( - NativeSelect *template.Template - NativeSelectTest *template.Template + NativeSelect *template.Template + NativeSelectTest *template.Template + NativeSelectStubs *template.Template ) func init() { NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw)) NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw)) + NativeSelectStubs = template.Must(template.New("NativeSelectStub").Funcs(funcs).Parse(nativeSelectStubsRaw)) } -func generateNativeSelect(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) - fmt.Fprintf(f, "%v\n", checkNativeSelectable) - ks := filter(ak.Kinds, isSpecialized) - for _, k := range ks { - fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) - NativeSelect.Execute(f, k) - fmt.Fprint(f, "\n\n") +// generateNativeSelect generates code for the native selection. `isNative` indicates if the +// code is meant to be generated for the native package. The code is generated for the native package +// only for the purposes of the `purego` build tag. 
+func generateNativeSelect(isNative bool) func(io.Writer, Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeSelectable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) + NativeSelect.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } } } -func generateNativeSelectTests(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) +func generateNativeSelectTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeSelectTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeSelectStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work ks := filter(ak.Kinds, isSpecialized) for _, k := range ks { - NativeSelectTest.Execute(f, k) - fmt.Fprint(f, "\n\n") + NativeSelectStubs.Execute(f, k) + fmt.Fprintf(f, "\n\n") } } + +func generateNativeSelChecks(f io.Writer, ak Kinds) { + // fmt.Fprintf(f, importUnqualifiedTensor) // already generated by generateNativeIterChecks + fmt.Fprintf(f, "%v\n", checkNativeSelectable) +} diff --git a/iterator_native2.go b/iterator_native2.go new file mode 100644 index 0000000..d3cf1f2 --- /dev/null +++ b/iterator_native2.go @@ -0,0 +1,635 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. 
+ +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} + +/* Native Select for bool */ + +// nativeSelectB creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return nativeDenseMatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// nativeSelectI creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return nativeDenseMatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// nativeSelectI8 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return nativeDenseMatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// nativeSelectI16 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return nativeDenseMatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// nativeSelectI32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return nativeDenseMatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// nativeSelectI64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return nativeDenseMatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// nativeSelectU creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return nativeDenseMatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// nativeSelectU8 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return nativeDenseMatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// nativeSelectU16 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return nativeDenseMatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// nativeSelectU32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return nativeDenseMatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// nativeSelectU64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return nativeDenseMatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// nativeSelectF32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return nativeDenseMatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// nativeSelectF64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return nativeDenseMatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// nativeSelectC64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return nativeDenseMatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// nativeSelectC128 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return nativeDenseMatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// nativeSelectStr creates a slice of flat data types. See Example of NativeSelectF64. 
+func nativeSelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return nativeDenseMatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} diff --git a/iterator_native2_test.go b/iterator_native2_test.go new file mode 100644 index 0000000..02291b5 --- /dev/null +++ b/iterator_native2_test.go @@ -0,0 +1,841 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. 
+ +func TestnativeSelectB(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]bool + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(false)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectB(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if 
x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int(0))) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int8 + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int8(0))) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int16 + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, 
len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int16(0))) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int32 + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int32(0))) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI64(t 
*testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int64 + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int64(0))) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if 
x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint(0))) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint8 + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint8(0))) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint16 + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + 
assert.Equal(60, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint16(0))) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint32 + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint32(0))) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func 
TestnativeSelectU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint64 + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint64(0))) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float32 + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, 
len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float32(0))) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float64 + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float64(0))) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex64 + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, 
err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex64(0))) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex128 + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex128(0))) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + 
assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC128(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]string + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar("")) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectStr(T, 10); err == nil { + t.Fatal("Expected errors") + } +} diff --git a/native/iterator_native2.go b/native/iterator_native2.go index 2d0cdf5..f6b2e0e 100644 --- a/native/iterator_native2.go +++ b/native/iterator_native2.go @@ -1,636 +1,91 @@ +// +build !purego + package native // Code generated by genlib2. DO NOT EDIT. import ( - "reflect" - "unsafe" + _ "unsafe" - "github.com/pkg/errors" - "gorgonia.org/dtype" - . 
"gorgonia.org/tensor" + "gorgonia.org/tensor" ) -func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { - if !t.IsNativelyAccessible() { - return errors.New("Cannot select on non-natively accessible data") - } - if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { - return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) - } - if t.F() || t.RequiresIterator() { - return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") - } - if t.Dtype() != dt { - return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) - } - return nil -} - -/* Native Select for bool */ - -// SelectB creates a slice of flat data types. See Example of NativeSelectF64. -func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { - if err := checkNativeSelectable(t, axis, Bool); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]bool, 1) - retVal[0] = t.Bools() - case 2: - if axis == 0 { - return MatrixB(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Bools() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]bool, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]bool, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for int */ - -// SelectI creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectI(t *Dense, axis int) (retVal [][]int, err error) { - if err := checkNativeSelectable(t, axis, Int); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]int, 1) - retVal[0] = t.Ints() - case 2: - if axis == 0 { - return MatrixI(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Ints() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]int, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]int, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for int8 */ - -// SelectI8 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { - if err := checkNativeSelectable(t, axis, Int8); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]int8, 1) - retVal[0] = t.Int8s() - case 2: - if axis == 0 { - return MatrixI8(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Int8s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]int8, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]int8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for int16 */ - -// SelectI16 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { - if err := checkNativeSelectable(t, axis, Int16); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]int16, 1) - retVal[0] = t.Int16s() - case 2: - if axis == 0 { - return MatrixI16(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Int16s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]int16, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]int16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for int32 */ - -// SelectI32 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { - if err := checkNativeSelectable(t, axis, Int32); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]int32, 1) - retVal[0] = t.Int32s() - case 2: - if axis == 0 { - return MatrixI32(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Int32s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]int32, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]int32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for int64 */ - -// SelectI64 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { - if err := checkNativeSelectable(t, axis, Int64); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]int64, 1) - retVal[0] = t.Int64s() - case 2: - if axis == 0 { - return MatrixI64(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Int64s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]int64, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]int64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for uint */ - -// SelectU creates a slice of flat data types. See Example of NativeSelectF64. -func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { - if err := checkNativeSelectable(t, axis, Uint); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]uint, 1) - retVal[0] = t.Uints() - case 2: - if axis == 0 { - return MatrixU(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Uints() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]uint, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]uint, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for uint8 */ - -// SelectU8 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { - if err := checkNativeSelectable(t, axis, Uint8); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]uint8, 1) - retVal[0] = t.Uint8s() - case 2: - if axis == 0 { - return MatrixU8(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Uint8s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]uint8, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]uint8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for uint16 */ - -// SelectU16 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { - if err := checkNativeSelectable(t, axis, Uint16); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]uint16, 1) - retVal[0] = t.Uint16s() - case 2: - if axis == 0 { - return MatrixU16(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Uint16s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]uint16, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]uint16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for uint32 */ - -// SelectU32 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { - if err := checkNativeSelectable(t, axis, Uint32); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]uint32, 1) - retVal[0] = t.Uint32s() - case 2: - if axis == 0 { - return MatrixU32(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Uint32s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]uint32, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]uint32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for uint64 */ - -// SelectU64 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { - if err := checkNativeSelectable(t, axis, Uint64); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]uint64, 1) - retVal[0] = t.Uint64s() - case 2: - if axis == 0 { - return MatrixU64(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Uint64s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]uint64, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]uint64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for float32 */ - -// SelectF32 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { - if err := checkNativeSelectable(t, axis, Float32); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]float32, 1) - retVal[0] = t.Float32s() - case 2: - if axis == 0 { - return MatrixF32(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Float32s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]float32, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]float32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for float64 */ - -// SelectF64 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { - if err := checkNativeSelectable(t, axis, Float64); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]float64, 1) - retVal[0] = t.Float64s() - case 2: - if axis == 0 { - return MatrixF64(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Float64s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]float64, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]float64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for complex64 */ - -// SelectC64 creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { - if err := checkNativeSelectable(t, axis, Complex64); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]complex64, 1) - retVal[0] = t.Complex64s() - case 2: - if axis == 0 { - return MatrixC64(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Complex64s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]complex64, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]complex64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for complex128 */ - -// SelectC128 creates a slice of flat data types. See Example of NativeSelectF64. -func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { - if err := checkNativeSelectable(t, axis, Complex128); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]complex128, 1) - retVal[0] = t.Complex128s() - case 2: - if axis == 0 { - return MatrixC128(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Complex128s() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]complex128, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]complex128, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} - -/* Native Select for string */ - -// SelectStr creates a slice of flat data types. See Example of NativeSelectF64. 
-func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { - if err := checkNativeSelectable(t, axis, String); err != nil { - return nil, err - } - - switch t.Shape().Dims() { - case 0, 1: - retVal = make([][]string, 1) - retVal[0] = t.Strings() - case 2: - if axis == 0 { - return MatrixStr(t) - } - fallthrough - default: - // size := t.Shape()[axis] - data := t.Strings() - stride := t.Strides()[axis] - upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]string, 0, upper) - for i, r := 0, 0; r < upper; i += stride { - s := make([]string, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = stride - hdr.Cap = stride - retVal = append(retVal, s) - r++ - } - return retVal, nil - - } - return -} +//go:linkname SelectB gorgonia.org/tensor.nativeSelectB + +// SelectB creates a slice of bools. See Example of NativeSelectF64. +func SelectB(t *tensor.Dense, axis int) (retVal [][]bool, err error) + +//go:linkname SelectI gorgonia.org/tensor.nativeSelectI + +// SelectI creates a slice of ints. See Example of NativeSelectF64. +func SelectI(t *tensor.Dense, axis int) (retVal [][]int, err error) + +//go:linkname SelectI8 gorgonia.org/tensor.nativeSelectI8 + +// SelectI8 creates a slice of int8s. See Example of NativeSelectF64. +func SelectI8(t *tensor.Dense, axis int) (retVal [][]int8, err error) + +//go:linkname SelectI16 gorgonia.org/tensor.nativeSelectI16 + +// SelectI16 creates a slice of int16s. See Example of NativeSelectF64. +func SelectI16(t *tensor.Dense, axis int) (retVal [][]int16, err error) + +//go:linkname SelectI32 gorgonia.org/tensor.nativeSelectI32 + +// SelectI32 creates a slice of int32s. See Example of NativeSelectF64. +func SelectI32(t *tensor.Dense, axis int) (retVal [][]int32, err error) + +//go:linkname SelectI64 gorgonia.org/tensor.nativeSelectI64 + +// SelectI64 creates a slice of int64s. See Example of NativeSelectF64. 
+func SelectI64(t *tensor.Dense, axis int) (retVal [][]int64, err error) + +//go:linkname SelectU gorgonia.org/tensor.nativeSelectU + +// SelectU creates a slice of uints. See Example of NativeSelectF64. +func SelectU(t *tensor.Dense, axis int) (retVal [][]uint, err error) + +//go:linkname SelectU8 gorgonia.org/tensor.nativeSelectU8 + +// SelectU8 creates a slice of uint8s. See Example of NativeSelectF64. +func SelectU8(t *tensor.Dense, axis int) (retVal [][]uint8, err error) + +//go:linkname SelectU16 gorgonia.org/tensor.nativeSelectU16 + +// SelectU16 creates a slice of uint16s. See Example of NativeSelectF64. +func SelectU16(t *tensor.Dense, axis int) (retVal [][]uint16, err error) + +//go:linkname SelectU32 gorgonia.org/tensor.nativeSelectU32 + +// SelectU32 creates a slice of uint32s. See Example of NativeSelectF64. +func SelectU32(t *tensor.Dense, axis int) (retVal [][]uint32, err error) + +//go:linkname SelectU64 gorgonia.org/tensor.nativeSelectU64 + +// SelectU64 creates a slice of uint64s. See Example of NativeSelectF64. +func SelectU64(t *tensor.Dense, axis int) (retVal [][]uint64, err error) + +//go:linkname SelectF32 gorgonia.org/tensor.nativeSelectF32 + +// SelectF32 creates a slice of float32s. See Example of NativeSelectF64. +func SelectF32(t *tensor.Dense, axis int) (retVal [][]float32, err error) + +//go:linkname SelectF64 gorgonia.org/tensor.nativeSelectF64 + +// SelectF64 creates a slice of float64s. See Example of NativeSelectF64. +func SelectF64(t *tensor.Dense, axis int) (retVal [][]float64, err error) + +//go:linkname SelectC64 gorgonia.org/tensor.nativeSelectC64 + +// SelectC64 creates a slice of complex64s. See Example of NativeSelectF64. +func SelectC64(t *tensor.Dense, axis int) (retVal [][]complex64, err error) + +//go:linkname SelectC128 gorgonia.org/tensor.nativeSelectC128 + +// SelectC128 creates a slice of complex128s. See Example of NativeSelectF64. 
+func SelectC128(t *tensor.Dense, axis int) (retVal [][]complex128, err error) + +//go:linkname SelectStr gorgonia.org/tensor.nativeSelectStr + +// SelectStr creates a slice of strings. See Example of NativeSelectF64. +func SelectStr(t *tensor.Dense, axis int) (retVal [][]string, err error) diff --git a/native/iterator_native2_purego.go b/native/iterator_native2_purego.go new file mode 100644 index 0000000..e2f1e2c --- /dev/null +++ b/native/iterator_native2_purego.go @@ -0,0 +1,620 @@ +// +build purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +/* Native Select for bool */ + +// SelectB creates a slice of flat data types. See Example of NativeSelectF64. +func SelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return MatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// SelectI creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return MatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// SelectI8 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return MatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// SelectI16 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return MatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// SelectI32 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return MatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// SelectI64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return MatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// SelectU creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return MatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// SelectU8 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return MatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// SelectU16 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return MatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// SelectU32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return MatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// SelectU64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return MatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// SelectF32 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return MatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// SelectF64 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return MatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// SelectC64 creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return MatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// SelectC128 creates a slice of flat data types. See Example of NativeSelectF64. +func SelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return MatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// SelectStr creates a slice of flat data types. See Example of NativeSelectF64. 
+func SelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return MatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} diff --git a/native/utils.go b/native/utils.go index 78c561e..341388e 100644 --- a/native/utils.go +++ b/native/utils.go @@ -28,3 +28,19 @@ func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { return nil } + +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. 
Got %v", dt, t.Dtype()) + } + return nil +} From e5c930c0fd58ab03d0811b16858f6ce08c1201db Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 14:34:37 +1000 Subject: [PATCH 106/154] Added lazy native select --- example_lazy_native_select_test.go | 38 ++++++++++ iterator_native2_lazy.go | 108 +++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 example_lazy_native_select_test.go create mode 100644 iterator_native2_lazy.go diff --git a/example_lazy_native_select_test.go b/example_lazy_native_select_test.go new file mode 100644 index 0000000..8a5249e --- /dev/null +++ b/example_lazy_native_select_test.go @@ -0,0 +1,38 @@ +package tensor + +import ( + "fmt" +) + +func ExampleLazySelectF64() { + T := New(WithShape(50, 5), WithBacking(Range(Float64, 1, 251))) + + // now let's iterate this using a lazy native select, selecting 10 rows at time + + it := NewLazySelectF64(T, 0, 10) + + var i int + cur := it.Native() + fmt.Printf("%d: %v\n", i, cur) + hasRem, trunc := it.Next() + i++ + for ; hasRem; hasRem, trunc = it.Next() { + cur = it.Native() + fmt.Printf("%d: %v\n", i, cur) + + i++ + } + cur = it.Native() + fmt.Printf("%d: %v\n", i, cur) + if trunc { + // do something + } + + // Output: + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] + // 1: [[51 52 53 54 55] [56 57 58 59 60] [61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75] [76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90] [91 92 93 94 95] [96 97 98 99 100]] + // 2: [[101 102 103 104 105] [106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120] [121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135] [136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 3: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165] [166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180] [181 182 183 184 185] [186 187 188 
189 190] [191 192 193 194 195] [196 197 198 199 200]] + // 4: [[201 202 203 204 205] [206 207 208 209 210] [211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225] [226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240] [241 242 243 244 245] [246 247 248 249 250]] + +} diff --git a/iterator_native2_lazy.go b/iterator_native2_lazy.go new file mode 100644 index 0000000..fb51d7a --- /dev/null +++ b/iterator_native2_lazy.go @@ -0,0 +1,108 @@ +package tensor + +import ( + "log" + "reflect" + "runtime" + "unsafe" +) + +type LazySelectF64 struct { + t *Dense + it [][]float64 // FUTURE: this can be made into generic in the future + + // state + + upper int // the outer dimension after being "reshaped" + limit int // limit as to how many rows the `it` can store + stride int // stride + r int // current row +} + +func NewLazySelectF64(t *Dense, axis int, limit int) *LazySelectF64 { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + panic(err) + } + + if limit <= 0 { + limit = runtime.NumCPU() // default + } + upper := ProdInts(t.Shape()[:axis+1]) + if limit > upper { + limit = upper + // `it` should come from nativeSelectF64 + } + stride := t.Strides()[axis] + data := t.Float64s() + + it := make([][]float64, 0, limit) + var i, r int + for i, r = 0, 0; r < limit; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + it = append(it, s) + r++ + } + + return &LazySelectF64{ + t: t, + it: it, + upper: upper, + limit: limit, + stride: stride, + r: r, + } +} + +// Next moves the next batch into the native iterator. 
+func (it *LazySelectF64) Next() (hasRemaingRows, truncated bool) { + var ( + i int // data ptr + r int // relative row + s int // absolute row + ) + data := it.t.Float64s() + for i, r, s = it.r*it.stride, 0, it.r; r < it.limit && s < it.upper; i, r, s = i+it.stride, r+1, s+1 { + sl := it.it[r] + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sl)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = it.stride + hdr.Cap = it.stride + it.it[r] = sl + } + it.r = s + + log.Printf("r %v limit %v, s %v upper %v", r, it.limit, s, it.upper) + + if r < it.limit { + // truncate it.it + it.it = it.it[:r] + return false, true + } + if s == it.upper { + return false, false + } + + return true, false +} + +func (it *LazySelectF64) Native() [][]float64 { return it.it } + +func (it *LazySelectF64) Reset() { + it.it = it.it[:it.limit] + + data := it.t.Float64s() + var i, r int + for i, r = 0, 0; r < it.limit; i += it.stride { + sl := it.it[r] + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sl)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = it.stride + hdr.Cap = it.stride + it.it[r] = sl + r++ + } +} From d113c2b7f6c16eeb42b9ae2da3fbacf4ede51abf Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 14:40:38 +1000 Subject: [PATCH 107/154] renamed the files --- ...ect_test.go => example_batched_nativeselect_test.go | 0 genlib2/main.go | 10 +++++----- native/{iterator_native2.go => select_native.go} | 0 ...rator_native2_purego.go => select_native_purego.go} | 0 ...{iterator_native2_test.go => select_native_test.go} | 0 iterator_native2.go => select_native.go | 0 iterator_native2_lazy.go => select_native_batched.go | 0 iterator_native2_test.go => select_native_test.go | 0 8 files changed, 5 insertions(+), 5 deletions(-) rename example_lazy_native_select_test.go => example_batched_nativeselect_test.go (100%) rename native/{iterator_native2.go => select_native.go} (100%) rename native/{iterator_native2_purego.go => select_native_purego.go} (100%) rename 
native/{iterator_native2_test.go => select_native_test.go} (100%) rename iterator_native2.go => select_native.go (100%) rename iterator_native2_lazy.go => select_native_batched.go (100%) rename iterator_native2_test.go => select_native_test.go (100%) diff --git a/example_lazy_native_select_test.go b/example_batched_nativeselect_test.go similarity index 100% rename from example_lazy_native_select_test.go rename to example_batched_nativeselect_test.go diff --git a/genlib2/main.go b/genlib2/main.go index 87f2788..6745508 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -117,16 +117,16 @@ func main() { // native iterators - the ones in the tensor package pipeline(tensorPkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators(false)) pipeline(tensorPkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(false)) - pipeline(tensorPkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect(false)) - pipeline(tensorPkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests(false)) + pipeline(tensorPkgLoc, "select_native.go", Kinds{allKinds}, generateNativeSelect(false)) + pipeline(tensorPkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(false)) // native iterators - exported into gorgonia.org/tensor/native pipeline(nativePkgLoc+"_unsafe", "iterator_native.go", Kinds{allKinds}, generateNativeIteratorStubs) pipeline(nativePkgLoc+"_purego", "iterator_native_purego.go", Kinds{allKinds}, generateNativeIterators(true)) pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(true)) - pipeline(nativePkgLoc+"_unsafe", "iterator_native2.go", Kinds{allKinds}, generateNativeSelectStubs) - pipeline(nativePkgLoc+"_purego", "iterator_native2_purego.go", Kinds{allKinds}, generateNativeSelect(true)) - pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests(true)) + pipeline(nativePkgLoc+"_unsafe", "select_native.go", 
Kinds{allKinds}, generateNativeSelectStubs) + pipeline(nativePkgLoc+"_purego", "select_native_purego.go", Kinds{allKinds}, generateNativeSelect(true)) + pipeline(nativePkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(true)) pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeIterChecks, generateNativeSelChecks) } diff --git a/native/iterator_native2.go b/native/select_native.go similarity index 100% rename from native/iterator_native2.go rename to native/select_native.go diff --git a/native/iterator_native2_purego.go b/native/select_native_purego.go similarity index 100% rename from native/iterator_native2_purego.go rename to native/select_native_purego.go diff --git a/native/iterator_native2_test.go b/native/select_native_test.go similarity index 100% rename from native/iterator_native2_test.go rename to native/select_native_test.go diff --git a/iterator_native2.go b/select_native.go similarity index 100% rename from iterator_native2.go rename to select_native.go diff --git a/iterator_native2_lazy.go b/select_native_batched.go similarity index 100% rename from iterator_native2_lazy.go rename to select_native_batched.go diff --git a/iterator_native2_test.go b/select_native_test.go similarity index 100% rename from iterator_native2_test.go rename to select_native_test.go From b3d23369d0f868fb391ee923c2e200a68d83aa4e Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 14:46:11 +1000 Subject: [PATCH 108/154] Renamed LazyNativeSelectF64 to BatchedNativeSelectF64 --- select_native_batched.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/select_native_batched.go b/select_native_batched.go index fb51d7a..35bc72d 100644 --- a/select_native_batched.go +++ b/select_native_batched.go @@ -7,7 +7,7 @@ import ( "unsafe" ) -type LazySelectF64 struct { +type BatchedNativeSelectF64 struct { t *Dense it [][]float64 // FUTURE: this can be made into generic in the future @@ -19,7 +19,7 @@ type 
LazySelectF64 struct { r int // current row } -func NewLazySelectF64(t *Dense, axis int, limit int) *LazySelectF64 { +func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { if err := checkNativeSelectable(t, axis, Float64); err != nil { panic(err) } @@ -47,7 +47,7 @@ func NewLazySelectF64(t *Dense, axis int, limit int) *LazySelectF64 { r++ } - return &LazySelectF64{ + return &BatchedNativeSelectF64{ t: t, it: it, upper: upper, @@ -58,7 +58,7 @@ func NewLazySelectF64(t *Dense, axis int, limit int) *LazySelectF64 { } // Next moves the next batch into the native iterator. -func (it *LazySelectF64) Next() (hasRemaingRows, truncated bool) { +func (it *BatchedNativeSelectF64) Next() (hasRemaingRows, truncated bool) { var ( i int // data ptr r int // relative row @@ -89,9 +89,9 @@ func (it *LazySelectF64) Next() (hasRemaingRows, truncated bool) { return true, false } -func (it *LazySelectF64) Native() [][]float64 { return it.it } +func (it *BatchedNativeSelectF64) Native() [][]float64 { return it.it } -func (it *LazySelectF64) Reset() { +func (it *BatchedNativeSelectF64) Reset() { it.it = it.it[:it.limit] data := it.t.Float64s() From c588e451a385d000cb144d2aef630dc9ce431068 Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 14:59:44 +1000 Subject: [PATCH 109/154] Playing with the API a bit to see if I can make it nicer --- example_batched_nativeselect_test.go | 16 ++++++---------- select_native_batched.go | 12 +++++++++++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go index 8a5249e..7428faf 100644 --- a/example_batched_nativeselect_test.go +++ b/example_batched_nativeselect_test.go @@ -4,29 +4,25 @@ import ( "fmt" ) -func ExampleLazySelectF64() { +func ExampleBatchedNativeSelectF64() { T := New(WithShape(50, 5), WithBacking(Range(Float64, 1, 251))) // now let's iterate this using a lazy native select, selecting 10 rows at time - it := 
NewLazySelectF64(T, 0, 10) + it := BatchSelectF64(T, 0, 10) var i int - cur := it.Native() - fmt.Printf("%d: %v\n", i, cur) - hasRem, trunc := it.Next() - i++ - for ; hasRem; hasRem, trunc = it.Next() { + var cur [][]float64 + for hasRem, trunc := it.Start(); hasRem; hasRem, trunc = it.Next() { cur = it.Native() fmt.Printf("%d: %v\n", i, cur) + if trunc { + } i++ } cur = it.Native() fmt.Printf("%d: %v\n", i, cur) - if trunc { - // do something - } // Output: // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] diff --git a/select_native_batched.go b/select_native_batched.go index 35bc72d..b35d343 100644 --- a/select_native_batched.go +++ b/select_native_batched.go @@ -57,6 +57,16 @@ func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { } } +func (it *BatchedNativeSelectF64) Start() (hasRemainingRows, truncated bool) { + if it.r != it.limit || len(it.it) != it.limit { + // then it's been moved, so we reset + it.Reset() + } + hasRemainingRows = it.upper > it.r + truncated = false + return +} + // Next moves the next batch into the native iterator. func (it *BatchedNativeSelectF64) Next() (hasRemaingRows, truncated bool) { var ( @@ -82,7 +92,7 @@ func (it *BatchedNativeSelectF64) Next() (hasRemaingRows, truncated bool) { it.it = it.it[:r] return false, true } - if s == it.upper { + if it.r == it.upper { return false, false } From 2194ae82a2d95a9fdeb62e6cab686f85b058db1d Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 20:13:55 +1000 Subject: [PATCH 110/154] Playing with API a bit more. 
--- example_batched_nativeselect_test.go | 16 +++++----------- select_native_batched.go | 16 +++++++++------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go index 7428faf..006c751 100644 --- a/example_batched_nativeselect_test.go +++ b/example_batched_nativeselect_test.go @@ -11,18 +11,12 @@ func ExampleBatchedNativeSelectF64() { it := BatchSelectF64(T, 0, 10) - var i int - var cur [][]float64 - for hasRem, trunc := it.Start(); hasRem; hasRem, trunc = it.Next() { - cur = it.Native() - fmt.Printf("%d: %v\n", i, cur) - if trunc { - } - - i++ + var batchNo int + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ } - cur = it.Native() - fmt.Printf("%d: %v\n", i, cur) + fmt.Printf("%d: %v\n", batchNo, it.Native()) // Output: // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] diff --git a/select_native_batched.go b/select_native_batched.go index b35d343..c630b55 100644 --- a/select_native_batched.go +++ b/select_native_batched.go @@ -57,18 +57,18 @@ func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { } } -func (it *BatchedNativeSelectF64) Start() (hasRemainingRows, truncated bool) { +func (it *BatchedNativeSelectF64) Start() (curBatch [][]float64, hasRemainingRows bool) { if it.r != it.limit || len(it.it) != it.limit { // then it's been moved, so we reset it.Reset() } + curBatch = it.it hasRemainingRows = it.upper > it.r - truncated = false return } // Next moves the next batch into the native iterator. 
-func (it *BatchedNativeSelectF64) Next() (hasRemaingRows, truncated bool) { +func (it *BatchedNativeSelectF64) Next() (curBatch [][]float64, hasRemaingRows bool) { var ( i int // data ptr r int // relative row @@ -90,19 +90,19 @@ func (it *BatchedNativeSelectF64) Next() (hasRemaingRows, truncated bool) { if r < it.limit { // truncate it.it it.it = it.it[:r] - return false, true + return it.it, false } if it.r == it.upper { - return false, false + return it.it, false } - return true, false + return it.it, true } func (it *BatchedNativeSelectF64) Native() [][]float64 { return it.it } func (it *BatchedNativeSelectF64) Reset() { - it.it = it.it[:it.limit] + it.it = it.it[:it.limit:it.limit] data := it.t.Float64s() var i, r int @@ -116,3 +116,5 @@ func (it *BatchedNativeSelectF64) Reset() { r++ } } + +func (it *BatchedNativeSelectF64) IsTruncated() bool { return len(it.it) != it.limit } From 63c57ae647fe8387594eb4c479a52c8a40fab064 Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 29 Jul 2021 20:45:58 +1000 Subject: [PATCH 111/154] I am finally satisfied with the API. 
Added a lot more documentation --- example_batched_nativeselect_test.go | 49 +++++++++++++++++++++++-- select_native_batched.go | 53 ++++++++++++++++++---------- 2 files changed, 81 insertions(+), 21 deletions(-) diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go index 006c751..696f92d 100644 --- a/example_batched_nativeselect_test.go +++ b/example_batched_nativeselect_test.go @@ -2,6 +2,7 @@ package tensor import ( "fmt" + "log" ) func ExampleBatchedNativeSelectF64() { @@ -9,20 +10,64 @@ func ExampleBatchedNativeSelectF64() { // now let's iterate this using a lazy native select, selecting 10 rows at time + fmt.Println("Batchsize of 10") it := BatchSelectF64(T, 0, 10) - var batchNo int for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { fmt.Printf("%d: %v\n", batchNo, cur) batchNo++ } - fmt.Printf("%d: %v\n", batchNo, it.Native()) + fmt.Printf("Is Truncated? %t\n", it.IsTruncated()) + + log.Printf("XXX") + fmt.Println("Reusing the same iterator for another loop") + batchNo = 0 + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ + } + + fmt.Println("Batchsize of 3") + it = BatchSelectF64(T, 0, 3) + batchNo = 0 + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ + } + fmt.Printf("Is Truncated? 
%t\n", it.IsTruncated()) // Output: + // Batchsize of 10 + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] + // 1: [[51 52 53 54 55] [56 57 58 59 60] [61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75] [76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90] [91 92 93 94 95] [96 97 98 99 100]] + // 2: [[101 102 103 104 105] [106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120] [121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135] [136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 3: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165] [166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180] [181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195] [196 197 198 199 200]] + // 4: [[201 202 203 204 205] [206 207 208 209 210] [211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225] [226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240] [241 242 243 244 245] [246 247 248 249 250]] + // Is Truncated? 
false + // Reusing the same iterator for another loop // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] // 1: [[51 52 53 54 55] [56 57 58 59 60] [61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75] [76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90] [91 92 93 94 95] [96 97 98 99 100]] // 2: [[101 102 103 104 105] [106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120] [121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135] [136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] // 3: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165] [166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180] [181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195] [196 197 198 199 200]] // 4: [[201 202 203 204 205] [206 207 208 209 210] [211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225] [226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240] [241 242 243 244 245] [246 247 248 249 250]] + // Batchsize of 3 + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15]] + // 1: [[16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30]] + // 2: [[31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45]] + // 3: [[46 47 48 49 50] [51 52 53 54 55] [56 57 58 59 60]] + // 4: [[61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75]] + // 5: [[76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90]] + // 6: [[91 92 93 94 95] [96 97 98 99 100] [101 102 103 104 105]] + // 7: [[106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120]] + // 8: [[121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135]] + // 9: [[136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 10: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165]] + // 11: [[166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180]] + // 12: [[181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195]] + // 13: 
[[196 197 198 199 200] [201 202 203 204 205] [206 207 208 209 210]] + // 14: [[211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225]] + // 15: [[226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240]] + // 16: [[241 242 243 244 245] [246 247 248 249 250]] + // Is Truncated? true } diff --git a/select_native_batched.go b/select_native_batched.go index c630b55..38a56ba 100644 --- a/select_native_batched.go +++ b/select_native_batched.go @@ -1,7 +1,6 @@ package tensor import ( - "log" "reflect" "runtime" "unsafe" @@ -38,12 +37,14 @@ func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { it := make([][]float64, 0, limit) var i, r int for i, r = 0, 0; r < limit; i += stride { - s := make([]float64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + // this block of code is basically + // it = append(it, data[i:i+stride]) + // TODO: benchmark + it = append(it, make([]float64, 0)) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it[len(it)-1])) hdr.Data = uintptr(unsafe.Pointer(&data[i])) hdr.Len = stride hdr.Cap = stride - it = append(it, s) r++ } @@ -58,7 +59,7 @@ func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { } func (it *BatchedNativeSelectF64) Start() (curBatch [][]float64, hasRemainingRows bool) { - if it.r != it.limit || len(it.it) != it.limit { + if it.r != it.limit || it.IsTruncated() { // then it's been moved, so we reset it.Reset() } @@ -71,29 +72,42 @@ func (it *BatchedNativeSelectF64) Start() (curBatch [][]float64, hasRemainingRow func (it *BatchedNativeSelectF64) Next() (curBatch [][]float64, hasRemaingRows bool) { var ( i int // data ptr - r int // relative row + r int // relative row / row counter for this batch s int // absolute row ) + if it.r == it.upper { + return it.it, false + } data := it.t.Float64s() + + // this loop statement looks scary. But it isn't. 
Let me break it down: + // Initialization: + // i := it.r*it.stride // the data pointer is the row number * the stride of the matrix. + // r := 0 // loop counter. We're gonna iterate `it.limit` times. + // s := it.r // the current row number of the matrix. + // Condition (continue if the following are true): + // r < it.limit // we only want to iterate at most `it.limit` times. + // s < it.upper // we want to make sure we don't iterate more rows than there are rows in the matrix. + // Next: + // i = i + it.stride // we're ready to go to the next row. + // r = r+1 // we increment the row counter. + // s = s+1 // we increment the absolute row number. + // + // Could this be written in a less concise way? Sure. But then there'd be a lot more places to keep track of things. for i, r, s = it.r*it.stride, 0, it.r; r < it.limit && s < it.upper; i, r, s = i+it.stride, r+1, s+1 { - sl := it.it[r] - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sl)) + // the block of code below is basically: + // it.it[r] = data[i:i+stride] + // r++ + // For some reason when this is done, Go actually does a lot more allocations. + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it.it[r])) hdr.Data = uintptr(unsafe.Pointer(&data[i])) - hdr.Len = it.stride - hdr.Cap = it.stride - it.it[r] = sl } it.r = s - log.Printf("r %v limit %v, s %v upper %v", r, it.limit, s, it.upper) - - if r < it.limit { - // truncate it.it + if it.r == it.upper && r < it.limit { + // truncate it.it because iterated rows is less than the limit. + // This implies that there are some extra rows. 
it.it = it.it[:r] - return it.it, false - } - if it.r == it.upper { - return it.it, false } return it.it, true @@ -115,6 +129,7 @@ func (it *BatchedNativeSelectF64) Reset() { it.it[r] = sl r++ } + it.r = r } func (it *BatchedNativeSelectF64) IsTruncated() bool { return len(it.it) != it.limit } From 7a4060aea4d3253f04f88bbaa83c26310af52466 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 Aug 2021 23:07:29 +1000 Subject: [PATCH 112/154] fixed the WithBacking funcopt to actually play nicer with CUDA. --- dense.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/dense.go b/dense.go index 20750a4..39b0f90 100644 --- a/dense.go +++ b/dense.go @@ -278,7 +278,28 @@ func (t *Dense) fix() { t.oe = oe } + _, isNonStdEng := t.e.(NonStdEngine) + switch { + case isNonStdEng && t.Shape() != nil: + // if there is already data in the array, we should back it up now + raw := t.array.Header.Raw + + // make the array + size := t.Shape().TotalSize() + if t.Shape().IsScalar() { + size = 1 + } + t.makeArray(size) + + if len(raw) != 0 { + // copy over if natively accessible + if t.IsNativelyAccessible() { + bs := t.byteSlice() + copy(bs, raw) + } + } + case t.IsScalar() && t.array.Header.Raw == nil: t.makeArray(1) case t.Shape() == nil && t.array.Header.Raw != nil: From 082061fe13c51d2e4ef46daabb67dedfc27e3710 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 Aug 2021 23:08:36 +1000 Subject: [PATCH 113/154] Added a temporary range based iter... --- example_batched_nativeselect_test.go | 35 ++++++++++++++++++++++++++++ select_native_batched.go | 28 ++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go index 696f92d..7350aa9 100644 --- a/example_batched_nativeselect_test.go +++ b/example_batched_nativeselect_test.go @@ -71,3 +71,38 @@ func ExampleBatchedNativeSelectF64() { // Is Truncated? 
true } + +func ExampleIterSelect() { + T := New(WithShape(20, 5), WithBacking(Range(Float64, 1, 101))) + it := NewIterSelect(T, 0) + data := T.Float64s() + var rowNo int + for start, end, hasRem := it.Start(); hasRem; start, end, hasRem = it.Next() { + sl := data[start:end] + fmt.Printf("%d: %v\n", rowNo, sl) + rowNo++ + } + + // Output: + // 0: [1 2 3 4 5] + // 1: [6 7 8 9 10] + // 2: [11 12 13 14 15] + // 3: [16 17 18 19 20] + // 4: [21 22 23 24 25] + // 5: [26 27 28 29 30] + // 6: [31 32 33 34 35] + // 7: [36 37 38 39 40] + // 8: [41 42 43 44 45] + // 9: [46 47 48 49 50] + // 10: [51 52 53 54 55] + // 11: [56 57 58 59 60] + // 12: [61 62 63 64 65] + // 13: [66 67 68 69 70] + // 14: [71 72 73 74 75] + // 15: [76 77 78 79 80] + // 16: [81 82 83 84 85] + // 17: [86 87 88 89 90] + // 18: [91 92 93 94 95] + // 19: [96 97 98 99 100] + +} diff --git a/select_native_batched.go b/select_native_batched.go index 38a56ba..c05bd76 100644 --- a/select_native_batched.go +++ b/select_native_batched.go @@ -133,3 +133,31 @@ func (it *BatchedNativeSelectF64) Reset() { } func (it *BatchedNativeSelectF64) IsTruncated() bool { return len(it.it) != it.limit } + +type IterSelect struct { + r int + upper int + stride int + total int +} + +func NewIterSelect(t *Dense, axis int) *IterSelect { + upper := ProdInts(t.Shape()[:axis+1]) + stride := t.Strides()[axis] + total := t.DataSize() + return &IterSelect{upper: upper, stride: stride, total: total} +} + +func (it *IterSelect) Start() (start, end int, hasRem bool) { + if it.r > it.stride { + it.Reset() + } + return it.r, it.stride, it.r*it.stride+it.stride < it.total +} + +func (it *IterSelect) Next() (start, end int, hasRem bool) { + it.r += it.stride + return it.r, it.r + it.stride, it.r+it.stride <= it.total +} + +func (it *IterSelect) Reset() { it.r = 0 } From 94622fd3013654a931b828ffbb27ee0f0daa4dcb Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 5 Aug 2021 15:07:34 +1000 Subject: [PATCH 114/154] Fixed a bug in the ElNe API --- 
api_cmp.go | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/api_cmp.go b/api_cmp.go index ffb602d..b2ac050 100644 --- a/api_cmp.go +++ b/api_cmp.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // public API for comparison ops @@ -295,12 +297,26 @@ func ElNe(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { eleqer, ok = at.Engine().(ElEqer) switch bt := b.(type) { case Tensor: - if !ok { - if eleqer, ok = bt.Engine().(ElEqer); !ok { - return nil, errors.Errorf("Neither operands have engines that support ElEq") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if eleqer, ok = bt.Engine().(ElEqer); !ok { + return nil, errors.Errorf("Neither operands have engines that support ElEq") + } + } + return eleqer.ElNe(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false + at, bt = bt, at + } else { + leftTensor = true + } + if !ok { + return nil, errors.Errorf("Engine does not support ElNE") } + return eleqer.NeScalar(at, bt, leftTensor, opts...) } - return eleqer.ElNe(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support ElEq") From a7c5545ef499599ca0cabab3969a1c2e9c9800b5 Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 19 Aug 2021 09:09:22 +1000 Subject: [PATCH 115/154] Added some funcopts to Inner to allow for context handling --- api_arith.go | 4 ++-- dense_linalg.go | 8 +++++++- interfaces.go | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/api_arith.go b/api_arith.go index 62a30a3..6369875 100644 --- a/api_arith.go +++ b/api_arith.go @@ -628,7 +628,7 @@ func MatVecMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // Inner finds the inner products of two vector Tensors. Both arguments to the functions are eexpected to be vectors. 
-func Inner(a, b Tensor) (retVal interface{}, err error) { +func Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { if a.Dtype() != b.Dtype() { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return @@ -637,7 +637,7 @@ func Inner(a, b Tensor) (retVal interface{}, err error) { switch at := a.(type) { case *Dense: bt := b.(*Dense) - return at.Inner(bt) + return at.Inner(bt, opts...) } panic("Unreachable") } diff --git a/dense_linalg.go b/dense_linalg.go index 756d9b7..38efca0 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -16,7 +16,13 @@ func (t *Dense) Trace() (retVal interface{}, err error) { } // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. -func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { +func (t *Dense) Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + // check that the data is a float if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") diff --git a/interfaces.go b/interfaces.go index 4b44154..aabde3d 100644 --- a/interfaces.go +++ b/interfaces.go @@ -112,7 +112,7 @@ type DenseTensor interface { cap() int // operations - Inner(other Tensor) (retVal interface{}, err error) + Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err error) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err error) From 0f1ea5e7c3ba871c07bd3e9cc0c076650aff53bc Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 19 Aug 2021 10:52:30 +1000 Subject: [PATCH 116/154] Added some more support for context.Context --- defaultengine_linalg.go | 10 ++++++++-- defaultenginefloat32.go | 20 +++++++++++++++++--- defaultenginefloat64.go | 21 +++++++++++++++++---- dense_linalg.go | 15 +++++---------- engine.go | 8 ++++---- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 4a6073b..90cff91 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -10,7 +10,7 @@ import ( ) // Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error -func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { +func (e StdEng) Trace(t Tensor, opts ...FuncOpt) (retVal interface{}, err error) { if t.Dims() != 2 { err = errors.Errorf(dimMismatch, 2, t.Dims()) return @@ -372,7 +372,13 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // Inner is a thin layer over BLAS's D/Sdot. // It returns a scalar value, wrapped in an interface{}, which is not quite nice. 
-func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { +func (e StdEng) Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + var ad, bd DenseTensor if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Inner") diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 9f1ebf7..44a126b 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" @@ -9,9 +11,10 @@ import ( "gorgonia.org/vecf32" ) -func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -179,9 +182,14 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -220,7 +228,13 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, return } -func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) { +func (e Float32Engine) Inner(a, b Tensor, opts ...FuncOpt) (retVal float32, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float32 var AD, BD *Dense var ok bool diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 4e2167a..6bbf95c 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" @@ -9,9 +11,9 @@ import ( "gorgonia.org/vecf64" ) -func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) - + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -176,9 +178,14 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -217,7 +224,13 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, return } -func (e Float64Engine) Inner(a, b Tensor) (retVal float64, err error) { +func (e Float64Engine) Inner(a, b Tensor, opts ...FuncOpt) (retVal float64, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float64 var AD, BD *Dense var ok bool diff --git a/dense_linalg.go b/dense_linalg.go index 38efca0..002f0c7 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -6,22 +6,17 @@ import ( ) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices -func (t *Dense) Trace() (retVal interface{}, err error) { +func (t *Dense) Trace(opts ...FuncOpt) (retVal interface{}, err error) { e := t.e if tracer, ok := e.(Tracer); ok { - return tracer.Trace(t) + return tracer.Trace(t, opts...) } return nil, errors.Errorf("Engine %T does not support Trace", e) } // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. func (t *Dense) Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err error) { - fo := ParseFuncOpts(opts...) - ctx := fo.Context() - if err = handleCtx(ctx); err != nil { - return nil, err // this err will be noopError{}, no need to wrap. - } // check that the data is a float if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { @@ -42,11 +37,11 @@ func (t *Dense) Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err er e := t.e switch ip := e.(type) { case InnerProderF32: - return ip.Inner(t, other) + return ip.Inner(t, other, opts...) case InnerProderF64: - return ip.Inner(t, other) + return ip.Inner(t, other, opts...) 
case InnerProder: - return ip.Inner(t, other) + return ip.Inner(t, other, opts...) } return nil, errors.Errorf("Engine does not support Inner()") diff --git a/engine.go b/engine.go index f539ab1..542ee7b 100644 --- a/engine.go +++ b/engine.go @@ -172,7 +172,7 @@ type Moder interface { // Tracer is any engine that can return the trace (aka the sum of the diagonal elements). type Tracer interface { - Trace(a Tensor) (interface{}, error) + Trace(a Tensor, opts ...FuncOpt) (interface{}, error) } // FMAer is any engine that can perform fused multiply add functions: A * X + Y. Also known as Axpy. @@ -193,17 +193,17 @@ type MatVecMuler interface { // InnerProder is any engine that can perform inner product multiplication type InnerProder interface { - Inner(a, b Tensor) (interface{}, error) // Inner always returns a scalar value + Inner(a, b Tensor, opts ...FuncOpt) (interface{}, error) // Inner always returns a scalar value } // InnerProderF32 is an optimization for float32 - results are returned as float32. type InnerProderF32 interface { - Inner(a, b Tensor) (float32, error) + Inner(a, b Tensor, opts ...FuncOpt) (float32, error) } // InnerProderF64 is an optimization for float64 - results are returned as float64 type InnerProderF64 interface { - Inner(a, b Tensor) (float64, error) + Inner(a, b Tensor, opts ...FuncOpt) (float64, error) } // OuterProder is any engine that can perform outer product (kronecker) multiplication From 519fa045de6f661c4e86fed4e3b34ddd3939f145 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 09:17:32 +1000 Subject: [PATCH 117/154] Changed almost all the engine interface definitions to also include context.Context. This breaks the package in very many ways. The rest of the changes are forthcoming. 
--- engine.go | 64 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/engine.go b/engine.go index 542ee7b..93b19d5 100644 --- a/engine.go +++ b/engine.go @@ -1,6 +1,10 @@ package tensor -import "gorgonia.org/dtype" +import ( + "context" + + "gorgonia.org/dtype" +) // Memory is a representation of memory of the value. // @@ -74,33 +78,33 @@ type NonStdEngine interface { // Transposer is any engine that can perform an unsafe transpose of a tensor. type Transposer interface { - Transpose(t Tensor, expStrides []int) error + Transpose(ctx context.Context, t Tensor, expStrides []int) error } // Concater is any enegine that can concatenate multiple Tensors together type Concater interface { - Concat(t Tensor, axis int, others ...Tensor) (Tensor, error) + Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // Stacker is any engine that can stack multiple Tenosrs along an axis type Stacker interface { - Stack(t Tensor, axis int, others ...Tensor) (Tensor, error) + Stack(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // DenseStacker is any engine that can stack DenseTensors along an axis. This is a specialization of Stacker. type DenseStacker interface { - StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) + StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) } // Repeater is any engine that can repeat values along the given axis. 
type Repeater interface { - Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) - RepeatReuse(t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) + Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) + RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) } // Diager is any engine that can return a tensor that only contains the diagonal values of the input type Diager interface { - Diag(a Tensor) (Tensor, error) + Diag(ctx context.Context, a Tensor) (Tensor, error) } /* NUMBER INTERFACES @@ -172,53 +176,53 @@ type Moder interface { // Tracer is any engine that can return the trace (aka the sum of the diagonal elements). type Tracer interface { - Trace(a Tensor, opts ...FuncOpt) (interface{}, error) + Trace(ctx context.Context, a Tensor) (interface{}, error) } // FMAer is any engine that can perform fused multiply add functions: A * X + Y. Also known as Axpy. type FMAer interface { - FMA(a, x, y Tensor) (Tensor, error) - FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) + FMA(ctx context.Context, a, x, y Tensor) (Tensor, error) + FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (Tensor, error) } // MatMuler is any engine that can perform matrix multiplication type MatMuler interface { - MatMul(a, b, preallocated Tensor) error + MatMul(ctx context.Context, a, b, preallocated Tensor) error } // MatVecMuler is any engine that can perform matrix vector multiplication type MatVecMuler interface { - MatVecMul(a, b, preallocated Tensor) error + MatVecMul(ctx context.Context, a, b, preallocated Tensor) error } // InnerProder is any engine that can perform inner product multiplication type InnerProder interface { - Inner(a, b Tensor, opts ...FuncOpt) (interface{}, error) // Inner always returns a scalar value + Inner(ctx context.Context, a, b Tensor) (interface{}, error) // Inner always returns a scalar value } // InnerProderF32 is an optimization for 
float32 - results are returned as float32. type InnerProderF32 interface { - Inner(a, b Tensor, opts ...FuncOpt) (float32, error) + Inner(ctx context.Context, a, b Tensor) (float32, error) } // InnerProderF64 is an optimization for float64 - results are returned as float64 type InnerProderF64 interface { - Inner(a, b Tensor, opts ...FuncOpt) (float64, error) + Inner(ctx context.Context, a, b Tensor) (float64, error) } // OuterProder is any engine that can perform outer product (kronecker) multiplication type OuterProder interface { - Outer(a, b, preallocated Tensor) error + Outer(ctx context.Context, a, b, preallocated Tensor) error } // Dotter is used to implement sparse matrices type Dotter interface { - Dot(a, b Tensor, opts ...FuncOpt) (Tensor, error) + Dot(ctx context.Context, a, b Tensor) (Tensor, error) } // SVDer is any engine that can perform SVD type SVDer interface { - SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) + SVD(ctx context.Context, a Tensor, uv, full bool) (s, u, v Tensor, err error) } /* ORD INTERFACES */ @@ -354,22 +358,22 @@ type OptimizedReducer interface { // Sumer is any engine that can perform summation along an axis of a Tensor. type Sumer interface { - Sum(a Tensor, along ...int) (Tensor, error) + Sum(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Proder is any engine that can perform product along an axis of a Tensor. type Proder interface { - Prod(a Tensor, along ...int) (Tensor, error) + Prod(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Miner is any engine that can find the minimum value along an axis of a Tensor. type Miner interface { - Min(a Tensor, along ...int) (Tensor, error) + Min(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Maxer is any engine that can find the maximum value along an axis of a Tensor. 
type Maxer interface { - Max(a Tensor, along ...int) (Tensor, error) + Max(ctx context.Context, a Tensor, along ...int) (Tensor, error) } /* Arg methods */ @@ -377,39 +381,39 @@ type Maxer interface { // Argmaxer is any engine that can find the indices of the maximum values along an axis. // By convention the returned Tensor has Dtype of Int. type Argmaxer interface { - Argmax(t Tensor, axis int) (Tensor, error) + Argmax(ctx context.Context, t Tensor, axis int) (Tensor, error) } // Argmaxer is any engine that can find the indices of the minimum values along an axis. // By convention the returned Tensor has Dtype of Int. type Argminer interface { - Argmin(t Tensor, axis int) (Tensor, error) + Argmin(ctx context.Context, t Tensor, axis int) (Tensor, error) } // NaNChecker checks that the tensor contains a NaN // Errors are to be returned if the concept of NaN does not apply to the data type. // Other errors may also occur. See specific implementations for details type NaNChecker interface { - HasNaN(t Tensor) (bool, error) + HasNaN(ctx context.Context, t Tensor) (bool, error) } // InfChecker checks that the tensor contains a Inf. // Errors are to be returned if the concept of Inf does not apply to the data type. // Other errors may also occur. See specific implementations for details type InfChecker interface { - HasInf(t Tensor) (bool, error) + HasInf(ctx context.Context, t Tensor) (bool, error) } /* Advanced Indexing */ // ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the `indices` tensor. 
type ByIndiceser interface { - SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) - SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndices(ctx context.Context, a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndicesB(ctx context.Context, a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } type Scatterer interface { - Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) + Scatter(ctx context.Context, a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) } /* Internal interfaces for faster shit */ From b0b47474bb5d089443da840d39aac11fc94bbd29 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 09:39:48 +1000 Subject: [PATCH 118/154] Converted transpose and concat to use context. --- defaultengine_matop_misc.go | 6 +++++- defaultengine_matop_transpose.go | 8 +++++++- defaultengine_matop_transpose_inplace.go | 8 +++++++- dense_matop_memmove.go | 8 +++++--- engine.go | 5 +++++ utils.go | 11 +++++++++++ 6 files changed, 40 insertions(+), 6 deletions(-) diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index ffcbf4c..7a13195 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -232,7 +232,11 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, } // Concat tensors -func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { +func (e StdEng) Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: var denses []DenseTensor diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index cef220e..7ca63cb 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -3,10 +3,16 @@ package tensor 
import ( + "context" + "github.com/pkg/errors" ) -func (e StdEng) Transpose(a Tensor, expStrides []int) error { +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + if !a.IsNativelyAccessible() { return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") } diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 8627927..8d1d5f3 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -3,10 +3,16 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) -func (e StdEng) Transpose(a Tensor, expStrides []int) error { +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + if !a.IsNativelyAccessible() { return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") } diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index f2a54e2..75bdc8c 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -43,13 +43,14 @@ func (t *Dense) Transpose() error { } // actually move data - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) transposer, ok := e.(Transposer) if !ok { return errors.Errorf("Engine does not support Transpose()") } - return transposer.Transpose(t, expStrides) + return transposer.Transpose(ctx, t, expStrides) } // Repeat is like Numpy's repeat. It repeats the elements of an array. @@ -67,11 +68,12 @@ func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { // Concat concatenates the other tensors along the given axis. It is like Numpy's concatenate() function. 
func (t *Dense) Concat(axis int, Ts ...*Dense) (retVal *Dense, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if c, ok := e.(Concater); ok { var ret Tensor others := densesToTensors(Ts) - if ret, err = c.Concat(t, axis, others...); err != nil { + if ret, err = c.Concat(ctx, t, axis, others...); err != nil { return nil, errors.Wrapf(err, opFail, "Concat") } return ret.(*Dense), nil diff --git a/engine.go b/engine.go index 93b19d5..025a757 100644 --- a/engine.go +++ b/engine.go @@ -69,6 +69,11 @@ type arrayMaker interface { makeArray(arr *array, t dtype.Dtype, size int) } +// contexter is any engine (or type) that returns the current context. +type contexter interface { + Context() context.Context +} + // NonStdEngine are any engines that do not allocate using the default built in allocator type NonStdEngine interface { NonStdAlloc() // noop diff --git a/utils.go b/utils.go index 22db57b..eba80de 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) @@ -289,6 +291,7 @@ func memsetBools(a []bool, v bool) { } } +// allones checks that a slice of ints are all 1. func allones(a []int) bool { for i := range a { if a[i] != 1 { @@ -298,6 +301,14 @@ func allones(a []int) bool { return true } +// ctxFromEngine gets a context from an engine if it's a contexter. Otherwise it returns a context.Background() +func ctxFromEngine(e Engine) context.Context { + if c, ok := e.(contexter); ok { + return c.Context() + } + return context.Background() +} + /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. the dumb, unoptimized version) From f6a3a04b61f700e730affcaa5eec0c15739db1de Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 09:52:36 +1000 Subject: [PATCH 119/154] added some notes clarifying the API design of the tensor package. 
Looks like it's poorly designed as there are repeat things --- README.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 66d6372..98108c8 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,16 @@ Package `tensor` is a package that provides efficient, generic (by some definiti The main purpose of this package is to support the operations required by [Gorgonia](https://gorgonia.org/gorgonia). ## Introduction ## -In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. +In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. While slices are cool, a large majority of scientific and numeric computing work relies heavily on matrices (two-dimensional arrays), three dimensional arrays and so on. In Go, the typical way of getting multidimensional arrays is to use something like `[][]T`. Applications that are more math heavy may opt to use the very excellent Gonum [`matrix` package](https://github.com/gonum/matrix). What then if we want to go beyond having a `float64` matrix? What if we wanted a 3-dimensional `float32` array? -It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. +It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. ### Basic Idea: Tensor ### A tensor is a multidimensional array. It's like a slice, but works in multiple dimensions. 
-With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). +With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). Tensors come with their own set of usage patterns and abstractions. Most of these have analogues in slices, enumerated below (do note that certain slice operation will have more than one tensor analogue - this is due to the number of options available): @@ -26,7 +26,7 @@ Tensors come with their own set of usage patterns and abstractions. Most of thes | `a[0]` | `T.At(x,y)` | | `append(a, ...)`| `T.Stack(...)`, `T.Concat(...)` | | `copy(dest, src)`| `T.CopyTo(dest)`, `tensor.Copy(dest, src)` | -| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | +| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | Some operations for a tensor does not have direct analogues to slice operations. However, they stem from the same idea, and can be considered a superset of all operations common to slices. 
They're enumerated below: @@ -77,7 +77,7 @@ fmt.Printf("a:\n%v\n", a) To create a 3-Tensor is just as easy - just put the correct shape and you're good to go: -```go +```go // Creating a (2,3,4) 3-Tensor of float32 b := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) fmt.Printf("b:\n%1.1f\n", b) @@ -133,6 +133,12 @@ fmt.Printf("b:\n%v", b) There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/gorgonia.org/tensor) page +## API Notes ## + +This package has a notion of "layers" in its API. This section clarifies the different patterns seen in the API. + + + ## Design of `*Dense` ## @@ -142,7 +148,7 @@ The design of the `*Dense` tensor is quite simple in concept. However, let's sta The data structure for `*Dense` is similar, but a lot more complex. Much of the complexity comes from the need to do accounting work on the data structure as well as preserving references to memory locations. This is how the `*Dense` is defined: -```go +```go type Dense struct { *AP array @@ -168,7 +174,7 @@ type array struct { } ``` -`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores a `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. +`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores a `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. The `storage.Header` field of the `array` (and hence `*Dense`) is there to provide a quick and easy way to translate back into a slice for operations that use familiar slice semantics, of which much of the operations are dependent upon. 
@@ -205,17 +211,17 @@ The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https: Example: -```go +```go x := New(WithBacking([]string{"hello", "world", "hello", "world"}), WithShape(2,2)) x = New(WithBacking([]int{1,2,3,4}), WithShape(2,2)) ``` -The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. +The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. One could argue that this sidesteps the compiler's type checking system, deferring it to runtime (which a number of people consider dangerous). However, tools are being developed to type check these things, and until Go does support typechecked generics, unfortunately this will be the way it has to be. -Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. +Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. # How This Package is Developed # Much of the code in this package is generated. The code to generate them is in the directory `genlib2`. `genlib2` requires [`goimports`](https://godoc.org/golang.org/x/tools/cmd/goimports) binary to be available in the $PATH. @@ -246,7 +252,7 @@ See also: CONTRIBUTING.md ## Contributors and Significant Contributors ## -All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. +All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. A Significant Contributor is one who has shown *deep understanding* of how the library works and/or its environs. 
Here are examples of what constitutes a Significant Contribution: From 47dcee9bf3e8025e7202bdf012db5ba36a7a6675 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:01:45 +1000 Subject: [PATCH 120/154] Fixed up Stack, Repeat and Diag --- defaultengine.go | 7 ++++++- defaultengine_matop_misc.go | 20 +++++++++++++++++--- defaultengine_matop_stack.go | 13 ++++++++++++- dense_matop_memmove.go | 7 +++++-- genlib2/engine.go | 3 +++ 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/defaultengine.go b/defaultengine.go index f9f5854..5338391 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -6,9 +6,14 @@ import ( "gorgonia.org/tensor/internal/execution" ) +// stdDenseEng is the default execution engine for dense tensor operations. +type stdDenseEng struct { + execution.E +} + // StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. type StdEng struct { - execution.E + stdDenseEng } // makeArray allocates a slice for the array diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 7a13195..303c5e2 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" @@ -17,7 +19,11 @@ type fastcopier interface { } // Repeat ... -func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { +func (e StdEng) Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) @@ -32,7 +38,11 @@ func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { } // RepeatReuse is like Repeat, but with a provided reuse Tensor. 
The reuseTensor must be of the same type as the input t. -func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { +func (e StdEng) RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) @@ -368,7 +378,11 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen } // Diag ... -func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { +func (e StdEng) Diag(ctx context.Context, t Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + a, ok := t.(DenseTensor) if !ok { return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index d5e661a..33c148d 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -1,12 +1,23 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) // This file contains code for the execution engine to stack tensors -func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { +var ( + // _ Stacker = StdEng{} + _ DenseStacker = StdEng{} +) + +func (e StdEng) StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + opdims := t.Dims() if axis >= opdims+1 { err = errors.Errorf(dimMismatch, opdims+1, axis) diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index 75bdc8c..9d63082 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -58,9 +58,10 @@ func (t *Dense) Transpose() error { // Just like NumPy, the repeats param is broadcasted to fit the size of the given axis. 
func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if rp, ok := e.(Repeater); ok { - return rp.Repeat(t, axis, repeats...) + return rp.Repeat(ctx, t, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } @@ -129,8 +130,10 @@ func (t *Dense) Stack(axis int, others ...*Dense) (retVal *Dense, err error) { } func (t *Dense) stackDense(axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) if ds, ok := t.Engine().(DenseStacker); ok { - return ds.StackDense(t, axis, others...) + return ds.StackDense(ctx, t, axis, others...) } return nil, errors.Errorf("Engine does not support DenseStacker") } diff --git a/genlib2/engine.go b/genlib2/engine.go index 7b7a207..d9561e4 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -7,6 +7,7 @@ import ( ) type EngineArith struct { + isStdDenseEng bool Name string VecVar string PrepData string @@ -33,9 +34,11 @@ func (fn *EngineArith) Signature() *Signature { case fn.VV: paramNames = []string{"a", "b", "opts"} paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} + default: paramNames = []string{"t", "s", "leftTensor", "opts"} paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} + } return &Signature{ Name: fn.methName(), From 084762d9e8630f912474fa00559def64f2edf978 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:17:16 +1000 Subject: [PATCH 121/154] Fixed the linear algebra methods to include context.Context --- defaultengine_linalg.go | 45 +++++++++++++++++++++++++++++++---------- defaultenginefloat32.go | 16 ++++++++++----- defaultenginefloat64.go | 15 +++++++++----- dense_linalg.go | 30 ++++++++++++++------------- engine.go | 2 +- 5 files changed, 72 insertions(+), 36 deletions(-) diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 90cff91..287fb18 100644 --- 
a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -1,6 +1,7 @@ package tensor import ( + "context" "reflect" "github.com/pkg/errors" @@ -10,7 +11,11 @@ import ( ) // Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error -func (e StdEng) Trace(t Tensor, opts ...FuncOpt) (retVal interface{}, err error) { +func (e StdEng) Trace(ctx context.Context, t Tensor, opts ...FuncOpt) (retVal interface{}, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + if t.Dims() != 2 { err = errors.Errorf(dimMismatch, 2, t.Dims()) return @@ -119,6 +124,12 @@ func (e StdEng) Trace(t Tensor, opts ...FuncOpt) (retVal interface{}, err error) } func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return nil, err + } + if _, ok := x.(DenseTensor); !ok { err = errors.Errorf("Engine only supports working on x that is a DenseTensor. Got %T instead", x) return @@ -139,8 +150,6 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } - fo := ParseFuncOpts(opts...) 
- var reuse, incr DenseTensor if reuse, err = getFloatDenseTensor(fo.reuse); err != nil { err = errors.Wrapf(err, opFail, "Dot - reuse") @@ -212,7 +221,7 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } var ret interface{} - if ret, err = e.Inner(a, b); err != nil { + if ret, err = e.Inner(ctx, a, b); err != nil { return nil, errors.Wrapf(err, opFail, "Dot") } return New(FromScalar(ret)), nil @@ -309,7 +318,11 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // TODO: make it take DenseTensor -func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { +func (e StdEng) SVD(ctx context.Context, a Tensor, uv, full bool) (s, u, v Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, nil, nil, err + } + var t *Dense var ok bool if err = e.checkAccessible(a); err != nil { @@ -372,9 +385,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // Inner is a thin layer over BLAS's D/Sdot. // It returns a scalar value, wrapped in an interface{}, which is not quite nice. -func (e StdEng) Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { - fo := ParseFuncOpts(opts...) - ctx := fo.Context() +func (e StdEng) Inner(ctx context.Context, a, b Tensor) (retVal interface{}, err error) { if err = handleCtx(ctx); err != nil { return nil, err // this err will be noopError{}, no need to wrap. 
} @@ -405,7 +416,11 @@ func (e StdEng) Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err err // Because DGEMV computes: // y = αA * x + βy // we set beta to 0, so we don't have to manually zero out the reused/retval tensor data -func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatVecMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { @@ -477,7 +492,11 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { // DGEMM computes: // C = αA * B + βC // To prevent needless zeroing out of the slice, we just set β to 0 -func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { @@ -585,7 +604,11 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { } // Outer is a thin wrapper over S/Dger -func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { +func (e StdEng) Outer(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err = handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 44a126b..2b78aad 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -127,7 +127,11 @@ func (e Float32Engine) makeArray(arr *array, t dtype.Dtype, size int) { arr.t = t } -func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err 
error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -150,7 +154,11 @@ func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -228,9 +236,7 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, return } -func (e Float32Engine) Inner(a, b Tensor, opts ...FuncOpt) (retVal float32, err error) { - fo := ParseFuncOpts(opts...) - ctx := fo.Context() +func (e Float32Engine) Inner(ctx context.Context, a, b Tensor) (retVal float32, err error) { if err = handleCtx(ctx); err != nil { return 0, err // this err will be noopError{}, no need to wrap. 
} diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 6bbf95c..85c59b2 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -123,7 +123,11 @@ func (e Float64Engine) makeArray(arr *array, t dtype.Dtype, size int) { arr.t = t } -func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -146,7 +150,10 @@ func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float64Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -224,9 +231,7 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, return } -func (e Float64Engine) Inner(a, b Tensor, opts ...FuncOpt) (retVal float64, err error) { - fo := ParseFuncOpts(opts...) - ctx := fo.Context() +func (e Float64Engine) Inner(ctx context.Context, a, b Tensor, opts ...FuncOpt) (retVal float64, err error) { if err = handleCtx(ctx); err != nil { return 0, err // this err will be noopError{}, no need to wrap. } diff --git a/dense_linalg.go b/dense_linalg.go index 002f0c7..5b103b2 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -6,18 +6,17 @@ import ( ) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). 
It only works for matrices -func (t *Dense) Trace(opts ...FuncOpt) (retVal interface{}, err error) { +func (t *Dense) Trace() (retVal interface{}, err error) { e := t.e - + ctx := ctxFromEngine(e) if tracer, ok := e.(Tracer); ok { - return tracer.Trace(t, opts...) + return tracer.Trace(ctx, t) } return nil, errors.Errorf("Engine %T does not support Trace", e) } // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. -func (t *Dense) Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err error) { - +func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { // check that the data is a float if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") @@ -35,13 +34,14 @@ func (t *Dense) Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err er } e := t.e + ctx := ctxFromEngine(e) switch ip := e.(type) { case InnerProderF32: - return ip.Inner(t, other, opts...) + return ip.Inner(ctx, t, other) case InnerProderF64: - return ip.Inner(t, other, opts...) + return ip.Inner(ctx, t, other) case InnerProder: - return ip.Inner(t, other, opts...) 
+ return ip.Inner(ctx, t, other) } return nil, errors.Errorf("Engine does not support Inner()") @@ -95,11 +95,11 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e - if mvm, ok := e.(MatVecMuler); ok { - if err = mvm.MatVecMul(t, other, retVal); err != nil { + if err = mvm.MatVecMul(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "MatVecMul") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -144,10 +144,11 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e if mm, ok := e.(MatMuler); ok { - if err = mm.MatMul(t, other, retVal); err != nil { + if err = mm.MatMul(ctx, t, other, retVal); err != nil { return } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -183,13 +184,14 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e // DGER does not have any beta. 
So the values have to be zeroed first if the tensor is to be reused retVal.Zero() if op, ok := e.(OuterProder); ok { - if err = op.Outer(t, other, retVal); err != nil { + if err = op.Outer(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "engine.uter") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -357,10 +359,10 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err // In the future, when gonum/lapack fully supports float32, we'll look into rewriting this func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { e := t.Engine() - + ctx := ctxFromEngine(e) if svder, ok := e.(SVDer); ok { var sT, uT, vT Tensor - if sT, uT, vT, err = svder.SVD(t, uv, full); err != nil { + if sT, uT, vT, err = svder.SVD(ctx, t, uv, full); err != nil { return nil, nil, nil, errors.Wrap(err, "Error while performing *Dense.SVD") } if s, err = assertDense(sT); err != nil { diff --git a/engine.go b/engine.go index 025a757..7f99a91 100644 --- a/engine.go +++ b/engine.go @@ -222,7 +222,7 @@ type OuterProder interface { // Dotter is used to implement sparse matrices type Dotter interface { - Dot(ctx context.Context, a, b Tensor) (Tensor, error) + Dot(a, b Tensor, opts ...FuncOpt) (Tensor, error) } // SVDer is any engine that can perform SVD From f39319c09eceaa222fba5a0d6992590515d84c95 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:24:08 +1000 Subject: [PATCH 122/154] Fixed reduction methods to use context --- defaultengine_mapreduce.go | 18 ++++++++++++--- dense_reduction_methods.go | 46 ++++++++++++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index f553fb7..2c7ccd4 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -181,7 +181,11 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, return } -func (e StdEng) Sum(a Tensor, along ...int) 
(retVal Tensor, err error) { +func (e StdEng) Sum(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() @@ -189,7 +193,11 @@ func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...) } -func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) Min(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() @@ -197,7 +205,11 @@ func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...) } -func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) Max(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() diff --git a/dense_reduction_methods.go b/dense_reduction_methods.go index cb744b5..28058a2 100644 --- a/dense_reduction_methods.go +++ b/dense_reduction_methods.go @@ -3,37 +3,65 @@ package tensor import "github.com/pkg/errors" func (t *Dense) Sum(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if sumer, ok := e.(Sumer); ok { var ret Tensor - if ret, err = sumer.Sum(t, along...); err != nil { + if ret, err = sumer.Sum(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Sum") + } + return } return nil, errors.Errorf("Engine does not support Sum") } +func (t *Dense) 
Prod(along ...int) (retVal *Dense, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Proder); ok { + var ret Tensor + if ret, err = sumer.Prod(ctx, t, along...); err != nil { + return + } + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Prod") + } + return + } + return nil, errors.Errorf("Engine does not support Prod") +} + func (t *Dense) Max(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if maxer, ok := e.(Maxer); ok { var ret Tensor - if ret, err = maxer.Max(t, along...); err != nil { + if ret, err = maxer.Max(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Max") + } + return } return nil, errors.Errorf("Engine does not support Max") } func (t *Dense) Min(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if miner, ok := e.(Miner); ok { var ret Tensor - if ret, err = miner.Min(t, along...); err != nil { + if ret, err = miner.Min(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Min") + } + return } return nil, errors.Errorf("Engine does not support Min") } From 2ce32f759a2460f6f00544f6f58065ce64674ec4 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:26:19 +1000 Subject: [PATCH 123/154] Fixed argmethods to use context --- defaultengine_argmethods.go | 12 ++++++++++-- dense_argmethods.go | 6 ++++-- interfaces.go | 2 +- testutils_test.go | 9 +++++++-- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 21373fd..a25d5b6 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -1,11 +1,16 @@ package tensor import ( + "context" + "github.com/pkg/errors" 
"gorgonia.org/dtype" ) -func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { +func (e StdEng) Argmax(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } switch tt := t.(type) { case DenseTensor: @@ -92,7 +97,10 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e return New(WithShape(newShape...), WithBacking(indices)), nil } -func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) { +func (e StdEng) Argmin(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } switch tt := t.(type) { case DenseTensor: diff --git a/dense_argmethods.go b/dense_argmethods.go index bfdc0d7..ca64784 100644 --- a/dense_argmethods.go +++ b/dense_argmethods.go @@ -7,13 +7,14 @@ import "github.com/pkg/errors" // Argmax finds the index of the max value along the axis provided func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgmaxer: return am.argmaxDenseTensor(t, axis) case Argmaxer: var ret Tensor var ok bool - if ret, err = am.Argmax(t, axis); err != nil { + if ret, err = am.Argmax(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { @@ -29,13 +30,14 @@ func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { // Argmin finds the index of the min value along the axis provided func (t *Dense) Argmin(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgminer: return am.argminDenseTensor(t, axis) case Argminer: var ret Tensor var ok bool - if ret, err = am.Argmin(t, axis); err != nil { + if ret, err = am.Argmin(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { diff --git a/interfaces.go b/interfaces.go index 
aabde3d..4b44154 100644 --- a/interfaces.go +++ b/interfaces.go @@ -112,7 +112,7 @@ type DenseTensor interface { cap() int // operations - Inner(other Tensor, opts ...FuncOpt) (retVal interface{}, err error) + Inner(other Tensor) (retVal interface{}, err error) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err error) diff --git a/testutils_test.go b/testutils_test.go index 39a7bf7..77312fb 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -2,6 +2,7 @@ package tensor import ( "bytes" + "context" "errors" "math" "math/cmplx" @@ -504,8 +505,12 @@ func (e dummyEngine2) Memcpy(dst, src Memory) error { return e.e.Mem func (e dummyEngine2) Accessible(mem Memory) (Memory, error) { return e.e.Accessible(mem) } func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.WorksWith(order) } -func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } -func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } +func (e dummyEngine2) Argmax(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmax(ctx, t, axis) +} +func (e dummyEngine2) Argmin(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmin(ctx, t, axis) +} func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { if eqtc == nilTC { From a5388ef6716c9f1ab512c601929db6de4a3b8b96 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:36:06 +1000 Subject: [PATCH 124/154] Fixed up a missing FMA fix for handling contexts --- defaultengine_misc.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/defaultengine_misc.go b/defaultengine_misc.go index c7fc933..6a0c570 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -80,9 +80,15 @@ func (e StdEng) Clamp(a Tensor, min, max interface{}, 
opts ...FuncOpt) (retVal T return } -func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { +func (e StdEng) FMA(ctx context.Context, a, x, y Tensor) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } return e.Mul(a, x, WithIncr(y)) } -func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { +func (e StdEng) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } return e.MulScalar(a, x, true, WithIncr(y)) } From 695f4ec693db41affdfc463f8cf0d669fb87cf0c Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:46:37 +1000 Subject: [PATCH 125/154] Fixed most of the APIs. Some of them like the reduction-related package level function needs to be redefined --- api_arith.go | 19 +++++++++++++------ api_matop.go | 16 +++++++++++----- api_reduction.go | 34 ++++++++++++++++++++++++++-------- defaultengine_linalg.go | 2 +- engine.go | 6 +++--- 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/api_arith.go b/api_arith.go index 6369875..5638549 100644 --- a/api_arith.go +++ b/api_arith.go @@ -526,19 +526,24 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { // FMA performs Y = A * X + Y. func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { - var e FMAer + var fm FMAer + if xTensor, ok := x.(Tensor); ok { for _, T := range [3]Tensor{a, xTensor, y} { - e, ok = T.Engine().(FMAer) + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) if ok { - return e.FMA(a, xTensor, y) + return fm.FMA(ctx, a, xTensor, y) } } } else { for _, T := range [2]Tensor{a, y} { - e, ok = T.Engine().(FMAer) + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) if ok { - return e.FMAScalar(a, x, y) + return fm.FMAScalar(ctx, a, x, y) } } } @@ -593,14 +598,16 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var reuse Tensor fo := ParseFuncOpts(opts...) 
defer returnOpOpt(fo) + ctx := fo.Context() reuse = fo.Reuse() if reuse == nil { return nil, errors.Errorf("MatMul requires passing in of a reuse Tensor for now.") } + if err := checkFixShape(reuse, expectedShape); err != nil { return nil, errors.Wrapf(err, opFail, "MatMul") } - if err = mm.MatMul(a, b, reuse); err != nil { + if err = mm.MatMul(ctx, a, b, reuse); err != nil { return nil, errors.Wrapf(err, opFail, "MatMul") } diff --git a/api_matop.go b/api_matop.go index 5df9a8f..71051f5 100644 --- a/api_matop.go +++ b/api_matop.go @@ -9,16 +9,20 @@ import ( // Repeat repeats a Tensor along the axis and given the number of repeats. func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { - if r, ok := t.Engine().(Repeater); ok { - return r.Repeat(t, axis, repeats...) + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.Repeat(ctx, t, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } // RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given, but the results will still be valid. func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { - if r, ok := t.Engine().(Repeater); ok { - return r.RepeatReuse(t, reuse, axis, repeats...) + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.RepeatReuse(ctx, t, reuse, axis, repeats...) 
} return nil, errors.New("Engine does not support Repeat") } @@ -134,8 +138,10 @@ func Materialize(t Tensor) Tensor { } func Diag(t Tensor) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) if d, ok := t.Engine().(Diager); ok { - return d.Diag(t) + return d.Diag(ctx, t) } return nil, errors.Errorf("Unable to perform diagonalization of tensor ") } diff --git a/api_reduction.go b/api_reduction.go index 414abfd..63c2257 100644 --- a/api_reduction.go +++ b/api_reduction.go @@ -4,32 +4,50 @@ import "github.com/pkg/errors" // Sum sums a Tensor along the given axes. func Sum(t Tensor, along ...int) (retVal Tensor, err error) { - if sumer, ok := t.Engine().(Sumer); ok { - return sumer.Sum(t, along...) + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Sumer); ok { + return sumer.Sum(ctx, t, along...) } return nil, errors.New("Engine does not support Sum()") } +// Prod sums a Tensor along the given axes. +func Prod(t Tensor, along ...int) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Proder); ok { + return sumer.Prod(ctx, t, along...) + } + return nil, errors.New("Engine does not support Prod()") +} + // Max finds the maximum value along the given axes. func Max(t Tensor, along ...int) (retVal Tensor, err error) { - if maxer, ok := t.Engine().(Maxer); ok { - return maxer.Max(t, along...) + e := t.Engine() + ctx := ctxFromEngine(e) + if maxer, ok := e.(Maxer); ok { + return maxer.Max(ctx, t, along...) 
} return nil, errors.New("Engine does not support Max()") } // Argmax finds the index of the max value along the axis provided func Argmax(t Tensor, axis int) (retVal Tensor, err error) { - if argmaxer, ok := t.Engine().(Argmaxer); ok { - return argmaxer.Argmax(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argmaxer, ok := e.(Argmaxer); ok { + return argmaxer.Argmax(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } // Argmin finds the index of the min value along the axis provided func Argmin(t Tensor, axis int) (retVal Tensor, err error) { - if argminer, ok := t.Engine().(Argminer); ok { - return argminer.Argmin(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argminer, ok := e.(Argminer); ok { + return argminer.Argmin(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 287fb18..6274504 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -636,7 +636,7 @@ func (e StdEng) Outer(ctx context.Context, a, b, prealloc Tensor) (err error) { return err } - if err = e.MatMul(a, b, prealloc); err != nil { + if err = e.MatMul(ctx, a, b, prealloc); err != nil { return err } diff --git a/engine.go b/engine.go index 7f99a91..cec9c3b 100644 --- a/engine.go +++ b/engine.go @@ -413,12 +413,12 @@ type InfChecker interface { // ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the `indices` tensor. 
type ByIndiceser interface { - SelectByIndices(ctx context.Context, a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) - SelectByIndicesB(ctx context.Context, a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } type Scatterer interface { - Scatter(ctx context.Context, a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) + Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) } /* Internal interfaces for faster shit */ From c745b2dfcada6be898e832c54670978e550ac27c Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 20 Aug 2021 11:56:16 +1000 Subject: [PATCH 126/154] Fixed up argmethods. Now to work on actually fixing the API --- defaultengine_argmethods.go | 20 ++++++++++---------- dense_argmethods.go | 4 ++-- engine.go | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index a25d5b6..f2bbf60 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -8,19 +8,19 @@ import ( ) func (e StdEng) Argmax(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { - if err = handleCtx(ctx); err != nil { - return nil, err - } switch tt := t.(type) { case DenseTensor: - return e.argmaxDenseTensor(tt, axis) + return e.argmaxDenseTensor(ctx, tt, axis) default: return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t) } } -func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { +func (e StdEng) argmaxDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } @@ -98,19 +98,19 @@ func (e StdEng) 
argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e } func (e StdEng) Argmin(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { - if err = handleCtx(ctx); err != nil { - return nil, err - } switch tt := t.(type) { case DenseTensor: - return e.argminDenseTensor(tt, axis) + return e.argminDenseTensor(ctx, tt, axis) default: return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t) } } -func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { +func (e StdEng) argminDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmin") } diff --git a/dense_argmethods.go b/dense_argmethods.go index ca64784..fdace5f 100644 --- a/dense_argmethods.go +++ b/dense_argmethods.go @@ -10,7 +10,7 @@ func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgmaxer: - return am.argmaxDenseTensor(t, axis) + return am.argmaxDenseTensor(ctx, t, axis) case Argmaxer: var ret Tensor var ok bool @@ -33,7 +33,7 @@ func (t *Dense) Argmin(axis int) (retVal *Dense, err error) { ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgminer: - return am.argminDenseTensor(t, axis) + return am.argminDenseTensor(ctx, t, axis) case Argminer: var ret Tensor var ok bool diff --git a/engine.go b/engine.go index cec9c3b..fb31170 100644 --- a/engine.go +++ b/engine.go @@ -424,9 +424,9 @@ type Scatterer interface { /* Internal interfaces for faster shit */ type denseArgmaxer interface { - argmaxDenseTensor(t DenseTensor, axis int) (*Dense, error) + argmaxDenseTensor(ctx context.Context, t DenseTensor, axis int) (*Dense, error) } type denseArgminer interface { - argminDenseTensor(t DenseTensor, axis int) (*Dense, error) + argminDenseTensor(ctx context.Context, t DenseTensor, axis int) (*Dense, error) } From 
b6b33e3ff9e300b9fde428cacf534557a5da05a4 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 3 Sep 2021 09:56:57 +1000 Subject: [PATCH 127/154] Fixed the API for Inner --- api_arith.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_arith.go b/api_arith.go index 5638549..13ccd05 100644 --- a/api_arith.go +++ b/api_arith.go @@ -644,7 +644,7 @@ func Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { switch at := a.(type) { case *Dense: bt := b.(*Dense) - return at.Inner(bt, opts...) + return at.Inner(bt) } panic("Unreachable") } From 139501815a90301567faba98aad7c7d3a05b5d01 Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 3 Sep 2021 10:04:00 +1000 Subject: [PATCH 128/154] Fixed the StdEng's definition of Trace to actually implement Tracer --- defaultengine_linalg.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 6274504..1f7eaef 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -11,7 +11,7 @@ import ( ) // Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). 
If the Tensor provided is not a matrix, it will return an error -func (e StdEng) Trace(ctx context.Context, t Tensor, opts ...FuncOpt) (retVal interface{}, err error) { +func (e StdEng) Trace(ctx context.Context, t Tensor) (retVal interface{}, err error) { if err := handleCtx(ctx); err != nil { return nil, err } From 5efcf9a70c0edda2a7b9625cbea06b3d5b465e52 Mon Sep 17 00:00:00 2001 From: chewxy Date: Thu, 9 Sep 2021 13:34:20 +1000 Subject: [PATCH 129/154] Changed the shape definition of Outer such that it only checks for the product of shapes Added a few more engine interfaces (expm1 and log1p) --- dense_linalg.go | 6 ------ dense_matop.go | 4 +++- engine.go | 10 ++++++++++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/dense_linalg.go b/dense_linalg.go index 5b103b2..10eb936 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -159,12 +159,6 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) // Outer finds the outer product of two vectors func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { - // check both are vectors - if !t.Shape().IsVector() || !other.Shape().IsVector() { - err = errors.Errorf("Outer only works when there are two vectors. t's shape: %v. 
other's shape: %v", t.Shape(), other.Shape()) - return - } - m := t.Size() n := other.Size() diff --git a/dense_matop.go b/dense_matop.go index cb976cc..e48b5cd 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -131,7 +131,9 @@ func (t *Dense) SetAt(v interface{}, coords ...int) error { return errors.Errorf(inaccessibleData, t) } - if len(coords) != t.Dims() { + switch { + case t.IsScalar() && len(coords) == 1: + case len(coords) != t.Dims(): return errors.Errorf(dimMismatch, t.Dims(), len(coords)) } diff --git a/engine.go b/engine.go index fb31170..9902809 100644 --- a/engine.go +++ b/engine.go @@ -334,6 +334,16 @@ type InvSqrter interface { InvSqrt(a Tensor, opts ...FuncOpt) (Tensor, error) } +// Expm1er is any engine that can perform expm1 on the values of a Tensor. +type Expm1er interface { + Expm1(a Tensor, opts ...FuncOpt) (Tensor, error) +} + +// Log1per is any engine that can perform log1p on the values of a Tensor. +type Log1per interface { + Log1p(a Tensor, opts ...FuncOpt) (Tensor, error) +} + // Signer is any engine that can perform a sign function on the values of a Tensor. 
type Signer interface { Sign(a Tensor, opts ...FuncOpt) (Tensor, error) From c08a781544d86349e38dfebcaccb77f9fdf82b8a Mon Sep 17 00:00:00 2001 From: Chewxy Date: Thu, 16 Sep 2021 12:46:15 +1000 Subject: [PATCH 130/154] Added Min/Max Between (#117) * Fixed #90 * Added MinBetween and MaxBetween engine def * Added code to generate Min/MaxBetween * Moved example out from the generated file * Generated MinBetween and MaxBetween methods for StdEng * Added some compile time assertions * Added API for Min/Max between * Added better prep for min/max between of engine --- api_minmax.go | 155 ++ defaultengine_minmax.go | 349 +++++ dense_reduction_test.go | 14 - engine.go | 16 + example_dense_reduction_test.go | 30 + genlib2/agg1_body.go | 59 + genlib2/agg2_body.go | 126 +- genlib2/engine.go | 117 ++ genlib2/generic_cmp.go | 160 +- genlib2/internaleng.go | 96 ++ genlib2/main.go | 2 + internal/execution/eng_minmaxbetween.go | 778 ++++++++++ internal/execution/generic_minmax.go | 1858 +++++++++++++++++++++++ 13 files changed, 3738 insertions(+), 22 deletions(-) create mode 100644 api_minmax.go create mode 100644 defaultengine_minmax.go create mode 100644 example_dense_reduction_test.go create mode 100644 internal/execution/eng_minmaxbetween.go diff --git a/api_minmax.go b/api_minmax.go new file mode 100644 index 0000000..964df7d --- /dev/null +++ b/api_minmax.go @@ -0,0 +1,155 @@ +package tensor + +import "github.com/pkg/errors" + +func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + var minbetweener MinBetweener + var oe standardEngine + var ok bool + switch at := a.(type) { + case Tensor: + oe = at.standardEngine() + switch bt := b.(type) { + case Tensor: + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.MinBetween(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.MinBetween(at, bt, opts...) 
+ } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetween(at, bt, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetween(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support MinBetween") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support MinBetween") + } + + default: + if oe != nil { + return oe.MinBetweenScalar(at, bt, true, opts...) + } + if minbetweener, ok = at.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(at, bt, true, opts...) + } + return nil, errors.New("Operand A's engine does not support MinBetween") + } + default: + switch bt := b.(type) { + case Tensor: + if oe = bt.standardEngine(); oe != nil { + return oe.MinBetweenScalar(bt, at, false, opts...) + } + if minbetweener, ok = bt.Engine().(MinBetweener); ok { + return minbetweener.MinBetweenScalar(bt, at, false, opts...) 
+ } + return nil, errors.New("Operand B's engine does not support MinBetween") + default: + return nil, errors.Errorf("Cannot perform MinBetween of %T and %T", a, b) + } + } + panic("Unreachable") +} + +func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + var maxbetweener MaxBetweener + var oe standardEngine + var ok bool + switch at := a.(type) { + case Tensor: + oe = at.standardEngine() + switch bt := b.(type) { + case Tensor: + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition + if oe != nil { + return oe.MaxBetween(at, bt, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.MaxBetween(at, bt, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetween(at, bt, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetween(at, bt, opts...) + } + return nil, errors.New("Neither engines of either operand support MaxBetween") + + } else { // at least one of the operands is a scalar + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false // a Scalar-Tensor * b Tensor + tmp := at + at = bt + bt = tmp + } else { + leftTensor = true // a Tensor * b Scalar-Tensor + } + + if oe != nil { + return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if oe = bt.standardEngine(); oe != nil { + return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...) + } + return nil, errors.New("Neither engines of either operand support MaxBetween") + } + + default: + if oe != nil { + return oe.MaxBetweenScalar(at, bt, true, opts...) + } + if maxbetweener, ok = at.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(at, bt, true, opts...) 
+ } + return nil, errors.New("Operand A's engine does not support MaxBetween") + } + default: + switch bt := b.(type) { + case Tensor: + if oe = bt.standardEngine(); oe != nil { + return oe.MaxBetweenScalar(bt, at, false, opts...) + } + if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { + return maxbetweener.MaxBetweenScalar(bt, at, false, opts...) + } + return nil, errors.New("Operand B's engine does not support MaxBetween") + default: + return nil, errors.Errorf("Cannot perform MaxBetween of %T and %T", a, b) + } + } + panic("Unreachable") +} diff --git a/defaultengine_minmax.go b/defaultengine_minmax.go new file mode 100644 index 0000000..56ac432 --- /dev/null +++ b/defaultengine_minmax.go @@ -0,0 +1,349 @@ +// Code generated by genlib2. DO NOT EDIT. + +package tensor + +import ( + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +var ( + _ MinBetweener = StdEng{} + _ MaxBetweener = StdEng{} +) + +func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if err = binaryCheck(a, b, ordTypes); err != nil { + return nil, errors.Wrapf(err, "MinBetween failed") + } + + var reuse DenseTensor + var safe bool + if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter, swap bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, swap, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.MinBetween") + } + // check to see if anything needs to be created + if reuse == nil { + if swap { + reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e)) + } else { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + } + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && 
reuse == nil: + err = e.E.MinBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + return + } + + // standard + switch { + case !safe && reuse == nil: + err = e.E.MinBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + default: + panic("Unreachable") + } + return +} + +func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if err = binaryCheck(a, b, ordTypes); err != nil { + return nil, errors.Wrapf(err, "MaxBetween failed") + } + + var reuse DenseTensor + var safe bool + if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + typ := a.Dtype().Type + var dataA, dataB, dataReuse *storage.Header + var ait, bit, iit Iterator + var useIter, swap bool + if dataA, dataB, dataReuse, ait, bit, iit, useIter, swap, err = prepDataVV(a, b, reuse); err != nil { + return nil, errors.Wrapf(err, "StdEng.MaxBetween") + } + // check to see if anything needs to be created + if reuse == nil { + if swap { + reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e)) + } else { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + } + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MaxBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + 
panic("Unreachable") + } + return + } + + // standard + switch { + case !safe && reuse == nil: + err = e.E.MaxBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + default: + panic("Unreachable") + } + return +} + +func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(t, ordTypes); err != nil { + return nil, errors.Wrapf(err, "MinBetween failed") + } + + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "MinBetween failed") + } + + var reuse DenseTensor + var safe bool + if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + a := t + typ := t.Dtype().Type + var ait, bit, iit Iterator + var dataA, dataB, dataReuse, scalarHeader *storage.Header + var useIter, newAlloc bool + + if leftTensor { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MinBetween") + } + scalarHeader = dataB + } else { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MinBetween") + } + scalarHeader = dataA + } + + // check to see if anything needs to be created + if reuse == nil { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MinBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil && !leftTensor: + storage.CopyIter(typ, dataReuse, dataB, iit, bit) + bit.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataA, dataReuse, ait, bit) + retVal = reuse 
+ case safe && reuse != nil && leftTensor: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MinBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return + } + + // handle special case where A and B have both len 1 + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { + switch { + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + return + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MinBetween(typ, dataReuse, dataA) + retVal = reuse + return + } + } + // standard + switch { + case !safe && reuse == nil: + err = e.E.MinBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MinBetween(typ, dataReuse, dataB) + retVal = reuse + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MinBetween(typ, dataA, dataReuse) + retVal = reuse + default: + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return +} + +func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(t, ordTypes); err != nil { + return nil, errors.Wrapf(err, "MaxBetween failed") + } + + if err = scalarDtypeCheck(t, s); err != nil { + return nil, errors.Wrap(err, "MaxBetween failed") + } + + var reuse DenseTensor + var safe bool + if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + a := t + typ := t.Dtype().Type + var ait, bit, iit Iterator + var dataA, dataB, dataReuse, scalarHeader 
*storage.Header + var useIter, newAlloc bool + + if leftTensor { + if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MaxBetween") + } + scalarHeader = dataB + } else { + if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.MaxBetween") + } + scalarHeader = dataA + } + + // check to see if anything needs to be created + if reuse == nil { + reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e)) + dataReuse = reuse.hdr() + if useIter { + iit = IteratorFromDense(reuse) + } + } + + if useIter { + switch { + case !safe && reuse == nil: + err = e.E.MaxBetweenIter(typ, dataA, dataB, ait, bit) + retVal = a + case safe && reuse != nil && !leftTensor: + storage.CopyIter(typ, dataReuse, dataB, iit, bit) + bit.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataA, dataReuse, ait, bit) + retVal = reuse + case safe && reuse != nil && leftTensor: + storage.CopyIter(typ, dataReuse, dataA, iit, ait) + ait.Reset() + iit.Reset() + err = e.E.MaxBetweenIter(typ, dataReuse, dataB, iit, bit) + retVal = reuse + default: // safe && bool + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return + } + + // handle special case where A and B have both len 1 + if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) { + switch { + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + return + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MaxBetween(typ, dataReuse, dataA) + retVal = reuse + return + } + } + // standard + switch { + case !safe && reuse == nil: + err = e.E.MaxBetween(typ, dataA, dataB) + retVal = a + case safe && reuse != nil && leftTensor: + storage.Copy(typ, dataReuse, 
dataA) + err = e.E.MaxBetween(typ, dataReuse, dataB) + retVal = reuse + case safe && reuse != nil && !leftTensor: + storage.Copy(typ, dataReuse, dataB) + err = e.E.MaxBetween(typ, dataA, dataReuse) + retVal = reuse + default: + panic("Unreachable") + } + if newAlloc { + freeScalar(scalarHeader.Raw) + } + returnHeader(scalarHeader) + return +} diff --git a/dense_reduction_test.go b/dense_reduction_test.go index f83d0c6..b10e3ac 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -547,17 +547,3 @@ func TestDense_Min(t *testing.T) { _, err = T.Min(1000) assert.NotNil(err) } - -func TestSlicedSum(t *testing.T) { - T := New(WithShape(4, 4), WithBacking([]int{ - 1, 2, 3, 4, - 5, 6, 7, 8, - 1, 2, 3, 4, - 5, 6, 7, 8, - })) - s, _ := T.Slice(sli(1, 3), sli(1, 3)) - sum, _ := Sum(s) - if sum.Data().(int) != 18 { - t.Errorf("Expected the sum of %v to be 18. Got %v instead", s, sum) - } -} diff --git a/engine.go b/engine.go index 1ac8400..f4d7bd2 100644 --- a/engine.go +++ b/engine.go @@ -44,6 +44,8 @@ type standardEngine interface { Gter Gteer ElEqer + MinBetweener + MaxBetweener // Anything that returns interface{} cannot be added here because they will likely have additional // optimized versions of the functions for types. @@ -157,6 +159,20 @@ type Moder interface { ModScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error) } +// MinBetweener is any engine that can perform an elementwise min=between. +type MinBetweener interface { + MinBetween(a, b Tensor, opts ...FuncOpt) (Tensor, error) + + MinBetweenScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error) +} + +// MaxBetweener is any engine that can perform an elementwise ma b[i]{ + b[i] = a + } + } +} + +func MaxVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){ + for i := range a { + if b > a[i]{ + a[i] = b + } + } +} +` + +// Iter Min/Max +const genericIterMinMaxRaw = `func MinIterSV{{short . 
| title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ + var i,j int + var validi ,validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + + +func MaxIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIter{{short . 
| title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ + var i,j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil{ + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + ` // scalar Min/Max @@ -413,6 +562,7 @@ const genericScalarMinMaxRaw = `func Min{{short .}}(a, b {{asType .}}) (c {{asTy return b } + func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b { return a } @@ -421,13 +571,15 @@ func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b { ` var ( - genericElMinMax *template.Template - genericMinMax *template.Template + genericElMinMax *template.Template + genericMinMax *template.Template + genericElMinMaxIter *template.Template ) func init() { genericElMinMax = template.Must(template.New("genericVecVecMinMax").Funcs(funcs).Parse(genericElMinMaxRaw)) genericMinMax = template.Must(template.New("genericMinMax").Funcs(funcs).Parse(genericScalarMinMaxRaw)) + genericElMinMaxIter = template.Must(template.New("genericIterMinMax").Funcs(funcs).Parse(genericIterMinMaxRaw)) } func generateMinMax(f io.Writer, ak Kinds) { @@ -438,4 +590,8 @@ func generateMinMax(f io.Writer, ak Kinds) { for _, k := range filter(ak.Kinds, isOrd) { genericMinMax.Execute(f, k) } + + for _, k := range filter(ak.Kinds, isOrd) { + genericElMinMaxIter.Execute(f, k) + } } diff --git a/genlib2/internaleng.go b/genlib2/internaleng.go index ee5f15c..6d07d32 100644 --- a/genlib2/internaleng.go +++ b/genlib2/internaleng.go @@ -308,6 +308,102 @@ func generateECmp(f io.Writer, kinds Kinds) { } } +/* MIN/MAX BETWEEN */ + +type InternalEngMinMaxBetween struct { + BinOp + Kinds []reflect.Kind + Iter bool +} + +func (fn *InternalEngMinMaxBetween) Name() string { + name := fn.BinOp.Name() + + switch { + case fn.Iter: + return fmt.Sprintf("%sBetweenIter", name) + default: + return name 
+ "Between" + } +} + +func (fn *InternalEngMinMaxBetween) Signature() *Signature { + var paramNames []string + var paramTemplates []*template.Template + + switch { + case fn.Iter: + paramNames = []string{"t", "a", "b", "ait", "bit"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType, iteratorType, iteratorType} + default: + paramNames = []string{"t", "a", "b"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType} + } + return &Signature{ + Name: fn.Name(), + NameTemplate: plainName, + ParamNames: paramNames, + ParamTemplates: paramTemplates, + + Err: true, + } +} + +func (fn *InternalEngMinMaxBetween) WriteBody(w io.Writer) { + var T *template.Template + + switch { + case fn.Iter: + T = eMinMaxIter + default: + T = eMinMaxSame + } + + lb := eLoopBody{ + BinOp: fn.BinOp, + Kinds: fn.Kinds, + } + T.Execute(w, lb) +} + +func (fn *InternalEngMinMaxBetween) Write(w io.Writer) { + w.Write([]byte("func (e E) ")) + sig := fn.Signature() + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n\n")) +} + +func generateEMinMaxBetween(f io.Writer, kinds Kinds) { + minmaxOps := []cmpOp{cmpBinOps[0], cmpBinOps[2]} // Gt and Lt + minmaxOps[0].name = "Max" + minmaxOps[1].name = "Min" + var methods []*InternalEngMinMaxBetween + for _, bo := range minmaxOps { + var ks []reflect.Kind + for _, k := range kinds.Kinds { + if tc := bo.TypeClass(); tc != nil && tc(k) { + ks = append(ks, k) + } + } + meth := &InternalEngMinMaxBetween{ + BinOp: bo, + Kinds: ks, + } + methods = append(methods, meth) + } + + for _, meth := range methods { + meth.Write(f) + meth.Iter = true + } + for _, meth := range methods { + meth.Write(f) + } + +} + /* REDUCTION */ type InternalEngReduce struct { diff --git a/genlib2/main.go b/genlib2/main.go index 328cd19..46327c4 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -73,6 +73,7 @@ func main() { pipeline(execLoc, "eng_arith.go", Kinds{allKinds}, generateEArith) pipeline(execLoc, 
"eng_map.go", Kinds{allKinds}, generateEMap) pipeline(execLoc, "eng_cmp.go", Kinds{allKinds}, generateECmp) + pipeline(execLoc, "eng_minmaxbetween.go", Kinds{allKinds}, generateEMinMaxBetween) pipeline(execLoc, "eng_reduce.go", Kinds{allKinds}, generateEReduce) pipeline(execLoc, "eng_unary.go", Kinds{allKinds}, generateUncondEUnary, generateCondEUnary, generateSpecialEUnaries) pipeline(execLoc, "reduction_specialization.go", Kinds{allKinds}, generateReductionSpecialization) @@ -82,6 +83,7 @@ func main() { pipeline(tensorPkgLoc, "defaultengine_arith.go", Kinds{allKinds}, generateStdEngArith) pipeline(tensorPkgLoc, "defaultengine_cmp.go", Kinds{allKinds}, generateStdEngCmp) pipeline(tensorPkgLoc, "defaultengine_unary.go", Kinds{allKinds}, generateStdEngUncondUnary, generateStdEngCondUnary) + pipeline(tensorPkgLoc, "defaultengine_minmax.go", Kinds{allKinds}, generateStdEngMinMax) // level 3 aggregation pipeline(tensorPkgLoc, "dense_arith.go", Kinds{allKinds}, generateDenseArith) diff --git a/internal/execution/eng_minmaxbetween.go b/internal/execution/eng_minmaxbetween.go new file mode 100644 index 0000000..5d31706 --- /dev/null +++ b/internal/execution/eng_minmaxbetween.go @@ -0,0 +1,778 @@ +// Code generated by genlib2. DO NOT EDIT. 
+ +package execution + +import ( + "reflect" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +func (e E) MaxBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMaxI(at, bt) + case as && !bs: + MaxSVI(at[0], bt) + case !as && bs: + MaxVSI(at, bt[0]) + default: + VecMaxI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMaxI8(at, bt) + case as && !bs: + MaxSVI8(at[0], bt) + case !as && bs: + MaxVSI8(at, bt[0]) + default: + VecMaxI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMaxI16(at, bt) + case as && !bs: + MaxSVI16(at[0], bt) + case !as && bs: + MaxVSI16(at, bt[0]) + default: + VecMaxI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMaxI32(at, bt) + case as && !bs: + MaxSVI32(at[0], bt) + case !as && bs: + MaxVSI32(at, bt[0]) + default: + VecMaxI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMaxI64(at, bt) + case as && !bs: + MaxSVI64(at[0], bt) + case !as && bs: + MaxVSI64(at, bt[0]) + default: + VecMaxI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMaxU(at, bt) + case as && !bs: + MaxSVU(at[0], bt) + case !as && bs: + MaxVSU(at, bt[0]) + default: + VecMaxU(at, bt) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMaxU8(at, bt) + case as && !bs: + MaxSVU8(at[0], bt) + case !as && bs: + MaxVSU8(at, bt[0]) + default: + VecMaxU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMaxU16(at, bt) + case as && !bs: + MaxSVU16(at[0], bt) + case !as && bs: + MaxVSU16(at, bt[0]) + default: + 
VecMaxU16(at, bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMaxU32(at, bt) + case as && !bs: + MaxSVU32(at[0], bt) + case !as && bs: + MaxVSU32(at, bt[0]) + default: + VecMaxU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMaxU64(at, bt) + case as && !bs: + MaxSVU64(at[0], bt) + case !as && bs: + MaxVSU64(at, bt[0]) + default: + VecMaxU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMaxF32(at, bt) + case as && !bs: + MaxSVF32(at[0], bt) + case !as && bs: + MaxVSF32(at, bt[0]) + default: + VecMaxF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMaxF64(at, bt) + case as && !bs: + MaxSVF64(at[0], bt) + case !as && bs: + MaxVSF64(at, bt[0]) + default: + VecMaxF64(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMaxStr(at, bt) + case as && !bs: + MaxSVStr(at[0], bt) + case !as && bs: + MaxVSStr(at, bt[0]) + default: + VecMaxStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Max", t) + } +} + +func (e E) MinBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMinI(at, bt) + case as && !bs: + MinSVI(at[0], bt) + case !as && bs: + MinVSI(at, bt[0]) + default: + VecMinI(at, bt) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMinI8(at, bt) + case as && !bs: + MinSVI8(at[0], bt) + case !as && bs: + MinVSI8(at, bt[0]) + default: + VecMinI8(at, bt) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMinI16(at, bt) + case as && !bs: + MinSVI16(at[0], bt) + case !as && bs: + 
MinVSI16(at, bt[0]) + default: + VecMinI16(at, bt) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMinI32(at, bt) + case as && !bs: + MinSVI32(at[0], bt) + case !as && bs: + MinVSI32(at, bt[0]) + default: + VecMinI32(at, bt) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMinI64(at, bt) + case as && !bs: + MinSVI64(at[0], bt) + case !as && bs: + MinVSI64(at, bt[0]) + default: + VecMinI64(at, bt) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMinU(at, bt) + case as && !bs: + MinSVU(at[0], bt) + case !as && bs: + MinVSU(at, bt[0]) + default: + VecMinU(at, bt) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMinU8(at, bt) + case as && !bs: + MinSVU8(at[0], bt) + case !as && bs: + MinVSU8(at, bt[0]) + default: + VecMinU8(at, bt) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMinU16(at, bt) + case as && !bs: + MinSVU16(at[0], bt) + case !as && bs: + MinVSU16(at, bt[0]) + default: + VecMinU16(at, bt) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMinU32(at, bt) + case as && !bs: + MinSVU32(at[0], bt) + case !as && bs: + MinVSU32(at, bt[0]) + default: + VecMinU32(at, bt) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMinU64(at, bt) + case as && !bs: + MinSVU64(at[0], bt) + case !as && bs: + MinVSU64(at, bt[0]) + default: + VecMinU64(at, bt) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMinF32(at, bt) + case as && !bs: + MinSVF32(at[0], bt) + case !as && bs: + MinVSF32(at, bt[0]) + default: + VecMinF32(at, bt) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMinF64(at, bt) + case as && !bs: + MinSVF64(at[0], 
bt) + case !as && bs: + MinVSF64(at, bt[0]) + default: + VecMinF64(at, bt) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMinStr(at, bt) + case as && !bs: + MinSVStr(at[0], bt) + case !as && bs: + MinVSStr(at, bt[0]) + default: + VecMinStr(at, bt) + } + return + default: + return errors.Errorf("Unsupported type %v for Min", t) + } +} + +func (e E) MaxBetweenIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMaxI(at, bt) + case as && !bs: + MaxIterSVI(at[0], bt, bit) + case !as && bs: + MaxIterVSI(at, bt[0], ait) + default: + VecMaxIterI(at, bt, ait, bit) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMaxI8(at, bt) + case as && !bs: + MaxIterSVI8(at[0], bt, bit) + case !as && bs: + MaxIterVSI8(at, bt[0], ait) + default: + VecMaxIterI8(at, bt, ait, bit) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMaxI16(at, bt) + case as && !bs: + MaxIterSVI16(at[0], bt, bit) + case !as && bs: + MaxIterVSI16(at, bt[0], ait) + default: + VecMaxIterI16(at, bt, ait, bit) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMaxI32(at, bt) + case as && !bs: + MaxIterSVI32(at[0], bt, bit) + case !as && bs: + MaxIterVSI32(at, bt[0], ait) + default: + VecMaxIterI32(at, bt, ait, bit) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMaxI64(at, bt) + case as && !bs: + MaxIterSVI64(at[0], bt, bit) + case !as && bs: + MaxIterVSI64(at, bt[0], ait) + default: + VecMaxIterI64(at, bt, ait, bit) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMaxU(at, bt) + case as && !bs: + MaxIterSVU(at[0], bt, bit) + case !as && bs: + 
MaxIterVSU(at, bt[0], ait) + default: + VecMaxIterU(at, bt, ait, bit) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMaxU8(at, bt) + case as && !bs: + MaxIterSVU8(at[0], bt, bit) + case !as && bs: + MaxIterVSU8(at, bt[0], ait) + default: + VecMaxIterU8(at, bt, ait, bit) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + case as && bs: + VecMaxU16(at, bt) + case as && !bs: + MaxIterSVU16(at[0], bt, bit) + case !as && bs: + MaxIterVSU16(at, bt[0], ait) + default: + VecMaxIterU16(at, bt, ait, bit) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMaxU32(at, bt) + case as && !bs: + MaxIterSVU32(at[0], bt, bit) + case !as && bs: + MaxIterVSU32(at, bt[0], ait) + default: + VecMaxIterU32(at, bt, ait, bit) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMaxU64(at, bt) + case as && !bs: + MaxIterSVU64(at[0], bt, bit) + case !as && bs: + MaxIterVSU64(at, bt[0], ait) + default: + VecMaxIterU64(at, bt, ait, bit) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMaxF32(at, bt) + case as && !bs: + MaxIterSVF32(at[0], bt, bit) + case !as && bs: + MaxIterVSF32(at, bt[0], ait) + default: + VecMaxIterF32(at, bt, ait, bit) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMaxF64(at, bt) + case as && !bs: + MaxIterSVF64(at[0], bt, bit) + case !as && bs: + MaxIterVSF64(at, bt[0], ait) + default: + VecMaxIterF64(at, bt, ait, bit) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMaxStr(at, bt) + case as && !bs: + MaxIterSVStr(at[0], bt, bit) + case !as && bs: + MaxIterVSStr(at, bt[0], ait) + default: + VecMaxIterStr(at, bt, ait, bit) + } + return + default: + return errors.Errorf("Unsupported type %v for Max", t) + } +} + +func (e E) 
MinBetweenIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Iterator, bit Iterator) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + switch { + case as && bs: + VecMinI(at, bt) + case as && !bs: + MinIterSVI(at[0], bt, bit) + case !as && bs: + MinIterVSI(at, bt[0], ait) + default: + VecMinIterI(at, bt, ait, bit) + } + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + switch { + case as && bs: + VecMinI8(at, bt) + case as && !bs: + MinIterSVI8(at[0], bt, bit) + case !as && bs: + MinIterVSI8(at, bt[0], ait) + default: + VecMinIterI8(at, bt, ait, bit) + } + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + switch { + case as && bs: + VecMinI16(at, bt) + case as && !bs: + MinIterSVI16(at[0], bt, bit) + case !as && bs: + MinIterVSI16(at, bt[0], ait) + default: + VecMinIterI16(at, bt, ait, bit) + } + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + switch { + case as && bs: + VecMinI32(at, bt) + case as && !bs: + MinIterSVI32(at[0], bt, bit) + case !as && bs: + MinIterVSI32(at, bt[0], ait) + default: + VecMinIterI32(at, bt, ait, bit) + } + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + switch { + case as && bs: + VecMinI64(at, bt) + case as && !bs: + MinIterSVI64(at[0], bt, bit) + case !as && bs: + MinIterVSI64(at, bt[0], ait) + default: + VecMinIterI64(at, bt, ait, bit) + } + return + case Uint: + at := a.Uints() + bt := b.Uints() + switch { + case as && bs: + VecMinU(at, bt) + case as && !bs: + MinIterSVU(at[0], bt, bit) + case !as && bs: + MinIterVSU(at, bt[0], ait) + default: + VecMinIterU(at, bt, ait, bit) + } + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + switch { + case as && bs: + VecMinU8(at, bt) + case as && !bs: + MinIterSVU8(at[0], bt, bit) + case !as && bs: + MinIterVSU8(at, bt[0], ait) + default: + VecMinIterU8(at, bt, ait, bit) + } + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + switch { + 
case as && bs: + VecMinU16(at, bt) + case as && !bs: + MinIterSVU16(at[0], bt, bit) + case !as && bs: + MinIterVSU16(at, bt[0], ait) + default: + VecMinIterU16(at, bt, ait, bit) + } + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + switch { + case as && bs: + VecMinU32(at, bt) + case as && !bs: + MinIterSVU32(at[0], bt, bit) + case !as && bs: + MinIterVSU32(at, bt[0], ait) + default: + VecMinIterU32(at, bt, ait, bit) + } + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + switch { + case as && bs: + VecMinU64(at, bt) + case as && !bs: + MinIterSVU64(at[0], bt, bit) + case !as && bs: + MinIterVSU64(at, bt[0], ait) + default: + VecMinIterU64(at, bt, ait, bit) + } + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + switch { + case as && bs: + VecMinF32(at, bt) + case as && !bs: + MinIterSVF32(at[0], bt, bit) + case !as && bs: + MinIterVSF32(at, bt[0], ait) + default: + VecMinIterF32(at, bt, ait, bit) + } + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + switch { + case as && bs: + VecMinF64(at, bt) + case as && !bs: + MinIterSVF64(at[0], bt, bit) + case !as && bs: + MinIterVSF64(at, bt[0], ait) + default: + VecMinIterF64(at, bt, ait, bit) + } + return + case String: + at := a.Strings() + bt := b.Strings() + switch { + case as && bs: + VecMinStr(at, bt) + case as && !bs: + MinIterSVStr(at[0], bt, bit) + case !as && bs: + MinIterVSStr(at, bt[0], ait) + default: + VecMinIterStr(at, bt, ait, bit) + } + return + default: + return errors.Errorf("Unsupported type %v for Min", t) + } +} diff --git a/internal/execution/generic_minmax.go b/internal/execution/generic_minmax.go index 170f01b..8398d5f 100644 --- a/internal/execution/generic_minmax.go +++ b/internal/execution/generic_minmax.go @@ -12,6 +12,23 @@ func VecMinI(a, b []int) { } } } + +func MinSVI(a int, b []int) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI(a []int, b int) { + for i := range a { + if b < a[i] { + a[i] = b + 
} + } +} + func VecMaxI(a, b []int) { a = a[:] b = b[:len(a)] @@ -22,6 +39,22 @@ func VecMaxI(a, b []int) { } } } + +func MaxSVI(a int, b []int) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI(a []int, b int) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI8(a, b []int8) { a = a[:] b = b[:len(a)] @@ -32,6 +65,23 @@ func VecMinI8(a, b []int8) { } } } + +func MinSVI8(a int8, b []int8) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI8(a []int8, b int8) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI8(a, b []int8) { a = a[:] b = b[:len(a)] @@ -42,6 +92,22 @@ func VecMaxI8(a, b []int8) { } } } + +func MaxSVI8(a int8, b []int8) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI8(a []int8, b int8) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI16(a, b []int16) { a = a[:] b = b[:len(a)] @@ -52,6 +118,23 @@ func VecMinI16(a, b []int16) { } } } + +func MinSVI16(a int16, b []int16) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI16(a []int16, b int16) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI16(a, b []int16) { a = a[:] b = b[:len(a)] @@ -62,6 +145,22 @@ func VecMaxI16(a, b []int16) { } } } + +func MaxSVI16(a int16, b []int16) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI16(a []int16, b int16) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI32(a, b []int32) { a = a[:] b = b[:len(a)] @@ -72,6 +171,23 @@ func VecMinI32(a, b []int32) { } } } + +func MinSVI32(a int32, b []int32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI32(a []int32, b int32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI32(a, b []int32) { a = a[:] b = b[:len(a)] @@ -82,6 +198,22 @@ func VecMaxI32(a, b []int32) { } } } + +func MaxSVI32(a int32, b []int32) { + for i := range b { + if 
a > b[i] { + b[i] = a + } + } +} + +func MaxVSI32(a []int32, b int32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinI64(a, b []int64) { a = a[:] b = b[:len(a)] @@ -92,6 +224,23 @@ func VecMinI64(a, b []int64) { } } } + +func MinSVI64(a int64, b []int64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSI64(a []int64, b int64) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxI64(a, b []int64) { a = a[:] b = b[:len(a)] @@ -102,6 +251,22 @@ func VecMaxI64(a, b []int64) { } } } + +func MaxSVI64(a int64, b []int64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSI64(a []int64, b int64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU(a, b []uint) { a = a[:] b = b[:len(a)] @@ -112,6 +277,23 @@ func VecMinU(a, b []uint) { } } } + +func MinSVU(a uint, b []uint) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU(a []uint, b uint) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU(a, b []uint) { a = a[:] b = b[:len(a)] @@ -122,6 +304,22 @@ func VecMaxU(a, b []uint) { } } } + +func MaxSVU(a uint, b []uint) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU(a []uint, b uint) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU8(a, b []uint8) { a = a[:] b = b[:len(a)] @@ -132,6 +330,23 @@ func VecMinU8(a, b []uint8) { } } } + +func MinSVU8(a uint8, b []uint8) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU8(a []uint8, b uint8) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU8(a, b []uint8) { a = a[:] b = b[:len(a)] @@ -142,6 +357,22 @@ func VecMaxU8(a, b []uint8) { } } } + +func MaxSVU8(a uint8, b []uint8) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU8(a []uint8, b uint8) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU16(a, b []uint16) { a = a[:] b = 
b[:len(a)] @@ -152,6 +383,23 @@ func VecMinU16(a, b []uint16) { } } } + +func MinSVU16(a uint16, b []uint16) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU16(a []uint16, b uint16) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU16(a, b []uint16) { a = a[:] b = b[:len(a)] @@ -162,6 +410,22 @@ func VecMaxU16(a, b []uint16) { } } } + +func MaxSVU16(a uint16, b []uint16) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU16(a []uint16, b uint16) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU32(a, b []uint32) { a = a[:] b = b[:len(a)] @@ -172,6 +436,23 @@ func VecMinU32(a, b []uint32) { } } } + +func MinSVU32(a uint32, b []uint32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU32(a []uint32, b uint32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU32(a, b []uint32) { a = a[:] b = b[:len(a)] @@ -182,6 +463,22 @@ func VecMaxU32(a, b []uint32) { } } } + +func MaxSVU32(a uint32, b []uint32) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU32(a []uint32, b uint32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinU64(a, b []uint64) { a = a[:] b = b[:len(a)] @@ -192,6 +489,23 @@ func VecMinU64(a, b []uint64) { } } } + +func MinSVU64(a uint64, b []uint64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSU64(a []uint64, b uint64) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxU64(a, b []uint64) { a = a[:] b = b[:len(a)] @@ -202,6 +516,22 @@ func VecMaxU64(a, b []uint64) { } } } + +func MaxSVU64(a uint64, b []uint64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSU64(a []uint64, b uint64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinF32(a, b []float32) { a = a[:] b = b[:len(a)] @@ -212,6 +542,23 @@ func VecMinF32(a, b []float32) { } } } + +func MinSVF32(a float32, b 
[]float32) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSF32(a []float32, b float32) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxF32(a, b []float32) { a = a[:] b = b[:len(a)] @@ -222,6 +569,22 @@ func VecMaxF32(a, b []float32) { } } } + +func MaxSVF32(a float32, b []float32) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSF32(a []float32, b float32) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinF64(a, b []float64) { a = a[:] b = b[:len(a)] @@ -232,6 +595,23 @@ func VecMinF64(a, b []float64) { } } } + +func MinSVF64(a float64, b []float64) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSF64(a []float64, b float64) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxF64(a, b []float64) { a = a[:] b = b[:len(a)] @@ -242,6 +622,22 @@ func VecMaxF64(a, b []float64) { } } } + +func MaxSVF64(a float64, b []float64) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSF64(a []float64, b float64) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func VecMinStr(a, b []string) { a = a[:] b = b[:len(a)] @@ -252,6 +648,23 @@ func VecMinStr(a, b []string) { } } } + +func MinSVStr(a string, b []string) { + for i := range b { + if a < b[i] { + b[i] = a + } + } +} + +func MinVSStr(a []string, b string) { + for i := range a { + if b < a[i] { + a[i] = b + } + } +} + func VecMaxStr(a, b []string) { a = a[:] b = b[:len(a)] @@ -262,6 +675,22 @@ func VecMaxStr(a, b []string) { } } } + +func MaxSVStr(a string, b []string) { + for i := range b { + if a > b[i] { + b[i] = a + } + } +} + +func MaxVSStr(a []string, b string) { + for i := range a { + if b > a[i] { + a[i] = b + } + } +} func MinI(a, b int) (c int) { if a < b { return a @@ -431,3 +860,1432 @@ func MaxStr(a, b string) (c string) { } return b } +func MinIterSVI(a int, b []int, bit Iterator) (err error) { + var i int + var validi bool + for { + 
if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI(a []int, b int, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI(a, b []int, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI(a int, b []int, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI(a []int, b int, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI(a, b []int, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI8(a int8, b []int8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI8(a []int8, b int8, ait Iterator) (err 
error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI8(a, b []int8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI8(a int8, b []int8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI8(a []int8, b int8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI8(a, b []int8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI16(a int16, b []int16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI16(a []int16, b int16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + 
return +} + +func VecMinIterI16(a, b []int16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI16(a int16, b []int16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI16(a []int16, b int16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI16(a, b []int16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI32(a int32, b []int32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI32(a []int32, b int32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI32(a, b []int32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + 
err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVI32(a int32, b []int32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI32(a []int32, b int32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI32(a, b []int32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVI64(a int64, b []int64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSI64(a []int64, b int64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterI64(a, b []int64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } 
+ } + } + return +} + +func MaxIterSVI64(a int64, b []int64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSI64(a []int64, b int64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterI64(a, b []int64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU(a uint, b []uint, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU(a []uint, b uint, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU(a, b []uint, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU(a uint, b []uint, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = 
handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU(a []uint, b uint, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU(a, b []uint, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU8(a uint8, b []uint8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU8(a []uint8, b uint8, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU8(a, b []uint8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU8(a uint8, b []uint8, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU8(a []uint8, b uint8, ait Iterator) (err error) { + var i int + var validi bool + for 
{ + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU8(a, b []uint8, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU16(a uint16, b []uint16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU16(a []uint16, b uint16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU16(a, b []uint16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU16(a uint16, b []uint16, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU16(a []uint16, b uint16, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func 
VecMaxIterU16(a, b []uint16, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU32(a uint32, b []uint32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU32(a []uint32, b uint32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU32(a, b []uint32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU32(a uint32, b []uint32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU32(a []uint32, b uint32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU32(a, b []uint32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = 
handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVU64(a uint64, b []uint64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSU64(a []uint64, b uint64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterU64(a, b []uint64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVU64(a uint64, b []uint64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSU64(a []uint64, b uint64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterU64(a, b []uint64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] 
+ } + } + } + return +} + +func MinIterSVF32(a float32, b []float32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSF32(a []float32, b float32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterF32(a, b []float32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVF32(a float32, b []float32, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSF32(a []float32, b float32, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterF32(a, b []float32, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVF64(a float64, b []float64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = 
bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSF64(a []float64, b float64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterF64(a, b []float64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVF64(a float64, b []float64, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSF64(a []float64, b float64, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterF64(a, b []float64, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} + +func MinIterSVStr(a string, b []string, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a < b[i] { + b[i] = a + } + } + } + return +} + +func MinIterVSStr(a []string, 
b string, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b < a[i] { + a[i] = b + } + } + } + return +} + +func VecMinIterStr(a, b []string, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] < a[i] { + a[i] = b[j] + } + } + } + return +} + +func MaxIterSVStr(a string, b []string, bit Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if a > b[i] { + b[i] = a + } + } + } + return +} + +func MaxIterVSStr(a []string, b string, ait Iterator) (err error) { + var i int + var validi bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi { + if b > a[i] { + a[i] = b + } + } + } + return +} + +func VecMaxIterStr(a, b []string, ait, bit Iterator) (err error) { + var i, j int + var validi, validj bool + for { + if i, validi, err = ait.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if j, validj, err = bit.NextValidity(); err != nil { + err = handleNoOp(err) + break + } + if validi && validj { + if b[j] > a[i] { + a[i] = b[j] + } + } + } + return +} From 8bf05a882f6182e95a352588f1e4c612d4ad55d8 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 5 Oct 2021 05:02:59 +1100 Subject: [PATCH 131/154] Fixed minmaxbetween for 0.10.0 --- api_minmax.go | 20 ++++++------ defaultengine_arith.go | 6 ++++ defaultengine_cmp.go | 6 ++++ defaultengine_minmax.go | 41 +++++++++++++++++++------ genlib2/agg2_body.go | 6 +++- internal/execution/eng_minmaxbetween.go | 4 +-- 6 files changed, 60 insertions(+), 23 deletions(-) diff --git 
a/api_minmax.go b/api_minmax.go index 964df7d..e8a7de1 100644 --- a/api_minmax.go +++ b/api_minmax.go @@ -4,18 +4,18 @@ import "github.com/pkg/errors" func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var minbetweener MinBetweener - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.MinBetween(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetween(at, bt, opts...) } if minbetweener, ok = at.Engine().(MinBetweener); ok { @@ -40,7 +40,7 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MinBetweenScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetweenScalar(at, bt, leftTensor, opts...) } if minbetweener, ok = at.Engine().(MinBetweener); ok { @@ -64,7 +64,7 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetweenScalar(bt, at, false, opts...) 
} if minbetweener, ok = bt.Engine().(MinBetweener); ok { @@ -80,18 +80,18 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var maxbetweener MaxBetweener - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.MaxBetween(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetween(at, bt, opts...) } if maxbetweener, ok = at.Engine().(MaxBetweener); ok { @@ -116,7 +116,7 @@ func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) } if maxbetweener, ok = at.Engine().(MaxBetweener); ok { @@ -140,7 +140,7 @@ func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetweenScalar(bt, at, false, opts...) 
} if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 0a4d800..65a7bf0 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -14,6 +14,7 @@ import ( // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Add failed") } @@ -83,6 +84,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Sub failed") } @@ -152,6 +154,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mul failed") } @@ -221,6 +224,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Div failed") } @@ -290,6 +294,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Pow failed") } @@ -359,6 +364,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal 
Tensor, err err // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mod failed") } diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 749cf11..6a986d3 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -16,6 +16,7 @@ import ( // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gt failed") } @@ -98,6 +99,7 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gte failed") } @@ -180,6 +182,7 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lt failed") } @@ -262,6 +265,7 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. 
func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lte failed") } @@ -344,6 +348,7 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Eq failed") } @@ -426,6 +431,7 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Ne failed") } diff --git a/defaultengine_minmax.go b/defaultengine_minmax.go index 56ac432..a16cbf0 100644 --- a/defaultengine_minmax.go +++ b/defaultengine_minmax.go @@ -1,27 +1,35 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. 
+ var ( _ MinBetweener = StdEng{} _ MaxBetweener = StdEng{} ) func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MinBetween failed") } var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -75,15 +83,20 @@ func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, } func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MaxBetween failed") } var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -137,7 +150,7 @@ func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, } func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "MinBetween failed") } @@ -147,9 +160,13 @@ func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -243,7 +260,7 @@ func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts } func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "MaxBetween failed") } @@ -253,9 +270,13 @@ func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 53464f9..e1da1aa 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -28,9 +28,13 @@ const arithPrepRaw = `var safe, toReuse, incr bool ` const minmaxPrepRaw = `var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err !=nil{ + return nil, err // this err will be noopError{}, no need to wrap. 
+ } ` const prepVVRaw = `if err = binaryCheck(a, b, dtype.{{.TypeClassCheck}}); err != nil { diff --git a/internal/execution/eng_minmaxbetween.go b/internal/execution/eng_minmaxbetween.go index 5d31706..8c41606 100644 --- a/internal/execution/eng_minmaxbetween.go +++ b/internal/execution/eng_minmaxbetween.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) MaxBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) From 47f11bda23ac109d2f900f66bbfafc1e1b42feef Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 13 Oct 2021 06:29:31 +1100 Subject: [PATCH 132/154] Working on a bug on selbyidxB function --- defaultengine_selbyidx.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index cdcc318..88fa5c3 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -228,8 +228,8 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data for i, idx := range indices { dstCoord[axis] = idx srcCoord[axis] = i - dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...) - start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...) + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + start, _ := Ltoi(apB.shape, apB.strides, srcCoord...) for o := 0; o < outer; o++ { dstEnd := dstStart + axStride From 11a940ebec5834c4cfad085474a1a00ab732723b Mon Sep 17 00:00:00 2001 From: David Cuadrado Date: Fri, 15 Oct 2021 19:18:31 -0500 Subject: [PATCH 133/154] Softmax backwards (#119) * Add support for LogSoftMax * Add suport for SoftMax operation * Update min go version * Adding backwards operations for SoftMax and LogSoftMax This is still missing the case for inner dimensions. 
* Complete backwards operations for SoftMax and LogSoftMax * Add more tests * Optimizations for SoftMax and LogSoftMax --- api_matop.go | 40 ++- defaultengine_softmax.go | 605 +++++++++++++++++++++++++++++++++++++++ dense_softmax_test.go | 287 +++++++++++++++++++ engine.go | 8 + go.mod | 4 +- go.sum | 7 +- 6 files changed, 942 insertions(+), 9 deletions(-) create mode 100644 defaultengine_softmax.go create mode 100644 dense_softmax_test.go diff --git a/api_matop.go b/api_matop.go index e0f479d..2cb41a8 100644 --- a/api_matop.go +++ b/api_matop.go @@ -135,7 +135,7 @@ func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndices(a, indices, axis, opts...) } - return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine()) + return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) } // ByIndicesB is the backpropagation of ByIndices. @@ -146,5 +146,41 @@ func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndicesB(a, b, indices, axis, opts...) } - return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine()) + return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) +} + +// LogSoftMax applies log softmax to the given tensor. +func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := x.Engine().(SoftMaxer); ok { + return sm.LogSoftMax(x, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine()) +} + +// SoftMax applies softmax to the given tensor. +func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := x.Engine().(SoftMaxer); ok { + return sm.SoftMax(x, axis, opts...) 
+ } + + return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine()) +} + +// SoftMaxB applies softmax backwards operation +func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := output.Engine().(SoftMaxer); ok { + return sm.SoftMaxB(output, grad, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) +} + +// LogSoftMaxB applies softmax backwards operation +func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if sm, ok := output.Engine().(SoftMaxer); ok { + return sm.LogSoftMaxB(output, grad, axis, opts...) + } + + return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) } diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go new file mode 100644 index 0000000..1a22675 --- /dev/null +++ b/defaultengine_softmax.go @@ -0,0 +1,605 @@ +package tensor + +import ( + "fmt" + "math" + + "github.com/chewxy/math32" + "github.com/pkg/errors" +) + +// if dims = 2 and axis -1 it returns the last dimension. 
In this case 1 +func resolveAxis(axis int, dims int) int { + res := axis % dims + if (res < 0 && dims > 0) || (res > 0 && dims < 0) { + return res + dims + } + + return res +} + +func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + axis = resolveAxis(axis, x.Dims()) + expectedShape := x.Shape().Clone() + + var reuse DenseTensor + var safe, toReuse, _ bool + if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(x.Dtype())) + } + + switch x.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF32(reuse, x, axis, false) + } else { + e.softMaxInnerDimF32(reuse, x, axis, false) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF64(reuse, x, axis, false) + } else { + e.softMaxInnerDimF64(reuse, x, axis, false) + } + default: + return nil, fmt.Errorf("type %v not supported", x.Dtype()) + } + + return reuse, nil +} + +func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !output.Shape().Eq(grad.Shape()) { + return nil, fmt.Errorf("output and grad shapes don't match") + } + + if !output.Dtype().Eq(grad.Dtype()) { + return nil, fmt.Errorf("output and grad types don't match") + } + + axis = resolveAxis(axis, output.Dims()) + expectedShape := output.Shape().Clone() + + var reuse DenseTensor + var safe, toReuse, _ bool + if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(output.Dtype())) + } + + switch output.Dtype() { + case Float32: + if 
expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF32(reuse, output, grad, axis, false) + } else { + e.softMaxBInnerDimF32(reuse, output, grad, axis, false) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF64(reuse, output, grad, axis, false) + } else { + e.softMaxBInnerDimF64(reuse, output, grad, axis, false) + } + default: + return nil, fmt.Errorf("type %v not supported", output.Dtype()) + } + + return reuse, nil +} + +func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + axis = resolveAxis(axis, x.Dims()) + expectedShape := x.Shape().Clone() + + var reuse DenseTensor + var safe, toReuse, _ bool + if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(x.Dtype())) + } + + switch x.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF32(reuse, x, axis, true) + } else { + e.softMaxInnerDimF32(reuse, x, axis, true) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxLastDimF64(reuse, x, axis, true) + } else { + e.softMaxInnerDimF64(reuse, x, axis, true) + } + default: + return nil, fmt.Errorf("type %v not supported", x.Dtype()) + } + + return reuse, nil +} + +func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !output.Shape().Eq(grad.Shape()) { + return nil, fmt.Errorf("output and grad shapes don't match") + } + + if !output.Dtype().Eq(grad.Dtype()) { + return nil, fmt.Errorf("output and grad types don't match") + } + + axis = resolveAxis(axis, output.Dims()) + expectedShape := output.Shape().Clone() + + var reuse DenseTensor + var safe, toReuse, _ bool + if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), 
output.DataOrder(), true, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if safe || !toReuse && reuse == nil && safe { + // create reuse + reuse = New(WithShape(expectedShape...), Of(output.Dtype())) + } + + switch output.Dtype() { + case Float32: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF32(reuse, output, grad, axis, true) + } else { + e.softMaxBInnerDimF32(reuse, output, grad, axis, true) + } + case Float64: + if expectedShape.Dims()-1 == axis { + e.softMaxBLastDimF64(reuse, output, grad, axis, true) + } else { + e.softMaxBInnerDimF64(reuse, output, grad, axis, true) + } + default: + return nil, fmt.Errorf("type %v not supported", output.Dtype()) + } + + return reuse, nil +} + +func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) { + outputArr := output.Data().([]float64) + xArr := x.Data().([]float64) + xShape := x.Shape() + + outerSize := 1 + dimSize := xShape[axis] + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for ii := 0; ii < outerSize; ii++ { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + if xArr[i] > maxInput { + maxInput = xArr[i] + } + } + + sumExp := float64(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math.Exp(z) + + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } + + sumExp += exp + } + + if !logSoftMax { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + if logSoftMax { + outputArr[i] -= math.Log(sumExp) + } else { + outputArr[i] *= sumExp + } + } + } +} + +func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { + dx := inputGrad.Data().([]float64) + outputArr := output.Data().([]float64) + gradArr := grad.Data().([]float64) + + outputShape := output.Shape() + + outerSize := 1 + dimSize := outputShape[axis] + for i := 0; i < axis; i++ { + outerSize *= 
outputShape[i] + } + + for ii := 0; ii < outerSize; ii++ { + if logSoftMax { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + sum += gradArr[i] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) + } + } else { + mul := make([]float64, dimSize) + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + mul[j] = outputArr[i] * gradArr[i] + } + + sum := mul[0] + for j := 1; j < dimSize; j++ { + sum += mul[j] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + } + } +} + +func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) { + xShape := x.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for i := axis + 1; i < xShape.Dims(); i++ { + innerSize *= xShape[i] + } + + dimSize := xShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + outputArr := output.Data().([]float64) + xArr := x.Data().([]float64) + + for ii := 0; ii < innerSize*outerSize; ii++ { + outerIndex, innerIndex := divmod(ii, innerSize) + + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride + + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } + } + + sumExp := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride + + exp := math.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } + + sumExp += exp + } + + if logSoftmax { + sumExp = math.Log(sumExp) + } else { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + } +} + +func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax 
bool) { + dxShape := inputGrad.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= dxShape[i] + } + + for i := axis + 1; i < dxShape.Dims(); i++ { + innerSize *= dxShape[i] + } + + dimSize := dxShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + dxArr := inputGrad.Data().([]float64) + outputArr := output.Data().([]float64) + gradArr := grad.Data().([]float64) + + for ii := 0; ii < innerSize*outerSize; ii++ { + outerIndex, innerIndex := divmod(ii, innerSize) + + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + sum := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } + } + } +} + +func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) { + outputArr := output.Data().([]float32) + xArr := x.Data().([]float32) + xShape := x.Shape() + + outerSize := 1 + dimSize := xShape[axis] + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for ii := 0; ii < outerSize; ii++ { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + if xArr[i] > maxInput { + maxInput = xArr[i] + } + } + + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math32.Exp(z) + + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } + + sumExp += exp + } + + if !logSoftMax { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + if logSoftMax { + outputArr[i] -= math32.Log(sumExp) + } else { + outputArr[i] *= sumExp + } + } + } +} + +func (e StdEng) 
softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { + dx := inputGrad.Data().([]float32) + outputArr := output.Data().([]float32) + gradArr := grad.Data().([]float32) + + outputShape := output.Shape() + + outerSize := 1 + dimSize := outputShape[axis] + for i := 0; i < axis; i++ { + outerSize *= outputShape[i] + } + + for ii := 0; ii < outerSize; ii++ { + if logSoftMax { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j + + sum += gradArr[i] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) + } + } else { + mul := make([]float32, dimSize) + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + mul[j] = outputArr[i] * gradArr[i] + } + + sum := mul[0] + for j := 1; j < dimSize; j++ { + sum += mul[j] + } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + } + } +} + +func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) { + xShape := x.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= xShape[i] + } + + for i := axis + 1; i < xShape.Dims(); i++ { + innerSize *= xShape[i] + } + + dimSize := xShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + outputArr := output.Data().([]float32) + xArr := x.Data().([]float32) + + for ii := 0; ii < innerSize*outerSize; ii++ { + outerIndex, innerIndex := divmod(ii, innerSize) + + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride + + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } + } + + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + exp := math32.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } + + sumExp += exp + } + + if logSoftmax { + 
sumExp = math32.Log(sumExp) + } else { + sumExp = 1 / sumExp + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + } +} + +func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { + dxShape := inputGrad.Shape() + + innerSize, outerSize := 1, 1 + for i := 0; i < axis; i++ { + outerSize *= dxShape[i] + } + + for i := axis + 1; i < dxShape.Dims(); i++ { + innerSize *= dxShape[i] + } + + dimSize := dxShape[axis] + dimStride := innerSize + outerStride := dimSize * dimStride + + dxArr := inputGrad.Data().([]float32) + outputArr := output.Data().([]float32) + gradArr := grad.Data().([]float32) + + for ii := 0; ii < innerSize*outerSize; ii++ { + outerIndex, innerIndex := divmod(ii, innerSize) + + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] + + sum := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } + } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } + } + } +} diff --git a/dense_softmax_test.go b/dense_softmax_test.go new file mode 100644 index 0000000..eaa68df --- /dev/null +++ b/dense_softmax_test.go @@ -0,0 +1,287 @@ +package tensor + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSoftMax(t *testing.T) { + testCases := []struct { + fn func(x Tensor, axis int, opts ...FuncOpt) (Tensor, error) + x Tensor + axis int + expectedOutput interface{} + }{ + { + fn: LogSoftMax, + x: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 
1.1}), + ), + axis: -1, + expectedOutput: []float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{-0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443966, -0.64439666, -0.7443966, -0.64439666, -0.7443967, -0.64439666}, + }, + { + fn: LogSoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: 1, + expectedOutput: []float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: 1, + expectedOutput: []float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 
0.549833997312478}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}, + }, + { + fn: SoftMax, + x: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, + }, + { + fn: SoftMax, + x: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), + ), + axis: -1, + expectedOutput: []float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, + }, + } + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.x.Shape(), tC.x.Dtype()), func(t *testing.T) { + c := assert.New(t) + + output, err := tC.fn(tC.x, tC.axis) + t.Logf("output: %#v", output.Data()) + + c.NoError(err) + c.NotNil(output) + + c.Equal(tC.x.Shape(), output.Shape()) + c.InDeltaSlice(tC.expectedOutput, output.Data(), 1e-6) + }) + } +} + +func TestSoftMaxB(t *testing.T) { + testCases := []struct { + fn func(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error) + output Tensor + grad Tensor + axis int + expectedOutput interface{} + }{ + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), + ), + grad: New( + Of(Float64), + 
WithShape(3, 4), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float64{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 
0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, 
-0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), + ), + grad: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}), + ), + grad: New( + Of(Float32), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: 1, + expectedOutput: []float32{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float32), + WithShape(3, 4), + WithBacking([]float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, 
-0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, + }, + { + fn: LogSoftMaxB, + output: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), + ), + grad: New( + Of(Float64), + WithShape(3, 4), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, + }, + { + fn: SoftMaxB, + output: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float32{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), + ), + grad: New( + Of(Float64), + WithShape(3, 2, 2), + WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), + ), + axis: -1, + expectedOutput: []float32{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, + }, + } + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.output.Shape(), tC.output.Dtype()), func(t *testing.T) { + c := assert.New(t) + + dx, err := tC.fn(tC.output, tC.grad, tC.axis) + t.Logf("output: %#v", tC.output.Data()) + + 
c.NoError(err) + c.NotNil(dx) + + c.Equal(tC.output.Shape(), dx.Shape()) + c.InDeltaSlice(tC.expectedOutput, dx.Data(), 1e-6) + }) + } +} diff --git a/engine.go b/engine.go index f4d7bd2..5730c60 100644 --- a/engine.go +++ b/engine.go @@ -422,3 +422,11 @@ type denseArgmaxer interface { type denseArgminer interface { argminDenseTensor(t DenseTensor, axis int) (*Dense, error) } + +type SoftMaxer interface { + LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + + SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) +} diff --git a/go.mod b/go.mod index 7106ca9..e488a8d 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,11 @@ module gorgonia.org/tensor -go 1.13 +go 1.15 require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc github.com/chewxy/hm v1.0.0 - github.com/chewxy/math32 v1.0.6 + github.com/chewxy/math32 v1.0.8 github.com/gogo/protobuf v1.3.1 github.com/golang/protobuf v1.4.3 github.com/google/flatbuffers v1.12.0 diff --git a/go.sum b/go.sum index 21b3359..8d91866 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= -github.com/chewxy/math32 v1.0.6 h1:JWZYUNl2rtgVVui6z8JBsDgkOG2DYmfSODyo95yKfx4= -github.com/chewxy/math32 v1.0.6/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/chewxy/math32 v1.0.8 h1:fU5E4Ec4Z+5RtRAi3TovSxUjQPkgRh+HbP7tKB2OFbM= +github.com/chewxy/math32 v1.0.8/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= github.com/client9/misspell v0.3.4/go.mod 
h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= @@ -24,7 +24,6 @@ github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGw github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -33,11 +32,9 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v1.12.0 h1:/PtAHvnBY4Kqnx/xCQ3OIV9uYcSFGScBsWI3Oogeh6w= 
github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= From f2726a559493918283a9db3e2481223c8a6a33e2 Mon Sep 17 00:00:00 2001 From: David Cuadrado Date: Sun, 17 Oct 2021 23:11:17 -0500 Subject: [PATCH 134/154] Implement the Narrow operation for Dense tensors. (#120) * Implement the Narrow operation for Dense tensors. Also, implemented a default slicer * Added a new failing test. This seems like a bug with the non contiguous View implementation * Fixed the test. The test was correct Co-authored-by: chewxy --- ap.go | 1 + ap_test.go | 41 ++++-------------- dense_matop.go | 10 +++++ dense_matop_test.go | 75 +++++++++++++++++++++++++++++++++ example_dense_reduction_test.go | 2 +- example_mapreduce_test.go | 4 +- slice.go | 36 ++++++++++++++++ 7 files changed, 133 insertions(+), 36 deletions(-) diff --git a/ap.go b/ap.go index 145af0a..410ec40 100644 --- a/ap.go +++ b/ap.go @@ -296,6 +296,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err er offset++ } } + newAP = MakeAP(newShape, newStrides, order, ap.Δ) } return diff --git a/ap_test.go b/ap_test.go index b813d1f..8314546 100644 --- a/ap_test.go +++ b/ap_test.go @@ -7,31 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -type dummySlice struct { - start, end, step int -} - -func (s dummySlice) Start() int { return s.start } -func (s dummySlice) End() int { return s.end } -func (s dummySlice) Step() int { return s.step } - -func sli(start int, opt ...int) dummySlice { - var end, step int - switch len(opt) { - case 0: - end = start + 1 - step = 0 - case 1: - end = opt[0] - step = 1 - default: - end = opt[0] - step = opt[1] - - } - return dummySlice{start: start, end: end, step: step} -} - func dummyScalar1() AP { return AP{} } func dummyScalar2() AP { return AP{shape: Shape{1}} } @@ -203,16 +178,16 @@ var sliceTests = []struct { contiguous bool }{ // vectors - {"a[0]", Shape{5}, []Slice{sli(0)}, 0, 1, ScalarShape(), nil, true}, - {"a[0:2]", 
Shape{5}, []Slice{sli(0, 2)}, 0, 2, Shape{2}, []int{1}, true}, - {"a[1:3]", Shape{5}, []Slice{sli(1, 3)}, 1, 3, Shape{2}, []int{1}, true}, - {"a[1:5:2]", Shape{5}, []Slice{sli(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false}, + {"a[0]", Shape{5}, []Slice{S(0)}, 0, 1, ScalarShape(), nil, true}, + {"a[0:2]", Shape{5}, []Slice{S(0, 2)}, 0, 2, Shape{2}, []int{1}, true}, + {"a[1:3]", Shape{5}, []Slice{S(1, 3)}, 1, 3, Shape{2}, []int{1}, true}, + {"a[1:5:2]", Shape{5}, []Slice{S(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false}, // matrix - {"A[0]", Shape{2, 3}, []Slice{sli(0)}, 0, 3, Shape{1, 3}, []int{1}, true}, - {"A[1:3]", Shape{4, 5}, []Slice{sli(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, - {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{sli(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened - {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, sli(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, + {"A[0]", Shape{2, 3}, []Slice{S(0)}, 0, 3, Shape{1, 3}, []int{1}, true}, + {"A[1:3]", Shape{4, 5}, []Slice{S(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, + {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{S(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened + {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, S(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, // tensor {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true}, diff --git a/dense_matop.go b/dense_matop.go index 7e81419..ea56b8a 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -198,6 +198,16 @@ func (t *Dense) CopyTo(other *Dense) error { return errors.Errorf(methodNYI, "CopyTo", "views") } +// Narrow narrows the tensor +func (t *Dense) Narrow(dim, start, lenght int) (View, error) { + dim = resolveAxis(dim, t.Dims()) + + slices := make([]Slice, MinInt(dim+1, t.Dims())) + slices[dim] = S(start, start+lenght, 1) + + return t.Slice(slices...) +} + // Slice performs slicing on the *Dense Tensor. 
It returns a view which shares the same underlying memory as the original *Dense. // // Given: diff --git a/dense_matop_test.go b/dense_matop_test.go index d9de697..2b6a8b8 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -1,6 +1,7 @@ package tensor import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -639,6 +640,80 @@ func TestDense_Slice(t *testing.T) { } } +func TestDense_Narrow(t *testing.T) { + testCases := []struct { + x *Dense + dim, start, length int + expected *Dense + }{ + { + x: New( + WithShape(3), + WithBacking([]int{1, 2, 3}), + ), + dim: 0, + start: 1, + length: 1, + expected: New( + WithShape(), + WithBacking([]int{2}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 0, + start: 0, + length: 2, + expected: New( + WithShape(2, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 1, + start: 1, + length: 2, + expected: New( + WithShape(3, 2), + WithBacking([]int{2, 3, 5, 6, 8, 9}), + ), + }, + { + x: New( + WithShape(3, 3), + WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), + ), + dim: 1, + start: 0, + length: 1, + expected: New( + WithShape(3), + WithBacking([]int{1, 4, 7}), + ), + }, + } + + for i, tC := range testCases { + t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) { + c := assert.New(t) + + y, err := tC.x.Narrow(tC.dim, tC.start, tC.length) + c.NoError(err) + + yMat := y.Materialize() + c.Equal(tC.expected.Shape(), yMat.Shape()) + c.Equal(tC.expected.Data(), yMat.Data()) + }) + } +} + func TestDense_SliceInto(t *testing.T) { V := New(WithShape(100), Of(Byte)) T := New(WithBacking([]float64{1, 2, 3, 4, 5, 6}), WithShape(2, 3)) diff --git a/example_dense_reduction_test.go b/example_dense_reduction_test.go index 1a536a9..fa97e98 100644 --- a/example_dense_reduction_test.go +++ 
b/example_dense_reduction_test.go @@ -9,7 +9,7 @@ func Example_sum_Sliced() { 1, 2, 3, 4, 5, 6, 7, 8, })) - s, _ := T.Slice(sli(1, 3), sli(1, 3)) + s, _ := T.Slice(S(1, 3), S(1, 3)) sum, _ := Sum(s) fmt.Printf("T:\n%v\nsliced:\n%v\nSum: %v", T, s, sum) diff --git a/example_mapreduce_test.go b/example_mapreduce_test.go index e08c6da..47bd2ce 100644 --- a/example_mapreduce_test.go +++ b/example_mapreduce_test.go @@ -35,7 +35,7 @@ func ExampleSum_sliced() { T := New(WithBacking([]float64{0, 1, 2, 3}), WithShape(2, 2)) fmt.Printf("T:\n%v\n", T) - V, _ := T.Slice(nil, sli(1)) + V, _ := T.Slice(nil, S(1)) fmt.Printf("V:\n%v\n", V) Σ, _ := Sum(V) @@ -75,7 +75,7 @@ func ExampleArgmax_sliced() { fmt.Printf("T:\n%v\n", T) // slice creates a view - V, _ := T.Slice(nil, sli(1)) + V, _ := T.Slice(nil, S(1)) // argmax along the x-axis am, _ := Argmax(V, 0) diff --git a/slice.go b/slice.go index ecba60d..41e1419 100644 --- a/slice.go +++ b/slice.go @@ -34,3 +34,39 @@ type ss int func (s ss) Start() int { return int(s) } func (s ss) End() int { return int(s) + 1 } func (s ss) Step() int { return 0 } + +// sli is slice. It's named sli to prevent confusion over naming +type sli struct { + start, end, step int +} + +// S creates a Slice. +// end is optional. It should be passed in as the first param of the optionals. +// step is optional. It should be passed in as the second param of the optionals. +// +// Default end is start+1. 
Default step is 1, unless end == step+1, then it defaults to 0 +func S(start int, opt ...int) Slice { + var end, step int + if len(opt) > 0 { + end = opt[0] + } else { + end = start + 1 + } + + step = 1 + if len(opt) > 1 { + step = opt[1] + } else if end == start+1 { + step = 0 + } + + return &sli{ + start: start, + end: end, + step: step, + } +} + +func (s *sli) Start() int { return s.start } +func (s *sli) End() int { return s.end } +func (s *sli) Step() int { return s.step } From 60b483543c1097f0a5a298836909e178b80d0359 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Tue, 19 Oct 2021 04:59:07 +1100 Subject: [PATCH 135/154] Narrow op (#122) * Implement the Narrow operation for Dense tensors. Also, implemented a default slicer * Added a new failing test. This seems like a bug with the non contiguous View implementation * Fixed the test. The test was correct * Added a package level API for Narrow Co-authored-by: David Cuadrado --- api_matop.go | 10 ++++++++++ dense_matop.go | 6 +++--- dense_matop_test.go | 4 ++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/api_matop.go b/api_matop.go index 2cb41a8..a9797c3 100644 --- a/api_matop.go +++ b/api_matop.go @@ -7,6 +7,16 @@ import ( // this file handles matops. While by default most of these matops should already have been defined as part of the // Tensor interface, not all are possible(for example, concatenating a sparse tensor), hence the need for the following functions +// Narrow narrows the tensor. +func Narrow(t Tensor, dim, start, length int) (View, error) { + dim = resolveAxis(dim, t.Dims()) + + slices := make([]Slice, MinInt(dim+1, t.Dims())) + slices[dim] = S(start, start+length, 1) + + return t.Slice(slices...) +} + // Repeat repeats a Tensor along the axis and given the number of repeats. 
func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { if r, ok := t.Engine().(Repeater); ok { diff --git a/dense_matop.go b/dense_matop.go index ea56b8a..b059f35 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -198,12 +198,12 @@ func (t *Dense) CopyTo(other *Dense) error { return errors.Errorf(methodNYI, "CopyTo", "views") } -// Narrow narrows the tensor -func (t *Dense) Narrow(dim, start, lenght int) (View, error) { +// Narrow narrows the tensor. +func (t *Dense) Narrow(dim, start, length int) (View, error) { dim = resolveAxis(dim, t.Dims()) slices := make([]Slice, MinInt(dim+1, t.Dims())) - slices[dim] = S(start, start+lenght, 1) + slices[dim] = S(start, start+length, 1) return t.Slice(slices...) } diff --git a/dense_matop_test.go b/dense_matop_test.go index 2b6a8b8..48c854f 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -710,6 +710,10 @@ func TestDense_Narrow(t *testing.T) { yMat := y.Materialize() c.Equal(tC.expected.Shape(), yMat.Shape()) c.Equal(tC.expected.Data(), yMat.Data()) + + // err = y.Memset(1024) + // c.Nil(err) + // t.Logf("example %d y \n%v\n%v", i+1, y, y.Data()) }) } } From 2c194a81953044a98900ea941fa45e8ba0601fd5 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 07:09:42 +1100 Subject: [PATCH 136/154] fix a bug in ByIdx --- defaultengine_selbyidx.go | 26 ++++++++-------- dense_selbyidx_test.go | 64 +++++++++++++++++++++++++-------------- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index 88fa5c3..c7f9fa8 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -131,38 +131,38 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da } // SelectByIndicesB is the backwards function of SelectByIndices. 
-func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { +func (e StdEng) SelectByIndicesB(output, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !indices.Shape().IsVectorLike() { - return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape()) + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape()) } if indices.Dtype() != Int { - return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype()) } // if b is a scalar, then use Slice - if a.Shape().IsScalarEquiv() { - slices := make([]Slice, a.Shape().Dims()) - slices[axis] = ss(b.Data().([]int)[0]) - return a.Slice(slices...) + if output.Shape().IsScalarEquiv() { + slices := make([]Slice, output.Shape().Dims()) + slices[axis] = ss(outGrad.Data().([]int)[0]) + return output.Slice(slices...) 
} - expectedShape := a.Shape().Clone() + expectedShape := output.Shape().Clone() var reuse DenseTensor var _, toReuse, _ bool - if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + if reuse, _, toReuse, _, _, err = handleFuncOpts(output.Shape(), output.Dtype(), output.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !toReuse && reuse == nil { // create reuse - reuse = New(WithShape(expectedShape...), Of(a.Dtype())) + reuse = New(WithShape(expectedShape...), Of(output.Dtype())) } - typ := a.Dtype().Type + typ := output.Dtype().Type var _, dataB, dataReuse *storage.Header var _, bit, iit Iterator var useIter bool - if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(output, outGrad, reuse); err != nil { return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB") } @@ -172,7 +172,7 @@ func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt return } - e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, b.(*Dense).AP, reuse.(*Dense).AP) + e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP) return reuse, nil } diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index 86369be..e542133 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -19,6 +19,9 @@ type selByIndicesTest struct { } var selByIndicesTests = []selByIndicesTest{ + {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, + Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, + }, {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: 
Shape{2, 2, 4}}, @@ -35,7 +38,7 @@ var selByIndicesTests = []selByIndicesTest{ Correct: []int{1, 1}, CorrectShape: Shape{2}}, {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, - {Name: "(2,1) Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, }, } @@ -60,38 +63,47 @@ var selByIndicesBTests = []struct { CorrectGrad interface{} CorrectGradShape Shape }{ + // Basic + { + CorrectGrad: []float64{1, 1, 1, 1}, + }, + // 3-tensor, axis 0 + { + CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + // 3-tensor, axis 1 { - selByIndicesTest: selByIndicesTests[0], - CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []float64{0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2}, }, + // 3-tensor, axis 2 { - selByIndicesTest: selByIndicesTests[1], - CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []float64{0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0}, }, + // vector, axis 0 { - selByIndicesTest: selByIndicesTests[2], - CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []int{0, 2, 0, 0, 0}, }, + // vector, axis 1 { - 
selByIndicesTest: selByIndicesTests[3], - CorrectGrad: []int{0, 2, 0, 0, 0}, - CorrectGradShape: Shape{5}, + CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, }, + // (4,2) Matrix with (10) indices { - selByIndicesTest: selByIndicesTests[5], - CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, - CorrectGradShape: Shape{4, 2}, + CorrectGrad: []float32{2, 2, 4, 4, 4, 4, 0, 0}, }, + // (2, 1) Matrix (colvec) with (10) indices { - selByIndicesTest: selByIndicesTests[6], - CorrectGrad: []float64{0, 10}, - CorrectGradShape: Shape{2, 1}, + CorrectGrad: []float64{0, 10}, }, } +func init() { + for i := range selByIndicesBTests { + selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] + selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape + } +} + func TestDense_SelectByIndicesB(t *testing.T) { assert := assert.New(t) @@ -102,12 +114,20 @@ func TestDense_SelectByIndicesB(t *testing.T) { if checkErr(t, tc.WillErr, err, tc.Name, i) { continue } - grad, err := ByIndicesB(T, ret, indices, tc.Axis) + outGrad := ret.Clone().(*Dense) + switch outGrad.Dtype() { + case Float64: + outGrad.Memset(1.0) + case Float32: + outGrad.Memset(float32(1.0)) + } + + grad, err := ByIndicesB(T, outGrad, indices, tc.Axis) if checkErr(t, tc.WillErr, err, tc.Name, i) { continue } - assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name) - assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape()) + assert.Equal(tc.CorrectGrad, grad.Data(), "%v - x:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, T, indices, ret, grad) + assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. 
Got %v instead.\n\nx:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, tc.CorrectGradShape, grad.Shape(), T, indices, ret, grad) } } From 37fa72d41c54f2abda98480c521badb5ca123d97 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 07:10:07 +1100 Subject: [PATCH 137/154] Added some logging to the Narrow test to ensure sanity --- dense_matop_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dense_matop_test.go b/dense_matop_test.go index 48c854f..652c71d 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -703,17 +703,19 @@ func TestDense_Narrow(t *testing.T) { for i, tC := range testCases { t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) { c := assert.New(t) + // t.Logf("X:\n%v", tC.x) y, err := tC.x.Narrow(tC.dim, tC.start, tC.length) c.NoError(err) + // t.Logf("y:\n%v", y) yMat := y.Materialize() c.Equal(tC.expected.Shape(), yMat.Shape()) c.Equal(tC.expected.Data(), yMat.Data()) - - // err = y.Memset(1024) - // c.Nil(err) - // t.Logf("example %d y \n%v\n%v", i+1, y, y.Data()) + + // err = y.Memset(1024) + // c.NoError(err) + // t.Logf("After Memset\nY: %v\nX:\n%v", y, tC.x) }) } } From 0dd0c0ad3fe7576f791524dd4837016f59583532 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 07:10:30 +1100 Subject: [PATCH 138/154] Added comments and documentation to softmax --- defaultengine_softmax.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index 1a22675..c3b0e99 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -18,6 +18,11 @@ func resolveAxis(axis int, dims int) int { return res } +// SoftMax performs the softmax operation on the given tensor. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. 
+// +// The softmax function is defined as : +// σ(x) = e^x_i / Σ(e^x_i) func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) expectedShape := x.Shape().Clone() @@ -52,6 +57,8 @@ func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err return reuse, nil } +// SoftMaxB computes gradient of the input `x`, given the `output = SoftMax(x)` and its associated gradient. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !output.Shape().Eq(grad.Shape()) { return nil, fmt.Errorf("output and grad shapes don't match") @@ -94,6 +101,10 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal return reuse, nil } +// LogSoftMax performs softmax but in log space. This provides some amount of numerical stabilization. +// Conceptually it is the same as performing a logarithm after applying the softmax function. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) expectedShape := x.Shape().Clone() @@ -128,6 +139,9 @@ func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, return reuse, nil } +// LogSoftMaxB computes the gradient of the input `x`, given the `output = LogSoftmax(x)` and its associated gradient. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. 
func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !output.Shape().Eq(grad.Shape()) { return nil, fmt.Errorf("output and grad shapes don't match") From fde293bf678733f935e91c8b7a07e19cfd44a76c Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 07:11:48 +1100 Subject: [PATCH 139/154] Renamed the variables in SelectByIndicesB and SelectByIndices for better clarity. --- defaultengine_selbyidx.go | 41 ++++++++++++++++++++------------------- engine.go | 2 +- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index c7f9fa8..8007dcc 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -7,23 +7,24 @@ import ( "reflect" ) -func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { - if !b.Shape().IsVectorLike() { - return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape()) +// SelectByIndices selects the values given the in `indices` tensor. +func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !indices.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape()) } - if b.Dtype() != Int { - return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + if indices.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype()) } // if b is a scalar, then use Slice if a.Shape().IsScalarEquiv() { slices := make([]Slice, a.Shape().Dims()) - slices[axis] = ss(b.Data().([]int)[0]) + slices[axis] = ss(indices.Data().([]int)[0]) return a.Slice(slices...) 
} expectedShape := a.Shape().Clone() - expectedShape[axis] = b.Shape().TotalSize() + expectedShape[axis] = indices.Shape().TotalSize() var reuse DenseTensor var safe, toReuse, _ bool @@ -36,9 +37,9 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal } if !safe { - if a.Shape()[axis] != b.Shape().TotalSize() { + if a.Shape()[axis] != indices.Shape().TotalSize() { expected := a.Shape().Clone() - expected[axis] = b.Shape().TotalSize() + expected[axis] = indices.Shape().TotalSize() return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape()) } @@ -49,7 +50,7 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator var useIter bool - if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil { return nil, errors.Wrapf(err, "StdEng.Add") } @@ -130,8 +131,8 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da } } -// SelectByIndicesB is the backwards function of SelectByIndices. -func (e StdEng) SelectByIndicesB(output, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { +// SelectByIndicesB computes the gradient of the result of `SelectByIndices`. +func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !indices.Shape().IsVectorLike() { return nil, errors.Errorf("Expected indices to be a vector. 
Got %v instead", outGrad.Shape()) } @@ -140,29 +141,29 @@ func (e StdEng) SelectByIndicesB(output, outGrad, indices Tensor, axis int, opts } // if b is a scalar, then use Slice - if output.Shape().IsScalarEquiv() { - slices := make([]Slice, output.Shape().Dims()) + if input.Shape().IsScalarEquiv() { + slices := make([]Slice, input.Shape().Dims()) slices[axis] = ss(outGrad.Data().([]int)[0]) - return output.Slice(slices...) + return input.Slice(slices...) } - expectedShape := output.Shape().Clone() + expectedShape := input.Shape().Clone() var reuse DenseTensor var _, toReuse, _ bool - if reuse, _, toReuse, _, _, err = handleFuncOpts(output.Shape(), output.Dtype(), output.DataOrder(), true, opts...); err != nil { + if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !toReuse && reuse == nil { // create reuse - reuse = New(WithShape(expectedShape...), Of(output.Dtype())) + reuse = New(WithShape(expectedShape...), Of(input.Dtype())) } - typ := output.Dtype().Type + typ := input.Dtype().Type var _, dataB, dataReuse *storage.Header var _, bit, iit Iterator var useIter bool - if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(output, outGrad, reuse); err != nil { + if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil { return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB") } diff --git a/engine.go b/engine.go index 5730c60..39e3f04 100644 --- a/engine.go +++ b/engine.go @@ -410,7 +410,7 @@ type InfChecker interface { // ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the `indices` tensor. 
type ByIndiceser interface { SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) - SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } /* Internal interfaces for faster shit */ From 32c82b5ec85bce51ba4fb017e1359b50e89b7a56 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 15:31:57 +1100 Subject: [PATCH 140/154] Added example of ByIndices --- defaultengine_selbyidx.go | 6 ++++ example_byindices_test.go | 74 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 example_byindices_test.go diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index 8007dcc..b99c69e 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -8,6 +8,9 @@ import ( ) // SelectByIndices selects the values given the in `indices` tensor. +// +// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators. +// Please make a pull request to support tensors that require the use of an iterator to traverse data. func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !indices.Shape().IsVectorLike() { return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape()) @@ -132,6 +135,9 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da } // SelectByIndicesB computes the gradient of the result of `SelectByIndices`. +// +// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators. +// Please make a pull request to support tensors that require the use of an iterator to traverse data. 
func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !indices.Shape().IsVectorLike() { return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape()) diff --git a/example_byindices_test.go b/example_byindices_test.go new file mode 100644 index 0000000..7f94781 --- /dev/null +++ b/example_byindices_test.go @@ -0,0 +1,74 @@ +package tensor + +import "fmt" + +func ExampleByIndices() { + a := New(WithShape(2, 2), WithBacking([]float64{ + 100, 200, + 300, 400, + })) + indices := New(WithBacking([]int{1, 1, 1, 0, 1})) + b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\n", a, indices, b) + + // Output: + // a: + // ⎡100 200⎤ + // ⎣300 400⎦ + // + // indices: [1 1 1 0 1] + // b: + // ⎡300 400⎤ + // ⎢300 400⎥ + // ⎢300 400⎥ + // ⎢100 200⎥ + // ⎣300 400⎦ + +} + +func ExampleByIndicesB() { + a := New(WithShape(2, 2), WithBacking([]float64{ + 100, 200, + 300, 400, + })) + indices := New(WithBacking([]int{1, 1, 1, 0, 1})) + b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 + if err != nil { + fmt.Println(err) + return + } + + outGrad := b.Clone().(*Dense) + outGrad.Memset(1.0) + + grad, err := ByIndicesB(a, outGrad, indices, 0) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\ngrad:\n%v", a, indices, b, grad) + + // Output: + // a: + // ⎡100 200⎤ + // ⎣300 400⎦ + // + // indices: [1 1 1 0 1] + // b: + // ⎡300 400⎤ + // ⎢300 400⎥ + // ⎢300 400⎥ + // ⎢100 200⎥ + // ⎣300 400⎦ + // + // grad: + // ⎡1 1⎤ + // ⎣4 4⎦ + +} From e546815db9116f5d376287eefa42fbf5588a537d Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 19 Oct 2021 15:43:22 +1100 Subject: [PATCH 141/154] Removed the allocation as suggested by @dcu --- defaultengine_softmax.go | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 
deletions(-) diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index c3b0e99..368cfa7 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -264,22 +264,22 @@ func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, log dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) } } else { - mul := make([]float64, dimSize) - + //mul := make([]float64, dimSize) + var sum float64 for j := 0; j < dimSize; j++ { i := ii*dimSize + j - mul[j] = outputArr[i] * gradArr[i] + // mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] } - sum := mul[0] - for j := 1; j < dimSize; j++ { - sum += mul[j] - } + //sum := mul[0] + //for j := 1; j < dimSize; j++ { + // sum += mul[j] + //} for j := 0; j < dimSize; j++ { i := ii*dimSize + j - dx[i] = (gradArr[i] - sum) * outputArr[i] } } @@ -481,18 +481,19 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) } } else { - mul := make([]float32, dimSize) - + // mul := make([]float32, dimSize) + var sum float32 for j := 0; j < dimSize; j++ { i := ii*dimSize + j - mul[j] = outputArr[i] * gradArr[i] + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] } - sum := mul[0] - for j := 1; j < dimSize; j++ { - sum += mul[j] - } + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + //} for j := 0; j < dimSize; j++ { i := ii*dimSize + j From 32e4381e29e64100245bee3b15b0f68ac2288b8c Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 20 Oct 2021 06:15:57 +1100 Subject: [PATCH 142/154] Switch to using `getFloat64s` and `getFloat32s` (new utility func) to reduce allocations Results vs prev: benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2237 2057 -8.05% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 2138 1920 -10.20% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 2112 1798 -14.87% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 2123 1844 -13.14% 
BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 2236 1937 -13.37% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2305 2040 -11.50% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 2167 1931 -10.89% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 2261 1884 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2119 2035 -3.96% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 2143 1846 -13.86% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 2212 1821 -17.68% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 2164 1930 -10.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36898948 36137745 -2.06% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35541861 35019509 -1.47% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 16 12 -25.00% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 17 13 -23.53% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 17 13 -23.53% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 616 520 -15.58% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 616 520 -15.58% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 648 552 -14.81% 
BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 648 552 -14.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392926 19392912 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701448 9701351 -0.00% --- defaultengine_selbyidx.go | 4 +-- defaultengine_softmax.go | 57 ++++++++++++++++++++------------------- interfaces.go | 1 + utils.go | 21 +++++++++++++++ 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index b99c69e..e0564e6 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -22,7 +22,7 @@ func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (r // if b is a scalar, then use Slice if a.Shape().IsScalarEquiv() { slices := make([]Slice, a.Shape().Dims()) - slices[axis] = ss(indices.Data().([]int)[0]) + slices[axis] = ss(getInts(indices)[0]) return a.Slice(slices...) 
} @@ -179,7 +179,7 @@ func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts return } - e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP) + e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP) return reuse, nil } diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index 368cfa7..b30cbd2 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -185,8 +185,9 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret } func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) { - outputArr := output.Data().([]float64) - xArr := x.Data().([]float64) + outputArr := getFloat64s(output) + xArr := getFloat64s(x) + xShape := x.Shape() outerSize := 1 @@ -237,9 +238,9 @@ func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax } func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { - dx := inputGrad.Data().([]float64) - outputArr := output.Data().([]float64) - gradArr := grad.Data().([]float64) + dx := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) outputShape := output.Shape() @@ -269,14 +270,14 @@ func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, log for j := 0; j < dimSize; j++ { i := ii*dimSize + j - // mul[j] = outputArr[i] * gradArr[i] + //mul[j] = outputArr[i] * gradArr[i] sum += outputArr[i] * gradArr[i] } - //sum := mul[0] - //for j := 1; j < dimSize; j++ { - // sum += mul[j] - //} + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } for j := 0; j < dimSize; j++ { i := ii*dimSize + j @@ -302,8 +303,8 @@ func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax dimStride := innerSize outerStride := dimSize * dimStride - outputArr := output.Data().([]float64) - xArr := 
x.Data().([]float64) + outputArr := getFloat64s(output) + xArr := getFloat64s(x) for ii := 0; ii < innerSize*outerSize; ii++ { outerIndex, innerIndex := divmod(ii, innerSize) @@ -367,9 +368,9 @@ func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, lo dimStride := innerSize outerStride := dimSize * dimStride - dxArr := inputGrad.Data().([]float64) - outputArr := output.Data().([]float64) - gradArr := grad.Data().([]float64) + dxArr := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) for ii := 0; ii < innerSize*outerSize; ii++ { outerIndex, innerIndex := divmod(ii, innerSize) @@ -402,8 +403,8 @@ func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, lo } func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) { - outputArr := output.Data().([]float32) - xArr := x.Data().([]float32) + outputArr := getFloat32s(output) + xArr := getFloat32s(x) xShape := x.Shape() outerSize := 1 @@ -454,9 +455,9 @@ func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax } func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { - dx := inputGrad.Data().([]float32) - outputArr := output.Data().([]float32) - gradArr := grad.Data().([]float32) + dx := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) outputShape := output.Shape() @@ -481,7 +482,7 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) } } else { - // mul := make([]float32, dimSize) + //mul := make([]float32, dimSize) var sum float32 for j := 0; j < dimSize; j++ { i := ii*dimSize + j @@ -492,8 +493,8 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log // sum := mul[0] // for j := 1; j < dimSize; j++ { - // sum += mul[j] - //} + // sum += mul[j] + // } for j := 0; j < dimSize; j++ { i := ii*dimSize 
+ j @@ -520,8 +521,8 @@ func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax dimStride := innerSize outerStride := dimSize * dimStride - outputArr := output.Data().([]float32) - xArr := x.Data().([]float32) + outputArr := getFloat32s(output) + xArr := getFloat32s(x) for ii := 0; ii < innerSize*outerSize; ii++ { outerIndex, innerIndex := divmod(ii, innerSize) @@ -585,9 +586,9 @@ func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, lo dimStride := innerSize outerStride := dimSize * dimStride - dxArr := inputGrad.Data().([]float32) - outputArr := output.Data().([]float32) - gradArr := grad.Data().([]float32) + dxArr := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) for ii := 0; ii < innerSize*outerSize; ii++ { outerIndex, innerIndex := divmod(ii, innerSize) diff --git a/interfaces.go b/interfaces.go index c0fd7e3..e33502f 100644 --- a/interfaces.go +++ b/interfaces.go @@ -144,6 +144,7 @@ type unsafeMem interface { Set(i int, x interface{}) GetF64(i int) float64 GetF32(i int) float32 + Ints() []int Float64s() []float64 Float32s() []float32 Complex64s() []complex64 diff --git a/utils.go b/utils.go index 064c812..2b3aa65 100644 --- a/utils.go +++ b/utils.go @@ -300,6 +300,27 @@ func allones(a []int) bool { return true } +func getFloat64s(a Tensor) []float64 { + if um, ok := a.(unsafeMem); ok { + return um.Float64s() + } + return a.Data().([]float64) +} + +func getFloat32s(a Tensor) []float32 { + if um, ok := a.(unsafeMem); ok { + return um.Float32s() + } + return a.Data().([]float32) +} + +func getInts(a Tensor) []int { + if um, ok := a.(unsafeMem); ok { + return um.Ints() + } + return a.Data().([]int) +} + /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. 
the dumb, unoptimized version) From 923510bda9d049d25f4c77e3638b7f96b73a34c4 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 20 Oct 2021 06:47:10 +1100 Subject: [PATCH 143/154] Removed unnecessary calls to .Clone() which reduces the number of allocs. Results: ``` benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2057 1619 -21.29% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 1920 1563 -18.59% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 1798 1508 -16.13% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 1844 1575 -14.59% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 1937 1836 -5.21% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2040 1672 -18.04% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 1931 1704 -11.76% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 1884 1542 -18.15% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2035 1558 -23.44% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 1846 1626 -11.92% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 1821 1552 -14.77% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 1930 1499 -22.33% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36137745 36795574 +1.82% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35019509 34759423 -0.74% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 13 11 -15.38% 
BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 13 11 -15.38% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 568 528 -7.04% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 520 480 -7.69% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 568 528 -7.04% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 520 480 -7.69% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 552 504 -8.70% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392912 19392892 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701351 9701312 -0.00% ``` --- defaultengine_softmax.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index b30cbd2..b70da25 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -25,7 +25,7 @@ func resolveAxis(axis int, dims int) int { // σ(x) = e^x_i / Σ(e^x_i) func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) - expectedShape := x.Shape().Clone() + expectedShape := x.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -69,7 +69,7 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal } axis = resolveAxis(axis, output.Dims()) - expectedShape := output.Shape().Clone() + expectedShape := output.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -107,7 +107,7 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal // Please make a pull request to support sparse tensors. 
func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) - expectedShape := x.Shape().Clone() + expectedShape := x.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -152,7 +152,7 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret } axis = resolveAxis(axis, output.Dims()) - expectedShape := output.Shape().Clone() + expectedShape := output.Shape() var reuse DenseTensor var safe, toReuse, _ bool From 8e712327720be39776d43290f5bf2fd6ec4e65db Mon Sep 17 00:00:00 2001 From: chewxy Date: Fri, 22 Oct 2021 09:27:27 +1100 Subject: [PATCH 144/154] Parallelized the softmax code --- defaultengine_softmax.go | 479 ++++++++++++++++++++++----------------- 1 file changed, 270 insertions(+), 209 deletions(-) diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index b70da25..ffc5a06 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -3,6 +3,7 @@ package tensor import ( "fmt" "math" + "sync" "github.com/chewxy/math32" "github.com/pkg/errors" @@ -196,45 +197,52 @@ func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax outerSize *= xShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { - maxInput := xArr[0] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - if xArr[i] > maxInput { - maxInput = xArr[i] + if xArr[i] > maxInput { + maxInput = xArr[i] + } } - } - sumExp := float64(0.0) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - z := xArr[i] - maxInput - exp := math.Exp(z) + sumExp := float64(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math.Exp(z) - if logSoftMax { - outputArr[i] = z - } else { - outputArr[i] = exp - } + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } - sumExp += exp - } 
+ sumExp += exp + } - if !logSoftMax { - sumExp = 1 / sumExp - } + if !logSoftMax { + sumExp = 1 / sumExp + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - if logSoftMax { - outputArr[i] -= math.Log(sumExp) - } else { - outputArr[i] *= sumExp + if logSoftMax { + outputArr[i] -= math.Log(sumExp) + } else { + outputArr[i] *= sumExp + } } - } + wg.Done() + }(ii, &wg) + } + wg.Wait() } func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { @@ -250,41 +258,51 @@ func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, log outerSize *= outputShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { + wg.Add(1) if logSoftMax { - sum := gradArr[ii*dimSize] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j - - sum += gradArr[i] - } - - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + go func(gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) - } - } else { - //mul := make([]float64, dimSize) - var sum float64 - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + sum += gradArr[i] + } - //mul[j] = outputArr[i] * gradArr[i] - sum += outputArr[i] * gradArr[i] - } + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - // sum := mul[0] - // for j := 1; j < dimSize; j++ { - // sum += mul[j] - // } + dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) + } + wg.Done() + }(gradArr, dx, ii, &wg) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - dx[i] = (gradArr[i] - sum) * outputArr[i] - } + } else { + go func(outputArr, gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + //mul := make([]float64, dimSize) + var sum float64 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // 
for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(outputArr, gradArr, dx, ii, &wg) } } + wg.Wait() } func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) { @@ -306,50 +324,56 @@ func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax outputArr := getFloat64s(output) xArr := getFloat64s(x) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - inputPart := xArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - maxInput := inputPart[0] - for j := 1; j < dimSize; j++ { - i := j * dimStride + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride - if inputPart[i] > maxInput { - maxInput = inputPart[i] + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } } - } - - sumExp := 0.0 - for j := 0; j < dimSize; j++ { - i := j * dimStride - - exp := math.Exp(inputPart[i] - maxInput) - if !logSoftmax { - outputPart[i] = exp - } + sumExp := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride - sumExp += exp - } + exp := math.Exp(inputPart[i] - maxInput) - if logSoftmax { - sumExp = math.Log(sumExp) - } else { - sumExp = 1 / sumExp - } + if !logSoftmax { + outputPart[i] = exp + } - for j := 0; j < dimSize; j++ { - i := j * dimStride + sumExp += exp + } if logSoftmax { - outputPart[i] = inputPart[i] - maxInput - sumExp + sumExp = math.Log(sumExp) } else { - outputPart[i] *= sumExp + sumExp = 1 / sumExp } - } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - 
sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { @@ -372,34 +396,41 @@ func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, lo outputArr := getFloat64s(output) gradArr := getFloat64s(grad) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - gradPart := gradArr[outerIndex*outerStride+innerIndex:] - dxPart := dxArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - sum := 0.0 - for j := 0; j < dimSize; j++ { - i := j * dimStride + sum := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - sum += gradPart[i] - } else { - sum += gradPart[i] * outputPart[i] + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } } - } - for j := 0; j < dimSize; j++ { - i := j * dimStride + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum - } else { - dxPart[i] = outputPart[i] * (gradPart[i] - sum) + if logSoftmax { + dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } } - } + wg.Done() + }(ii, &wg) + } + wg.Wait() } func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) { @@ -413,45 +444,51 @@ func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax outerSize *= xShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { - maxInput := xArr[0] - for j := 1; j < dimSize; 
j++ { - i := ii*dimSize + j + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - if xArr[i] > maxInput { - maxInput = xArr[i] + if xArr[i] > maxInput { + maxInput = xArr[i] + } } - } - sumExp := float32(0.0) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - z := xArr[i] - maxInput - exp := math32.Exp(z) + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math32.Exp(z) - if logSoftMax { - outputArr[i] = z - } else { - outputArr[i] = exp - } + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } - sumExp += exp - } + sumExp += exp + } - if !logSoftMax { - sumExp = 1 / sumExp - } + if !logSoftMax { + sumExp = 1 / sumExp + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - if logSoftMax { - outputArr[i] -= math32.Log(sumExp) - } else { - outputArr[i] *= sumExp + if logSoftMax { + outputArr[i] -= math32.Log(sumExp) + } else { + outputArr[i] *= sumExp + } } - } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { @@ -467,42 +504,52 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log outerSize *= outputShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + if logSoftMax { - sum := gradArr[ii*dimSize] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + go func(ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - sum += gradArr[i] - } + sum += gradArr[i] + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) - } + dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) + } + wg.Done() + }(ii, &wg) } else { - //mul := 
make([]float32, dimSize) - var sum float32 - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - //mul[j] = outputArr[i] * gradArr[i] - sum += outputArr[i] * gradArr[i] - } - - // sum := mul[0] - // for j := 1; j < dimSize; j++ { - // sum += mul[j] - // } - - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - dx[i] = (gradArr[i] - sum) * outputArr[i] - } + go func(ii int, wg *sync.WaitGroup) { + //mul := make([]float32, dimSize) + var sum float32 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(ii, &wg) } } + wg.Wait() } func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) { @@ -524,50 +571,57 @@ func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax outputArr := getFloat32s(output) xArr := getFloat32s(x) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) - - inputPart := xArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + wg.Add(1) - maxInput := inputPart[0] - for j := 1; j < dimSize; j++ { - i := j * dimStride - - if inputPart[i] > maxInput { - maxInput = inputPart[i] - } - } + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - sumExp := float32(0.0) - for j := 0; j < dimSize; j++ { - i := j * dimStride + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - exp := math32.Exp(inputPart[i] - maxInput) + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride - if !logSoftmax { - outputPart[i] = exp + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } } - 
sumExp += exp - } + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - sumExp = math32.Log(sumExp) - } else { - sumExp = 1 / sumExp - } + exp := math32.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } - for j := 0; j < dimSize; j++ { - i := j * dimStride + sumExp += exp + } if logSoftmax { - outputPart[i] = inputPart[i] - maxInput - sumExp + sumExp = math32.Log(sumExp) } else { - outputPart[i] *= sumExp + sumExp = 1 / sumExp } - } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { @@ -590,32 +644,39 @@ func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, lo outputArr := getFloat32s(output) gradArr := getFloat32s(grad) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) - gradPart := gradArr[outerIndex*outerStride+innerIndex:] - dxPart := dxArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - sum := float32(0.0) - for j := 0; j < dimSize; j++ { - i := j * dimStride + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - if logSoftmax { - sum += gradPart[i] - } else { - sum += gradPart[i] * outputPart[i] + sum := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } } - } - for j := 0; j < dimSize; j++ { - i := j * dimStride + for j := 0; j < dimSize; j++ { + i := j * dimStride 
- if logSoftmax { - dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum - } else { - dxPart[i] = outputPart[i] * (gradPart[i] - sum) + if logSoftmax { + dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } } - } + wg.Done() + }(ii, &wg) } + wg.Wait() } From 0aa7e64df0ebcc80b275f5da368a059fbf631f72 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sat, 23 Oct 2021 05:36:32 +1100 Subject: [PATCH 145/154] Optimizations (#121) * fix a bug in ByIdx * Added comments and documentation to softmax * Renamed the variables in SelectByIndicesB and SelectByIndices for better clarity. * Added example of ByIndices * Removed the allocation in `SoftMax` and `SoftMaxB` as suggested by @dcu * Switch to using `getFloat64s` and `getFloat32s` (new utility func) to reduce allocations ``` Results vs prev: benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2237 2057 -8.05% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 2138 1920 -10.20% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 2112 1798 -14.87% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 2123 1844 -13.14% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 2236 1937 -13.37% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2305 2040 -11.50% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 2167 1931 -10.89% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 2261 1884 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2119 2035 -3.96% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 2143 1846 -13.86% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 2212 1821 -17.68% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 2164 1930 -10.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36898948 36137745 -2.06% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35541861 35019509 -1.47% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 16 12 -25.00% 
BenchmarkSoftmax/(3,4)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 16 12 -25.00% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 17 13 -23.53% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 17 13 -23.53% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 616 520 -15.58% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 616 520 -15.58% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 648 552 -14.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392926 19392912 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701448 9701351 -0.00% ``` * `SoftMax` and `SoftMaxB` optimization: Removed unnecessary calls to .Clone() which reduces the number of allocs. 
Results: ``` benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2057 1619 -21.29% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 1920 1563 -18.59% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 1798 1508 -16.13% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 1844 1575 -14.59% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 1937 1836 -5.21% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2040 1672 -18.04% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 1931 1704 -11.76% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 1884 1542 -18.15% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2035 1558 -23.44% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 1846 1626 -11.92% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 1821 1552 -14.77% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 1930 1499 -22.33% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36137745 36795574 +1.82% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35019509 34759423 -0.74% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 13 11 -15.38% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 13 11 -15.38% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 568 528 -7.04% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 520 480 -7.69% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 568 528 -7.04% 
BenchmarkSoftmax/(3,4)_Float32_axis_1-20 520 480 -7.69% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 552 504 -8.70% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392912 19392892 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701351 9701312 -0.00% ``` * Parallelized the softmax code --- defaultengine_selbyidx.go | 59 +++-- defaultengine_softmax.go | 541 ++++++++++++++++++++++---------------- dense_matop_test.go | 10 +- dense_selbyidx_test.go | 64 +++-- engine.go | 2 +- example_byindices_test.go | 74 ++++++ interfaces.go | 1 + utils.go | 21 ++ 8 files changed, 487 insertions(+), 285 deletions(-) create mode 100644 example_byindices_test.go diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index cdcc318..e0564e6 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -7,23 +7,27 @@ import ( "reflect" ) -func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { - if !b.Shape().IsVectorLike() { - return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape()) +// SelectByIndices selects the values given the in `indices` tensor. +// +// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators. +// Please make a pull request to support tensors that require the use of an iterator to traverse data. +func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if !indices.Shape().IsVectorLike() { + return nil, errors.Errorf("Expected indices to be a vector. 
Got %v instead", indices.Shape()) } - if b.Dtype() != Int { - return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + if indices.Dtype() != Int { + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype()) } // if b is a scalar, then use Slice if a.Shape().IsScalarEquiv() { slices := make([]Slice, a.Shape().Dims()) - slices[axis] = ss(b.Data().([]int)[0]) + slices[axis] = ss(getInts(indices)[0]) return a.Slice(slices...) } expectedShape := a.Shape().Clone() - expectedShape[axis] = b.Shape().TotalSize() + expectedShape[axis] = indices.Shape().TotalSize() var reuse DenseTensor var safe, toReuse, _ bool @@ -36,9 +40,9 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal } if !safe { - if a.Shape()[axis] != b.Shape().TotalSize() { + if a.Shape()[axis] != indices.Shape().TotalSize() { expected := a.Shape().Clone() - expected[axis] = b.Shape().TotalSize() + expected[axis] = indices.Shape().TotalSize() return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape()) } @@ -49,7 +53,7 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator var useIter bool - if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil { return nil, errors.Wrapf(err, "StdEng.Add") } @@ -130,39 +134,42 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da } } -// SelectByIndicesB is the backwards function of SelectByIndices. -func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { +// SelectByIndicesB computes the gradient of the result of `SelectByIndices`. 
+// +// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators. +// Please make a pull request to support tensors that require the use of an iterator to traverse data. +func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !indices.Shape().IsVectorLike() { - return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape()) + return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape()) } if indices.Dtype() != Int { - return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype()) + return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype()) } // if b is a scalar, then use Slice - if a.Shape().IsScalarEquiv() { - slices := make([]Slice, a.Shape().Dims()) - slices[axis] = ss(b.Data().([]int)[0]) - return a.Slice(slices...) + if input.Shape().IsScalarEquiv() { + slices := make([]Slice, input.Shape().Dims()) + slices[axis] = ss(outGrad.Data().([]int)[0]) + return input.Slice(slices...) 
} - expectedShape := a.Shape().Clone() + expectedShape := input.Shape().Clone() var reuse DenseTensor var _, toReuse, _ bool - if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !toReuse && reuse == nil { // create reuse - reuse = New(WithShape(expectedShape...), Of(a.Dtype())) + reuse = New(WithShape(expectedShape...), Of(input.Dtype())) } - typ := a.Dtype().Type + typ := input.Dtype().Type var _, dataB, dataReuse *storage.Header var _, bit, iit Iterator var useIter bool - if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil { + if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil { return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB") } @@ -172,7 +179,7 @@ func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt return } - e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, b.(*Dense).AP, reuse.(*Dense).AP) + e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP) return reuse, nil } @@ -228,8 +235,8 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data for i, idx := range indices { dstCoord[axis] = idx srcCoord[axis] = i - dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...) - start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...) + dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...) + start, _ := Ltoi(apB.shape, apB.strides, srcCoord...) 
for o := 0; o < outer; o++ { dstEnd := dstStart + axStride diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index 1a22675..ffc5a06 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -3,6 +3,7 @@ package tensor import ( "fmt" "math" + "sync" "github.com/chewxy/math32" "github.com/pkg/errors" @@ -18,9 +19,14 @@ func resolveAxis(axis int, dims int) int { return res } +// SoftMax performs the softmax operation on the given tensor. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. +// +// The softmax function is defined as : +// σ(x) = e^x_i / Σ(e^x_i) func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) - expectedShape := x.Shape().Clone() + expectedShape := x.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -52,6 +58,8 @@ func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err return reuse, nil } +// SoftMaxB computes gradient of the input `x`, given the `output = SoftMax(x)` and its associated gradient. Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !output.Shape().Eq(grad.Shape()) { return nil, fmt.Errorf("output and grad shapes don't match") @@ -62,7 +70,7 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal } axis = resolveAxis(axis, output.Dims()) - expectedShape := output.Shape().Clone() + expectedShape := output.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -94,9 +102,13 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal return reuse, nil } +// LogSoftMax performs softmax but in log space. This provides some amount of numerical stabilization. 
+// Conceptually it is the same as performing a logarithm after applying the softmax function. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { axis = resolveAxis(axis, x.Dims()) - expectedShape := x.Shape().Clone() + expectedShape := x.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -128,6 +140,9 @@ func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, return reuse, nil } +// LogSoftMaxB computes the gradient of the input `x`, given the `output = LogSoftmax(x)` and its associated gradient. +// Currently it expects the tensor to be a Dense tensor. +// Please make a pull request to support sparse tensors. func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { if !output.Shape().Eq(grad.Shape()) { return nil, fmt.Errorf("output and grad shapes don't match") @@ -138,7 +153,7 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret } axis = resolveAxis(axis, output.Dims()) - expectedShape := output.Shape().Clone() + expectedShape := output.Shape() var reuse DenseTensor var safe, toReuse, _ bool @@ -171,8 +186,9 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret } func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) { - outputArr := output.Data().([]float64) - xArr := x.Data().([]float64) + outputArr := getFloat64s(output) + xArr := getFloat64s(x) + xShape := x.Shape() outerSize := 1 @@ -181,51 +197,58 @@ func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax outerSize *= xShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { - maxInput := xArr[0] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < 
dimSize; j++ { + i := ii*dimSize + j - if xArr[i] > maxInput { - maxInput = xArr[i] + if xArr[i] > maxInput { + maxInput = xArr[i] + } } - } - sumExp := float64(0.0) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - z := xArr[i] - maxInput - exp := math.Exp(z) + sumExp := float64(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math.Exp(z) - if logSoftMax { - outputArr[i] = z - } else { - outputArr[i] = exp - } + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } - sumExp += exp - } + sumExp += exp + } - if !logSoftMax { - sumExp = 1 / sumExp - } + if !logSoftMax { + sumExp = 1 / sumExp + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - if logSoftMax { - outputArr[i] -= math.Log(sumExp) - } else { - outputArr[i] *= sumExp + if logSoftMax { + outputArr[i] -= math.Log(sumExp) + } else { + outputArr[i] *= sumExp + } } - } + wg.Done() + }(ii, &wg) + } + wg.Wait() } func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { - dx := inputGrad.Data().([]float64) - outputArr := output.Data().([]float64) - gradArr := grad.Data().([]float64) + dx := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) outputShape := output.Shape() @@ -235,41 +258,51 @@ func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, log outerSize *= outputShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { + wg.Add(1) if logSoftMax { - sum := gradArr[ii*dimSize] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + go func(gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - sum += gradArr[i] - } + sum += gradArr[i] + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - dx[i] = gradArr[i] - 
(math.Exp(outputArr[i]) * sum) - } - } else { - mul := make([]float64, dimSize) + dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum) + } + wg.Done() + }(gradArr, dx, ii, &wg) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - mul[j] = outputArr[i] * gradArr[i] - } - - sum := mul[0] - for j := 1; j < dimSize; j++ { - sum += mul[j] - } - - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - dx[i] = (gradArr[i] - sum) * outputArr[i] - } + } else { + go func(outputArr, gradArr, dx []float64, ii int, wg *sync.WaitGroup) { + //mul := make([]float64, dimSize) + var sum float64 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(outputArr, gradArr, dx, ii, &wg) } } + wg.Wait() } func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) { @@ -288,53 +321,59 @@ func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax dimStride := innerSize outerStride := dimSize * dimStride - outputArr := output.Data().([]float64) - xArr := x.Data().([]float64) + outputArr := getFloat64s(output) + xArr := getFloat64s(x) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - inputPart := xArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - maxInput := inputPart[0] - for j := 1; j < dimSize; j++ { - i := j * dimStride + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride - if 
inputPart[i] > maxInput { - maxInput = inputPart[i] + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } } - } - sumExp := 0.0 - for j := 0; j < dimSize; j++ { - i := j * dimStride - - exp := math.Exp(inputPart[i] - maxInput) + sumExp := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride - if !logSoftmax { - outputPart[i] = exp - } + exp := math.Exp(inputPart[i] - maxInput) - sumExp += exp - } - - if logSoftmax { - sumExp = math.Log(sumExp) - } else { - sumExp = 1 / sumExp - } + if !logSoftmax { + outputPart[i] = exp + } - for j := 0; j < dimSize; j++ { - i := j * dimStride + sumExp += exp + } if logSoftmax { - outputPart[i] = inputPart[i] - maxInput - sumExp + sumExp = math.Log(sumExp) } else { - outputPart[i] *= sumExp + sumExp = 1 / sumExp } - } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { @@ -353,43 +392,50 @@ func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, lo dimStride := innerSize outerStride := dimSize * dimStride - dxArr := inputGrad.Data().([]float64) - outputArr := output.Data().([]float64) - gradArr := grad.Data().([]float64) + dxArr := getFloat64s(inputGrad) + outputArr := getFloat64s(output) + gradArr := getFloat64s(grad) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - gradPart := gradArr[outerIndex*outerStride+innerIndex:] - dxPart := dxArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := 
outputArr[outerIndex*outerStride+innerIndex:] - sum := 0.0 - for j := 0; j < dimSize; j++ { - i := j * dimStride + sum := 0.0 + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - sum += gradPart[i] - } else { - sum += gradPart[i] * outputPart[i] + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } } - } - for j := 0; j < dimSize; j++ { - i := j * dimStride + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum - } else { - dxPart[i] = outputPart[i] * (gradPart[i] - sum) + if logSoftmax { + dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } } - } + wg.Done() + }(ii, &wg) + } + wg.Wait() } func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) { - outputArr := output.Data().([]float32) - xArr := x.Data().([]float32) + outputArr := getFloat32s(output) + xArr := getFloat32s(x) xShape := x.Shape() outerSize := 1 @@ -398,51 +444,57 @@ func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax outerSize *= xShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { - maxInput := xArr[0] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + wg.Add(1) + go func(ii int, wg *sync.WaitGroup) { + maxInput := xArr[0] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - if xArr[i] > maxInput { - maxInput = xArr[i] + if xArr[i] > maxInput { + maxInput = xArr[i] + } } - } - sumExp := float32(0.0) - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - z := xArr[i] - maxInput - exp := math32.Exp(z) + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + z := xArr[i] - maxInput + exp := math32.Exp(z) - if logSoftMax { - outputArr[i] = z - } else { - outputArr[i] = exp - } + if logSoftMax { + outputArr[i] = z + } else { + outputArr[i] = exp + } - sumExp += exp - } + sumExp += exp + } - if 
!logSoftMax { - sumExp = 1 / sumExp - } + if !logSoftMax { + sumExp = 1 / sumExp + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - if logSoftMax { - outputArr[i] -= math32.Log(sumExp) - } else { - outputArr[i] *= sumExp + if logSoftMax { + outputArr[i] -= math32.Log(sumExp) + } else { + outputArr[i] *= sumExp + } } - } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) { - dx := inputGrad.Data().([]float32) - outputArr := output.Data().([]float32) - gradArr := grad.Data().([]float32) + dx := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) outputShape := output.Shape() @@ -452,41 +504,52 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log outerSize *= outputShape[i] } + var wg sync.WaitGroup for ii := 0; ii < outerSize; ii++ { + wg.Add(1) + if logSoftMax { - sum := gradArr[ii*dimSize] - for j := 1; j < dimSize; j++ { - i := ii*dimSize + j + go func(ii int, wg *sync.WaitGroup) { + sum := gradArr[ii*dimSize] + for j := 1; j < dimSize; j++ { + i := ii*dimSize + j - sum += gradArr[i] - } + sum += gradArr[i] + } - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j - dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) - } + dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum) + } + wg.Done() + }(ii, &wg) } else { - mul := make([]float32, dimSize) - - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - mul[j] = outputArr[i] * gradArr[i] - } - - sum := mul[0] - for j := 1; j < dimSize; j++ { - sum += mul[j] - } - - for j := 0; j < dimSize; j++ { - i := ii*dimSize + j - - dx[i] = (gradArr[i] - sum) * outputArr[i] - } + go func(ii int, wg *sync.WaitGroup) { + //mul := make([]float32, dimSize) + var sum float32 + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + //mul[j] = outputArr[i] * 
gradArr[i] + sum += outputArr[i] * gradArr[i] + } + + // sum := mul[0] + // for j := 1; j < dimSize; j++ { + // sum += mul[j] + // } + + for j := 0; j < dimSize; j++ { + i := ii*dimSize + j + + dx[i] = (gradArr[i] - sum) * outputArr[i] + } + wg.Done() + }(ii, &wg) } } + wg.Wait() } func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) { @@ -505,53 +568,60 @@ func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax dimStride := innerSize outerStride := dimSize * dimStride - outputArr := output.Data().([]float32) - xArr := x.Data().([]float32) + outputArr := getFloat32s(output) + xArr := getFloat32s(x) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) - - inputPart := xArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + wg.Add(1) - maxInput := inputPart[0] - for j := 1; j < dimSize; j++ { - i := j * dimStride + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - if inputPart[i] > maxInput { - maxInput = inputPart[i] - } - } + inputPart := xArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - sumExp := float32(0.0) - for j := 0; j < dimSize; j++ { - i := j * dimStride - - exp := math32.Exp(inputPart[i] - maxInput) + maxInput := inputPart[0] + for j := 1; j < dimSize; j++ { + i := j * dimStride - if !logSoftmax { - outputPart[i] = exp + if inputPart[i] > maxInput { + maxInput = inputPart[i] + } } - sumExp += exp - } + sumExp := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - sumExp = math32.Log(sumExp) - } else { - sumExp = 1 / sumExp - } + exp := math32.Exp(inputPart[i] - maxInput) + + if !logSoftmax { + outputPart[i] = exp + } - for j := 0; j < dimSize; j++ { - i := j * dimStride + sumExp += exp + } if logSoftmax { - outputPart[i] = inputPart[i] - maxInput - sumExp 
+ sumExp = math32.Log(sumExp) } else { - outputPart[i] *= sumExp + sumExp = 1 / sumExp } - } + + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + outputPart[i] = inputPart[i] - maxInput - sumExp + } else { + outputPart[i] *= sumExp + } + } + wg.Done() + }(ii, &wg) } + wg.Wait() } func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) { @@ -570,36 +640,43 @@ func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, lo dimStride := innerSize outerStride := dimSize * dimStride - dxArr := inputGrad.Data().([]float32) - outputArr := output.Data().([]float32) - gradArr := grad.Data().([]float32) + dxArr := getFloat32s(inputGrad) + outputArr := getFloat32s(output) + gradArr := getFloat32s(grad) + var wg sync.WaitGroup for ii := 0; ii < innerSize*outerSize; ii++ { - outerIndex, innerIndex := divmod(ii, innerSize) + wg.Add(1) - gradPart := gradArr[outerIndex*outerStride+innerIndex:] - dxPart := dxArr[outerIndex*outerStride+innerIndex:] - outputPart := outputArr[outerIndex*outerStride+innerIndex:] + go func(ii int, wg *sync.WaitGroup) { + outerIndex, innerIndex := divmod(ii, innerSize) - sum := float32(0.0) - for j := 0; j < dimSize; j++ { - i := j * dimStride + gradPart := gradArr[outerIndex*outerStride+innerIndex:] + dxPart := dxArr[outerIndex*outerStride+innerIndex:] + outputPart := outputArr[outerIndex*outerStride+innerIndex:] - if logSoftmax { - sum += gradPart[i] - } else { - sum += gradPart[i] * outputPart[i] + sum := float32(0.0) + for j := 0; j < dimSize; j++ { + i := j * dimStride + + if logSoftmax { + sum += gradPart[i] + } else { + sum += gradPart[i] * outputPart[i] + } } - } - for j := 0; j < dimSize; j++ { - i := j * dimStride + for j := 0; j < dimSize; j++ { + i := j * dimStride - if logSoftmax { - dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum - } else { - dxPart[i] = outputPart[i] * (gradPart[i] - sum) + if logSoftmax { + dxPart[i] = gradPart[i] - 
math32.Exp(outputPart[i])*sum + } else { + dxPart[i] = outputPart[i] * (gradPart[i] - sum) + } } - } + wg.Done() + }(ii, &wg) } + wg.Wait() } diff --git a/dense_matop_test.go b/dense_matop_test.go index 48c854f..652c71d 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -703,17 +703,19 @@ func TestDense_Narrow(t *testing.T) { for i, tC := range testCases { t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) { c := assert.New(t) + // t.Logf("X:\n%v", tC.x) y, err := tC.x.Narrow(tC.dim, tC.start, tC.length) c.NoError(err) + // t.Logf("y:\n%v", y) yMat := y.Materialize() c.Equal(tC.expected.Shape(), yMat.Shape()) c.Equal(tC.expected.Data(), yMat.Data()) - - // err = y.Memset(1024) - // c.Nil(err) - // t.Logf("example %d y \n%v\n%v", i+1, y, y.Data()) + + // err = y.Memset(1024) + // c.NoError(err) + // t.Logf("After Memset\nY: %v\nX:\n%v", y, tC.x) }) } } diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index 86369be..e542133 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -19,6 +19,9 @@ type selByIndicesTest struct { } var selByIndicesTests = []selByIndicesTest{ + {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, + Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, + }, {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, @@ -35,7 +38,7 @@ var selByIndicesTests = []selByIndicesTest{ Correct: []int{1, 1}, CorrectShape: Shape{2}}, {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, - {Name: "(2,1) 
Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, }, } @@ -60,38 +63,47 @@ var selByIndicesBTests = []struct { CorrectGrad interface{} CorrectGradShape Shape }{ + // Basic + { + CorrectGrad: []float64{1, 1, 1, 1}, + }, + // 3-tensor, axis 0 + { + CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + // 3-tensor, axis 1 { - selByIndicesTest: selByIndicesTests[0], - CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []float64{0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2}, }, + // 3-tensor, axis 2 { - selByIndicesTest: selByIndicesTests[1], - CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []float64{0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0}, }, + // vector, axis 0 { - selByIndicesTest: selByIndicesTests[2], - CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0}, - CorrectGradShape: Shape{3, 2, 4}, + CorrectGrad: []int{0, 2, 0, 0, 0}, }, + // vector, axis 1 { - selByIndicesTest: selByIndicesTests[3], - CorrectGrad: []int{0, 2, 0, 0, 0}, - CorrectGradShape: Shape{5}, + CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, }, + // (4,2) Matrix with (10) indices { - selByIndicesTest: selByIndicesTests[5], - CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, - CorrectGradShape: Shape{4, 2}, + CorrectGrad: []float32{2, 2, 4, 4, 4, 4, 0, 0}, }, + // (2, 1) Matrix (colvec) with 
(10) indices { - selByIndicesTest: selByIndicesTests[6], - CorrectGrad: []float64{0, 10}, - CorrectGradShape: Shape{2, 1}, + CorrectGrad: []float64{0, 10}, }, } +func init() { + for i := range selByIndicesBTests { + selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] + selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape + } +} + func TestDense_SelectByIndicesB(t *testing.T) { assert := assert.New(t) @@ -102,12 +114,20 @@ func TestDense_SelectByIndicesB(t *testing.T) { if checkErr(t, tc.WillErr, err, tc.Name, i) { continue } - grad, err := ByIndicesB(T, ret, indices, tc.Axis) + outGrad := ret.Clone().(*Dense) + switch outGrad.Dtype() { + case Float64: + outGrad.Memset(1.0) + case Float32: + outGrad.Memset(float32(1.0)) + } + + grad, err := ByIndicesB(T, outGrad, indices, tc.Axis) if checkErr(t, tc.WillErr, err, tc.Name, i) { continue } - assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name) - assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape()) + assert.Equal(tc.CorrectGrad, grad.Data(), "%v - x:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, T, indices, ret, grad) + assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead.\n\nx:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, tc.CorrectGradShape, grad.Shape(), T, indices, ret, grad) } } diff --git a/engine.go b/engine.go index 5730c60..39e3f04 100644 --- a/engine.go +++ b/engine.go @@ -410,7 +410,7 @@ type InfChecker interface { // ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the `indices` tensor. 
type ByIndiceser interface { SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) - SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) + SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } /* Internal interfaces for faster shit */ diff --git a/example_byindices_test.go b/example_byindices_test.go new file mode 100644 index 0000000..7f94781 --- /dev/null +++ b/example_byindices_test.go @@ -0,0 +1,74 @@ +package tensor + +import "fmt" + +func ExampleByIndices() { + a := New(WithShape(2, 2), WithBacking([]float64{ + 100, 200, + 300, 400, + })) + indices := New(WithBacking([]int{1, 1, 1, 0, 1})) + b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\n", a, indices, b) + + // Output: + // a: + // ⎡100 200⎤ + // ⎣300 400⎦ + // + // indices: [1 1 1 0 1] + // b: + // ⎡300 400⎤ + // ⎢300 400⎥ + // ⎢300 400⎥ + // ⎢100 200⎥ + // ⎣300 400⎦ + +} + +func ExampleByIndicesB() { + a := New(WithShape(2, 2), WithBacking([]float64{ + 100, 200, + 300, 400, + })) + indices := New(WithBacking([]int{1, 1, 1, 0, 1})) + b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 + if err != nil { + fmt.Println(err) + return + } + + outGrad := b.Clone().(*Dense) + outGrad.Memset(1.0) + + grad, err := ByIndicesB(a, outGrad, indices, 0) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\ngrad:\n%v", a, indices, b, grad) + + // Output: + // a: + // ⎡100 200⎤ + // ⎣300 400⎦ + // + // indices: [1 1 1 0 1] + // b: + // ⎡300 400⎤ + // ⎢300 400⎥ + // ⎢300 400⎥ + // ⎢100 200⎥ + // ⎣300 400⎦ + // + // grad: + // ⎡1 1⎤ + // ⎣4 4⎦ + +} diff --git a/interfaces.go b/interfaces.go index c0fd7e3..e33502f 100644 --- a/interfaces.go +++ b/interfaces.go @@ -144,6 +144,7 @@ type unsafeMem interface { Set(i int, x interface{}) 
GetF64(i int) float64 GetF32(i int) float32 + Ints() []int Float64s() []float64 Float32s() []float32 Complex64s() []complex64 diff --git a/utils.go b/utils.go index 064c812..2b3aa65 100644 --- a/utils.go +++ b/utils.go @@ -300,6 +300,27 @@ func allones(a []int) bool { return true } +func getFloat64s(a Tensor) []float64 { + if um, ok := a.(unsafeMem); ok { + return um.Float64s() + } + return a.Data().([]float64) +} + +func getFloat32s(a Tensor) []float32 { + if um, ok := a.(unsafeMem); ok { + return um.Float32s() + } + return a.Data().([]float32) +} + +func getInts(a Tensor) []int { + if um, ok := a.(unsafeMem); ok { + return um.Ints() + } + return a.Data().([]int) +} + /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. the dumb, unoptimized version) From 160896e859d27891c5cd86e91a20862824773731 Mon Sep 17 00:00:00 2001 From: chewxy Date: Mon, 21 Mar 2022 15:40:44 +1100 Subject: [PATCH 146/154] Made the Slice type = shapes.Slice --- slice.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/slice.go b/slice.go index f45e9fc..55dc62c 100644 --- a/slice.go +++ b/slice.go @@ -7,12 +7,7 @@ import ( var xxx Slice = ss(1) var _ shapes.Slice = xxx -// A Slice represents a slicing operation for a Tensor. -type Slice interface { - Start() int - End() int - Step() int -} +type Slice = shapes.Slice type rs struct { start, end, step int From e3db7f8e5291aebdb759e5b502d39162670b7f01 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Apr 2022 14:12:11 +1000 Subject: [PATCH 147/154] Bump github.com/gogo/protobuf from 1.3.1 to 1.3.2 (#125) Bumps [github.com/gogo/protobuf](https://github.com/gogo/protobuf) from 1.3.1 to 1.3.2. 
- [Release notes](https://github.com/gogo/protobuf/releases) - [Commits](https://github.com/gogo/protobuf/compare/v1.3.1...v1.3.2) --- updated-dependencies: - dependency-name: github.com/gogo/protobuf dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index e488a8d..1666dd0 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc github.com/chewxy/hm v1.0.0 github.com/chewxy/math32 v1.0.8 - github.com/gogo/protobuf v1.3.1 + github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.4.3 github.com/google/flatbuffers v1.12.0 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 8d91866..c350f4e 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= -github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod 
h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -45,7 +45,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -59,9 +59,12 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp 
v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -72,36 +75,49 @@ golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86h golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/oauth2 
v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200909081042-eff7692f9009 h1:W0lCpv29Hv0UaM1LXb9QlBHLNP8UFfcKjblhVCWftOM= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= From c20848f9fe0ce1d1ceff12e529d5d8a10c8e91f7 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Mon, 4 Apr 2022 14:46:19 +1000 Subject: [PATCH 148/154] Fixed a 
major bug when reusing an operand (#126) * Fixed the main bug where a reuse tensor is the same as one of the operands. This was fixed by means of generating new internal generic functions * Added more tests, which found more issues which is now fixed * Because Pow is such a PITA, I'm removing it from the generated tests --- api_arith_generated_test.go | 10 +- defaultengine_arith.go | 18 +- dense_arith_test.go | 426 ++++++++++++++- errors.go | 2 +- example_dense_arith_test.go | 131 +++++ genlib2/agg1_body.go | 29 +- genlib2/agg2_body.go | 3 +- genlib2/agg3_body.go | 57 +- genlib2/arith_tests.go | 41 +- genlib2/engine.go | 2 + genlib2/generic_arith.go | 24 +- genlib2/internaleng.go | 27 +- internal/execution/eng_arith.go | 552 ++++++++++++++++++- internal/execution/generic_arith_vv.go | 714 +++++++++++++++++++++++++ internal/storage/header.go | 6 + 15 files changed, 1991 insertions(+), 51 deletions(-) diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index 1120fba..ce08af9 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -686,7 +686,7 @@ func TestAddScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Add (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Add (tensor as right, scalar as left) failed: %v", err) } } func TestSubScalar(t *testing.T) { @@ -766,7 +766,7 @@ func TestSubScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Sub (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Sub (tensor as right, scalar as left) failed: %v", err) } } func TestMulScalar(t *testing.T) { @@ -846,7 +846,7 @@ func TestMulScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != 
nil { - t.Errorf("WrongTYpe test for Mul (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Mul (tensor as right, scalar as left) failed: %v", err) } } func TestDivScalar(t *testing.T) { @@ -901,7 +901,7 @@ func TestDivScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Div (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Div (tensor as right, scalar as left) failed: %v", err) } } func TestPowScalar(t *testing.T) { @@ -956,7 +956,7 @@ func TestPowScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Pow (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Pow (tensor as right, scalar as left) failed: %v", err) } } func TestAddScalar_unsafe(t *testing.T) { diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 918e1ca..131ea33 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -55,8 +55,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.AddIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Add(typ, dataReuse, dataB) + err = e.E.AddRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Add(typ, dataA, dataB) @@ -120,8 +119,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.SubIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Sub(typ, dataReuse, dataB) + err = e.E.SubRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Sub(typ, dataA, dataB) @@ -185,8 +183,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.MulIncr(typ, dataA, dataB, 
dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Mul(typ, dataReuse, dataB) + err = e.E.MulRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Mul(typ, dataA, dataB) @@ -250,8 +247,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.DivIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Div(typ, dataReuse, dataB) + err = e.E.DivRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Div(typ, dataA, dataB) @@ -315,8 +311,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.PowIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Pow(typ, dataReuse, dataB) + err = e.E.PowRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Pow(typ, dataA, dataB) @@ -380,8 +375,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err err = e.E.ModIncr(typ, dataA, dataB, dataReuse) retVal = reuse case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Mod(typ, dataReuse, dataB) + err = e.E.ModRecv(typ, dataA, dataB, dataReuse) retVal = reuse case !safe: err = e.E.Mod(typ, dataA, dataB) diff --git a/dense_arith_test.go b/dense_arith_test.go index d414dd2..8d791fd 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -329,6 +329,58 @@ func TestDense_Add_reuse(t *testing.T) { t.Errorf("Identity test for Add failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Add(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Adder) + we 
= we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Add(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Add(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Add failed: %v", err) + } + } func TestDense_Sub_reuse(t *testing.T) { inv := func(a *Dense) bool { @@ -361,6 +413,58 @@ func TestDense_Sub_reuse(t *testing.T) { if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Sub(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Suber) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Sub(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Sub(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Sub 
failed: %v", err) + } + } func TestDense_Mul_reuse(t *testing.T) { iden := func(a *Dense) bool { @@ -394,6 +498,58 @@ func TestDense_Mul_reuse(t *testing.T) { t.Errorf("Identity test for Mul failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Mul(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Muler) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Mul(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Mul(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Mul failed: %v", err) + } + } func TestDense_Div_reuse(t *testing.T) { inv := func(a *Dense) bool { @@ -427,6 +583,58 @@ func TestDense_Div_reuse(t *testing.T) { if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Div failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Div(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Diver) + 
we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Div(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Div(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Div failed: %v", err) + } + } func TestDense_Pow_reuse(t *testing.T) { iden := func(a *Dense) bool { @@ -686,7 +894,7 @@ func TestDense_AddScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Add (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Add (tensor as right, scalar as left) failed: %v", err) } } func TestDense_SubScalar(t *testing.T) { @@ -766,7 +974,7 @@ func TestDense_SubScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Sub (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Sub (tensor as right, scalar as left) failed: %v", err) } } func TestDense_MulScalar(t *testing.T) { @@ -846,7 +1054,7 @@ func TestDense_MulScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Mul (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Mul (tensor as right, scalar as left) failed: %v", err) } } func TestDense_DivScalar(t *testing.T) { @@ -901,7 +1109,7 @@ func TestDense_DivScalar(t *testing.T) { return true } if err := 
quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Div (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Div (tensor as right, scalar as left) failed: %v", err) } } func TestDense_PowScalar(t *testing.T) { @@ -956,7 +1164,7 @@ func TestDense_PowScalar(t *testing.T) { return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for Pow (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for Pow (tensor as right, scalar as left) failed: %v", err) } } func TestDense_AddScalar_unsafe(t *testing.T) { @@ -1284,6 +1492,58 @@ func TestDense_AddScalar_reuse(t *testing.T) { t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Add(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Adder) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Add(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Add(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Add failed: %v", err) + } + } func TestDense_SubScalar_reuse(t *testing.T) { inv1 
:= func(q *Dense) bool { @@ -1350,6 +1610,58 @@ func TestDense_SubScalar_reuse(t *testing.T) { if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Sub(b) + we, willFailEq := willerr(a, numberTypes, unsignedTypes) + _, ok := a.Engine().(Suber) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Sub(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Sub(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Sub failed: %v", err) + } + } func TestDense_MulScalar_reuse(t *testing.T) { iden1 := func(q *Dense) bool { @@ -1416,6 +1728,58 @@ func TestDense_MulScalar_reuse(t *testing.T) { t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Mul(b) + we, willFailEq := willerr(a, numberTypes, nil) + 
_, ok := a.Engine().(Muler) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Mul(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Mul(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Mul failed: %v", err) + } + } func TestDense_DivScalar_reuse(t *testing.T) { inv1 := func(q *Dense) bool { @@ -1451,6 +1815,58 @@ func TestDense_DivScalar_reuse(t *testing.T) { t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) } + mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype() { + return true + } + if !a.Shape().Eq(b.Shape()) { + return true + } + + correct, err := a.Div(b) + we, willFailEq := willerr(a, numberTypes, nil) + _, ok := a.Engine().(Diver) + we = we || !ok + + var ret, reuse *Dense + if reuseA { + ret, err = a.Div(b, WithReuse(a)) + reuse = a + } else { + ret, err = a.Div(b, WithReuse(b)) + reuse = b + } + + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + if reuse != ret { + t.Errorf("Expected reuse to be the same as retVal") + return false + } + + return true + } + if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for Div failed: %v", 
err) + } + } func TestDense_PowScalar_reuse(t *testing.T) { iden1 := func(q *Dense) bool { diff --git a/errors.go b/errors.go index cd6a297..314c91c 100644 --- a/errors.go +++ b/errors.go @@ -44,7 +44,7 @@ const ( sliceIndexOOB = "Slice index out of bounds: Start: %d, End: %d. Length: %d" broadcastError = "Cannot broadcast together. Resulting shape will be at least (%d, 1). Repeats is (%d, 1)" lenMismatch = "Cannot compare with differing lengths: %d and %d" - typeMismatch = "TypeMismatch: a %T and b %T" + typeMismatch = "TypeMismatch: a %v and b %v" typeclassMismatch = "Typeclass mismatch on %v" shapeMismatch = "Shape mismatch. Expected %v. Got %v" sizeMismatch = "Size Mismatch. %d and %d" diff --git a/example_dense_arith_test.go b/example_dense_arith_test.go index 1ea0c1d..a78fd21 100644 --- a/example_dense_arith_test.go +++ b/example_dense_arith_test.go @@ -121,6 +121,40 @@ func ExampleDense_Add_reuse() { // T3: // ⎡10 12⎤ // ⎣15 17⎦ + +} + +// An optional reuse tensor can also be specified with the WithReuse function option. Passing in an operand would not cause a problem. 
+func ExampleDense_Add_reuse_operand() { + var T1, T2, T3 *Dense + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Add(T2, WithReuse(T1)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T1: %t\nT3:\n%v\n", T3 == T1, T3) + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Add(T2, WithReuse(T2)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T2: %t\nT3:\n%v\n", T3 == T2, T3) + + // Output: + // Reuse tensor passed in + // ====================== + // T3 == T1: true + // T3: + // ⎡10 12 14⎤ + // ⎢16 18 20⎥ + // ⎣22 24 26⎦ + // + // Reuse tensor passed in + // ====================== + // T3 == T2: true + // T3: + // ⎡10 12 14⎤ + // ⎢16 18 20⎥ + // ⎣22 24 26⎦ + } // Incrementing a tensor is also a function option provided by the package @@ -285,6 +319,38 @@ func ExampleDense_Sub_reuse() { // ⎣ -9 -9⎦ } +// An optional reuse tensor can also be specified with the WithReuse function option. Passing in an operand would not cause a problem. 
+func ExampleDense_Sub_reuse_operand() { + var T1, T2, T3 *Dense + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Sub(T2, WithReuse(T1)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T1: %t\nT3:\n%v\n", T3 == T1, T3) + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Sub(T2, WithReuse(T2)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T2: %t\nT3:\n%v\n", T3 == T2, T3) + + // Output: + // Reuse tensor passed in + // ====================== + // T3 == T1: true + // T3: + // ⎡-10 -10 -10⎤ + // ⎢-10 -10 -10⎥ + // ⎣-10 -10 -10⎦ + // + // Reuse tensor passed in + // ====================== + // T3 == T2: true + // T3: + // ⎡-10 -10 -10⎤ + // ⎢-10 -10 -10⎥ + // ⎣-10 -10 -10⎦ +} + // Incrementing a tensor is also a function option provided by the package func ExampleDense_Sub_incr() { var T1, T2, T3, Incr, V *Dense @@ -447,6 +513,39 @@ func ExampleDense_Mul_reuse() { // ⎣36 52⎦ } +// An optional reuse tensor can also be specified with the WithReuse function option. Passing in an operand would not cause a problem. 
+func ExampleDense_Mul_reuse_operand() { + var T1, T2, T3 *Dense + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Mul(T2, WithReuse(T1)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T1: %t\nT3:\n%v\n", T3 == T1, T3) + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Mul(T2, WithReuse(T2)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T2: %t\nT3:\n%v\n", T3 == T2, T3) + + // Output: + // Reuse tensor passed in + // ====================== + // T3 == T1: true + // T3: + // ⎡ 0 11 24⎤ + // ⎢ 39 56 75⎥ + // ⎣ 96 119 144⎦ + // + // Reuse tensor passed in + // ====================== + // T3 == T2: true + // T3: + // ⎡ 0 11 24⎤ + // ⎢ 39 56 75⎥ + // ⎣ 96 119 144⎦ + +} + // Incrementing a tensor is also a function option provided by the package func ExampleDense_Mul_incr() { var T1, T2, T3, Incr, V *Dense @@ -609,6 +708,38 @@ func ExampleDense_Div_reuse() { // ⎣ 0.2 0.3⎦ } +// An optional reuse tensor can also be specified with the WithReuse function option. Passing in an operand would not cause a problem. 
+func ExampleDense_Div_reuse_operand() { + var T1, T2, T3 *Dense + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Div(T2, WithReuse(T1)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T1: %t\nT3:\n%1.1v\n", T3 == T1, T3) + + T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) + T2 = New(WithBacking(Range(Float64, 10, 19)), WithShape(3, 3)) + T3, _ = T1.Div(T2, WithReuse(T2)) + fmt.Printf("Reuse tensor passed in\n======================\nT3 == T2: %t\nT3:\n%1.1v\n", T3 == T2, T3) + + // Output: + // Reuse tensor passed in + // ====================== + // T3 == T1: true + // T3: + // ⎡ 0 0.09 0.2⎤ + // ⎢ 0.2 0.3 0.3⎥ + // ⎣ 0.4 0.4 0.4⎦ + // + // Reuse tensor passed in + // ====================== + // T3 == T2: true + // T3: + // ⎡ 0 0.09 0.2⎤ + // ⎢ 0.2 0.3 0.3⎥ + // ⎣ 0.4 0.4 0.4⎦ +} + // Incrementing a tensor is also a function option provided by the package func ExampleDense_Div_incr() { var T1, T2, T3, Incr, V *Dense diff --git a/genlib2/agg1_body.go b/genlib2/agg1_body.go index 838eb20..85580d0 100644 --- a/genlib2/agg1_body.go +++ b/genlib2/agg1_body.go @@ -87,7 +87,7 @@ const ( return {{end -}} default: - return errors.Errorf("Unsupported type %v for {{$name}}", t) + return errors.Errorf("Unsupported type %v for {{$name}}Iter", t) } ` @@ -122,7 +122,30 @@ const ( } {{end -}} default: - return errors.Errorf("Unsupported type %v for {{$name}}", t) + return errors.Errorf("Unsupported type %v for {{$name}}IterIncr", t) + } + ` + + eArithRecvRaw = `as :=isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + {{$name := .Name}} + switch t{ + {{range .Kinds -}} + case {{reflectKind .}}: + at := a.{{sliceOf .}} + bt := b.{{sliceOf .}} + rt := recv.{{sliceOf .}} + {{$name}}Recv{{short .}}(at, bt, rt) + return + {{end -}} + default: + return errors.Errorf("Unsupported type %v for {{$name}}Recv", t) } ` @@ -641,6 +664,7 @@ var ( eArithIncr *template.Template eArithIter *template.Template eArithIterIncr *template.Template + eArithRecv *template.Template eMap *template.Template eMapIter *template.Template @@ -674,6 +698,7 @@ func init() { eArithIncr = template.Must(template.New("eArithIncr").Funcs(funcs).Parse(eArithIncrRaw)) eArithIter = template.Must(template.New("eArithIter").Funcs(funcs).Parse(eArithIterRaw)) eArithIterIncr = template.Must(template.New("eArithIterIncr").Funcs(funcs).Parse(eArithIterIncrRaw)) + eArithRecv = template.Must(template.New("eArithRecv").Funcs(funcs).Parse(eArithRecvRaw)) eMap = template.Must(template.New("eMap").Funcs(funcs).Parse(eMapRaw)) eMapIter = template.Must(template.New("eMapIter").Funcs(funcs).Parse(eMapIterRaw)) diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index ff2fe3a..54dd1f2 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -153,8 +153,7 @@ const agg2BodyRaw = `if useIter { retVal = reuse {{if .VV -}} case toReuse: - storage.Copy(typ,dataReuse, dataA) - err = e.E.{{.Name}}(typ, dataReuse, dataB) + err = e.E.{{.Name}}Recv(typ, dataA, dataB, dataReuse) retVal = reuse {{else -}} case toReuse && leftTensor: diff --git a/genlib2/agg3_body.go b/genlib2/agg3_body.go index e7b6592..c780e90 100644 --- a/genlib2/agg3_body.go +++ b/genlib2/agg3_body.go @@ -270,10 +270,65 @@ wt2 := func(a *Dense) bool{ return true } if err := quick.Check(wt2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("WrongTYpe test for {{.Name}} (tensor as right, scalar as left) failed: %v", err) + t.Errorf("WrongType test for {{.Name}} (tensor as 
right, scalar as left) failed: %v", err) } ` +const denseArithReuseMutationTestRaw = `mut := func(a, b *Dense, reuseA bool) bool { + // req because we're only testing on one kind of tensor/engine combo + a.e = StdEng{} + a.oe = StdEng{} + a.flag = 0 + b.e = StdEng{} + b.oe = StdEng{} + b.flag = 0 + + if a.Dtype() != b.Dtype(){ + return true + } + if !a.Shape().Eq(b.Shape()){ + return true + } + + + + {{template "callVanilla" .}} + we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok + + + + var ret, reuse {{template "retType" .}} + if reuseA { + {{template "call0" .}}, WithReuse(a)) + reuse = a + } else { + {{template "call0" .}}, WithReuse(b)) + reuse = b + } + + + if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + + {{template "funcoptcheck" -}} + + return true +} +if err := quick.Check(mut, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Reuse Mutation test for {{.Name}} failed: %v", err) +} + +` + var ( denseArithBody *template.Template denseArithScalarBody *template.Template diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index 369cc0f..c65a97f 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -7,21 +7,28 @@ import ( ) const ( - APICallVVRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` - APICallVSRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` - APICallSVRaw = `ret, err := {{.Name}}(b, a {{template "funcoptuse"}})` + APICallVVxRaw = `correct, err := {{.Name}}(a, b)` // no funcopt + APICallVVReuseMutRaw = `ret, err = {{.Name}}(a, b` + APICallVVRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` + APICallVSRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` + APICallSVRaw = `ret, err := {{.Name}}(b, 
a {{template "funcoptuse"}})` APIInvVVRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` APIInvVSRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` APIInvSVRaw = `ret, err = {{.Name}}(b, ret, UseUnsafe())` - DenseMethodCallVVRaw = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})` - DenseMethodCallVSRaw = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})` - DenseMethodCallSVRaw = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})` + DenseMethodCallVVxRaw = `correct, err := a.{{.Name}}(b)` // no funcopt + DenseMethodCallVVReuseMutRaw = `ret, err = a.{{.Name}}(b` + DenseMethodCallVVRaw = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})` + DenseMethodCallVSRaw = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})` + DenseMethodCallSVRaw = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})` DenseMethodInvVVRaw = `ret, err = ret.{{.Inv}}(b, UseUnsafe())` DenseMethodInvVSRaw = `ret, err = ret.{{.Inv}}Scalar(b, true, UseUnsafe())` DenseMethodInvSVRaw = `ret, err = ret.{{.Name}}Scalar(b, false, UseUnsafe())` + + APIRetType = `Tensor` + DenseRetType = `*Dense` ) type ArithTest struct { @@ -65,6 +72,10 @@ func (fn *ArithTest) WriteBody(w io.Writer) { fn.writeInv(w) } fn.WriteScalarWrongType(w) + + if fn.FuncOpt == "reuse" && fn.arithOp.Name() != "Pow" { + fn.writeReuseMutate(w) + } } func (fn *ArithTest) canWrite() bool { @@ -143,6 +154,24 @@ func (fn *ArithTest) writeInv(w io.Writer) { t.Execute(w, fn) } + +func (fn *ArithTest) writeReuseMutate(w io.Writer) { + t := template.Must(template.New("Reuse mutation test").Funcs(funcs).Parse(denseArithReuseMutationTestRaw)) + switch fn.lvl { + case API: + return // tmp + case Dense: + template.Must(t.New("callVanilla").Parse(DenseMethodCallVVxRaw)) + template.Must(t.New("retType").Parse(DenseRetType)) + template.Must(t.New("call0").Parse(DenseMethodCallVVReuseMutRaw)) + + } + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + 
template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + t.Execute(w, fn) +} + func (fn *ArithTest) WriteScalarWrongType(w io.Writer) { if !fn.scalars { return diff --git a/genlib2/engine.go b/genlib2/engine.go index 2ae5b18..f48a0eb 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -11,6 +11,7 @@ type EngineArith struct { VecVar string PrepData string TypeClassCheck string + IsCommutative bool VV bool LeftVec bool @@ -97,6 +98,7 @@ func generateStdEngArith(f io.Writer, ak Kinds) { Name: abo.Name(), VV: true, TypeClassCheck: "Number", + IsCommutative: abo.IsCommutative, } methods = append(methods, meth) } diff --git a/genlib2/generic_arith.go b/genlib2/generic_arith.go index fa6fa39..a0ab358 100644 --- a/genlib2/generic_arith.go +++ b/genlib2/generic_arith.go @@ -11,6 +11,7 @@ type GenericVecVecArith struct { TypedBinOp Iter bool Incr bool + WithRecv bool // not many BinOps have this Check TypeClass // can be nil CheckTemplate string } @@ -23,6 +24,8 @@ func (fn *GenericVecVecArith) Name() string { return fmt.Sprintf("%sIter", fn.TypedBinOp.Name()) case !fn.Iter && fn.Incr: return fmt.Sprintf("%sIncr", fn.TypedBinOp.Name()) + case fn.WithRecv: + return fmt.Sprintf("%vRecv", fn.TypedBinOp.Name()) default: return fmt.Sprintf("Vec%s", fn.TypedBinOp.Name()) } @@ -45,6 +48,9 @@ func (fn *GenericVecVecArith) Signature() *Signature { case !fn.Iter && fn.Incr: paramNames = []string{"a", "b", "incr"} paramTemplates = []*template.Template{sliceType, sliceType, sliceType} + case fn.WithRecv: + paramNames = []string{"a", "b", "recv"} + paramTemplates = []*template.Template{sliceType, sliceType, sliceType} default: paramNames = []string{"a", "b"} paramTemplates = []*template.Template{sliceType, sliceType} @@ -97,6 +103,11 @@ func (fn *GenericVecVecArith) WriteBody(w io.Writer) { Right = "b[i]" T = template.Must(T.Parse(genericLoopRaw)) template.Must(T.New("loopbody").Parse(basicIncr)) + 
case fn.WithRecv: + Range = "recv" + Right = "b[i]" + T = template.Must(T.Parse(genericLoopRaw)) + template.Must(T.New("loopbody").Parse(basicSet)) default: Right = "b[i]" T = template.Must(T.Parse(genericLoopRaw)) @@ -130,7 +141,7 @@ func (fn *GenericVecVecArith) WriteBody(w io.Writer) { func (fn *GenericVecVecArith) Write(w io.Writer) { sig := fn.Signature() - if !fn.Iter && isFloat(fn.Kind()) { + if !fn.Iter && isFloat(fn.Kind()) && !fn.WithRecv { // golinkPragma.Execute(w, fn) w.Write([]byte("func ")) sig.Write(w) @@ -148,7 +159,9 @@ func (fn *GenericVecVecArith) Write(w io.Writer) { switch { case !fn.Iter && fn.Incr: w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; incr = incr[:len(a)]\n")) - case !fn.Iter && !fn.Incr: + case fn.WithRecv: + w.Write([]byte("{\na = a[:len(recv)]; b = b[:len(recv)]\n")) + case !fn.Iter && !fn.Incr && !fn.WithRecv: w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n")) default: w.Write([]byte("{\n")) @@ -390,6 +403,7 @@ func makeGenericVecVecAriths(tbo []TypedBinOp) (retVal []*GenericVecVecArith) { fn.Check = panicsDiv0 fn.CheckTemplate = check0 } + retVal = append(retVal, fn) } @@ -457,6 +471,12 @@ func generateGenericVecVecArith(f io.Writer, ak Kinds) { for _, g := range gen { g.Write(f) } + for _, g := range gen { + g.Incr = false + g.Iter = false + g.WithRecv = true + g.Write(f) + } } func generateGenericMixedArith(f io.Writer, ak Kinds) { diff --git a/genlib2/internaleng.go b/genlib2/internaleng.go index 6d07d32..48a3938 100644 --- a/genlib2/internaleng.go +++ b/genlib2/internaleng.go @@ -10,9 +10,10 @@ import ( type InternalEngArithMethod struct { BinOp - Kinds []reflect.Kind - Incr bool - Iter bool + Kinds []reflect.Kind + Incr bool + Iter bool + WithRecv bool } type eLoopBody struct { @@ -30,6 +31,8 @@ func (fn *InternalEngArithMethod) Name() string { return fmt.Sprintf("%sIncr", fn.BinOp.Name()) case !fn.Incr && fn.Iter: return fmt.Sprintf("%sIter", fn.BinOp.Name()) + case fn.WithRecv: + return fmt.Sprintf("%sRecv", 
fn.BinOp.Name()) default: return fn.BinOp.Name() } @@ -48,6 +51,9 @@ func (fn *InternalEngArithMethod) Signature() *Signature { case !fn.Iter && fn.Incr: paramNames = []string{"t", "a", "b", "incr"} paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} + case fn.WithRecv: + paramNames = []string{"t", "a", "b", "recv"} + paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} default: paramNames = []string{"t", "a", "b"} paramTemplates = []*template.Template{reflectType, arrayType, arrayType} @@ -72,6 +78,8 @@ func (fn *InternalEngArithMethod) WriteBody(w io.Writer) { T = eArithIncr case fn.Iter && !fn.Incr: T = eArithIter + case fn.WithRecv: + T = eArithRecv default: T = eArith } @@ -107,22 +115,35 @@ func generateEArith(f io.Writer, kinds Kinds) { methods = append(methods, meth) } + // write vanilla for _, meth := range methods { meth.Write(f) meth.Incr = true } + // write incr for _, meth := range methods { meth.Write(f) meth.Incr = false meth.Iter = true } + + // write iter for _, meth := range methods { meth.Write(f) meth.Incr = true } + // write iter incr + for _, meth := range methods { + meth.Write(f) + meth.Incr = false + meth.Iter = false + } + + // write recv for _, meth := range methods { + meth.WithRecv = true meth.Write(f) } } diff --git a/internal/execution/eng_arith.go b/internal/execution/eng_arith.go index f3de110..bc0af43 100644 --- a/internal/execution/eng_arith.go +++ b/internal/execution/eng_arith.go @@ -2851,7 +2851,7 @@ func (e E) AddIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Add", t) + return errors.Errorf("Unsupported type %v for AddIter", t) } } @@ -3057,7 +3057,7 @@ func (e E) SubIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Sub", t) + return errors.Errorf("Unsupported type %v for SubIter", t) } } @@ -3263,7 
+3263,7 @@ func (e E) MulIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Mul", t) + return errors.Errorf("Unsupported type %v for MulIter", t) } } @@ -3469,7 +3469,7 @@ func (e E) DivIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Div", t) + return errors.Errorf("Unsupported type %v for DivIter", t) } } @@ -3535,7 +3535,7 @@ func (e E) PowIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Pow", t) + return errors.Errorf("Unsupported type %v for PowIter", t) } } @@ -3713,7 +3713,7 @@ func (e E) ModIter(t reflect.Type, a *storage.Header, b *storage.Header, ait Ite } return default: - return errors.Errorf("Unsupported type %v for Mod", t) + return errors.Errorf("Unsupported type %v for ModIter", t) } } @@ -4013,7 +4013,7 @@ func (e E) AddIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return AddIterIncrStr(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Add", t) + return errors.Errorf("Unsupported type %v for AddIterIncr", t) } } @@ -4294,7 +4294,7 @@ func (e E) SubIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return SubIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Sub", t) + return errors.Errorf("Unsupported type %v for SubIterIncr", t) } } @@ -4575,7 +4575,7 @@ func (e E) MulIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return MulIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Mul", t) + return errors.Errorf("Unsupported type %v for MulIterIncr", t) } } @@ -4856,7 +4856,7 @@ func (e E) DivIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return DivIterIncrC128(at, bt, it, ait, bit, iit) } default: - 
return errors.Errorf("Unsupported type %v for Div", t) + return errors.Errorf("Unsupported type %v for DivIterIncr", t) } } @@ -4947,7 +4947,7 @@ func (e E) PowIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return PowIterIncrC128(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Pow", t) + return errors.Errorf("Unsupported type %v for PowIterIncr", t) } } @@ -5190,6 +5190,534 @@ func (e E) ModIterIncr(t reflect.Type, a *storage.Header, b *storage.Header, inc return ModIterIncrF64(at, bt, it, ait, bit, iit) } default: - return errors.Errorf("Unsupported type %v for Mod", t) + return errors.Errorf("Unsupported type %v for ModIterIncr", t) + } +} + +func (e E) AddRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + AddRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + AddRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + AddRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + AddRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + AddRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + AddRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + AddRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + AddRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + 
AddRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + AddRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + AddRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + AddRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + AddRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + AddRecvC128(at, bt, rt) + return + case String: + at := a.Strings() + bt := b.Strings() + rt := recv.Strings() + AddRecvStr(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for AddRecv", t) + } +} + +func (e E) SubRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + SubRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + SubRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + SubRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + SubRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + SubRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + SubRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + SubRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + SubRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + SubRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + SubRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + SubRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + SubRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + SubRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + SubRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for SubRecv", t) + } +} + +func (e E) MulRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + MulRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + MulRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + MulRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + MulRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + MulRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + MulRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + MulRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + MulRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + MulRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + MulRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + MulRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + MulRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + MulRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + MulRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for MulRecv", t) + } +} + +func (e E) DivRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + DivRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + DivRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + DivRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + DivRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + DivRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + DivRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + DivRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + DivRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + DivRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt := b.Uint64s() + rt := recv.Uint64s() + DivRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + DivRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + DivRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + DivRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + DivRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for DivRecv", t) + } +} + +func (e E) PowRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. 
len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + PowRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + PowRecvF64(at, bt, rt) + return + case Complex64: + at := a.Complex64s() + bt := b.Complex64s() + rt := recv.Complex64s() + PowRecvC64(at, bt, rt) + return + case Complex128: + at := a.Complex128s() + bt := b.Complex128s() + rt := recv.Complex128s() + PowRecvC128(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for PowRecv", t) + } +} + +func (e E) ModRecv(t reflect.Type, a *storage.Header, b *storage.Header, recv *storage.Header) (err error) { + as := isScalar(a, t) + bs := isScalar(b, t) + rs := isScalar(recv, t) + + if ((as && !bs) || (bs && !as)) && rs { + return errors.Errorf("Cannot increment on a scalar increment. len(a): %d, len(b) %d", a.TypedLen(t), b.TypedLen(t)) + } + + switch t { + case Int: + at := a.Ints() + bt := b.Ints() + rt := recv.Ints() + ModRecvI(at, bt, rt) + return + case Int8: + at := a.Int8s() + bt := b.Int8s() + rt := recv.Int8s() + ModRecvI8(at, bt, rt) + return + case Int16: + at := a.Int16s() + bt := b.Int16s() + rt := recv.Int16s() + ModRecvI16(at, bt, rt) + return + case Int32: + at := a.Int32s() + bt := b.Int32s() + rt := recv.Int32s() + ModRecvI32(at, bt, rt) + return + case Int64: + at := a.Int64s() + bt := b.Int64s() + rt := recv.Int64s() + ModRecvI64(at, bt, rt) + return + case Uint: + at := a.Uints() + bt := b.Uints() + rt := recv.Uints() + ModRecvU(at, bt, rt) + return + case Uint8: + at := a.Uint8s() + bt := b.Uint8s() + rt := recv.Uint8s() + ModRecvU8(at, bt, rt) + return + case Uint16: + at := a.Uint16s() + bt := b.Uint16s() + rt := recv.Uint16s() + ModRecvU16(at, bt, rt) + return + case Uint32: + at := a.Uint32s() + bt := b.Uint32s() + rt := recv.Uint32s() + ModRecvU32(at, bt, rt) + return + case Uint64: + at := a.Uint64s() + bt 
:= b.Uint64s() + rt := recv.Uint64s() + ModRecvU64(at, bt, rt) + return + case Float32: + at := a.Float32s() + bt := b.Float32s() + rt := recv.Float32s() + ModRecvF32(at, bt, rt) + return + case Float64: + at := a.Float64s() + bt := b.Float64s() + rt := recv.Float64s() + ModRecvF64(at, bt, rt) + return + default: + return errors.Errorf("Unsupported type %v for ModRecv", t) } } diff --git a/internal/execution/generic_arith_vv.go b/internal/execution/generic_arith_vv.go index 26f3772..e2a3c46 100644 --- a/internal/execution/generic_arith_vv.go +++ b/internal/execution/generic_arith_vv.go @@ -4637,3 +4637,717 @@ func ModIterIncrF64(a []float64, b []float64, incr []float64, ait Iterator, bit } return } + +func AddRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = 
b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func AddRecvStr(a []string, b []string, recv []string) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] + b[i] + } +} + +func SubRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + 
b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func SubRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] - b[i] + } +} + +func MulRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU(a []uint, b []uint, recv []uint) { + a = 
a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func MulRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] * b[i] + } +} + +func DivRecvI(a []int, b []int, recv []int) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI8(a []int8, b []int8, recv []int8) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return 
errs + } + return nil +} + +func DivRecvI16(a []int16, b []int16, recv []int16) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI32(a []int32, b []int32, recv []int32) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvI64(a []int64, b []int64, recv []int64) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU(a []uint, b []uint, recv []uint) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU8(a []uint8, b []uint8, recv []uint8) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU16(a []uint16, b []uint16, recv []uint16) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != 
nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU32(a []uint32, b []uint32, recv []uint32) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvU64(a []uint64, b []uint64, recv []uint64) (err error) { + a = a[:len(recv)] + b = b[:len(recv)] + var errs errorIndices + for i := range recv { + if b[i] == 0 { + errs = append(errs, i) + recv[i] = 0 + continue + } + recv[i] = a[i] / b[i] + } + if err != nil { + return + } + if len(errs) > 0 { + return errs + } + return nil +} + +func DivRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func DivRecvC128(a []complex128, b []complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] / b[i] + } +} + +func PowRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math32.Pow(a[i], b[i]) + } +} + +func PowRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math.Pow(a[i], b[i]) + } +} + +func PowRecvC64(a []complex64, b []complex64, recv []complex64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = complex64(cmplx.Pow(complex128(a[i]), complex128(b[i]))) + } +} + +func PowRecvC128(a []complex128, b 
[]complex128, recv []complex128) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = cmplx.Pow(a[i], b[i]) + } +} + +func ModRecvI(a []int, b []int, recv []int) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI8(a []int8, b []int8, recv []int8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI16(a []int16, b []int16, recv []int16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI32(a []int32, b []int32, recv []int32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvI64(a []int64, b []int64, recv []int64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU(a []uint, b []uint, recv []uint) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU8(a []uint8, b []uint8, recv []uint8) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU16(a []uint16, b []uint16, recv []uint16) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU32(a []uint32, b []uint32, recv []uint32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvU64(a []uint64, b []uint64, recv []uint64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = a[i] % b[i] + } +} + +func ModRecvF32(a []float32, b []float32, recv []float32) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math32.Mod(a[i], b[i]) + } +} + +func ModRecvF64(a []float64, b []float64, recv []float64) { + a = a[:len(recv)] + b = b[:len(recv)] + for i := range recv { + recv[i] = math.Mod(a[i], b[i]) + } +} diff --git 
a/internal/storage/header.go b/internal/storage/header.go index 0e05a1d..93f67e7 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -35,6 +35,12 @@ func CopySliced(t reflect.Type, dst *Header, dstart, dend int, src *Header, ssta return copied / size } +func SwapCopy(a, b *Header) { + for i := range a.Raw { + a.Raw[i], b.Raw[i] = b.Raw[i], a.Raw[i] + } +} + func Fill(t reflect.Type, dst, src *Header) int { dstBA := dst.Raw srcBA := src.Raw From 5911a45530e9c4efeff43e74b885e4f4751d6377 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 May 2022 11:17:08 +1000 Subject: [PATCH 149/154] Fixed error messages to have a consistent error message for NYIs --- defaultengine_argmethods.go | 4 +-- defaultengine_linalg.go | 9 +++--- defaultengine_matop_misc.go | 8 ++--- dense.go | 2 +- dense_io.go | 2 +- dense_matop.go | 28 ++++++++++++++--- errors.go | 63 +++++++++++++++++++++++++++++++++++-- scalar.go | 30 +++++++----------- sparse.go | 6 ++-- 9 files changed, 110 insertions(+), 42 deletions(-) diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index f2bbf60..0bb1707 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -13,7 +13,7 @@ func (e StdEng) Argmax(ctx context.Context, t Tensor, axis int) (retVal Tensor, case DenseTensor: return e.argmaxDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t) + return nil, nyierr(typeNYI, t) } } @@ -103,7 +103,7 @@ func (e StdEng) Argmin(ctx context.Context, t Tensor, axis int) (retVal Tensor, case DenseTensor: return e.argminDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t) + return nil, nyierr(typeNYI, t) } } diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 1f7eaef..ca114fd 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -10,7 +10,7 @@ import ( "gorgonia.org/dtype" ) -// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). 
If the Tensor provided is not a matrix, it will return an error +// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error func (e StdEng) Trace(ctx context.Context, t Tensor) (retVal interface{}, err error) { if err := handleCtx(ctx); err != nil { return nil, err @@ -482,7 +482,7 @@ func (e StdEng) MatVecMul(ctx context.Context, a, b, prealloc Tensor) (err error var alpha, beta complex128 = complex(1, 0), complex(0, 0) whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) default: - return errors.Errorf(typeNYI, "matVecMul", bd.Data()) + return nyierr(typeNYI, bd.Data()) } return nil @@ -598,7 +598,8 @@ func (e StdEng) MatMul(ctx context.Context, a, b, prealloc Tensor) (err error) { whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) } default: - return errors.Errorf(typeNYI, "matMul", ad.Data()) + return nyierr(typeNYI, ad.Data()) + } return } @@ -674,7 +675,7 @@ func (e StdEng) Outer(ctx context.Context, a, b, prealloc Tensor) (err error) { var alpha complex128 = complex(1, 0) whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda) default: - return errors.Errorf(typeNYI, "outer", b.Data()) + return nyierr(typeNYI, b.Data()) } return nil } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 303c5e2..56641d3 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -33,7 +33,7 @@ func (e StdEng) Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } @@ -59,7 +59,7 @@ func (e StdEng) RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis in } return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } 
@@ -255,7 +255,7 @@ func (e StdEng) Concat(ctx context.Context, t Tensor, axis int, others ...Tensor } return e.denseConcat(tt, axis, denses) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } @@ -435,7 +435,7 @@ func (e StdEng) Diag(ctx context.Context, t Tensor) (retVal Tensor, err error) { bdata[i] = adata[i*stride] } default: - return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) + return nil, nyierr(typeNYI, "Arbitrary-sized .Diag()", t) } return b, nil } diff --git a/dense.go b/dense.go index 39b0f90..1623eee 100644 --- a/dense.go +++ b/dense.go @@ -141,7 +141,7 @@ func (t *Dense) Reshape(dims ...int) error { } if t.viewOf != 0 && t.o.IsNotContiguous() { - return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") + return nyierr(methodNYI, "non-contiguous views") } if !t.old.IsZero() { diff --git a/dense_io.go b/dense_io.go index 374daf0..1ef0d91 100644 --- a/dense_io.go +++ b/dense_io.go @@ -651,7 +651,7 @@ func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{ backing = append(backing, record...) return backing, nil default: - return nil, errors.Errorf(methodNYI, "convFromStrs", to) + return nil, nyierr(typeNYI, to) } } diff --git a/dense_matop.go b/dense_matop.go index e48b5cd..38a32fb 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -197,7 +197,7 @@ func (t *Dense) CopyTo(other *Dense) error { } // TODO: use copyDenseIter - return errors.Errorf(methodNYI, "CopyTo", "views") + return nyierr(methodNYI, "views") } // Slice performs slicing on the *Dense Tensor. It returns a view which shares the same underlying memory as the original *Dense. @@ -236,12 +236,31 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. // The underlying data is the same. // This method will override ALL the metadata in view. 
-func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { +func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal View, err error) { + switch view := view.(type) { + case DenseView: + v := view.Dense + if v, err = t.sliceIntoDense(v, slices...); err != nil { + return nil, err + } + return DenseView{v}, nil + + case *Dense: + if view, err = t.sliceIntoDense(view, slices...); err != nil { + return nil, err + } + return DenseView{view}, nil + default: + return nil, nyierr(typeNYI) + } +} + +func (t *Dense) sliceIntoDense(view *Dense, slices ...Slice) (retVal *Dense, err error) { var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { - return + return nil, err } view.AP.zero() @@ -257,8 +276,7 @@ func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) if t.IsMasked() { view.mask = t.mask[ndStart:ndEnd] } - - return DenseView{view}, err + return view, nil } // RollAxis rolls the axis backwards until it lies in the given position. diff --git a/errors.go b/errors.go index cd6a297..2ad399d 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,11 @@ package tensor -import "fmt" +import ( + "fmt" + "runtime" + + "github.com/pkg/errors" +) // NoOpError is a useful for operations that have no op. type NoOpError interface { @@ -60,6 +65,58 @@ const ( maskRequired = "Masked array type required for %v" inaccessibleData = "Data in %p inaccessible" - methodNYI = "%q not yet implemented for %v" - typeNYI = "%q not yet implemented for interactions with %T" + // NYI errors + + methodNYI = "%q not yet implemented for %v." + typeNYI = "%q not yet implemented for interactions with %T." + typeNYI2 = "%q (%v) not yet implemented for interactions with %T." + prmsg = "Please make a pull request at github.com/gorgonia/tensor if you wish to contribute a solution" ) + +// nyierr is a convenience function that decorates a NYI error message with additional information. 
+// +// It assumes that `msg` is either `typeNYI` or `methodNYI`. +func nyierr(msg string, args ...interface{}) error { + var fnName string = "UNKNOWN FUNCTION" + pc, _, _, ok := runtime.Caller(1) + if ok { + fnName = runtime.FuncForPC(pc).Name() + } + + switch len(args) { + case 0: + // no args + case 1: + // the usual + switch msg { + case methodNYI: + // do nothing + case typeNYI: + // do nothing + case typeNYI2: + // this is the wrong message to use, so we revert to typeNYI. + msg = typeNYI + default: + // do nothing + } + case 2: + switch msg { + case methodNYI: + // do nothing + case typeNYI: + // we assume that args[0] is an additional descriptive string. + msg = typeNYI2 + case typeNYI2: + // do nothing + default: + // do nothing + } + default: + } + + // prepend fnName + args = append(args, fnName) + copy(args[1:], args[0:]) + args[0] = fnName + return errors.Errorf(msg, args...) +} diff --git a/scalar.go b/scalar.go index 3a344ae..ee37ba0 100644 --- a/scalar.go +++ b/scalar.go @@ -50,17 +50,14 @@ func (s Scalar) RequiresIterator() bool { return false } func (s Scalar) Iterator() Iterator { return nil } func (s Scalar) DataOrder() DataOrder { return 0 } // TODO -func (s Scalar) Slice(...Slice) (View, error) { return nil, errors.New("Cannot slice a scalar") } -func (s Scalar) At(at ...int) (interface{}, error) { return nil, errors.New("Get a value of a scalar") } -func (s Scalar) SetAt(_ interface{}, _ ...int) error { return errors.New("Cannot set value of scalar") } -func (s Scalar) Reshape(_ ...int) error { return errors.New("Cannot reshape a scalar") } -func (s Scalar) T(_ ...int) error { return errors.New("Cannot transpose a scalar") } -func (s Scalar) UT() {} -func (s Scalar) Transpose() error { return errors.New("Cannot transpose a scalar") } -func (s Scalar) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { - // TODO - return nil, errors.New("Cannot apply ") -} +func (s Scalar) Slice(...Slice) (View, error) { return nil, errors.New("Cannot 
slice a scalar") } +func (s Scalar) At(at ...int) (interface{}, error) { return nil, errors.New("Get a value of a scalar") } +func (s Scalar) SetAt(_ interface{}, _ ...int) error { return errors.New("Cannot set value of scalar") } +func (s Scalar) Reshape(_ ...int) error { return errors.New("Cannot reshape a scalar") } +func (s Scalar) T(_ ...int) error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) UT() {} +func (s Scalar) Transpose() error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nyierr(typeNYI, s) } func (s Scalar) Zero() {} //TODO func (s Scalar) Memset(interface{}) error { return errors.New("Cannot Memset") } @@ -81,13 +78,10 @@ func (s Scalar) IsManuallyManaged() bool { return false } func (s Scalar) Format(t fmt.State, c rune) {} // TODO func (s Scalar) String() string { return fmt.Sprintf("%v", s) } -func (s Scalar) WriteNpy(io.Writer) error { return errors.Errorf(methodNYI, "WriteNpy", "Scalar") } -func (s Scalar) ReadNpy(io.Reader) error { return errors.Errorf(methodNYI, "ReadNypy", "Scalar") } -func (s Scalar) GobEncode() ([]byte, error) { - // TODO - return nil, errors.Errorf(methodNYI, "GobEncode", "Scalar") -} -func (s Scalar) GobDecode([]byte) error { return errors.Errorf(methodNYI, "GobDecode", "Scalar") } // TODO +func (s Scalar) WriteNpy(io.Writer) error { return nyierr(typeNYI, s) } +func (s Scalar) ReadNpy(io.Reader) error { return nyierr(typeNYI, s) } +func (s Scalar) GobEncode() ([]byte, error) { return nil, nyierr(typeNYI, s) } +func (s Scalar) GobDecode([]byte) error { return nyierr(typeNYI, s) } func (s Scalar) standardEngine() StandardEngine { return StdEng{} } func (s Scalar) hdr() *storage.Header { return nil } diff --git a/sparse.go b/sparse.go index 9d4884f..b500db8 100644 --- a/sparse.go +++ b/sparse.go @@ -234,7 +234,7 @@ func (t *CS) T(axes ...int) error { UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() t.o = 
MakeDataOrder(t.o, Transposed) - return errors.Errorf(methodNYI, "T", t) + return nyierr(typeNYI, t) } // UT untransposes the CS @@ -243,9 +243,7 @@ func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } // Transpose is a no-op. The data does not move func (t *CS) Transpose() error { return nil } -func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { - return nil, errors.Errorf(methodNYI, "Apply", t) -} +func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nil, nyierr(typeNYI, t) } func (t *CS) Eq(other interface{}) bool { if ot, ok := other.(*CS); ok { From 81d74ca27cc6be3424d4f02870d7aec3678208b9 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 May 2022 11:59:30 +1000 Subject: [PATCH 150/154] Fixed a bunch of issues due to merging from origin/master. Fixed a bunch of errors related stufff --- api_matop.go | 25 +++++++++++++++++++++--- defaultengine_selbyidx.go | 2 +- defaultengine_softmax.go | 29 ++++++++++++++++++++++++---- dense_arith_test.go | 16 +++++++-------- dense_io.go | 2 +- dense_linalg_test.go | 2 ++ example_batched_nativeselect_test.go | 2 -- interfaces.go | 7 +++++++ native/iterator_native.go | 1 + native/iterator_native_purego.go | 1 + native/select_native.go | 1 + native/select_native_purego.go | 1 + utils.go | 2 +- 13 files changed, 71 insertions(+), 20 deletions(-) diff --git a/api_matop.go b/api_matop.go index 3957651..0667ce4 100644 --- a/api_matop.go +++ b/api_matop.go @@ -42,6 +42,15 @@ func T(t Tensor, axes ...int) (retVal Tensor, err error) { switch tt := t.(type) { case *Dense: return tt.SafeT(axes...) 
+ case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, ".T() off a DenseView") + } + return DenseView{ret}, nil + + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -52,11 +61,20 @@ func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) { case *Dense: var ret *Dense if ret, err = tt.SafeT(axes...); err != nil { - return + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a *Dense") } ret.Transpose() retVal = ret return + case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a DenseView") + } + ret.Transpose() + return DenseView{ret}, nil + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -91,7 +109,8 @@ func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) } return T.Concat(axis, ts...) 
- + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -114,7 +133,7 @@ func Copy(dst, src Tensor) error { copyDense(dt, st) return nil default: - return errors.Errorf("NYI for Copy %T", src) + return nyierr(typeNYI, src) } panic("Unreachable") } diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index 1202031..e00cee4 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -164,7 +164,7 @@ func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts var reuse DenseTensor var _, toReuse, _ bool var ctx context.Context - if ctx, reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + if ctx, reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if err = handleCtx(ctx); err != nil { diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index ffc5a06..8a7dc3e 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -1,6 +1,7 @@ package tensor import ( + "context" "fmt" "math" "sync" @@ -30,9 +31,14 @@ func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(x.Dtype())) @@ -74,9 +80,14 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(output.Dtype())) @@ -112,9 +123,14 @@ func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. 
+ } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(x.Dtype())) @@ -157,9 +173,14 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(output.Dtype())) diff --git a/dense_arith_test.go b/dense_arith_test.go index b96a33c..423fc85 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -343,7 +343,7 @@ func TestDense_Add_reuse(t *testing.T) { } correct, err := a.Add(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -427,7 +427,7 @@ func TestDense_Sub_reuse(t *testing.T) { } correct, err := a.Sub(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -512,7 +512,7 @@ func TestDense_Mul_reuse(t *testing.T) { } correct, err := a.Mul(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -597,7 +597,7 @@ func TestDense_Div_reuse(t *testing.T) { } correct, err := a.Div(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -1498,7 +1498,7 @@ func TestDense_AddScalar_reuse(t 
*testing.T) { } correct, err := a.Add(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -1616,7 +1616,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { } correct, err := a.Sub(b) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := a.Engine().(Suber) we = we || !ok @@ -1734,7 +1734,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { } correct, err := a.Mul(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -1821,7 +1821,7 @@ func TestDense_DivScalar_reuse(t *testing.T) { } correct, err := a.Div(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok diff --git a/dense_io.go b/dense_io.go index 1ef0d91..374daf0 100644 --- a/dense_io.go +++ b/dense_io.go @@ -651,7 +651,7 @@ func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{ backing = append(backing, record...) 
return backing, nil default: - return nil, nyierr(typeNYI, to) + return nil, errors.Errorf(methodNYI, "convFromStrs", to) } } diff --git a/dense_linalg_test.go b/dense_linalg_test.go index a9a24dc..17b6fcd 100644 --- a/dense_linalg_test.go +++ b/dense_linalg_test.go @@ -408,11 +408,13 @@ var outerTests = []linalgTest{ []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, + /* TODO: this test is no longer valid with the new impl of outer // stupids - a or b not vector {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, + */ // stupids - bad incr shape {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go index 7350aa9..cfa128e 100644 --- a/example_batched_nativeselect_test.go +++ b/example_batched_nativeselect_test.go @@ -2,7 +2,6 @@ package tensor import ( "fmt" - "log" ) func ExampleBatchedNativeSelectF64() { @@ -19,7 +18,6 @@ func ExampleBatchedNativeSelectF64() { } fmt.Printf("Is Truncated? %t\n", it.IsTruncated()) - log.Printf("XXX") fmt.Println("Reusing the same iterator for another loop") batchNo = 0 for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { diff --git a/interfaces.go b/interfaces.go index fd970e4..7061997 100644 --- a/interfaces.go +++ b/interfaces.go @@ -72,6 +72,13 @@ type Slicer interface { Slice(...Slice) (View, error) } +// SlicerInto is any tensor that can slice into another tensor. +// The other tensor may already have data allocated in it. +// If that is the case then the slice will be a copy operation. 
+type SlicerInto interface { + SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) +} + // Reslicer is any tensor that can reslice. // To reslice is to reuse the container (*Dense, *CS) etc, but with new `Slice`s applied to it. // diff --git a/native/iterator_native.go b/native/iterator_native.go index 8ebc0e5..1ad0573 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -1,3 +1,4 @@ +//go:build !purego // +build !purego package native diff --git a/native/iterator_native_purego.go b/native/iterator_native_purego.go index 57e03c1..aba1b50 100644 --- a/native/iterator_native_purego.go +++ b/native/iterator_native_purego.go @@ -1,3 +1,4 @@ +//go:build purego // +build purego package native diff --git a/native/select_native.go b/native/select_native.go index f6b2e0e..b048ae9 100644 --- a/native/select_native.go +++ b/native/select_native.go @@ -1,3 +1,4 @@ +//go:build !purego // +build !purego package native diff --git a/native/select_native_purego.go b/native/select_native_purego.go index e2f1e2c..6285fe0 100644 --- a/native/select_native_purego.go +++ b/native/select_native_purego.go @@ -1,3 +1,4 @@ +//go:build purego // +build purego package native diff --git a/utils.go b/utils.go index 9fb455e..426a1dd 100644 --- a/utils.go +++ b/utils.go @@ -301,13 +301,13 @@ func allones(a []int) bool { return true } - // ctxFromEngine gets a context from an engine if it's a contexter. 
Otherwise it returns a context.Background() func ctxFromEngine(e Engine) context.Context { if c, ok := e.(contexter); ok { return c.Context() } return context.Background() +} func getFloat64s(a Tensor) []float64 { if um, ok := a.(unsafeMem); ok { From 61aec532ebc44bfef34f1965e5dce4803ab3b49a Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 May 2022 13:29:06 +1000 Subject: [PATCH 151/154] Added documentation for Transpose and T() --- api_matop.go | 3 +-- example_matop_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 example_matop_test.go diff --git a/api_matop.go b/api_matop.go index 0667ce4..4d98479 100644 --- a/api_matop.go +++ b/api_matop.go @@ -45,10 +45,9 @@ func T(t Tensor, axes ...int) (retVal Tensor, err error) { case DenseView: var ret *Dense if ret, err = tt.SafeT(axes...); err != nil { - return nil, errors.Wrap(err, ".T() off a DenseView") + return nil, errors.Wrap(err, "T() off a DenseView") } return DenseView{ret}, nil - default: return nil, nyierr(typeNYI, t) } diff --git a/example_matop_test.go b/example_matop_test.go new file mode 100644 index 0000000..4c0d4da --- /dev/null +++ b/example_matop_test.go @@ -0,0 +1,59 @@ +package tensor_test + +import ( + "fmt" + + "gorgonia.org/tensor" +) + +func ExampleTranspose() { + t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]int{1, 2, 3, 4, 5, 6})) + t2, err := tensor.Transpose(t) + if err != nil { + fmt.Printf("ERR: %v\n", err) + } + fmt.Printf("Transpose is a safe operation.\nT:\n%v\nT':\n%v\n", t, t2) + fmt.Printf("The data is changed:\nT : %v\nT': %v", t.Data(), t2.Data()) + + // Output: + // Transpose is a safe operation. 
+ // T: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // T': + // ⎡1 4⎤ + // ⎢2 5⎥ + // ⎣3 6⎦ + // + // The data is changed: + // T : [1 2 3 4 5 6] + // T': [1 4 2 5 3 6] + +} + +func ExampleT() { + t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]int{1, 2, 3, 4, 5, 6})) + t2, err := tensor.T(t) + if err != nil { + fmt.Printf("ERR: %v\n", err) + } + fmt.Printf("T is a safe version of the .T() method\nT:\n%v\nT':\n%v\n", t, t2) + fmt.Printf("The data is unchanged:\nT : %v\nT': %v\n", t.Data(), t2.Data()) + + // Output: + // T is a safe version of the .T() method + // T: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // T': + // ⎡1 4⎤ + // ⎢2 5⎥ + // ⎣3 6⎦ + // + // The data is unchanged: + // T : [1 2 3 4 5 6] + // T': [1 2 3 4 5 6] + +} From 1ec5b8943e273929347003cdb20f2f51491ab1e1 Mon Sep 17 00:00:00 2001 From: chewxy Date: Wed, 4 May 2022 14:21:29 +1000 Subject: [PATCH 152/154] Added examples for SliceInto --- dense_matop.go | 6 ++-- example_dense_matop_test.go | 65 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/dense_matop.go b/dense_matop.go index 3a04f9c..3a9f005 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -246,8 +246,10 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. // The underlying data is the same. // This method will override ALL the metadata in view. -func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal View, err error) { +func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) { switch view := view.(type) { + case nil: + return t.Slice(slices...) 
case DenseView: v := view.Dense if v, err = t.sliceIntoDense(v, slices...); err != nil { @@ -261,7 +263,7 @@ func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal View, err error) } return DenseView{view}, nil default: - return nil, nyierr(typeNYI) + return nil, nyierr(typeNYI, view) } } diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 1e81271..c6f50a5 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -31,6 +31,71 @@ func ExampleDense_Slice() { // [1 4] } +func ExampleDense_SliceInto() { + var v Tensor + var err error + T := New(WithBacking(Range(Int, 0, 9)), WithShape(3, 3)) + fmt.Println("SliceInto works with nil values. It simply creates a View.\n==========================================================") + fmt.Printf("T:\n%v\n", T) + + if v, err = T.SliceInto(v, makeRS(0, 2), makeRS(0, 2)); err != nil { + fmt.Printf("ERR %v\n", err) + return + } + fmt.Printf("T[0:2, 0:2]:\n%v\n", v) + + v.Zero() + fmt.Printf("When v is zeroed, T is zeroed too.\n==================================\nv:\n%v\nT:\n%v\n", v, T) + + fmt.Println("Primary use case of SliceInto.\n==============================") + T = New(WithBacking(Range(Int, 0, 9)), WithShape(3, 3)) + fmt.Printf("T:\n%v\nv:\n%v\n", T, v) + if v, err = T.SliceInto(v, makeRS(0, 2), makeRS(0, 2)); err != nil { + fmt.Printf("ERR %v\n", err) + return + } + fmt.Printf("v = T[0:2, 0:2]:\n%v\n", v) + + // Output: + // SliceInto works with nil values. It simply creates a View. + // ========================================================== + // T: + // ⎡0 1 2⎤ + // ⎢3 4 5⎥ + // ⎣6 7 8⎦ + // + // T[0:2, 0:2]: + // ⎡0 1⎤ + // ⎣3 4⎦ + // + // When v is zeroed, T is zeroed too. + // ================================== + // v: + // ⎡0 0⎤ + // ⎣0 0⎦ + // + // T: + // ⎡0 0 0⎤ + // ⎢0 0 5⎥ + // ⎣6 7 8⎦ + // + // Primary use case of SliceInto. 
+ // ============================== + // T: + // ⎡0 1 2⎤ + // ⎢3 4 5⎥ + // ⎣6 7 8⎦ + // + // v: + // ⎡0 0⎤ + // ⎣0 0⎦ + // + // v = T[0:2, 0:2]: + // ⎡0 1⎤ + // ⎣3 4⎦ + +} + // Slicing works on one dimensional arrays too: func ExampleDense_Slice_oneDimension() { var T Tensor From 0fa100d87850a8d5ae219dffad2d5234a0d603d2 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Mon, 18 Jul 2022 10:42:17 +1000 Subject: [PATCH 153/154] Wipsel 128 upgrade asume no moving gc to latest version (#130) * Upgrade asume-no-moving-gc. * Updated versions to run thests for * Fixed so govet won't mistaken the format vebs in the templates for actual format verbs Co-authored-by: wisse Co-authored-by: chewxy --- .github/workflows/.go.yml | 2 +- genlib2/dense_io.go | 4 ++-- go.mod | 14 +++++++++++--- go.sum | 4 ++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/.github/workflows/.go.yml b/.github/workflows/.go.yml index f658e9a..9710e2f 100644 --- a/.github/workflows/.go.yml +++ b/.github/workflows/.go.yml @@ -10,7 +10,7 @@ jobs: test: strategy: matrix: - go: [1.13.x, 1.14.x, 1.15.x] + go: [1.18.x, 1.17.x, 1.16.x, 1.15.x] os: [ubuntu-latest, macos-latest, windows-latest] tags: [avx, sse] allowfail: [false] diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 4a63ddd..0fe010a 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -666,12 +666,12 @@ func generateDenseIO(f io.Writer, generic Kinds) { fmt.Fprintln(f, npyDescRE) fmt.Fprintln(f, rowOrderRE) fmt.Fprintln(f, shapeRE) - fmt.Fprintln(f, writeNpyRaw) + f.Write([]byte(writeNpyRaw)) readNpy.Execute(f, mk) fmt.Fprint(f, "\n") fmt.Fprint(f, "/* CSV SERIALIZATION */\n\n") - fmt.Fprintln(f, writeCSVRaw) + f.Write([]byte(writeCSVRaw)) readCSV.Execute(f, mk) fmt.Fprint(f, "\n") diff --git a/go.mod b/go.mod index 1666dd0..f43f495 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorgonia.org/tensor -go 1.15 +go 1.18 require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc @@ -11,9 +11,17 @@ 
require ( github.com/google/flatbuffers v1.12.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.6.1 - github.com/xtgo/set v1.0.0 // indirect - go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 + go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 gonum.org/v1/gonum v0.8.2 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/xtgo/set v1.0.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/protobuf v1.25.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum index c350f4e..524d845 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 h1:FyBZqvoA/jbNzuAWLQE2kG820zMAkcilx6BMjGbL/E4= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= From 
6bc6f461f59911b3cb488f6284d54fbe478e4548 Mon Sep 17 00:00:00 2001 From: chewxy Date: Tue, 26 Sep 2023 11:07:04 +1000 Subject: [PATCH 154/154] Some corrections to the tests of selbyidx. Prepping for the larger change --- defaultengine_selbyidx.go | 2 -- dense_selbyidx_test.go | 40 +++++++++++++++++++-------------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index e00cee4..58e3e42 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -20,7 +20,6 @@ func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (r if indices.Dtype() != Int { return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype()) } - // if b is a scalar, then use Slice if a.Shape().IsScalarEquiv() { slices := make([]Slice, a.Shape().Dims()) @@ -111,7 +110,6 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da for o := 0; o < outer; o++ { end := start + axStride dstEnd := dstStart + retStride - storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) start += prevStride diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index e542133..98d309a 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -19,28 +19,28 @@ type selByIndicesTest struct { } var selByIndicesTests = []selByIndicesTest{ - {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, - Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, - }, - {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, - Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, + // {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, + // Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 
2}, + // }, + // {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, - {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, - Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, + // {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, + // Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, - {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, - Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, + // {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, + // Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, - {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, - Correct: []int{1, 1}, CorrectShape: Shape{2}}, + // {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []int{1, 1}, CorrectShape: Shape{2}}, {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, Correct: []int{1, 1}, CorrectShape: Shape{2}}, - {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, - Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, 
CorrectShape: Shape{10, 2}}, - {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, - Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, - }, + // {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, + // Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, + // {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, + // }, } func TestDense_SelectByIndices(t *testing.T) { @@ -98,10 +98,10 @@ var selByIndicesBTests = []struct { } func init() { - for i := range selByIndicesBTests { - selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] - selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape - } + // for i := range selByIndicesBTests { + // selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] + // selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape + // } } func TestDense_SelectByIndicesB(t *testing.T) {